题意: 有 $n$ 种球,每种有无限个,第 $i$ 种球有一个代价 $c_i$。你要拿总共不超过 $w$ 个球。若最终第 $i$ 种球拿了 $k_i$ 个,则该方案的权值为 $\prod_{i=1}^{n} k_i^{c_i}$,求所有合法方案的权值和(对 $998244353$ 取模)。数据范围: $n\leq 10^5$,$\sum c_i\leq 10^5$,$w\leq 10^{18}$。
#include <cinttypes>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

// Approach: for a single ball type with exponent c,
//   sum_{k>=0} k^c x^k = A_c(x) / (1 - x)^(c + 1),
// where A_c is the polynomial built by euler_num(). Allowing "at most w"
// balls multiplies by one more 1 / (1 - x), so the answer is
//   [x^w] (prod_i A_{c_i}(x)) / (1 - x)^b,   b = sum_i (c_i + 1) + 1
//   = sum_j prod[j] * C(w - j + b - 1, b - 1)   (terms with j > w vanish).

using int_t = int;
using std::cin;
using std::cout;
using std::endl;

const int_t LARGE = 2e6;  // capacity of the static NTT scratch buffers
using i64 = int64_t;

#ifdef NTTCNT
std::ofstream nttcnt("nttcnt.txt");  // debug: logs every NTT length
#endif

const int_t mod = 998244353;  // NTT-friendly prime: 119 * 2^23 + 1
const int_t g = 3;            // primitive root modulo `mod`

/// Returns b^i mod `mod` (b is expected to lie in [0, mod)).
/// A negative exponent is reduced modulo (mod - 1), so power(x, -1) is the
/// modular inverse of x for x != 0 (and 0 for x == 0).
int_t power(int_t b, int_t i) {
    int_t r = 1;
    if (i < 0) i = ((i64)i % (mod - 1) + mod - 1) % (mod - 1);
    while (i) {
        if (i & 1) r = (i64)r * b % mod;
        b = (i64)b * b % mod;
        i >>= 1;
    }
    return r;
}

/// Fills arr[0 .. 2^size2) with the bit-reversal permutation of size2 bits.
void makeflip(int_t* arr, int_t size2) {
    int_t len = (1 << size2);
    arr[0] = 0;
    for (int_t i = 1; i < len; i++) {
        arr[i] = (arr[i >> 1] >> 1) | ((i & 1) << (size2 - 1));
    }
}

/// Smallest r such that 2^r >= x.
int_t upper2n(int_t x) {
    int_t r = 0;
    while ((1 << r) < x) r++;
    return r;
}

/// In-place iterative NTT of length 2^size2 over Z_mod.
/// arg = 1 is the forward transform; arg = -1 uses inverse roots (the caller
/// still has to scale by len^{-1}). `flip` is the table built by makeflip().
template <int_t arg = 1>
void transform(int_t* A, int_t size2, int_t* flip) {
    int_t len = (1 << size2);
#ifdef NTTCNT
    nttcnt << len << endl;
#endif
    for (int_t i = 0; i < len; i++) {
        int_t r = flip[i];
        if (r > i) std::swap(A[i], A[r]);
    }
    for (int_t i = 2; i <= len; i *= 2) {
        // Primitive i-th root of unity (its inverse when arg == -1).
        int_t mr = power(g, (i64)arg * (mod - 1) / i);
        for (int_t j = 0; j < len; j += i) {
            int_t curr = 1;
            for (int_t k = 0; k < i / 2; k++) {
                int_t u = A[j + k], v = (i64)curr * A[j + k + i / 2] % mod;
                A[j + k] = ((i64)u + v) % mod;
                A[j + k + i / 2] = ((i64)u - v + mod) % mod;
                curr = (i64)curr * mr % mod;
            }
        }
    }
}

/// Computes C = A * B, where A has degree n and B has degree m.
/// C must have room for n + m + 1 coefficients; n + m + 1 must not exceed
/// LARGE after rounding up to a power of two.
void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
    int_t size2 = upper2n(n + m + 1);
    int_t len = 1 << size2;
    static int_t T1[LARGE], T2[LARGE];
    for (int_t i = 0; i < len; i++) {
        T1[i] = (i <= n) ? A[i] : 0;
        T2[i] = (i <= m) ? B[i] : 0;
    }
    static int_t fliparr[LARGE];
    makeflip(fliparr, size2);
    transform(T1, size2, fliparr);
    transform(T2, size2, fliparr);
    for (int_t i = 0; i < len; i++) T1[i] = (i64)T1[i] * T2[i] % mod;
    transform<-1>(T1, size2, fliparr);
    int_t inv = power(len, -1);
    for (int_t i = 0; i <= n + m; i++) C[i] = (i64)T1[i] * inv % mod;
}

const int_t FLARGE = 2e5 + 10;
int_t fact[FLARGE + 1], inv[FLARGE + 1], factInv[FLARGE + 1];

/// Writes into out[0 .. 2n] the convolution whose entry k (for k <= n) is
///   (n+1)! * sum_{i=0}^{k} (-1)^i / (i! (n+1-i)!) * (k - i)^n
///   = sum_i (-1)^i C(n+1, i) (k - i)^n,
/// i.e. the coefficient of x^k in (1 - x)^(n+1) * sum_{j>=0} j^n x^j
/// (with 0^0 = 1). Entries above n are scratch. `out` needs 2n + 1 slots.
void euler_num(int_t n, int_t* out) {
    static int_t P1[LARGE];
    static int_t P2[LARGE];
    for (int_t i = 0; i <= n; i++) {
        P1[i] = (i64)((i % 2) ? (mod - 1) : 1) * factInv[i] % mod *
                factInv[n + 1 - i] % mod;
        P2[i] = power(i, n);
    }
    poly_mul(P1, n, P2, n, out);
    for (int_t i = 0; i <= n; i++) out[i] = (i64)out[i] * fact[n + 1] % mod;
}

using poly_t = std::vector<int_t>;

/// Divide-and-conquer product P[left] * ... * P[right], with trailing zero
/// coefficients trimmed (at least one coefficient is always kept).
/// An empty range (left > right) yields the constant polynomial 1.
poly_t poly_dcmul(const poly_t* P, int_t left, int_t right) {
    if (left > right) return poly_t{1};
    if (left == right) return P[left];
    int_t mid = (left + right) / 2;
    auto lret = poly_dcmul(P, left, mid);
    auto rret = poly_dcmul(P, mid + 1, right);
    poly_t ret(lret.size() - 1 + rret.size() - 1 + 1);
    poly_mul(&lret[0], lret.size() - 1, &rret[0], rret.size() - 1, &ret[0]);
    // Keep one coefficient so callers can always index ret[0].
    while (ret.size() > 1 && ret.back() == 0) ret.pop_back();
    return ret;
}

/// Binomial C(n, m) mod `mod` for 0 <= m <= n <= FLARGE.
int_t C(int_t n, int_t m) {
    return (i64)fact[n] * factInv[m] % mod * factInv[n - m] % mod;
}

/// Binomial C(n, m) mod `mod` for a possibly huge n and small m (m <= FLARGE),
/// via the falling factorial n (n-1) ... (n-m+1) / m!. Returns 0 if n < m.
i64 getC(i64 n, int_t m) {
    if (n < 0 || n < m) return 0;
    i64 prod = 1;
    for (int_t i = 0; i < m; i++) prod = prod * (n % mod - i + mod) % mod;
    return prod * factInv[m] % mod;
}

int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    {
        fact[0] = fact[1] = inv[1] = factInv[0] = factInv[1] = 1;
        for (int_t i = 2; i <= FLARGE; i++) {
            fact[i] = (i64)fact[i - 1] * i % mod;
            inv[i] = (i64)(mod - mod / i) * inv[mod % i] % mod;
            factInv[i] = (i64)factInv[i - 1] * inv[i] % mod;
        }
    }
    int_t n;
    i64 w;
    cin >> n >> w;
    static poly_t up[int_t(1e5 + 10)];     // up[i] = A_{c_i}
    static poly_t cache[int_t(1e5 + 10)];  // A_c memoized by exponent c
    int_t sum = 0;
    for (int_t i = 1; i <= n; i++) {
        int_t c;
        cin >> c;
        sum += c + 1;
        poly_t& curr = up[i];
        if (!cache[c].empty()) {
            curr = cache[c];
            continue;
        }
        curr.resize(2 * c + 3);  // euler_num writes 2c + 1 entries
        euler_num(c, &curr[0]);
        curr.resize(c + 1);      // keep degrees 0..c
        cache[c] = curr;
#ifdef DEBUG
        cout << "up " << i << " = ";
        for (auto t : curr) cout << t << " ";
        cout << endl;
#endif
    }
    sum++;  // one extra 1 / (1 - x) for "at most w balls"
    auto upprod = poly_dcmul(up, 1, n);
    upprod.resize(sum + 1);
#ifdef DEBUG
    cout << "up prod = ";
    for (auto t : upprod) cout << t << " ";
    cout << endl;
    cout << "sum = " << sum << endl;
#endif

    // ret = sum_{i=0}^{min(deg, w)} upprod[i] * C(w - i + m, m),  m = b - 1.
    //
    // C(N, m) (m < mod) equals the falling factorial of the window
    // {N - m + 1, ..., N} times factInv[m]. Advancing i slides the window
    // down by one: factor (w - i + m) leaves, factor (w - i) enters.
    // Because w can be ~1e18, a factor may be divisible by `mod`. Such a
    // factor makes the binomial 0 only while it is inside the window, and it
    // cannot be divided out again. So the product of the non-zero residues
    // and the number of zero residues are tracked separately.
    i64 ret = 0;
    {
        const i64 b = sum;
        const i64 m = b - 1;
        i64 last = (i64)upprod.size() - 1;
        if (w < last) last = w;  // C(w - i + m, m) is 0 for i > w

        i64 nonzero_prod = 1;  // product of the non-zero factor residues
        i64 zero_count = 0;    // factors currently divisible by `mod`
        auto take = [&](i64 factor) {
            const i64 r = factor % mod;
            if (r == 0)
                ++zero_count;
            else
                nonzero_prod = nonzero_prod * r % mod;
        };
        auto drop = [&](i64 factor) {
            const i64 r = factor % mod;
            if (r == 0)
                --zero_count;
            else
                nonzero_prod = nonzero_prod * power((int_t)r, -1) % mod;
        };

        // Window for i = 0: factors w + 1 .. w + m (all positive).
        for (i64 t = 1; t <= m; t++) take(w + t);

        for (i64 i = 0; i <= last; i++) {
            if (zero_count == 0) {
                const i64 binom = nonzero_prod * factInv[m] % mod;
                ret = (ret + upprod[i] * binom) % mod;
            }
            if (i < last) {
                drop(w - i + m);
                take(w - i);
            }
        }
    }
    cout << ret << endl;
    return 0;
}