题意: 有$n$种球,每种有无限个,第$i$种球有一个代价$c_i$。你要拿总数不超过$w$个球。如果最后第$i$种球拿了$k_i$个,那么该方案的权值为$\prod_{1\leq i\leq n}{k_i^{c_i}}$,求所有合法方案的权值和(答案对$998244353$取模)。$n\leq 10^5,\ \sum{c_i}\leq 10^5,\ w\leq 10^{18}$
$$ \text{考虑对于价值是}c_i\text{的球,构造生成函数} \\ F_{c_i}\left( x \right) =\sum_{n\ge 0}{n^{c_i}x^n} \\ \text{这样}\frac{\prod_i{F_{c_i}\left( x \right)}}{1-x}\text{的}w\text{次项即为答案} \\ \text{设}F_k\left( x \right) =\sum_{n\ge 0}{n^kx^n},\text{显然可得}F_k\left( x \right) =x\frac{\mathrm{d}F_{k-1}\left( x \right)}{\mathrm{d}x} \\ \text{进一步递推可得}, F_k\left( x \right) =\frac{\sum_{0\le i\le k}{T\left( k,i \right) x^i}}{\left( 1-x \right) ^{k+1}},\text{其中}T\left( k,i \right) \text{表示欧拉数} \\ \text{考虑如何快速计算欧拉数} \\ \text{首先由具体数学可得} \\ \sum_{0\le i\le k}{T\left( k,i \right) \times \left( z+1 \right) ^i}=\sum_{0\le i\le k}{S_2\left( k,i \right) \times i!\times z^{k-i}} \\ \text{进一步推导得欧拉数的通项公式}T\left( n,k \right) =\sum_{0\le i\le k}{\begin{array}{c} \binom{n+1}{i}\left( -1 \right) ^i\left( k+1-i \right) ^n\\ \end{array}} \\ \text{构造卷积形式}T\left( n,k \right) =\left( n+1 \right) !\sum_{0\le i\le k}{\frac{\left( -1 \right) ^i}{i!\left( n+1-i \right) !}}\times \left( k+1-i \right) ^n \\ \text{卷积即可求出一行欧拉数。} \\ \text{现在我们可以对每一个}c_i\text{计算出}F_{c_i}\left( x \right) \text{的分子,并且得到他们分子的乘积,设为}F\left( x \right) \\ \text{设他们分母的乘积,再乘个}1-x\text{为}\left( 1-x \right) ^k \\ \text{则答案即为}\frac{F\left( x \right)}{\left( 1-x \right) ^k}\text{的}w\text{次项系数} \\ \text{即}\sum_{0\le i\le n}{\left[ x^i \right] F\left( x \right) \binom{w-i+k-1}{w-i}}\left( \text{广义二项式定理展开分母} \right) \\ \text{令}a_i=\left[ x^i \right] F\left( x \right) \\ \text{即}\sum_{0\le i\le n}{a_i\binom{w-i+k-1}{k-1}} \\ \text{后者可以通过递推快速转移。} \\ \\ \\ \frac{x^a}{\left( 1-x \right) ^b}=x^a\left( \sum_{n\ge 0}{\binom{-b}{n}\left( -1 \right) ^n}x^n \right) \\ =x^a\left( \sum_{n\ge 0}{\frac{\left( -b \right) ^{\underline{n}}}{n!}\left( -1 \right) ^n}x^n \right) \\ =x^a\left( \sum_{n\ge 0}{\frac{\left( -b-0 \right) \left( -b-1 \right) \left( -b-2 \right) ..\left( -b-\left( n-1 \right) \right)}{n!}\left( -1 \right) ^n}x^n \right) \\ =x^a\left( \sum_{n\ge 0}{\frac{\left( b+0 \right) \left( b+1 \right) \left( b+2 \right) ..\left( b+\left( n-1 \right) 
\right)}{n!}}x^n \right) \\ =x^a\left( \sum_{n\ge 0}{\frac{\left( b+n-1 \right) ^{\underline{n}}}{n!}}x^n \right) \\ =\sum_{n\ge 0}{\frac{\left( b+n-1 \right) ^{\underline{n}}}{n!}x^{n+a}} \\ =\sum_{n\ge 0}{\binom{b+n-1}{n}x^{n+a}} \\ =\sum_{n\ge a}{\binom{b+n-a-1}{n-a}x^n} \\ =\sum_{n\ge a}{\binom{b+n-a-1}{b-1}x^n} \\ =\sum_{n\ge a}{\begin{array}{c} \frac{\left( b-a+n-1 \right) ^{\underline{b-1}}}{\left( b-1 \right) !}x^n\\ \end{array}} \\ \text{令}t=b-1 \\ \text{则} \\ \sum_{n\ge a}{\begin{array}{c} \frac{\left( t-a+n \right) ^{\underline{t}}}{t!}x^n\\ \end{array}} \\ \text{从}a\text{推进到}a+1 \\ \frac{\left( t-a+n \right) ^{\underline{t}}}{\left( t-a+n-1 \right) ^{\underline{t}}}=\frac{t-a+n}{n-a} \\ \\ 10*9*8/\left( 9*8*7 \right) =10/7 \\ \\ \left( \frac{x^a}{\left( 1-x \right) ^b}\text{的}n\text{次项} \right) \\ \\ \binom{n+b-1}{n}=\binom{n+b-1}{b-1} \\ \binom{n+b-1}{n}\rightarrow \binom{n+b}{n+1} \\ \\ \frac{\binom{n+b}{n+1}}{\binom{n+b-1}{n}}=\frac{n+b}{\left( n+1 \right)} \\ \\ \frac{\left( n-2+b-i \right) !\left( b-1 \right) !\left( n-i \right) !}{\left( b-1 \right) !\left( n-i-1 \right) !\left( n-i+b-1 \right) !} \\ \\ \frac{\left( n-i+b-2 \right) !\left( n-i \right)}{\left( n-i+b-1 \right) !} \\ \frac{\left( n-i \right)}{\left( n-i+b-1 \right)} \\ \\ $$
// Sum over all ways of taking at most w balls of prod k_i^{c_i}, modulo 998244353.
//
// Approach (see the derivation above): for each colour the generating function
// sum_n n^c x^n equals N_c(x) / (1 - x)^{c + 1}, where N_c is a row of Eulerian
// numbers shifted by one power of x. The answer is the coefficient of x^w in
// prod N_{c_i}(x) / (1 - x)^{k},  k = sum (c_i + 1) + 1,
// i.e. sum_i a_i * C(w - i + k - 1, k - 1) with a_i the coefficients of the
// product of the numerators.
#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

using int_t = int;
using std::cin;
using std::cout;
using std::endl;
const int_t LARGE = 2e6;
using i64 = int64_t;
#ifdef NTTCNT
std::ofstream nttcnt("nttcnt.txt");
#endif
const int_t mod = 998244353;
const int_t g = 3;  // primitive root of `mod`

// Returns b^i mod `mod`. A negative exponent is reduced modulo (mod - 1), so
// power(b, -1) is the modular inverse of b (for b not divisible by `mod`).
int_t power(int_t b, int_t i) {
    int_t r = 1;
    if (i < 0) {
        i = static_cast<int_t>(
            (static_cast<i64>(i) % (mod - 1) + mod - 1) % (mod - 1));
    }
    while (i) {
        if (i & 1) r = static_cast<int_t>(static_cast<i64>(r) * b % mod);
        b = static_cast<int_t>(static_cast<i64>(b) * b % mod);
        i >>= 1;
    }
    return r;
}

// Fills arr[0 .. 2^size2) with the bit-reversal permutation on size2 bits.
void makeflip(int_t* arr, int_t size2) {
    const int_t len = (1 << size2);
    arr[0] = 0;
    for (int_t i = 1; i < len; i++) {
        arr[i] = (arr[i >> 1] >> 1) | ((i & 1) << (size2 - 1));
    }
}

// Returns the smallest r with 2^r >= x.
int_t upper2n(int_t x) {
    int_t r = 0;
    while ((1 << r) < x) r++;
    return r;
}

// In-place iterative NTT of length 2^size2 over A.
// arg = 1 is the forward transform, arg = -1 the inverse transform WITHOUT the
// final division by the length (the caller applies it).
// flip must hold the bit-reversal permutation produced by makeflip(size2).
template <int_t arg = 1>
void transform(int_t* A, int_t size2, int_t* flip) {
    const int_t len = (1 << size2);
#ifdef NTTCNT
    nttcnt << len << endl;
#endif
    for (int_t i = 0; i < len; i++) {
        const int_t r = flip[i];
        if (r > i) std::swap(A[i], A[r]);
    }
    for (int_t i = 2; i <= len; i *= 2) {
        // Primitive i-th root of unity (or its inverse when arg == -1; the
        // negative exponent is handled inside power()).
        const int_t mr =
            power(g, static_cast<int_t>(static_cast<i64>(arg) * (mod - 1) / i));
        for (int_t j = 0; j < len; j += i) {
            int_t curr = 1;
            for (int_t k = 0; k < i / 2; k++) {
                const int_t u = A[j + k];
                const int_t v = static_cast<int_t>(
                    static_cast<i64>(curr) * A[j + k + i / 2] % mod);
                A[j + k] = static_cast<int_t>((static_cast<i64>(u) + v) % mod);
                A[j + k + i / 2] =
                    static_cast<int_t>((static_cast<i64>(u) - v + mod) % mod);
                curr = static_cast<int_t>(static_cast<i64>(curr) * mr % mod);
            }
        }
    }
}

// Multiplies the degree-n polynomial A by the degree-m polynomial B and writes
// the n + m + 1 coefficients of the product to C.
// Uses static scratch buffers of LARGE entries, so n + m + 1 rounded up to a
// power of two must not exceed LARGE; the function is not reentrant.
void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
    const int_t size2 = upper2n(n + m + 1);
    const int_t len = 1 << size2;
    static int_t T1[LARGE], T2[LARGE];
    for (int_t i = 0; i < len; i++) {
        T1[i] = (i <= n) ? A[i] : 0;
        T2[i] = (i <= m) ? B[i] : 0;
    }
    static int_t fliparr[LARGE];
    makeflip(fliparr, size2);
    transform(T1, size2, fliparr);
    transform(T2, size2, fliparr);
    for (int_t i = 0; i < len; i++) {
        T1[i] = static_cast<int_t>(static_cast<i64>(T1[i]) * T2[i] % mod);
    }
    transform<-1>(T1, size2, fliparr);
    const int_t len_inv = power(len, -1);
    for (int_t i = 0; i <= n + m; i++) {
        C[i] = static_cast<int_t>(static_cast<i64>(T1[i]) * len_inv % mod);
    }
}

const int_t FLARGE = 2e5 + 10;
// fact[i] = i!, inv[i] = i^{-1}, factInv[i] = (i!)^{-1}, all modulo `mod`.
// Filled by main() before any other function reads them.
int_t fact[FLARGE + 1], inv[FLARGE + 1], factInv[FLARGE + 1];

// Writes to out[0..n] the numerator of sum_{j>=0} j^n x^j over (1-x)^{n+1}:
//   out[k] = sum_{0<=i<=k} (-1)^i * C(n+1, i) * (k - i)^n   (mod `mod`),
// computed as a single convolution. `out` must have room for 2n + 1 entries;
// entries above index n are scratch and are not meaningful.
void euler_num(int_t n, int_t* out) {
    static int_t P1[LARGE];
    static int_t P2[LARGE];
    for (int_t i = 0; i <= n; i++) {
        const int_t sign = (i % 2) ? (mod - 1) : 1;
        P1[i] = static_cast<int_t>(static_cast<i64>(sign) * factInv[i] % mod *
                                   factInv[n + 1 - i] % mod);
        P2[i] = power(i, n);
    }
    poly_mul(P1, n, P2, n, out);
    for (int_t i = 0; i <= n; i++) {
        out[i] = static_cast<int_t>(static_cast<i64>(out[i]) * fact[n + 1] % mod);
    }
}

using poly_t = std::vector<int_t>;

// Returns the product P[left] * P[left+1] * ... * P[right] by divide and
// conquer, with trailing zero coefficients stripped. An empty range
// (left > right) yields the constant polynomial 1.
poly_t poly_dcmul(const poly_t* P, int_t left, int_t right) {
    if (left > right) return poly_t{1};
    if (left == right) return P[left];
    const int_t mid = (left + right) / 2;
    const poly_t lret = poly_dcmul(P, left, mid);
    const poly_t rret = poly_dcmul(P, mid + 1, right);
    poly_t ret(lret.size() - 1 + rret.size() - 1 + 1);
    poly_mul(&lret[0], static_cast<int_t>(lret.size()) - 1, &rret[0],
             static_cast<int_t>(rret.size()) - 1, &ret[0]);
    while (!ret.empty() && ret.back() == 0) ret.pop_back();
    return ret;
}

// C(n, m) mod `mod` from the factorial tables; requires 0 <= m <= n <= FLARGE.
int_t C(int_t n, int_t m) {
    return static_cast<int_t>(static_cast<i64>(fact[n]) * factInv[m] % mod *
                              factInv[n - m] % mod);
}

// C(n, m) mod `mod` for a possibly huge n and 0 <= m <= FLARGE, evaluated as
// the falling factorial n(n-1)...(n-m+1) divided by m!. O(m).
i64 getC(i64 n, int_t m) {
    if (n < 0 || n < m) return 0;
    i64 prod = 1;
    for (int_t i = 0; i < m; i++) prod = prod * (n % mod - i + mod) % mod;
    return prod * factInv[m] % mod;
}

// Returns sum_{0 <= i <= min(deg a, w)} a[i] * C(w - i + t, t)  (mod `mod`).
// Requires 0 <= t <= FLARGE and a.size() small compared to `mod`.
//
// C(w - i + t, t) * t! is the product of the t consecutive integers
// (w - i + 1) .. (w - i + t). Going from i to i + 1 drops the factor
// (w - i + t) and gains (w - i). The window is kept as "number of factors that
// are 0 mod p" plus "product of the other factors", so a factor divisible by
// the modulus can enter and later leave the window without ever dividing by
// zero. (Multiplying a running binomial by (w-i)/(w-i+t) and stopping at the
// first zero gives wrong results once w + t reaches the modulus.)
int_t weighted_binomial_sum(const poly_t& a, i64 w, int_t t) {
    if (w < 0) return 0;
    const i64 w_mod = w % mod;
    int_t zeros = 0;  // window factors congruent to 0
    i64 prod = 1;     // product of the non-zero window factors
    const auto residue = [&](i64 offset) -> int_t {
        return static_cast<int_t>(((w_mod + offset) % mod + mod) % mod);
    };
    const auto insert = [&](int_t r) {
        if (r == 0) {
            ++zeros;
        } else {
            prod = prod * r % mod;
        }
    };
    const auto erase = [&](int_t r) {
        if (r == 0) {
            --zeros;
        } else {
            prod = prod * power(r, -1) % mod;
        }
    };
    for (int_t j = 1; j <= t; ++j) insert(residue(j));

    i64 ret = 0;
    // Terms with i > w vanish: x^i / (1 - x)^{t+1} has no x^w coefficient.
    for (std::size_t i = 0; i < a.size() && static_cast<i64>(i) <= w; ++i) {
        if (zeros == 0) {
            ret = (ret + prod * factInv[t] % mod * a[i]) % mod;
        }
        if (t > 0) {
            const i64 shift = -static_cast<i64>(i);
            erase(residue(shift + t));
            insert(residue(shift));
        }
    }
    return static_cast<int_t>(ret);
}

// Reads "n w" followed by c_1 .. c_n from stdin and prints the answer.
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Factorials, inverses and inverse factorials up to FLARGE.
    fact[0] = fact[1] = inv[1] = factInv[0] = factInv[1] = 1;
    for (int_t i = 2; i <= FLARGE; i++) {
        fact[i] = static_cast<int_t>(static_cast<i64>(fact[i - 1]) * i % mod);
        inv[i] = static_cast<int_t>(static_cast<i64>(mod - mod / i) *
                                    inv[mod % i] % mod);
        factInv[i] =
            static_cast<int_t>(static_cast<i64>(factInv[i - 1]) * inv[i] % mod);
    }

    int_t n;
    i64 w;
    if (!(cin >> n >> w)) return 0;

    static poly_t up[int_t(1e5 + 10)];     // numerator for each ball type
    static poly_t cache[int_t(1e5 + 10)];  // numerator memoised by exponent c
    int_t sum = 0;                         // total power of (1 - x) below
    for (int_t i = 1; i <= n; i++) {
        int_t c;
        cin >> c;
        sum += c + 1;
        poly_t& curr = up[i];
        if (!cache[c].empty()) {
            curr = cache[c];
            continue;
        }
        curr.resize(2 * c + 3);  // euler_num needs room for the full product
        euler_num(c, &curr[0]);
        curr.resize(c + 1);
        cache[c] = curr;
#ifdef DEBUG
        cout << "up " << i << " = ";
        for (auto t : curr) cout << t << " ";
        cout << endl;
#endif
    }
    sum++;  // the extra 1 / (1 - x) that turns "exactly w" into "at most w"

    const poly_t upprod = poly_dcmul(up, 1, n);
#ifdef DEBUG
    cout << "up prod = ";
    for (auto t : upprod) cout << t << " ";
    cout << endl;
    cout << "sum = " << sum << endl;
#endif

    // [x^w] upprod(x) / (1 - x)^sum = sum_i upprod[i] * C(w - i + sum - 1, sum - 1)
    cout << weighted_binomial_sum(upprod, w, sum - 1) << '\n';
    return 0;
}
发表回复