LOJ150 YT2sOJ21 挑战多项式

毕竟是板子集合，不管常数了。

#include <assert.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <ctime>
#include <iostream>
using int_t = long long;  // 64-bit intermediate for products of two residues
using std::cin;
using std::cout;
using std::endl;

// NTT-friendly prime modulus: mod - 1 = 2^23 * 119.
constexpr int mod = 998244353;
// Generator whose powers g^((mod-1)/len) serve as the roots of unity in transform().
constexpr int g = 3;
// Largest transform length supported (2^19); static work buffers hold LARGE + 1 ints.
constexpr int LARGE = 1 << 19;
// revs[s][j] for j < 2^s: j with its low s bits reversed. Filled in by main() at start-up.
int revs[20][LARGE + 1];

// Forward declarations; the definitions appear after main().
int power(int base, int index);
void transform(int* A, int len, int arg);
void poly_inv(const int* A, int n, int* result);
void poly_sqrt(const int* A, int n, int* result);
void poly_log(const int* A, int n, int* result);
void poly_exp(const int* A, int n, int* result);
void poly_power(const int* A, int n, int index, int* result);
int modSqrt(int a);
// Returns `base` with its lowest `size2` bits written in reverse order
// (bit 0 moves to bit size2-1, and so on); higher bits of `base` are ignored.
// For size2 <= 1 the result is just the lowest bit of `base`.
int bitRev(int base, int size2) {
    int reversed = base & 1;
    for (int bit = 1; bit < size2; ++bit) {
        base >>= 1;
        reversed = (reversed << 1) | (base & 1);
    }
    return reversed;
}
// Smallest power of two that is >= x (1 for any x <= 1).
int upper2n(int x) {
    int pow2 = 1;
    for (; pow2 < x; pow2 <<= 1) {
    }
    return pow2;
}
// Entry point for LOJ 150. Reads n, k and the n + 1 coefficients F[0..n], then
// prints the first n coefficients of
//   G(x) = ((1 + ln(2 + F(x) - F(0) - exp(integral of 1/sqrt(F(x))))) ^ k)'
// taken modulo x^n, built step by step below.
// NOTE(review): scanf/printf are declared in <cstdio>, which is not in this
// file's #include list; presumably it arrives transitively via <iostream> - confirm.
int main() {
#ifdef TIME
    auto begin = clock();
#endif
    // Fill the bit-reversal tables revs[0..19]; transform() reads them.
    for (int i = 0; (1 << i) <= LARGE; i++) {
        for (int j = 0; j < LARGE; j++) {
            revs[i][j] = bitRev(j, i);
        }
    }
    // Static storage, so all three arrays start zero-filled.
    static int F[LARGE + 1], A[LARGE + 1], B[LARGE + 1];
    int n, k;

    scanf("%d%d", &n, &k);
    // F keeps the input polynomial; A is reused as scratch below.
    for (int i = 0; i <= n; i++) {
        scanf("%d", &A[i]);
        F[i] = A[i];
    }
    // B = sqrt(F) mod x^(n+1)
    poly_sqrt(A, n + 1, B);

#ifdef DEBUG
    cout << "sqrt ";
    for (int_t i = 0; i <= n; i++) cout << B[i] << " ";
    cout << endl;
#endif
    // A = 1 / sqrt(F). A is cleared first because poly_inv feeds its output
    // buffer (beyond result[0]) to transform() as zero padding.
    memset(A, 0, sizeof(A));
    poly_inv(B, n + 1, A);
#ifdef DEBUG
    cout << "sqrt inv ";
    for (int_t i = 0; i <= n; i++) cout << A[i] << " ";
    cout << endl;
#endif
    // B = integral of A: B[i] = A[i - 1] / i, with a zero constant term.
    for (int i = n; i >= 1; i--) B[i] = (int_t)A[i - 1] * power(i, -1) % mod;
    B[0] = 0;
#ifdef DEBUG
    cout << "sqrt inv integrate ";
    for (int_t i = 0; i <= n; i++) cout << B[i] << " ";
    cout << endl;
#endif
    // A = exp(B); cleared first for the same zero-padding reason as above.
    memset(A, 0, sizeof(A));
    poly_exp(B, n + 1, A);
#ifdef DEBUG
    cout << "sqrt inv integrate exp ";
    for (int_t i = 0; i <= n; i++) cout << A[i] << " ";
    cout << endl;
#endif
    // A = 2 + F(x) - F(0) - exp(...). First subtract A from F coefficient-wise,
    // then add 2 - F(0) to the constant term.
    for (int i = 0; i < n + 1; i++) {
        A[i] = (F[i] - A[i] + mod) % mod;
    }
    A[0] = (A[0] + 2 - F[0] + mod) % mod;
#ifdef DEBUG
    cout << "sqrt inv integrate exp process ";
    for (int_t i = 0; i <= n; i++) cout << A[i] << " ";
    cout << endl;
#endif
    // B = ln(A). poly_log zero-fills the front of its output buffer itself,
    // so the memset below stays disabled.
    // memset(B, 0, sizeof(B));
    poly_log(A, n + 1, B);

    // poly_log leaves a zero constant term; setting it to 1 gives B = 1 + ln(A).
    B[0] = 1;
#ifdef DEBUG
    cout << "sqrt inv integrate exp process log +1 ";
    for (int_t i = 0; i <= n; i++) cout << B[i] << " ";
    cout << endl;
#endif
    // A = B^k mod x^(n+1). poly_power zero-fills its output buffer itself.
    // memset(A, 0, sizeof());
    poly_power(B, n + 1, k, A);
#ifdef DEBUG
    cout << "sqrt inv integrate exp process power ";
    for (int_t i = 0; i <= n; i++) cout << A[i] << " ";
    cout << endl;
#endif
    // Print the derivative of A: coefficient i is (i + 1) * A[i + 1].
    for (int i = 0; i < n; i++) {
        A[i] = (int_t)A[i + 1] * (i + 1) % mod;
        printf("%d ", A[i]);
    }
#ifdef TIME
    auto end = clock();
    cout << 1.0 * (end - begin) / CLOCKS_PER_SEC << endl;
#endif
    return 0;
}
//计算A(x)^index mod x^n
void poly_power(const int* A, int n, int index, int* result) {
    int size2 = upper2n(2 * n);
    static int base[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        if (i < n)
            base[i] = A[i];
        else
            base[i] = 0;
        result[i] = 0;
    }
    result[0] = 1;
    while (index) {
        transform(base, size2, 1);
        if (index & 1) {
            transform(result, size2, 1);
            for (int i = 0; i < size2; i++) {
                result[i] = (int_t)result[i] * base[i] % mod;
            }
            transform(result, size2, -1);
            for (int i = 0; i < size2; i++) {
                if (i < n)
                    result[i] = (int_t)result[i] * power(size2, -1) % mod;
                else
                    result[i] = 0;
            }
        }
        for (int i = 0; i < size2; i++)
            base[i] = (int_t)base[i] * base[i] % mod;
        transform(base, size2, -1);
        for (int i = 0; i < size2; i++) {
            if (i < n)
                base[i] = (int_t)base[i] * power(size2, -1) % mod;
            else
                base[i] = 0;
        }
        index >>= 1;
    }
}
// Computes base^index modulo mod by binary exponentiation.
// A negative index is reduced modulo phi = mod - 1 (Fermat's little theorem),
// so power(x, -1) is the modular inverse of x for x != 0 (mod mod).
// That reduction is only valid for bases coprime to mod. A base that is
// 0 (mod mod) is therefore handled separately: 0^0 = 1, and 0^index = 0
// otherwise. 0 has no inverse, so 0 is also returned for negative exponents.
int power(int base, int index) {
    const int phi = mod - 1;
    base = (base % mod + mod) % mod;
    if (base == 0) return index == 0 ? 1 : 0;
    index = (index % phi + phi) % phi;
    int result = 1;
    while (index) {
        if (index & 1) result = (int_t)result * base % mod;
        base = (int_t)base * base % mod;
        index >>= 1;
    }
    return result;
}
void transform(int* A, int len, int arg) {
    int size2 = log2(len);
    for (int i = 0; i < len; i++) {
        int x = revs[size2][i];
        if (x > i) std::swap(A[i], A[x]);
    }
    for (int i = 2; i <= len; i *= 2) {
        int mr = power(g, (int_t)arg * (mod - 1) / i);
        for (int j = 0; j < len; j += i) {
            int curr = 1;
            for (int k = 0; k < i / 2; k++) {
                int u = A[j + k];
                int t = (int_t)A[j + k + i / 2] * curr % mod;
                A[j + k] = (u + t) % mod;
                A[j + k + i / 2] = (u - t + mod) % mod;
                curr = (int_t)curr * mr % mod;
            }
        }
    }
}
//计算多项式A在模x^n下的逆元
// C(x)<-2B(x)-A(x)B^2(x)
void poly_inv(const int* A, int n, int* result) {
    if (n == 1) {
        result[0] = power(A[0], -1);
        return;
    }
    poly_inv(A, n / 2 + n % 2, result);
    static int temp[LARGE + 1];
    int size2 = upper2n(3 * n + 1);
    for (int i = 0; i < size2; i++) {
        if (i < n)
            temp[i] = A[i];
        else
            temp[i] = 0;
    }
    transform(temp, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] =
            ((int_t)2 * result[i] % mod -
             (int_t)temp[i] * result[i] % mod * result[i] % mod + 2 * mod) %
            mod;
    }
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++) {
        if (i < n) {
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        } else {
            result[i] = 0;
        }
    }
}
int modSqrt(int a) {
    const int_t b = 3;
    // mod-1=2^s*t
    int s = 23, t = 119;
    int x = power(a, (t + 1) / 2);
    int e = power(a, t);
    int k = 1;
    while (k < s) {
        if (power(e, 1 << (s - k - 1)) != 1) {
            x = (int_t)x * power(b, (1 << (k - 1)) * t) % mod;
        }
        e = (int_t)power(a, -1) * x % mod * x % mod;
        k++;
    }
    return x;
}
//计算多项式开根
void poly_sqrt(const int* A, int n, int* result) {
    if (n == 1) {
        int p = modSqrt(A[0]);
        result[0] = std::min(p, mod - p);
        return;
    }
    poly_sqrt(A, n / 2 + n % 2, result);
    int size2 = upper2n(3 * n);
    static int Ax[LARGE + 1], pInv[LARGE];
    for (int i = 0; i < size2; i++) {
        if (i < n) {
            Ax[i] = A[i];
        } else {
            Ax[i] = 0;
        }
        pInv[i] = 0;
    }
    poly_inv(result, n, pInv);
    transform(Ax, size2, 1);
    transform(result, size2, 1);
    transform(pInv, size2, 1);
    const int inv2 = power(2, -1);
    for (int i = 0; i < size2; i++) {
        result[i] = (result[i] -
                     ((int_t)result[i] * result[i] % mod - Ax[i] + mod) % mod *
                         inv2 % mod * pInv[i] % mod +
                     mod) %
                    mod;
    }
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++) {
        if (i < n)
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        else
            result[i] = 0;
    }
}
void poly_log(const int* A, int n, int* result) {
    int size2 = upper2n(2 * n);
    static int Ad[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        if (i < n) {
            Ad[i] = (int_t)(i + 1) * A[i + 1] % mod;

        } else {
            Ad[i] = 0;
        }
        result[i] = 0;
    }
    transform(Ad, size2, 1);
    poly_inv(A, n, result);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] = (int_t)result[i] * Ad[i] % mod;
    }
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++)
        if (i < n)
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        else
            result[i] = 0;
    for (int i = n - 1; i >= 1; i--) {
        result[i] = (int_t)result[i - 1] * power(i, -1) % mod;
    }
    result[0] = 0;
}
void poly_exp(const int* A, int n, int* result) {
    if (n == 1) {
        assert(A[0] == 0);
        result[0] = 1;
        return;
    }
    poly_exp(A, n / 2 + n % 2, result);
    static int G0[LARGE + 1], Ax[LARGE + 1];
    int size2 = upper2n(2 * n);
    poly_log(result, n, G0);
    for (int_t i = 0; i < size2; i++) {
        if (i < n)
            Ax[i] = A[i];
        else
            Ax[i] = 0;
    }
    transform(Ax, size2, 1);
    transform(G0, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++)
        result[i] =
            (result[i] - (int_t)result[i] * (G0[i] - Ax[i] + mod) % mod + mod) %
            mod;
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++)
        if (i < n)
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        else
            result[i] = 0;
}

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理