洛谷4705 玩游戏

$$
\begin{aligned}
\text{原式} &= \frac{1}{nm}\sum_{1\le k\le t} x^k \sum_{1\le i\le n}\sum_{1\le j\le m}\left( a_i+b_j \right)^k \\
&= \frac{1}{nm}\sum_{1\le k\le t} x^k \sum_{1\le i\le n}\sum_{1\le j\le m}\sum_{0\le y\le k}\binom{k}{y} a_i^{y} b_j^{k-y} \\
&= \frac{1}{nm}\sum_{1\le k\le t} x^k\, k! \sum_{0\le y\le k}\frac{\sum_{1\le i\le n} a_i^{y}}{y!}\times \frac{\sum_{1\le j\le m} b_j^{k-y}}{\left( k-y \right)!} \\
\\
&\text{考虑如何计算 }\sum_{1\le i\le n} a_i^{y}\text{（}y=0\text{ 时该和为 }n\text{，需单独处理）。} \\
&\text{令 }F\left( x \right) =\prod_{1\le i\le n}\left( 1+a_ix \right)\text{，显然 }F\left( x \right)\text{ 可以用分治乘法在 }O\left( n\log^2 n \right)\text{ 的时间内计算出。} \\
&\ln F\left( x \right) =\sum_{1\le i\le n}\ln \left( 1+a_ix \right)\text{，把 }\ln \left( 1+u \right)\text{ 在 }u=0\text{ 处展开并代入 }u=a_ix\text{ 可得} \\
&\ln F\left( x \right) =\sum_{1\le i\le n}\sum_{j\ge 1}\frac{\left( -1 \right)^{j-1}}{j}\left( a_ix \right)^j =\sum_{j\ge 1} x^j\,\frac{\left( -1 \right)^{j-1}}{j}\sum_{1\le i\le n} a_i^j \\
&\text{即 }\ln F\left( x \right)\text{ 的 }j\text{ 次项系数除以 }\frac{\left( -1 \right)^{j-1}}{j}\text{ 即为 }j\text{ 次幂的和}\sum_{1\le i\le n} a_i^j\text{。}
\end{aligned}
$$

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;

// Prime modulus for all arithmetic in this file (998244353 = 119 * 2^23 + 1).
const int_t mod = 998244353;
// Base that transform() raises to (mod - 1) / len to obtain its roots of unity.
const int_t g = 3;
// Element capacity of the static scratch buffers declared throughout the file.
const int_t LARGE = (1 << 20);

// Forward declarations: these routines are defined further down but are
// called by code that appears before their definitions.
// In-place number-theoretic transform over 2^size2 elements; arg is +1 or -1.
void transform(int_t* A, int_t size2, int_t arg);
// Writes the first n coefficients of the power-series inverse of A to result.
void poly_inv(const int_t* A, int_t n, int_t* result);
// Multiplies out the linear factors (1 + vec[i] * x) modulo mod.
std::vector<int> process(const std::vector<int>& vec);
// Writes the first n coefficients of the power-series logarithm of A to result.
void poly_log(const int_t* A, int_t n, int_t* result);
// Modular exponentiation: returns base^index modulo `mod`, in [0, mod).
//
// base  - any value; it is first normalised into [0, mod) so that neither a
//         large nor a negative base can overflow or produce a negative result.
// index - exponent. A negative exponent is mapped to its residue modulo
//         mod - 1, so power(a, -1) is the modular inverse of a whenever a is
//         not a multiple of `mod`.
int_t power(int_t base, int_t index) {
    const auto phi = mod - 1;
    if (index < 0) index = (index % phi + phi) % phi;
    // Normalise the base once up front; every later product then stays
    // below mod * mod and fits in a 64-bit integer.
    base %= mod;
    if (base < 0) base += mod;
    int_t result = 1;
    // Binary exponentiation, consuming the exponent from its lowest bit.
    while (index) {
        if (index & 1) result = result * base % mod;
        index >>= 1;
        base = base * base % mod;
    }
    return result;
}
std::vector<int> process(const std::vector<int>& vec) {
    if (vec.size() == 1) {
        return std::vector<int>{1, vec[0]};
    }
    std::vector<int> left, right;
    for (int i = 0; i < vec.size(); i++) {
        if (i < vec.size() / 2)
            left.push_back(vec[i]);
        else
            right.push_back(vec[i]);
    }
    auto leftRet = process(std::move(left));
    auto rightRet = process(std::move(right));
    static int_t A[LARGE], B[LARGE];
    int size2 = 0;
    while ((1 << size2) < leftRet.size() + rightRet.size()) size2++;
    std::copy(leftRet.begin(), leftRet.end(), A);
    std::copy(rightRet.begin(), rightRet.end(), B);
    transform(A, size2, 1);
    transform(B, size2, 1);
    for (int i = 0; i < (1 << size2); i++) A[i] = A[i] * B[i] % mod;
    transform(A, size2, -1);
    const auto inv = power(1 << size2, -1);
    std::vector<int> result;
    for (int i = 0; i < leftRet.size() + rightRet.size() - 1; i++)
        result.push_back(A[i] * inv % mod);
    std::fill(A, A + (1 << size2), 0);
    std::fill(B, B + (1 << size2), 0);
    return result;
}
int main() {
    std::vector<int> polyA;
    std::vector<int> polyB;
    static int_t fact[LARGE + 1], factInv[LARGE + 1];
    int n, m, t;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        polyA.push_back(0);
        scanf("%d", &polyA.back());
    }
    for (int i = 1; i <= m; i++) {
        polyB.push_back(0);
        scanf("%d", &polyB.back());
    }
    scanf("%d", &t);
    fact[0] = factInv[0] = fact[1] = factInv[1] = 1;
    for (int_t i = 2; i <= LARGE; i++) {
        fact[i] = fact[i - 1] * i % mod;
        factInv[i] = (mod - mod / i) * factInv[mod % i] % mod;
    }
    for (int_t i = 2; i <= LARGE; i++)
        factInv[i] = factInv[i - 1] * factInv[i] % mod;
    const auto retA = process(polyA);
    const auto retB = process(polyB);
    static int_t A[LARGE], B[LARGE], Ax[LARGE], Bx[LARGE];
    std::copy(retA.begin(), retA.end(), A);
    std::copy(retB.begin(), retB.end(), B);
    poly_log(A, t + 1, Ax);
    poly_log(B, t + 1, Bx);
    memset(A, 0, sizeof(A));
    memset(B, 0, sizeof(B));
    A[0] = n;
    B[0] = m;
    for (int i = 1; i <= t; i++) {
        A[i] =
            Ax[i] * i % mod * ((i % 2) ? 1 : mod - 1) % mod * factInv[i] % mod;
        B[i] =
            Bx[i] * i % mod * ((i % 2) ? 1 : mod - 1) % mod * factInv[i] % mod;
    }
    int size2 = 0, size = 1;
    while ((1 << size2) < 2 * t) size2++;
    size = (1 << size2);
    transform(A, size2, 1);
    transform(B, size2, 1);
    for (int i = 0; i < size; i++) A[i] = A[i] * B[i] % mod;
    transform(A, size2, -1);
    const auto inv = power((int_t)size * n % mod * m % mod, -1);
    for (int i = 1; i <= t; i++) {
        printf("%lld\n", A[i] * inv % mod * fact[i] % mod);
    }

    return 0;
}

void poly_log(const int_t* A, int_t n, int_t* result) {
    int_t size = 1, size2 = 0;
    while ((1 << size2) < n * 2) size2++;
    size = (1 << size2);
    static int_t under[LARGE + 1], Ax[LARGE + 1];
    for (int_t i = 0; i < size; i++) {
        if (i < n)
            Ax[i] = A[i + 1] * (i + 1) % mod;
        else
            under[i] = Ax[i] = 0;
        under[i] = 0;
    }
    poly_inv(A, n, under);
    transform(under, size2, 1);
    transform(Ax, size2, 1);
    for (int i = 0; i < size; i++) Ax[i] = Ax[i] * under[i] % mod;
    transform(Ax, size2, -1);
    const int_t inv = power(size, -1);
    for (int i = 1; i < n; i++)
        result[i] = Ax[i - 1] * power(i, -1) % mod * inv % mod;
    result[0] = 0;
}
// C(x)=2B(x)-B(x)^2A(x)
void poly_inv(const int_t* A, int_t n, int_t* result) {
    if (n == 1) {
        result[0] = power(A[0], -1);
        return;
    }
    poly_inv(A, n / 2 + n % 2, result);
    static int_t Ax[LARGE + 1];
    int_t size2 = 0;
    while ((1 << size2) < 3 * n) size2++;
    for (int_t i = 0; i < (1 << size2); i++) {
        if (i < n)
            Ax[i] = A[i];
        else
            Ax[i] = 0;
    }

    transform(Ax, size2, 1);
    transform(result, size2, 1);
    for (int_t i = 0; i < (1 << size2); i++)
        result[i] = (2 * result[i] % mod -
                     result[i] * result[i] % mod * Ax[i] % mod + mod) %
                    mod;
    transform(result, size2, -1);
    const int_t inv = power(1 << size2, -1);
    for (int_t i = 0; i < (1 << size2); i++)
        if (i < n)
            result[i] = result[i] * inv % mod;
        else
            result[i] = 0;
}

// In-place number-theoretic transform of A[0 .. 2^size2 - 1] modulo `mod`.
//
// arg = 1 runs the forward transform and arg = -1 the inverse one. The
// inverse is not scaled: callers divide by 2^size2 themselves.
void transform(int_t* A, int_t size2, int_t arg) {
    const int_t length = int_t{1} << size2;

    // Bit-reversal permutation: swap every index with its size2-bit mirror,
    // visiting each pair once.
    for (int_t idx = 0; idx < length; idx++) {
        int_t mirrored = 0;
        for (int_t bit = 0; bit < size2; bit++) {
            if ((idx >> bit) & 1) mirrored |= int_t{1} << (size2 - 1 - bit);
        }
        if (idx < mirrored) std::swap(A[idx], A[mirrored]);
    }

    // Butterfly passes over blocks of 2, 4, ..., length elements.
    for (int_t half = 1; half < length; half *= 2) {
        const int_t span = half * 2;
        // Root of unity of order `span`; a negative arg gives its inverse.
        const int_t step = power(g, (mod - 1) / span * arg);
        for (int_t start = 0; start < length; start += span) {
            int_t twiddle = 1;
            for (int_t k = 0; k < half; k++) {
                int_t& low = A[start + k];
                int_t& high = A[start + k + half];
                const int_t even = low;
                const int_t odd = high * twiddle % mod;
                low = (even + odd) % mod;
                high = (even - odd + mod) % mod;
                twiddle = twiddle * step % mod;
            }
        }
    }
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理