CFGYM 102823 CCPC2018 桂林 B题 Array Modify

题倒是还行,达标找规律找出来结论,然后就想到了一个多项式快速幂的做法。

一开始写的两个log的快速幂,TLE。

然后尝试改成exp+log,跑得看起来快了,但是遇到多组数据(n的大小递增的时候)就挂掉了?

仔细调查后发现源于自己一直以来exp函数内一直存在的错误:

exp内调用了log,log内调用了inv,exp内给G0(用以存储log返回值的数组)预留了upper2n(2*n)的空间,但是inv内却使用了upper2n(3*n+1)的空间!

(调试方法:数组切换为动态分配并开启内存检查)

对于只计算exp一次而言,这个错误不会造成影响,但是计算多次的时候由于空间污染会导致错误结果。

解决方法:在exp内也预留upper2n(3*n)的空间,通过本题。

另外小测试了一下,两个log的多项式快速幂跑本题极限数据,NTT变换的多项式总长度为“27262976”,而exp+log实现的快速幂,NTT变换的多项式总长度为 19922408,我原本以为因为常数原因,exp+log会跑得更慢,没想到事实并非如此。

 

// luogu-judger-enable-o2
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;
const int mod = 998244353;
const int g = 3;
const int LARGE = 1 << 20;
int revs[21][LARGE + 1];

int power(int base, int index);
void transform(int* A, int len, int arg);
void poly_inv(const int* A, int n, int* result);
void poly_log(const int* A, int n, int* result);
void poly_exp(const int* A, int n, int* result);

std::vector transform_count;
#ifdef TDEBUG
#define TRANSFORM_DEBUG(n) \
    { transform_count.push_back(n); }
#else
#define TRANSFORM_DEBUG(n)
#endif
int bitRev(int base, int size2) {
    int result = 0;
    for (int i = 1; i < size2; i++) {
        result |= (base & 1);
        base >>= 1;
        result <<= 1;
    }
    result |= (base & 1);
    return result;
}
int upper2n(int x) {
    int result = 1;
    while (result < x)
        result *= 2;
    return result;
}
int main() {
    std::ios::sync_with_stdio(false);

    for (int i = 0; (1 << i) <= LARGE; i++) {
        for (int j = 0; j < LARGE; j++) {
            revs[i][j] = bitRev(j, i);
        }
    }
    static int F[LARGE + 1], G[LARGE + 1];
    static int arr[LARGE + 1];
    int T;
    cin >> T;
    for (int_t _i = 1; _i <= T; _i++) {
        memset(F, 0, sizeof(F));
        memset(G, 0, sizeof(G));
        memset(arr, 0, sizeof(arr));
        int_t n, L, m;
        cin >> n >> L >> m;

        for (int_t i = 1; i <= n; i++)
            cin >> arr[i];
        int size2 = upper2n(4 * (n + 1));
        for (int i = 0; i < size2; i++) {
            if (i < L)
                F[i] = 1;
            else
                F[i] = 0;
            G[i] = 0;
        }
#ifdef DEBUG
        cout << "init = " << endl;
        cout << "F = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << F[i] << " ";
        cout << endl;

        cout << "G = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << G[i] << " ";
        cout << endl;
#endif

        poly_log(F, n, G);
#ifdef DEBUG
        cout << "after log = " << endl;
        cout << "G = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << G[i] << " ";
        cout << endl;
#endif
        for (int i = 0; i < n; i++)
            G[i] = G[i] * m % mod;
        for (int i = 0; i < size2; i++)
            F[i] = 0;
#ifdef DEBUG
        cout << "before exp = " << endl;
        cout << "G = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << G[i] << " ";
        cout << endl;
#endif
        poly_exp(G, n, F);
#ifdef DEBUG
        cout << "after exp = " << endl;
        cout << "F = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << F[i] << " ";
        cout << endl;
#endif

        for (int i = 0; i < size2; i++) {
            if (i >= n)
                F[i] = 0;
            if (i >= n && i < 2 * n)
                G[i] = arr[i - n + 1];
            else
                G[i] = 0;
        }
        std::reverse(G + n, G + 2 * n);

#ifdef DEBUG
        cout << "F = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << F[i] << " ";
        cout << endl;

        cout << "G = ";
        for (int_t i = 0; i < 2 * n; i++)
            cout << G[i] << " ";
        cout << endl;
#endif

        transform(F, size2, 1);
        transform(G, size2, 1);
        for (int i = 0; i < size2; i++)
            F[i] = (int_t)F[i] * G[i] % mod;
        transform(F, size2, -1);
        int inv = power(size2, -1);
        std::reverse(F + n, F + 2 * n);
        cout << "Case " << _i << ": ";

        for (int i = n; i < 2 * n; i++) {
            cout << (int_t)F[i] * inv % mod << " ";
        }
        cout << endl;
#ifdef DEBUG
        for (int i = 0; i < n; i++) {
            cout << "F " << i << " = " << F[i] << endl;
        }
#endif
#ifdef TDEBUG
        for (auto x : transform_count) {
            cout << x << " ";
        }
        cout << endl;
#endif
    }
    return 0;
}

int power(int base, int index) {
    const int phi = mod - 1;
    index = (index % phi + phi) % phi;
    base = (base % mod + mod) % mod;
    int result = 1;
    while (index) {
        if (index & 1)
            result = (int_t)result * base % mod;
        base = (int_t)base * base % mod;
        index >>= 1;
    }
    return result;
}
void transform(int* A, int len, int arg) {
    TRANSFORM_DEBUG(len);
    int size2 = int(log2(len) + 0.5);
    for (int i = 0; i < len; i++) {
        int x = revs[size2][i];
        if (x > i)
            std::swap(A[i], A[x]);
    }
    for (int i = 2; i <= len; i *= 2) {
        int mr = power(g, (int_t)arg * (mod - 1) / i);
        for (int j = 0; j < len; j += i) {
            int curr = 1;
            for (int k = 0; k < i / 2; k++) {
                int u = A[j + k];
                int t = (int_t)A[j + k + i / 2] * curr % mod;
                A[j + k] = (u + t) % mod;
                A[j + k + i / 2] = (u - t + mod) % mod;
                curr = (int_t)curr * mr % mod;
            }
        }
    }
}
//计算多项式A在模x^n下的逆元
// C(x)<-2B(x)-A(x)B^2(x)
void poly_inv(const int* A, int n, int* result) {
    if (n == 1) {
        result[0] = power(A[0], -1);
        return;
    }
    poly_inv(A, n / 2 + n % 2, result);
    static int temp[LARGE + 1];
    int size2 = upper2n(3 * n + 1);
    for (int i = 0; i < size2; i++) {
        if (i < n)
            temp[i] = A[i];
        else
            temp[i] = 0;
    }
    transform(temp, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] =
            ((int_t)2 * result[i] % mod -
             (int_t)temp[i] * result[i] % mod * result[i] % mod + 2 * mod) %
            mod;
    }
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++) {
        if (i < n) {
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        } else {
            result[i] = 0;
        }
    }
}
void poly_log(const int* A, int n, int* result) {
    int size2 = upper2n(3 * n + 1);
    static int Ad[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        if (i < n) {
            Ad[i] = (int_t)(i + 1) * A[i + 1] % mod;
        } else {
            Ad[i] = 0;
        }
        result[i] = 0;
    }
    transform(Ad, size2, 1);
    poly_inv(A, n, result);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] = (int_t)result[i] * Ad[i] % mod;
    }
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++)
        if (i < n)
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        else
            result[i] = 0;
    for (int i = n - 1; i >= 1; i--) {
        result[i] = (int_t)result[i - 1] * power(i, -1) % mod;
    }
    result[0] = 0;
}
void poly_exp(const int* A, int n, int* result) {
    if (n == 1) {
        assert(A[0] == 0);
        result[0] = 1;
        return;
    }
    poly_exp(A, n / 2 + n % 2, result);
    static int Ax[LARGE + 1];
    // memset(G0, 0, sizeof(G0));
    // memset(Ax,0,sizeof(Ax));
    int size2 = upper2n(3 * n + 1);
    // int* G0 = new int[size2];
    static int G0[LARGE + 1];
    for (int_t i = 0; i < size2; i++) {
        if (i < n)
            Ax[i] = A[i];
        else
            Ax[i] = 0;
        G0[i] = 0;
    }
    poly_log(result, n, G0);

    transform(Ax, size2, 1);
    transform(G0, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++)
        result[i] =
            (result[i] - (int_t)result[i] * (G0[i] - Ax[i] + mod) % mod + mod) %
            mod;
    transform(result, size2, -1);
    for (int i = 0; i < size2; i++)
        if (i < n)
            result[i] = (int_t)result[i] * power(size2, -1) % mod;
        else
            result[i] = 0;
    // delete[] G0;
}

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理