CometOJ 37F 真实无妄她们的人生之路

第一道非板子的多点求值题。

题意:

有$n$个物品,每个物品使用后有$p_i$的概率升一级(初始时等级为0),有$1-p_i$的概率不升级。

等级为$i$时的攻击力为$a_i$。

现在要求对于每一个$i\in [1,n]$,依次求出不使用第$i$件物品,但是使用了剩下的$n-1$件物品时,攻击力的期望。

$n\leq 10^5$,所有运算在模$998244353$意义下进行。

时限$10s$(Comet OJ的评测机非常烂)。

题解

对于第$i$个物品,构造多项式$C_i(x)=1-p_i+p_ix$。显然不使用第$i$个物品、但使用了剩下所有物品时,等级的概率生成函数为$P_i(x)=\frac {\prod_{1\leq j\leq n}C_j(x)}{C_i(x)}$。计算这个多项式的系数序列与序列$\{a_i\}$的点积,即为不使用第$i$个物品时的答案。

但是很显然这样只能做到$O(n^2+n\log n)$,考虑其他做法。

$$
\text{构造一个作用于多项式的函数}g\left( F \right) =\sum_{0\le i\le n-1}{a_i\times \left[ x^i \right] F\left( x \right)} \\
\text{显然有}g\left( F+G \right) =g\left( F \right) +g\left( G \right) ,\text{其中}F,G\text{均为多项式;} \\
g\left( cF \right) =cg\left( F \right) ,\text{其中}c\text{是常数。} \\
\text{令}F\left( x \right) =\prod_{1\le i\le n}{\left( 1-p_i+p_ix \right)}, \\
\text{我们要求的结果为所有的}g\left( \frac{F\left( x \right)}{1-p_i+p_ix} \right) ,\text{即}\frac{1}{1-p_i}g\left( \frac{F\left( x \right)}{1+\frac{p_i}{1-p_i}x} \right) \text{。} \\
\text{如果}p_i=1\text{,则答案为}g\left( \frac{F\left( x \right)}{x} \right) =\sum_{0\le i\le n-1}{a_i\left[ x^{i+1} \right] F\left( x \right)}\text{,可以直接计算;以下假设}p_i\ne 1 \\
\text{由于}\frac{1}{1-x}=\sum_{i\ge 0}{x^i} \\
\text{故}\frac{1}{1+x}=\sum_{i\ge 0}{\left( -x \right) ^i} \\
\text{故}\frac{1}{1-p_i}g\left( \frac{F\left( x \right)}{1+\frac{p_i}{1-p_i}x} \right) =\frac{1}{1-p_i}g\left( \sum_{j\ge 0}{F\left( x \right) \left( -\frac{p_i}{1-p_i}x \right) ^j} \right) \\
\text{又由于}g\left( A+B \right) =g\left( A \right) +g\left( B \right) \\
\text{故上式}=\frac{1}{1-p_i}\sum_{j\ge 0}{g\left( F\left( x \right) \left( -\frac{p_i}{1-p_i} \right) ^j\times x^j \right)} \\
\text{又由于}g\left( cA \right) =cg\left( A \right) \\
\text{可得上式}=\frac{1}{1-p_i}\sum_{j\ge 0}{\left( -\frac{p_i}{1-p_i} \right) ^jg\left( F\left( x \right) x^j \right)} \\
\text{令多项式}G\left( y \right) =\sum_{j\ge 0}{y^jg\left( F\left( x \right) x^j \right)} \\
\text{则答案可以在计算出多项式}G\left( y \right) \text{后,通过多点求值}G\left( -\frac{p_1}{1-p_1} \right) ,G\left( -\frac{p_2}{1-p_2} \right) ,\dots ,G\left( -\frac{p_n}{1-p_n} \right) \text{并各乘以}\frac{1}{1-p_i}\text{得到。} \\
\text{显然在}j\ge n\text{时}g\left( F\left( x \right) x^j \right) =0\text{,所以}G\left( y \right) \text{的次数不超过}n-1 \\
\text{考虑}g\left( F\left( x \right) x^j \right) \text{如何计算:}g\left( F\left( x \right) x^j \right) =\sum_{0\le i\le n-1}{a_i\times \left[ x^i \right] F\left( x \right) x^j}=\sum_{0\le i\le n-1}{a_i\times \left[ x^{i-j} \right] F\left( x \right)} \\
\text{构造多项式}F_0\left( x \right) =\sum_{0\le i\le n-1}{a_ix^i}\text{,记}F^R\left( x \right) =x^nF\left( \frac{1}{x} \right) \text{为}F\text{的系数翻转,然后计算}F_0\left( x \right) \left( xF^R\left( x \right) \right) \text{,其}x^{i+n+1}\text{项的系数即为}g\left( F\left( x \right) x^i \right)
$$

#pragma GCC optimize("O3")
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
// #include
// Integer type used for every coefficient and every modular product.
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;

// Capacity of every static scratch buffer (and of each row of `flips`).
constexpr int_t LARGE = 500000;
// Modulus for all arithmetic (998244353 = 119 * 2^23 + 1, NTT-friendly).
constexpr int_t mod = 998244353;
// Generator used by transformNTT to derive the roots of unity.
constexpr int_t g = 3;

// Forward declarations for helpers used by main().
// NTT core: in-place transform and NTT-based polynomial product.
void transformNTT(int_t* A, int_t size2, int_t arg);
void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* A0);
// Series inverse and polynomial remainder, built on poly_mul.
void poly_inv(const int_t* A, int_t n, int_t* result);
void poly_div(const int_t* A, int_t n, const int_t* B, int_t m, int_t* R);
// Divide-and-conquer product of linear factors and multipoint evaluation.
std::vector<int_t> poly_dc_mul(const std::vector<int_t>& A);
std::vector<int_t> poly_eval(const std::vector<int_t>& poly,
                             const std::vector<int_t>& vals);
// NOTE(review): declared but never defined or called in this file.
void transformX(int_t* A, int_t len, int_t g, int_t mod);
// Fast modular exponentiation: base^index mod `mod`.
// A negative base is normalized into [0, mod). The exponent is reduced modulo
// mod - 1 (Fermat), so power(y, -1) yields y^(mod-2), the inverse of y.
int_t power(int_t base, int_t index) {
    int_t b = ((base % mod) + mod) % mod;
    int_t e = ((index % (mod - 1)) + (mod - 1)) % (mod - 1);
    int_t acc = 1;
    for (; e != 0; e >>= 1, b = b * b % mod) {
        if (e & 1) acc = acc * b % mod;
    }
    assert(acc >= 0);
    return acc % mod;
}

// Bit-reversal lookup: flips[k][j] is j with its low k bits reversed.
// main() fills rows 1..18; every other entry stays zero.
int flips[20][LARGE] = {};
// Reverses the low `size2` bits of x (size2 >= 1), e.g. flip(1, 3) == 4.
// main() uses it to fill the `flips` table consumed by transformNTT.
// Empty capture list: a lambda at namespace scope may not have a
// capture-default, and there is nothing to capture anyway.
const auto flip = [](int x, int size2) {
    int result = 0;
    for (int i = 1; i < size2; i++) {
        result |= (x & 1);  // move the current lowest bit of x into result
        x >>= 1;
        result <<= 1;
    }
    return result | (x & 1);
};
std::vector<int_t> poly_dc_mul2(const std::vector<int_t>& A) {
    if (A.size() == 1) {
        return std::vector<int_t>{(int_t)((1 - A[0] + mod) % mod), A[0]};
    }
    std::vector<int_t> left, right;
    for (int i = 0; i < A.size(); i++) {
        if (i < A.size() / 2)
            left.push_back(A[i]);
        else
            right.push_back(A[i]);
    }
    auto leftr = poly_dc_mul2(std::move(left)),
         rightr = poly_dc_mul2(std::move(right));
    std::vector<int_t> result(leftr.size() + rightr.size() + 5);
    poly_mul(&leftr.front(), leftr.size() - 1, &rightr.front(),
             rightr.size() - 1, &result.front());
    while (result.empty() == false && result.back() == 0) result.pop_back();
    return result;
}

int main() {
    for (int i = 1; i < 19; i++) {
        for (int j = 1; j < (1 << i); j++) flips[i][j] = flip(j, i);
    }
    int_t n;
    static int_t pros[LARGE + 1], as[LARGE + 1], Fx[LARGE + 1];
    std::vector<int_t> F;
    scanf("%lld", &n);
    for (int i = 0; i < n; i++) {
        scanf("%lld", &as[i]);
    }
    for (int i = 1; i <= n; i++) {
        int_t x, y;
        scanf("%lld%lld", &x, &y);
        pros[i] = ((int_t)x * ((int_t)power(y, -1) % mod)) % mod;
        assert(pros[i] < mod);
        F.push_back(pros[i]);
    }
    F = poly_dc_mul2(F);
    int_t oneans = 0;
    for (int i = 1; i <= n; i++) {
        oneans = (oneans + (int_t)F[i] * as[i - 1] % mod) % mod;
    }
    F.push_back(0);
    std::reverse(F.begin(), F.end());
    poly_mul(as, F.size() - 1, &F.front(), F.size() - 1, Fx);
    std::copy(Fx + F.size() - 1, Fx + F.size() - 1 + n, Fx);
    std::vector<int_t> poly(Fx, Fx + n);
    std::vector<int_t> point_ts;
    for (int i = 1; i <= n; i++) {
        point_ts.push_back(
            (mod - (int_t)pros[i] * (int_t)power(1 - pros[i] + mod, -1) % mod) %
            mod);
    }
    auto res = poly_eval(poly, point_ts);
    for (int i = 1; i <= n; i++) {
        if (pros[i] == 1) {
            printf("%lld ", oneans);
        } else {
            printf("%lld ",
                   (int_t)res[i - 1] * power(1 - pros[i] + mod, -1) % mod);
        }
    }
    return 0;
}

void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* A0) {
    static int_t Ax[LARGE + 1], Bx[LARGE + 1];
    int size2 = 0;
    while ((1 << size2) < n + m + 1) size2++;
    int len = 1 << size2;
    for (int i = 0; i < len; i++) {
        if (i <= n)
            Ax[i] = A[i];
        else
            Ax[i] = 0;
        if (i <= m)
            Bx[i] = B[i];
        else
            Bx[i] = 0;
    }
    transformNTT(Ax, size2, 1);
    transformNTT(Bx, size2, 1);
    for (int i = 0; i < len; i++) Ax[i] = (int_t)Ax[i] * Bx[i] % mod;
    transformNTT(Ax, size2, -1);
    const int_t inv = power(len, mod - 2);
    for (int i = 0; i <= n + m; i++) A0[i] = (int_t)Ax[i] * inv % mod;
}
// In-place iterative number-theoretic transform of length 2^size2.
// arg == 1 selects the forward transform, arg == -1 the inverse (the caller
// divides by the length afterwards). Relies on flips[size2] holding the
// bit-reversal permutation.
void transformNTT(int_t* A, int_t size2, int_t arg) {
    const int_t len = int_t(1) << size2;
    // Reorder into bit-reversed positions; swap each pair only once.
    for (int_t pos = 0; pos < len; ++pos) {
        const int partner = flips[size2][pos];
        if (partner > pos) std::swap(A[pos], A[partner]);
    }
    // Butterfly passes over blocks of size span = 2, 4, ..., len.
    for (int_t span = 2; span <= len; span <<= 1) {
        const int_t half = span / 2;
        const int_t root = power(g, (mod - 1) + (mod - 1) / span * arg);
        for (int_t block = 0; block < len; block += span) {
            int_t w = 1;
            for (int_t k = 0; k < half; ++k) {
                int_t& lo = A[block + k];
                int_t& hi = A[block + k + half];
                const int_t u = lo % mod;
                const int_t t = hi * w % mod;
                lo = (u + t) % mod;
                hi = (u - t + mod) % mod;
                w = w * root % mod;
            }
        }
    }
}
void poly_inv(const int_t* A, int_t n, int_t* result) {
    if (n == 1) {
        result[0] = power(A[0], mod - 2);
        return;
    }
    poly_inv(A, n / 2 + n % 2, result);
    static int_t temp[LARGE + 1];
    std::fill(temp, temp + 2 * n, 0);
    int prev = n / 2 + n % 2;
    poly_mul(result, prev - 1, result, prev - 1, temp);
    poly_mul(A, n - 1, temp, n - 1, temp);
    for (int i = 0; i < n; i++)
        result[i] = (2 * result[i] % mod - temp[i] + 2 * mod) % mod;
}
// Remainder of polynomial A (degree n, coefficients A[0..n]) divided by
// polynomial B (degree m, coefficients B[0..m]), modulo `mod`.
// Leading zero coefficients of A and B are dropped first. R receives the
// remainder as m coefficients R[0..m-1], where m is B's trimmed degree.
// Method: reverse both polynomials, get the quotient as A^R * (B^R)^{-1}
// mod x^{n-m+1}, un-reverse it, then R = A - Q * B.
void poly_div(const int_t* A, int_t n, const int_t* B, int_t m, int_t* R) {
    while (n >= 0 && A[n] == 0) n--;
    while (m >= 0 && B[m] == 0) m--;
    assert(m >= 0);  // division by the zero polynomial is undefined
    if (n < m) {
        // deg A < deg B: A is its own remainder; clear the unused high slots.
        std::copy(A, A + n + 1, R);
        std::fill(R + n + 1, R + m, 0);
        return;
    }
    static int_t AR[LARGE + 1], BR[LARGE + 1];
    for (int i = 0; i <= n; i++) AR[i] = A[n - i];
    for (int i = 0; i <= m; i++) BR[i] = B[m - i];
    // poly_inv reads n - m + 1 coefficients of BR, so zero-pad it.
    for (int i = m + 1; i <= n - m; i++) BR[i] = 0;
    static int_t inv[LARGE + 1], Q[LARGE + 1], C[LARGE + 1];
    std::fill(inv, inv + n - m + 1, 0);
    poly_inv(BR, n - m + 1, inv);
    // Reversed quotient = AR * inv mod x^{n-m+1}; un-reverse it.
    poly_mul(AR, n - m, inv, n - m, Q);
    std::reverse(Q, Q + n - m + 1);
    poly_mul(Q, n - m, B, m, C);
    for (int_t i = 0; i < m; i++) R[i] = (A[i] - C[i] + mod) % mod;
}
std::vector<int_t> poly_dc_mul(const std::vector<int_t>& A) {
    if (A.size() == 1) {
        return std::vector<int_t>{(int_t)mod - A[0], 1};
    }
    std::vector<int_t> left, right;
    for (int i = 0; i < A.size(); i++) {
        if (i < A.size() / 2)
            left.push_back(A[i]);
        else
            right.push_back(A[i]);
    }
    auto leftr = poly_dc_mul(std::move(left)),
         rightr = poly_dc_mul(std::move(right));
    std::vector<int_t> result(leftr.size() + rightr.size() + 5);
    poly_mul(&leftr.front(), leftr.size() - 1, &rightr.front(),
             rightr.size() - 1, &result.front());
    while (result.empty() == false && result.back() == 0) result.pop_back();
    return result;
}

// vals是求值点
std::vector<int_t> poly_eval(const std::vector<int_t>& poly,
                             const std::vector<int_t>& vals) {
    if (vals.size() <= 300) {
        std::vector<int_t> result;
        for (auto x : vals) {
            int_t curr = 0, pow = 1;
            for (auto y : poly) {
                curr = ((int_t)curr + (int_t)y * pow % mod) % mod;
                pow = (int_t)pow * x % mod;
            }
            result.push_back(curr);
        }
        return result;
    }
    std::vector<int_t> left, right;
    for (int i = 0; i < vals.size(); i++)
        if (i < vals.size() / 2)
            left.push_back(vals[i]);
        else
            right.push_back(vals[i]);
    auto leftres = poly_dc_mul(left);
    auto rightres = poly_dc_mul(right);
    std::vector<int_t> pleft(leftres.size() - 1), pright(rightres.size() - 1);
    poly_div(&poly.front(), poly.size() - 1, &leftres.front(),
             leftres.size() - 1, &pleft.front());
    poly_div(&poly.front(), poly.size() - 1, &rightres.front(),
             rightres.size() - 1, &pright.front());
    auto leftx = poly_eval(pleft, left);
    auto rightx = poly_eval(pright, right);
    // assert(leftres.size())
    for (auto x : rightx) leftx.push_back(x);
    return leftx;
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理