毕竟是板子集合，不管常数了。
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>

using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;

const int mod = 998244353;
const int g = 3;
const int LARGE = 1 << 19;

// revs[k][j] = bit-reversal of j over k bits. transform() only reads
// revs[k][0 .. 2^k), so only that prefix of each row is filled.
int revs[20][LARGE + 1];

int power(int base, int index);
void transform(int* A, int len, int arg);
void poly_inv(const int* A, int n, int* result);
void poly_sqrt(const int* A, int n, int* result);
void poly_log(const int* A, int n, int* result);
void poly_exp(const int* A, int n, int* result);
void poly_power(const int* A, int n, int index, int* result);
int modSqrt(int a);

// Reverses the lowest `size2` bits of `base` (for size2 == 0 it returns
// base & 1, which is 0 for the only index used in that row).
int bitRev(int base, int size2) {
    int result = 0;
    for (int i = 1; i < size2; i++) {
        result |= (base & 1);
        base >>= 1;
        result <<= 1;
    }
    result |= (base & 1);
    return result;
}

// Smallest power of two that is >= x (1 for x <= 1).
int upper2n(int x) {
    int result = 1;
    while (result < x) result *= 2;
    return result;
}

// Inverse NTT of A[0, len), scaled by len^{-1}; afterwards only the first
// `keep` coefficients are retained and A[keep, len) is zeroed.
static void inverse_and_truncate(int* A, int len, int keep) {
    transform(A, len, -1);
    const int inv_len = power(len, -1);  // hoisted: one modular inverse per call
    for (int i = 0; i < len; i++) {
        A[i] = (i < keep) ? (int)((int_t)A[i] * inv_len % mod) : 0;
    }
}

#ifdef DEBUG
// Prints `tag` followed by coefficients P[0..n] (debug builds only).
static void debug_dump(const char* tag, const int* P, int n) {
    cout << tag;
    for (int_t i = 0; i <= n; i++) cout << P[i] << " ";
    cout << endl;
}
#endif

// Reads n, k and F[0..n], then prints the first n coefficients of
//   ( 1 + ln( 2 + F - F(0) - exp( integral( F^{-1/2} ) ) ) )^k ' .
int main() {
#ifdef TIME
    auto begin = clock();
#endif
    for (int i = 0; (1 << i) <= LARGE; i++) {
        for (int j = 0; j < (1 << i); j++) {
            revs[i][j] = bitRev(j, i);
        }
    }
    static int F[LARGE + 1], A[LARGE + 1], B[LARGE + 1];
    int n, k;
    if (scanf("%d%d", &n, &k) != 2) return 1;
    for (int i = 0; i <= n; i++) {
        if (scanf("%d", &A[i]) != 1) return 1;
        F[i] = A[i];
    }

    // B = sqrt(F)
    poly_sqrt(A, n + 1, B);
#ifdef DEBUG
    debug_dump("sqrt ", B, n);
#endif

    // A = 1 / sqrt(F)
    memset(A, 0, sizeof(A));
    poly_inv(B, n + 1, A);
#ifdef DEBUG
    debug_dump("sqrt inv ", A, n);
#endif

    // B = integral of A, with zero constant term
    for (int i = n; i >= 1; i--) B[i] = (int_t)A[i - 1] * power(i, -1) % mod;
    B[0] = 0;
#ifdef DEBUG
    debug_dump("sqrt inv integrate ", B, n);
#endif

    // A = exp(B)
    memset(A, 0, sizeof(A));
    poly_exp(B, n + 1, A);
#ifdef DEBUG
    debug_dump("sqrt inv integrate exp ", A, n);
#endif

    // A = 2 + F - F(0) - A
    for (int i = 0; i < n + 1; i++) {
        A[i] = (F[i] - A[i] + mod) % mod;
    }
    A[0] = (A[0] + 2 - F[0] + mod) % mod;
#ifdef DEBUG
    debug_dump("sqrt inv integrate exp process ", A, n);
#endif

    // B = 1 + ln(A)
    poly_log(A, n + 1, B);
    B[0] = 1;
#ifdef DEBUG
    debug_dump("sqrt inv integrate exp process log +1 ", B, n);
#endif

    // A = B^k
    poly_power(B, n + 1, k, A);
#ifdef DEBUG
    debug_dump("sqrt inv integrate exp process power ", A, n);
#endif

    // Output the derivative A'(x) mod x^n.
    for (int i = 0; i < n; i++) {
        A[i] = (int_t)A[i + 1] * (i + 1) % mod;
        printf("%d ", A[i]);
    }
    printf("\n");
#ifdef TIME
    auto end = clock();
    cout << 1.0 * (end - begin) / CLOCKS_PER_SEC << endl;
#endif
    return 0;
}

// Computes A(x)^index mod x^n by binary exponentiation, squaring in the NTT
// domain. Writes result[0, upper2n(2n)); result[n, ...) ends up zero.
// `index` must be non-negative (a negative value would never reach 0 under >>=).
void poly_power(const int* A, int n, int index, int* result) {
    assert(index >= 0);
    const int size2 = upper2n(2 * n);
    static int base[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        base[i] = (i < n) ? A[i] : 0;
        result[i] = 0;
    }
    result[0] = 1;
    while (index) {
        transform(base, size2, 1);
        if (index & 1) {
            transform(result, size2, 1);
            for (int i = 0; i < size2; i++) {
                result[i] = (int_t)result[i] * base[i] % mod;
            }
            inverse_and_truncate(result, size2, n);
        }
        for (int i = 0; i < size2; i++) base[i] = (int_t)base[i] * base[i] % mod;
        inverse_and_truncate(base, size2, n);
        index >>= 1;
    }
}

// base^index mod `mod`. Negative indices are reduced modulo mod-1 (Fermat), so
// power(x, -1) is the modular inverse of x; note power(0, -1) returns 0.
int power(int base, int index) {
    const int phi = mod - 1;
    index = (index % phi + phi) % phi;
    base = (base % mod + mod) % mod;
    int result = 1;
    while (index) {
        if (index & 1) result = (int_t)result * base % mod;
        base = (int_t)base * base % mod;
        index >>= 1;
    }
    return result;
}

// In-place iterative NTT of A[0, len); len must be a power of two <= LARGE.
// arg = 1 is the forward transform, arg = -1 the unscaled inverse transform.
void transform(int* A, int len, int arg) {
    int lg = 0;  // integer log2(len), avoids floating-point log2
    while ((1 << lg) < len) lg++;
    for (int i = 0; i < len; i++) {
        const int x = revs[lg][i];
        if (x > i) std::swap(A[i], A[x]);
    }
    for (int step = 2; step <= len; step *= 2) {
        const int half = step / 2;
        const int root = power(g, (int_t)arg * (mod - 1) / step);
        for (int j = 0; j < len; j += step) {
            int w = 1;
            for (int k = 0; k < half; k++) {
                const int u = A[j + k];
                const int t = (int_t)A[j + k + half] * w % mod;
                A[j + k] = (u + t) % mod;
                A[j + k + half] = (u - t + mod) % mod;
                w = (int_t)w * root % mod;
            }
        }
    }
}

// Computes the inverse of polynomial A modulo x^n by Newton iteration:
//   C(x) <- 2B(x) - A(x) B(x)^2, where B is the inverse modulo x^ceil(n/2).
// Requires A[0] invertible. For n > 1 the tail result[n, upper2n(2n)) is zeroed.
// The product A*B^2 has degree <= 2n-2, so a transform length of upper2n(2n)
// suffices.
void poly_inv(const int* A, int n, int* result) {
    if (n == 1) {
        result[0] = power(A[0], -1);
        return;
    }
    const int half = n / 2 + n % 2;
    poly_inv(A, half, result);
    static int temp[LARGE + 1];
    const int size2 = upper2n(2 * n);
    for (int i = 0; i < size2; i++) {
        temp[i] = (i < n) ? A[i] : 0;
        // Only result[0, half) is meaningful; clear the rest ourselves instead of
        // relying on the caller having zeroed the buffer.
        if (i >= half) result[i] = 0;
    }
    transform(temp, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] = ((int_t)2 * result[i] % mod -
                     (int_t)temp[i] * result[i] % mod * result[i] % mod + 2 * mod) % mod;
    }
    inverse_and_truncate(result, size2, n);
}

// Square root of a modulo 998244353 via Tonelli-Shanks. `a` is assumed to be a
// quadratic residue (asserted).
int modSqrt(int a) {
    const int b = 3;  // quadratic non-residue (3 is the primitive root g)
    // mod - 1 = 2^s * t with t odd
    const int s = 23, t = 119;
    const int a_inv = power(a, -1);
    int x = power(a, (t + 1) / 2);
    int e = power(a, t);
    for (int k = 1; k < s; k++) {
        if (power(e, 1 << (s - k - 1)) != 1) {
            x = (int_t)x * power(b, (1 << (k - 1)) * t) % mod;
        }
        e = (int_t)a_inv * x % mod * x % mod;
    }
    assert((int_t)x * x % mod == (a % mod + mod) % mod);
    return x;
}

// Computes the square root of polynomial A modulo x^n by Newton iteration:
//   S <- S - (S^2 - A) / (2S). The constant term is the smaller of the two
// modular roots of A[0]. Every product has degree <= 2n-2, so upper2n(2n)
// suffices. The tail result[n, upper2n(2n)) is zeroed for n > 1.
void poly_sqrt(const int* A, int n, int* result) {
    if (n == 1) {
        const int p = modSqrt(A[0]);
        result[0] = std::min(p, mod - p);
        return;
    }
    const int half = n / 2 + n % 2;
    poly_sqrt(A, half, result);
    const int size2 = upper2n(2 * n);
    static int Ax[LARGE + 1], pInv[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        Ax[i] = (i < n) ? A[i] : 0;
        pInv[i] = 0;
        if (i >= half) result[i] = 0;  // only result[0, half) is meaningful
    }
    poly_inv(result, n, pInv);
    transform(Ax, size2, 1);
    transform(result, size2, 1);
    transform(pInv, size2, 1);
    const int inv2 = power(2, -1);
    for (int i = 0; i < size2; i++) {
        result[i] = (result[i] -
                     ((int_t)result[i] * result[i] % mod - Ax[i] + mod) % mod * inv2 % mod *
                         pInv[i] % mod +
                     mod) % mod;
    }
    inverse_and_truncate(result, size2, n);
}

// Computes ln(A) modulo x^n as integral(A' / A), assuming A[0] == 1. Only
// A[0, n) is read. Writes result[0, upper2n(2n)), with result[0] = 0 and the
// tail beyond n zeroed.
void poly_log(const int* A, int n, int* result) {
    const int size2 = upper2n(2 * n);
    static int Ad[LARGE + 1];
    for (int i = 0; i < size2; i++) {
        // A' mod x^(n-1) only needs A[1, n); A[n] lies outside the input window.
        Ad[i] = (i + 1 < n) ? (int)((int_t)(i + 1) * A[i + 1] % mod) : 0;
        result[i] = 0;
    }
    transform(Ad, size2, 1);
    poly_inv(A, n, result);
    for (int i = n; i < size2; i++) result[i] = 0;  // poly_inv guarantees only [0, n)
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] = (int_t)result[i] * Ad[i] % mod;
    }
    inverse_and_truncate(result, size2, n);
    for (int i = n - 1; i >= 1; i--) {
        result[i] = (int_t)result[i - 1] * power(i, -1) % mod;
    }
    result[0] = 0;
}

// Computes exp(A) modulo x^n by Newton iteration: G <- G (1 - ln G + A).
// Requires A[0] == 0 (asserted). The tail result[n, upper2n(2n)) is zeroed
// for n > 1.
void poly_exp(const int* A, int n, int* result) {
    if (n == 1) {
        assert(A[0] == 0);
        result[0] = 1;
        return;
    }
    const int half = n / 2 + n % 2;
    poly_exp(A, half, result);
    static int G0[LARGE + 1], Ax[LARGE + 1];
    const int size2 = upper2n(2 * n);
    // Only result[0, half) is meaningful; poly_log reads result[0, n).
    for (int i = half; i < size2; i++) result[i] = 0;
    poly_log(result, n, G0);
    for (int i = 0; i < size2; i++) Ax[i] = (i < n) ? A[i] : 0;
    transform(Ax, size2, 1);
    transform(G0, size2, 1);
    transform(result, size2, 1);
    for (int i = 0; i < size2; i++) {
        result[i] = (result[i] - (int_t)result[i] * (G0[i] - Ax[i] + mod) % mod + mod) % mod;
    }
    inverse_and_truncate(result, size2, n);
}
发表回复