题意: 有 $n$ 种球, 每种有无限个, 同时第 $i$ 种球有一个代价 $c_i$, 你要拿不超过 $w$ 个球。如果最后第 $i$ 种球你拿了 $k_i$ 个, 那么你会获得 $\prod_{1\leq i\leq n} k_i^{c_i}$ 的权值, 求所有合法方案的权值和(代码中答案对 $998244353$ 取模)。数据范围: $n\leq 10^5$, $\sum c_i\leq 10^5$, $w\leq 10^{18}$。
#include <cinttypes>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
// Machine-word integer used for residues and indices throughout the file.
using int_t = int;
using std::cin;
using std::cout;
using std::endl;
// Capacity of the static scratch buffers declared in poly_mul and euler_num.
const int_t LARGE = 2e6;
// 64-bit integer used for intermediate products before reduction modulo `mod`.
using i64 = int64_t;
#ifdef NTTCNT
// Optional instrumentation: transform() writes the length of every transform
// it performs to this file when NTTCNT is defined.
std::ofstream nttcnt("nttcnt.txt");
#endif
// Modulus for all arithmetic; 998244353 = 119 * 2^23 + 1.
const int_t mod = 998244353;
// Base from which transform() derives its roots of unity via power(g, ...).
const int_t g = 3;
// Modular exponentiation: returns b^i modulo `mod` by binary exponentiation.
// A negative exponent is first mapped into [0, mod - 2] modulo (mod - 1), so
// power(x, -1) yields the modular inverse of x for x not divisible by mod
// (and 0 for x == 0).
int_t power(int_t b, int_t i) {
    if (i < 0)
        i = ((i64)i % (mod - 1) + mod - 1) % (mod - 1);
    int_t acc = 1;
    for (; i != 0; i >>= 1) {
        if (i & 1)
            acc = (i64)acc * b % mod;
        b = (i64)b * b % mod;
    }
    return acc;
}
// Fills arr[0 .. 2^size2 - 1] with the bit-reversal permutation on size2 bits:
// arr[i] is i with its lowest size2 bits written in reverse order.
void makeflip(int_t* arr, int_t size2) {
    const int_t count = 1 << size2;
    arr[0] = 0;
    for (int_t idx = 1; idx < count; ++idx) {
        // Reverse of idx = reverse of (idx / 2) shifted down by one, with the
        // lowest bit of idx moved to the top position.
        int_t reversed = arr[idx >> 1] >> 1;
        if (idx & 1)
            reversed |= 1 << (size2 - 1);
        arr[idx] = reversed;
    }
}
// Returns the smallest exponent e >= 0 such that 2^e >= x.
int_t upper2n(int_t x) {
    int_t exponent = 0;
    for (; (1 << exponent) < x; ++exponent) {
    }
    return exponent;
}
// In-place iterative number-theoretic transform of A[0 .. 2^size2 - 1].
//   arg  : 1 uses the roots g^((mod-1)/len); -1 uses their inverses. No 1/len
//          scaling is applied here; the caller does it after the inverse pass.
//   flip : bit-reversal permutation for size2 bits (see makeflip).
template <int_t arg = 1>
void transform(int_t* A, int_t size2, int_t* flip) {
    const int_t total = (1 << size2);
#ifdef NTTCNT
    nttcnt << total << endl;
#endif
    // Reorder the input into bit-reversed order; swap each pair only once.
    for (int_t pos = 0; pos < total; pos++) {
        const int_t partner = flip[pos];
        if (pos < partner)
            std::swap(A[pos], A[partner]);
    }
    // Butterfly passes over blocks of size 2, 4, ..., total.
    for (int_t span = 2; span <= total; span *= 2) {
        const int_t half = span / 2;
        const int_t root = power(g, (i64)arg * (mod - 1) / span);
        for (int_t base = 0; base < total; base += span) {
            int_t* lower = A + base;
            int_t* upper = lower + half;
            int_t twiddle = 1;
            for (int_t k = 0; k < half; k++) {
                const int_t u = lower[k];
                const int_t v = (i64)twiddle * upper[k] % mod;
                lower[k] = ((i64)u + v) % mod;
                upper[k] = ((i64)u - v + mod) % mod;
                twiddle = (i64)twiddle * root % mod;
            }
        }
    }
}
void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
/*
计算n次多项式A与m次多项式B的乘积
*/
int_t size2 = upper2n(n + m + 1);
int_t len = 1 << size2;
static int_t T1[LARGE], T2[LARGE];
for (int_t i = 0; i < len; i++) {
if (i <= n)
T1[i] = A[i];
else
T1[i] = 0;
if (i <= m)
T2[i] = B[i];
else
T2[i] = 0;
}
static int_t fliparr[LARGE];
makeflip(fliparr, size2);
transform(T1, size2, fliparr);
transform(T2, size2, fliparr);
for (int_t i = 0; i < len; i++)
T1[i] = (i64)T1[i] * T2[i] % mod;
transform<-1>(T1, size2, fliparr);
int_t inv = power(len, -1);
for (int_t i = 0; i <= n + m; i++)
C[i] = (i64)T1[i] * inv % mod;
}
// Largest index for which the factorial tables below are filled.
const int_t FLARGE = 2e5 + 10;
// Lookup tables modulo `mod`, filled at the start of main():
//   fact[i]    = i!
//   inv[i]     = modular inverse of i (for i >= 1)
//   factInv[i] = modular inverse of i!
int_t fact[FLARGE + 1], inv[FLARGE + 1], factInv[FLARGE + 1];
// Writes out[m] = sum_{k=0..m} (-1)^k * C(n+1, k) * (m-k)^n  (mod `mod`)
// for m = 0 .. n. This is the alternating-sum formula for Eulerian numbers
// with the index shifted by one, evaluated as a single convolution.
// `out` must have room for 2n + 1 entries because poly_mul writes the whole
// product; entries above index n are left unscaled.
// Requires fact / factInv to be filled up to n + 1.
void euler_num(int_t n, int_t* out) {
    static int_t P1[LARGE];
    static int_t P2[LARGE];
    for (int_t k = 0; k <= n; k++) {
        // P1[k] = (-1)^k / (k! (n+1-k)!),  P2[k] = k^n
        const int_t sign = (k % 2) ? (mod - 1) : 1;
        P1[k] = (i64)sign * factInv[k] % mod * factInv[n + 1 - k] % mod;
        P2[k] = power(k, n);
    }
    poly_mul(P1, n, P2, n, out);
    // Multiplying by (n+1)! turns 1 / (k! (n+1-k)!) into C(n+1, k).
    const int_t scale = fact[n + 1];
    for (int_t m = 0; m <= n; m++)
        out[m] = (i64)out[m] * scale % mod;
}
// Polynomial stored as a coefficient vector: element i is the coefficient of
// x^i, so a vector of size s is passed to poly_mul as degree s - 1.
using poly_t = std::vector<int_t>;
// Divide-and-conquer product of the polynomials P[left], ..., P[right]
// (both bounds inclusive), computed modulo `mod` with poly_mul.
//
// Returns the product with trailing zero coefficients removed; an empty
// vector represents the zero polynomial. An empty range (left > right)
// returns the constant polynomial 1, the empty product.
poly_t poly_dcmul(const poly_t* P, int_t left, int_t right) {
    // Empty range: without this guard the recursion never terminates.
    if (left > right)
        return poly_t{1};
    if (left == right)
        return P[left];
    const int_t mid = (left + right) / 2;
    const poly_t lret = poly_dcmul(P, left, mid);
    const poly_t rret = poly_dcmul(P, mid + 1, right);
    // A zero factor (empty vector) makes the whole product zero. Handling it
    // here avoids the unsigned wrap-around of size() - 1 below.
    if (lret.empty() || rret.empty())
        return poly_t{};
    const int_t ldeg = static_cast<int_t>(lret.size()) - 1;
    const int_t rdeg = static_cast<int_t>(rret.size()) - 1;
    poly_t ret(static_cast<std::size_t>(ldeg + rdeg + 1));
    poly_mul(&lret[0], ldeg, &rret[0], rdeg, &ret[0]);
    // Drop trailing zeros so degrees do not grow needlessly up the recursion.
    while (!ret.empty() && ret.back() == 0)
        ret.pop_back();
    return ret;
}
// Binomial coefficient C(n, m) modulo `mod` from the precomputed factorial
// tables. Returns 0 when m is outside [0, n], matching getC, instead of
// indexing the tables out of bounds. n must not exceed FLARGE.
int_t C(int_t n, int_t m) {
    if (m < 0 || m > n)
        return 0;
    return (i64)fact[n] * factInv[m] % mod * factInv[n - m] % mod;
}
// Binomial coefficient C(n, m) modulo `mod` for a possibly huge n and a small
// m (m must be a valid index into factInv). Evaluated as the falling factorial
// n (n-1) ... (n-m+1) times the inverse of m!. Returns 0 if n < 0 or n < m.
i64 getC(i64 n, int_t m) {
    if (n < 0 || n < m)
        return 0;
    i64 falling = 1;
    for (int_t step = 0; step < m; step++) {
        const i64 factor = n % mod - step + mod;  // kept positive before %
        falling = falling * factor % mod;
    }
    return falling * factInv[m] % mod;
}
int main() {
std::ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
{
fact[0] = fact[1] = inv[1] = factInv[0] = factInv[1] = 1;
for (int_t i = 2; i <= FLARGE; i++) {
fact[i] = (i64)fact[i - 1] * i % mod;
inv[i] = (i64)(mod - mod / i) * inv[mod % i] % mod;
factInv[i] = (i64)factInv[i - 1] * inv[i] % mod;
}
}
int_t n;
i64 w;
cin >> n >> w;
static poly_t up[int_t(1e5 + 10)];
static poly_t cache[int_t(1e5 + 10)];
int_t sum = 0;
for (int_t i = 1; i <= n; i++) {
int_t c;
cin >> c;
sum += c + 1;
poly_t& curr = up[i];
if (!cache[c].empty()) {
curr = cache[c];
continue;
}
curr.resize(2 * c + 3);
euler_num(c, &curr[0]);
curr.resize(c + 1);
cache[c] = curr;
#ifdef DEBUG
cout << "up " << i << " = ";
for (auto t : curr)
cout << t << " ";
cout << endl;
#endif
}
sum++;
// for (int_t i = 0; i <= sum; i++) {
// down[i] = (i64)((i % 2) ? (mod - 1) : 1) * C(sum, i) % mod;
// }
auto upprod = poly_dcmul(up, 1, n);
// cout << upprod.size() << endl;
upprod.resize(sum + 1);
#ifdef DEBUG
cout << "up prod = ";
for (auto t : upprod)
cout << t << " ";
cout << endl;
cout << "sum = " << sum << endl;
#endif
// cout << "down size = " << sum << endl;
// w -= cutcount;
// if (w < 0) {
// cout << 0 << endl;
// return 0;
// }
// int_t ret = poly_divat(&upprod[0], down, upprod.size() - 1, sum, w);
i64 ret = 0;
{
i64 b = sum;
i64 n = w;
i64 curr = getC(n - 0 + b - 1, b - 1);
// C(n-i+b-1,b-1)到C(n-(i+1)+b-1,b-1)
// getC(n - i + b - 1, b - 1)
#ifdef DEBUG
cout << "init curr = " << curr << endl;
#endif
for (i64 i = 0; i < upprod.size(); i++) {
ret = (ret + upprod[i] * curr % mod) % mod;
curr = curr * (n % mod - i + mod) % mod *
power((n % mod - i + b - 1 + (i64)3 * mod) % mod, -1) % mod;
if (curr == 0 || n <= i - n + 1)
break;
#ifdef DEBUG
cout << "curr = " << curr << " at i = " << i << " with n = " << n
<< endl;
#endif
}
// int_t curr = getC(w - b + b - 1, b - 1);
// for (i64 i = std::max<i64>(0, w - sum); i < upprod.size(); i++) {
// ret = (ret + curr * upprod[i] % mod) % mod;
// curr = curr * (n % mod + n % mod - sum % mod + mod + i) % mod *
// power(n + 1, -1) % mod;
// }
// i64 n = w;
// i64 t = sum - 1;
// i64 prod = 1;
// // i64 curr = (t + n) % mod;
// for (int_t i = 0; i < t; i++)
// prod = (prod * (t + n % mod - i) % mod);
// for (int_t a = 0; a < upprod.size(); a++) {
// if (n >= a) {
// ret = (ret + prod * factInv[t] % mod * upprod[a] % mod) %
// mod;
// }
// prod = prod * (t - a + n % mod) % mod *
// power((n % mod - a + mod) % mod, -1) % mod;
// }
}
cout << ret << endl;
return 0;
}