作者: officeyutong

  • CUDA (.cu文件) 通过runtime API加载PTX binary的过程分析

    背景

    我们经常会用nvcc直接编译一个.cu文件到binary来执行。执行时,nvcc生成的其他代码帮我们做了很多事情来实现加载PTX ELF binary等事情。现在我们需要拦截nvcc生成的代码所做的这些事情,所以写一篇文章分析一下。

    细节

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <iostream>
    #include <ostream>
    #include <stdio.h>
    #include <stdlib.h>
    #include <cuda_runtime.h>
    #include <vector>
    
    /*
     * Sums the first `size` elements of `input` and stores the total in
     * `*output`, then prints the total from the device.
     *
     * The total is built in a local accumulator and written once, so the
     * result does not depend on whatever value `*output` held before the
     * launch (cudaMalloc does not clear the memory it returns).
     *
     * Launch layout: the loop is not partitioned across threads; every thread
     * that executes this kernel walks the whole array and stores the same
     * value. It is meant to be launched with a single thread.
     *
     * input  - device pointer to at least `size` readable elements
     * output - device pointer to one writable element
     * size   - number of elements to add
     */
    __global__ void sumArray(uint64_t *input, uint64_t *output, size_t size)
    {
    	uint64_t sum = 0;
    	for (size_t i = 0; i < size; i++)
    		sum += input[i];
    	*output = sum;
    	printf("From device side: sum = %lu\n", (unsigned long)sum);
    }
    
    // Number of elements in the host array; main() fills it with 1..ARR_SIZE.
    constexpr size_t ARR_SIZE = 1000;
    
    /*
     * Host driver: builds the array 1..ARR_SIZE, copies it to the device,
     * launches sumArray with a single thread, copies the one-element result
     * back and prints it.
     *
     * NOTE(review): none of the CUDA runtime return codes are checked, and
     * there is no cudaGetLastError() after the launch, so a failed call goes
     * unnoticed.
     */
    int main()
    {
    	std::vector<uint64_t> arr;
    	for (size_t i = 1; i <= ARR_SIZE; i++) {
    		arr.push_back(i);
    	}
    	// Size in bytes of the whole input array.
    	auto data_size = sizeof(arr[0]) * arr.size();
    
    	uint64_t *d_input, *d_output;
    	cudaMalloc(&d_input, data_size);
    	// One element for the result. NOTE(review): cudaMalloc does not clear
    	// the allocation and this function never writes d_output before the
    	// launch.
    	cudaMalloc(&d_output, sizeof(arr[0]));
    	cudaMemcpy(d_input, arr.data(), data_size, cudaMemcpyHostToDevice);
    	// Launch configuration: 1 block, 1 thread per block, 1 byte of dynamic
    	// shared memory.
    	sumArray<<<1, 1, 1>>>(d_input, d_output, arr.size());
    	uint64_t host_sum;
    	cudaMemcpy(&host_sum, d_output, sizeof(arr[0]), cudaMemcpyDeviceToHost);
        cudaDeviceSynchronize();
    	std::cout << "Sum is " << host_sum << std::endl;
    	cudaFree(d_input);
    	cudaFree(d_output);
    
    	return 0;
    }
    

    这是一个很简单的CUDA程序。传一个数组到device上,device求和并写回内存,而后host读取结果。

    使用 nvcc -cuda victim.cu -o victim.cpp 来生成对应展开后的cpp文件。生成后的代码过大,这里就不粘贴了。

    初始化过程

    nvcc生成的代码里调用了大量内部API来实现加载。这些API的工作逻辑和driverAPI并不一致。

    
    /*
     * nvcc-generated registration routine. The declaration carries the
     * constructor attribute and the body performs the registration sequence:
     *   1. __cudaRegisterFatBinary is given the address of __fatDeviceText and
     *      its return value is stored in __cudaFatCubinHandle;
     *   2. __nv_cudaEntityRegisterCallback is invoked with that handle;
     *   3. __cudaRegisterFatBinaryEnd is called with the same handle;
     *   4. __cudaUnregisterBinaryUtil is registered with atexit.
     * The __cuda* entry points are internal runtime functions; what each one
     * does beyond the calls visible here is not documented.
     */
    static void __sti____cudaRegisterAll(void) __attribute__((__constructor__));
    static void __sti____cudaRegisterAll(void)
    {
    	__cudaFatCubinHandle =
    		__cudaRegisterFatBinary((void *)&__fatDeviceText);
    	{
    		// The callback is called through a function pointer cast to
    		// void (*)(void **).
    		void (*callback_fp)(void **) =
    			(void (*)(void **))(__nv_cudaEntityRegisterCallback);
    		(*callback_fp)(__cudaFatCubinHandle);
    		__cudaRegisterFatBinaryEnd(__cudaFatCubinHandle);
    	}
    	atexit(__cudaUnregisterBinaryUtil);
    }
    

    可以看到,nvcc生成的代码里给__sti____cudaRegisterAll打了constructor标记,也就是会先执行这个初始化binary的函数。我们可以看到依次调用了这些函数来初始化binary

    • __cudaRegisterFatBinary:加载一份binary。binary的内容是用内联汇编嵌入在源代码里的ELF
    • __nv_cudaEntityRegisterCallback:注册使用到的kernel函数。这个函数接下来会解释。
    • __cudaRegisterFatBinaryEnd:结束函数注册?(我猜的,这是没有文档的东西)

    同时注册了一个atexit函数__cudaUnregisterBinaryUtil

    核函数注册

    让我们看__nv_cudaEntityRegisterCallback。

    /*
     * nvcc-generated callback that registers the kernels of this translation
     * unit. __T5 is the handle passed in by the caller. For this file it
     * registers a single function: the host-side symbol sumArray together
     * with the name "_Z8sumArrayPmS_m".
     */
    static void __nv_cudaEntityRegisterCallback(void **__T5)
    {
    	{
    		// Stores the handle in an unused volatile static; presumably
    		// only to keep a reference alive - TODO confirm.
    		volatile static void **__ref __attribute__((unused));
    		__ref = (volatile void **)__T5;
    	};
    	__nv_save_fatbinhandle_for_managed_rt(__T5);
    	// Arguments, in order: handle, address of the host function sumArray
    	// (cast to const char *), the name string twice, -1, and five null
    	// pointers. NOTE(review): the meaning of the trailing arguments is
    	// not visible here.
    	__cudaRegisterFunction(__T5,
    			       (const char *)((void (*)(uint64_t *, uint64_t *,
    							size_t))sumArray),
    			       (char *)"_Z8sumArrayPmS_m", "_Z8sumArrayPmS_m",
    			       -1, (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0,
    			       (int *)0);
    }

    注意到这里会调用__cudaRegisterFunction来将kernel函数的符号名(定义在PTX ELF binary中)和host上的包装函数的地址关联起来。

    host上的包装函数

    /*
     * nvcc-generated host stub for the sumArray kernel. It collects the
     * addresses of its three parameters into an argument array, obtains the
     * launch configuration from __cudaPopCallConfiguration and calls
     * cudaLaunchKernel with the address of the host function sumArray as the
     * kernel identifier. Returns without launching if the configuration
     * cannot be popped.
     */
    void __device_stub__Z8sumArrayPmS_m(uint64_t *__par0, uint64_t *__par1,
    				    size_t __par2)
    {
    	// One slot per kernel parameter; each slot holds the address of the
    	// corresponding stub parameter.
    	void *__args_arr[3];
    	int __args_idx = 0;
    	__args_arr[__args_idx] = (void *)(char *)&__par0;
    	++__args_idx;
    	__args_arr[__args_idx] = (void *)(char *)&__par1;
    	++__args_idx;
    	__args_arr[__args_idx] = (void *)(char *)&__par2;
    	++__args_idx;
    	{
    		// Unused volatile static holding the host function address.
    		volatile static char *__f __attribute__((unused));
    		__f = ((char *)((
    			void (*)(uint64_t *, uint64_t *, size_t))sumArray));
    		dim3 __gridDim, __blockDim;
    		size_t __sharedMem;
    		cudaStream_t __stream;
    		// Fills in grid, block, shared-memory size and stream;
    		// presumably the values pushed by __cudaPushCallConfiguration
    		// at the call site - internal API, TODO confirm.
    		if (__cudaPopCallConfiguration(&__gridDim, &__blockDim,
    					       &__sharedMem,
    					       &__stream) != cudaSuccess)
    			return;
    		// __args_idx is 3 at this point (three increments above), so
    		// the else branch is the one taken for this kernel.
    		if (__args_idx == 0) {
    			(void)cudaLaunchKernel(
    				((char *)((void (*)(uint64_t *, uint64_t *,
    						    size_t))sumArray)),
    				__gridDim, __blockDim, &__args_arr[__args_idx],
    				__sharedMem, __stream);
    		} else {
    			(void)cudaLaunchKernel(
    				((char *)((void (*)(uint64_t *, uint64_t *,
    						    size_t))sumArray)),
    				__gridDim, __blockDim, &__args_arr[0],
    				__sharedMem, __stream);
    		}
    	};
    }
    
    /*
     * Host-side function emitted under the kernel's name. It only forwards
     * its three arguments to the device stub; its address is the one passed
     * to __cudaRegisterFunction and cudaLaunchKernel in the generated code.
     */
    void sumArray(uint64_t *__cuda_0, uint64_t *__cuda_1, size_t __cuda_2)
    {
    	__device_stub__Z8sumArrayPmS_m(__cuda_0, __cuda_1, __cuda_2);
    }

    包装函数内的核心是cudaLaunchKernel,直接传入了host上的函数地址来执行kernel。由于先前已经关联过了host上的包装函数的地址和kernel的符号名,所以只需要提供host上的包装函数的地址即可。

    main函数

    /*
     * main() as emitted by nvcc after expanding the <<<...>>> launch. The
     * kernel launch has become a call to __cudaPushCallConfiguration followed
     * by an ordinary call to the host function sumArray.
     */
    int main()
    {
    	std::vector<unsigned long> arr;
    	for (size_t i = (1); i <= ARR_SIZE; i++) {
    		arr.push_back(i);
    	}
    	auto data_size = sizeof arr[0] * arr.size();
    	uint64_t *d_input, *d_output;
    	cudaMalloc(&d_input, data_size);
    	cudaMalloc(&d_output, sizeof arr[0]);
    	cudaMemcpy(d_input, arr.data(), data_size, cudaMemcpyHostToDevice);
    	// sumArray is called only when __cudaPushCallConfiguration(1, 1, 1)
    	// evaluates to false (zero); a non-zero result skips the call.
    	(__cudaPushCallConfiguration(1, 1, 1)) ?
    		(void)0 :
    		sumArray(d_input, d_output, arr.size());
    	uint64_t host_sum;
    	cudaMemcpy(&host_sum, d_output, sizeof arr[0], cudaMemcpyDeviceToHost);
    	cudaDeviceSynchronize();
    	(((((std::cout << ("Sum is "))) << host_sum)) << (std::endl));
    	cudaFree(d_input);
    	cudaFree(d_output);
    	return 0;
    }

    可以看到,main函数直接执行host上的包装函数即可完成核函数运行。

    反注册binary

    注意到我们之前提到的atexit里面的退出函数。

    /*
     * nvcc-generated exit handler (registered with atexit by
     * __sti____cudaRegisterAll). Passes the stored fat-binary handle to
     * __cudaUnregisterFatBinary.
     */
    static void __cudaUnregisterBinaryUtil(void)
    {
    	// Takes the address of the handle variable; presumably only to mark
    	// it as referenced - TODO confirm.
    	____nv_dummy_param_ref((void *)&__cudaFatCubinHandle);
    	__cudaUnregisterFatBinary(__cudaFatCubinHandle);
    }
    

    这里调用一个内部函数来销毁binary handle。

  • CFGYM103415A CCPC2021广州A Math Ball

    题意: 有$n$种球,每种有无限个,同时第$i$种球有一个代价$c_i$,你要拿不超过$w$个球。如果最后第$i$种球你拿了$k_i$个,那么你会获得$\prod_{1\leq i\leq n}{k_i^{c_i}}$的权值,求所有合法方案的权值和。$n\leq 1e5,\sum{c_i}\leq 1e5,w\leq 10^{18}$

    $$  \text{考虑对于价值是}c_i\text{的球,构造生成函数}  \\  F_{c_i}\left( x \right) =\sum_{n\ge 0}{n^{c_i}x^n}  \\  \text{这样}\frac{\prod_i{F_{c_i}\left( x \right)}}{1-x}\text{的}w\text{次项即为答案}  \\  \text{设}F_k\left( x \right) =\sum_{n\ge 0}{n^kx^n},\text{显然可得}F_k\left( x \right) =x\frac{\mathrm{d}F_{k-1}\left( x \right)}{\mathrm{d}x}  \\  \text{进一步递推可得}, F_k\left( x \right) =\frac{\sum_{0\le i\le k}{T\left( k,i \right) x^i}}{\left( 1-x \right) ^{k+1}},\text{其中}T\left( k,i \right) \text{表示欧拉数}  \\  \text{考虑如何快速计算欧拉数}  \\  \text{首先由具体数学可得}  \\  \sum_{0\le i\le k}{T\left( k,i \right) \times \left( z+1 \right) ^i}=\sum_{0\le i\le k}{S_2\left( k,i \right) \times i!\times z^{k-i}}  \\  \text{进一步推导得欧拉数的通项公式}T\left( n,k \right) =\sum_{0\le i\le k}{\begin{array}{c}    \binom{n+1}{i}\left( -1 \right) ^i\left( k+1-i \right) ^n\\  \end{array}}  \\  \text{构造卷积形式}T\left( n,k \right) =\left( n+1 \right) !\sum_{0\le i\le k}{\frac{\left( -1 \right) ^i}{i!\left( n+1-i \right) !}}\times \left( k+1-i \right) ^n  \\  \text{卷积即可求出一行欧拉数。}  \\  \text{现在我们可以对每一个}c_i\text{计算出}F_{c_i}\left( x \right) \text{的分子,并且得到他们分子的乘积,设为}F\left( x \right)   \\  \text{设他们分母的乘积,再乘个}1-x\text{为}\left( 1-x \right) ^k  \\  \text{则答案即为}\frac{F\left( x \right)}{\left( 1-x \right) ^k}\text{的}w\text{次项系数}  \\  \text{即}\sum_{0\le i\le n}{\left[ x^i \right] F\left( x \right) \binom{w-i+k-1}{w-i}}\left( \text{广义二项式定理展开分母} \right)   \\  \text{令}a_i=\left[ x^i \right] F\left( x \right)   \\  \text{即}\sum_{0\le i\le n}{a_i\binom{w-i+k-1}{k-1}}  \\  \text{后者可以通过递推快速转移。}  \\    \\    \\  \frac{x^a}{\left( 1-x \right) ^b}=x^a\left( \sum_{n\ge 0}{\binom{-b}{n}\left( -1 \right) ^n}x^n \right)   \\  =x^a\left( \sum_{n\ge 0}{\frac{\left( -b \right) ^{\underline{n}}}{n!}\left( -1 \right) ^n}x^n \right)   \\  =x^a\left( \sum_{n\ge 0}{\frac{\left( -b-0 \right) \left( -b-1 \right) \left( -b-2 \right) ..\left( -b-\left( n-1 \right) \right)}{n!}\left( -1 \right) ^n}x^n \right)   \\  =x^a\left( \sum_{n\ge 0}{\frac{\left( b+0 \right) \left( b+1 
\right) \left( b+2 \right) ..\left( b+\left( n-1 \right) \right)}{n!}}x^n \right)   \\  =x^a\left( \sum_{n\ge 0}{\frac{\left( b+n-1 \right) ^{\underline{n}}}{n!}}x^n \right)   \\  =\sum_{n\ge 0}{\frac{\left( b+n-1 \right) ^{\underline{n}}}{n!}x^{n+a}}  \\  =\sum_{n\ge 0}{\binom{b+n-1}{n}x^{n+a}}  \\  =\sum_{n\ge a}{\binom{b+n-a-1}{n-a}x^n}  \\  =\sum_{n\ge a}{\binom{b+n-a-1}{b-1}x^n}  \\  =\sum_{n\ge a}{\begin{array}{c}     \frac{\left( b-a+n-1 \right) ^{\underline{b-1}}}{\left( b-1 \right) !}x^n\\  \end{array}}  \\  \text{令}t=b-1  \\  \text{则}  \\  \sum_{n\ge a}{\begin{array}{c}      \frac{\left( t-a+n \right) ^{\underline{t}}}{t!}x^n\\  \end{array}}  \\  \text{从}a\text{推进到}a+1  \\  \frac{\left( t-a+n \right) ^{\underline{t}}}{\left( t-a+n-1 \right) ^{\underline{t}}}=\frac{t-a+n}{n-a}  \\    \\  10*9*8/\left( 9*8*7 \right) =10/7  \\    \\  \left( \frac{x^a}{\left( 1-x \right) ^b}\text{的}n\text{次项} \right)   \\    \\  \binom{n+b-1}{n}=\binom{n+b-1}{b-1}  \\  \binom{n+b-1}{n}\rightarrow \binom{n+b}{n+1}  \\    \\  \frac{\binom{n+b}{n+1}}{\binom{n+b-1}{n}}=\frac{n+b}{\left( n+1 \right)}  \\    \\  \frac{\left( n-2+b-i \right) !\left( b-1 \right) !\left( n-i \right) !}{\left( b-1 \right) !\left( n-i-1 \right) !\left( n-i+b-1 \right) !}  \\    \\  \frac{\left( n-i+b-2 \right) !\left( n-i \right)}{\left( n-i+b-1 \right) !}  \\  \frac{\left( n-i \right)}{\left( n-i+b-1 \right)}  \\    \\    $$
    #include <cinttypes>
    #include <cstring>
    #include <fstream>
    #include <iostream>
    #include <random>
    using int_t = int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 2e6;
    
    using i64 = int64_t;
    #ifdef NTTCNT
    std::ofstream nttcnt("nttcnt.txt");
    #endif
    const int_t mod = 998244353;
    const int_t g = 3;
    
    // Binary exponentiation: returns b raised to the power i, reduced modulo
    // `mod`. A negative exponent is first reduced modulo (mod - 1) into the
    // range [0, mod - 2], so power(x, -1) is evaluated as x^(mod - 2).
    int_t power(int_t b, int_t i) {
        i64 exponent = i;
        if (exponent < 0) {
            const i64 period = mod - 1;
            exponent = (exponent % period + period) % period;
        }
        i64 base = b;
        i64 result = 1;
        for (; exponent != 0; exponent >>= 1) {
            if (exponent & 1)
                result = result * base % mod;
            base = base * base % mod;
        }
        return (int_t)result;
    }
    
    // Fills arr[0 .. 2^size2 - 1] with the bit-reversal table for size2-bit
    // indices: arr[i] is i with its size2 low bits written in reverse order.
    // Each entry is derived from the entry of i / 2.
    void makeflip(int_t* arr, int_t size2) {
        const int_t count = (1 << size2);
        const int_t top_shift = size2 - 1;
        arr[0] = 0;
        int_t idx = 1;
        while (idx < count) {
            const int_t parent_part = arr[idx >> 1] >> 1;
            const int_t low_bit = idx & 1;
            arr[idx] = parent_part | (low_bit << top_shift);
            ++idx;
        }
    }
    
    // Returns the smallest r >= 0 such that 2^r >= x.
    int_t upper2n(int_t x) {
        int_t bits = 0;
        for (; (1 << bits) < x; ++bits) {
        }
        return bits;
    }
    // In-place number-theoretic transform of A, which has 2^size2 entries.
    // `flip` is the bit-reversal table produced by makeflip for the same size2.
    // The template argument selects the root used at each level:
    // g^(arg * (mod - 1) / block). The code is called with arg = 1 and with
    // arg = -1; with arg = -1 the caller still has to multiply by the inverse
    // of the length.
    template <int_t arg = 1>
    void transform(int_t* A, int_t size2, int_t* flip) {
        const int_t total = (1 << size2);
    #ifdef NTTCNT
        nttcnt << total << endl;
    #endif
        // Reorder the input by bit-reversed index; each pair is swapped once.
        for (int_t pos = 0; pos < total; ++pos) {
            const int_t partner = flip[pos];
            if (pos < partner)
                std::swap(A[pos], A[partner]);
        }
        // Butterfly passes over blocks of size 2, 4, ..., total.
        for (int_t block = 2; block <= total; block *= 2) {
            const int_t half = block / 2;
            const int_t step = power(g, (i64)arg * (mod - 1) / block);
            for (int_t start = 0; start < total; start += block) {
                int_t* lo = A + start;
                int_t* hi = lo + half;
                int_t twiddle = 1;
                for (int_t k = 0; k < half; ++k) {
                    const int_t even = lo[k];
                    const int_t odd = (i64)twiddle * hi[k] % mod;
                    lo[k] = ((i64)even + odd) % mod;
                    hi[k] = ((i64)even - odd + mod) % mod;
                    twiddle = (i64)twiddle * step % mod;
                }
            }
        }
    }
    void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
        /*
           计算n次多项式A与m次多项式B的乘积
        */
        int_t size2 = upper2n(n + m + 1);
        int_t len = 1 << size2;
        static int_t T1[LARGE], T2[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = A[i];
            else
                T1[i] = 0;
            if (i <= m)
                T2[i] = B[i];
            else
                T2[i] = 0;
        }
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        transform(T1, size2, fliparr);
        transform(T2, size2, fliparr);
        for (int_t i = 0; i < len; i++)
            T1[i] = (i64)T1[i] * T2[i] % mod;
        transform<-1>(T1, size2, fliparr);
        int_t inv = power(len, -1);
        for (int_t i = 0; i <= n + m; i++)
            C[i] = (i64)T1[i] * inv % mod;
    }
    
    const int_t FLARGE = 2e5 + 10;
    
    int_t fact[FLARGE + 1], inv[FLARGE + 1], factInv[FLARGE + 1];
    // Computes one row of Eulerian numbers by a single convolution.
    // The two factors are
    //   P1[i] = (-1)^i / (i! * (n + 1 - i)!)   and   P2[i] = i^n,
    // for 0 <= i <= n; the product is written to out[0..2n] by poly_mul and
    // out[0..n] is then scaled by (n + 1)!. Requires the global fact/factInv
    // tables to be filled up to n + 1; `out` must have room for 2n + 1 entries.
    void euler_num(int_t n, int_t* out) {
        static int_t P1[LARGE];
        static int_t P2[LARGE];
        for (int_t i = 0; i <= n; ++i) {
            const int_t sign = (i % 2) ? (mod - 1) : 1;
            const i64 signed_inv = (i64)sign * factInv[i] % mod;
            P1[i] = signed_inv * factInv[n + 1 - i] % mod;
            P2[i] = power(i, n);
        }
        poly_mul(P1, n, P2, n, out);
        const int_t scale = fact[n + 1];
        for (int_t i = 0; i <= n; ++i)
            out[i] = (i64)out[i] * scale % mod;
    }
    
    using poly_t = std::vector<int_t>;
    /*
       Multiplies the polynomials P[left..right] (inclusive) by divide and
       conquer: both halves are multiplied recursively and the two partial
       products are combined with poly_mul.

       Each polynomial is a coefficient vector, lowest degree first. Trailing
       zero coefficients of a product are trimmed, but one coefficient is
       always kept so the result can be passed back into poly_mul. An empty
       range (left > right) yields the constant polynomial 1, and an empty
       operand is treated as the zero polynomial; previously the former
       recursed without bound and the latter took &v[0] of an empty vector.
    */
    poly_t poly_dcmul(const poly_t* P, int_t left, int_t right) {
        if (left > right)
            return poly_t(1, 1);  // empty product
        if (left == right)
            return P[left];
        const int_t mid = (left + right) / 2;
        auto lret = poly_dcmul(P, left, mid);
        auto rret = poly_dcmul(P, mid + 1, right);
        if (lret.empty() || rret.empty())
            return poly_t(1, 0);  // anything times the zero polynomial
        // deg(product) = deg(lret) + deg(rret), hence size = sizes summed - 1.
        poly_t ret(lret.size() + rret.size() - 1);
        poly_mul(&lret[0], lret.size() - 1, &rret[0], rret.size() - 1, &ret[0]);
        while (ret.size() > 1 && ret.back() == 0)
            ret.pop_back();
        return ret;
    }
    // Binomial coefficient C(n, m) modulo `mod`, from the precomputed
    // fact/factInv tables. Returns 0 when m is outside [0, n] (this also
    // covers negative n) instead of indexing the tables with a negative
    // subscript. n must not exceed the size of the tables (FLARGE).
    int_t C(int_t n, int_t m) {
        if (m < 0 || m > n)
            return 0;
        return (i64)fact[n] * factInv[m] % mod * factInv[n - m] % mod;
    }
    // Binomial coefficient C(n, m) modulo `mod` for a large n and a small m:
    // multiplies the m falling factors n, n - 1, ..., n - m + 1 (each taken
    // modulo `mod`) and divides by m! through factInv[m]. Returns 0 when n is
    // negative or smaller than m.
    i64 getC(i64 n, int_t m) {
        if (!(n >= 0 && n >= m))
            return 0;
        const i64 n_mod = n % mod;
        i64 falling = 1;
        for (int_t k = 0; k < m; ++k) {
            const i64 factor = n_mod - k + mod;
            falling = falling * factor % mod;
        }
        return falling * factInv[m] % mod;
    }
    /*
       Entry point of the Math Ball solution.
       Reads n and w, then n values c. For every c it builds a numerator
       polynomial with euler_num, multiplies all numerators together with
       poly_dcmul, and finally prints
           sum over i of upprod[i] * C(w - i + sum - 1, sum - 1)   (mod `mod`)
       where sum = (sum of (c + 1) over all inputs) + 1 is the exponent of
       (1 - x) in the combined denominator.
    */
    int main() {
        std::ios::sync_with_stdio(false);
        cin.tie(0), cout.tie(0);
        {
            // Precompute factorials, modular inverses and inverse factorials
            // for every index up to FLARGE.
            fact[0] = fact[1] = inv[1] = factInv[0] = factInv[1] = 1;
            for (int_t i = 2; i <= FLARGE; i++) {
                fact[i] = (i64)fact[i - 1] * i % mod;
                inv[i] = (i64)(mod - mod / i) * inv[mod % i] % mod;
                factInv[i] = (i64)factInv[i - 1] * inv[i] % mod;
            }
        }
        int_t n;
        i64 w;
        cin >> n >> w;
    
        // up[i] is the numerator for the i-th input (1-based); cache[c] keeps
        // the numerator already computed for exponent c so that repeated
        // values of c are not recomputed.
        static poly_t up[int_t(1e5 + 10)];
        static poly_t cache[int_t(1e5 + 10)];
        int_t sum = 0;
        for (int_t i = 1; i <= n; i++) {
            int_t c;
            cin >> c;
            sum += c + 1;
            poly_t& curr = up[i];
            if (!cache[c].empty()) {
                curr = cache[c];
                continue;
            }
            // euler_num writes through poly_mul, which needs room for the full
            // product; only the first c + 1 coefficients are kept afterwards.
            curr.resize(2 * c + 3);
            euler_num(c, &curr[0]);
            curr.resize(c + 1);
            cache[c] = curr;
    #ifdef DEBUG
            cout << "up " << i << " = ";
            for (auto t : curr)
                cout << t << " ";
            cout << endl;
    #endif
        }
        // One extra factor of (1 - x) in the denominator.
        sum++;
        // for (int_t i = 0; i <= sum; i++) {
        //     down[i] = (i64)((i % 2) ? (mod - 1) : 1) * C(sum, i) % mod;
        // }
        auto upprod = poly_dcmul(up, 1, n);
        // cout << upprod.size() << endl;
        upprod.resize(sum + 1);
    #ifdef DEBUG
        cout << "up prod = ";
        for (auto t : upprod)
            cout << t << " ";
        cout << endl;
        cout << "sum = " << sum << endl;
    #endif
        // cout << "down size = " << sum << endl;
        // w -= cutcount;
        // if (w < 0) {
        //     cout << 0 << endl;
        //     return 0;
        // }
        // int_t ret = poly_divat(&upprod[0], down, upprod.size() - 1, sum, w);
        i64 ret = 0;
        {
            // Inside this block `n` shadows the outer n and holds w.
            i64 b = sum;
            i64 n = w;
            // curr starts as C(n + b - 1, b - 1), the coefficient paired with
            // upprod[0].
            i64 curr = getC(n - 0 + b - 1, b - 1);
    // step from C(n-i+b-1,b-1) to C(n-(i+1)+b-1,b-1)
    // getC(n - i + b - 1, b - 1)
    #ifdef DEBUG
            cout << "init curr = " << curr << endl;
    #endif
            for (i64 i = 0; i < upprod.size(); i++) {
                ret = (ret + upprod[i] * curr % mod) % mod;
                // Multiply by (n - i) and by the modular inverse of
                // (n - i + b - 1) to obtain the binomial for i + 1.
                // NOTE(review): when (n - i + b - 1) is a multiple of `mod`,
                // power(0, -1) evaluates to 0, so curr becomes 0 and the loop
                // stops below - confirm this is the intended result.
                curr = curr * (n % mod - i + mod) % mod *
                       power((n % mod - i + b - 1 + (i64)3 * mod) % mod, -1) % mod;
                // NOTE(review): the second condition is equivalent to
                // 2 * n <= i + 1; its intent is not evident from the code.
                if (curr == 0 || n <= i - n + 1)
                    break;
    #ifdef DEBUG
                cout << "curr = " << curr << " at i = " << i << " with n = " << n
                     << endl;
    #endif
            }
            // int_t curr = getC(w - b + b - 1, b - 1);
            // for (i64 i = std::max<i64>(0, w - sum); i < upprod.size(); i++) {
            //     ret = (ret + curr * upprod[i] % mod) % mod;
            //     curr = curr * (n % mod + n % mod - sum % mod + mod + i) % mod *
            //            power(n + 1, -1) % mod;
            // }
            // i64 n = w;
            // i64 t = sum - 1;
            // i64 prod = 1;
            // // i64 curr = (t + n) % mod;
            // for (int_t i = 0; i < t; i++)
            //     prod = (prod * (t + n % mod - i) % mod);
            // for (int_t a = 0; a < upprod.size(); a++) {
            //     if (n >= a) {
            //         ret = (ret + prod * factInv[t] % mod * upprod[a] % mod) %
            //         mod;
            //     }
            //     prod = prod * (t - a + n % mod) % mod *
            //            power((n % mod - a + mod) % mod, -1) % mod;
            // }
        }
        cout << ret << endl;
        return 0;
    }

     

  • 跑的比较快的多项式板子

    #include <cstring>
    #include <iostream>
    #include <random>
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 2e6;
    
    const int_t mod = 998244353;
    const int_t g = 3;
    // Binary exponentiation: returns b raised to the power i, reduced modulo
    // `mod`. A negative exponent is first reduced modulo (mod - 1) into the
    // range [0, mod - 2], so power(x, -1) is evaluated as x^(mod - 2).
    int_t power(int_t b, int_t i) {
        if (i < 0) {
            const int_t period = mod - 1;
            i = (i % period + period) % period;
        }
        int_t acc = 1;
        for (; i != 0; i >>= 1) {
            if (i & 1)
                acc = acc * b % mod;
            b = b * b % mod;
        }
        return acc;
    }
    
    // Fills arr[0 .. 2^size2 - 1] with the bit-reversal table for size2-bit
    // indices: arr[i] is i with its size2 low bits written in reverse order.
    // Each entry is derived from the entry of i / 2.
    void makeflip(int_t* arr, int_t size2) {
        const int_t count = (1 << size2);
        const int_t top_shift = size2 - 1;
        arr[0] = 0;
        int_t idx = 1;
        while (idx < count) {
            const int_t parent_part = arr[idx >> 1] >> 1;
            const int_t low_bit = idx & 1;
            arr[idx] = parent_part | (low_bit << top_shift);
            ++idx;
        }
    }
    
    // Returns the smallest r >= 0 such that 2^r >= x.
    int_t upper2n(int_t x) {
        int_t bits = 0;
        for (; (1 << bits) < x; ++bits) {
        }
        return bits;
    }
    // In-place number-theoretic transform of A, which has 2^size2 entries.
    // `flip` is the bit-reversal table produced by makeflip for the same size2.
    // The template argument selects the root used at each level:
    // g^(arg * (mod - 1) / block). The code is called with arg = 1 and with
    // arg = -1; with arg = -1 the caller still has to multiply by the inverse
    // of the length.
    template <int_t arg = 1>
    void transform(int_t* A, int_t size2, int_t* flip) {
        const int_t total = (1 << size2);
        // Reorder the input by bit-reversed index; each pair is swapped once.
        for (int_t pos = 0; pos < total; ++pos) {
            const int_t partner = flip[pos];
            if (pos < partner)
                std::swap(A[pos], A[partner]);
        }
        // Butterfly passes over blocks of size 2, 4, ..., total.
        for (int_t block = 2; block <= total; block *= 2) {
            const int_t half = block / 2;
            const int_t step = power(g, arg * (mod - 1) / block);
            for (int_t start = 0; start < total; start += block) {
                int_t* lo = A + start;
                int_t* hi = lo + half;
                int_t twiddle = 1;
                for (int_t k = 0; k < half; ++k) {
                    const int_t even = lo[k];
                    const int_t odd = twiddle * hi[k] % mod;
                    lo[k] = (even + odd) % mod;
                    hi[k] = (even - odd + mod) % mod;
                    twiddle = twiddle * step % mod;
                }
            }
        }
    }
    void poly_inv(const int_t* A, int_t n, int_t* result) {
        /*
        这里的n是模x^n
        计算B(x)*A(x) = 1 mod x^n, 其中A(x)已知
        假设已知A(x)*B(x) = 1 mod x^{ceil(n/2)}
        假设C(x)*A(x) = 1 mod x^n
        (A(x)B(x)-1)^2 = A^2(x)B^2(x)-2A(x)B(x)+1= 0
        A(x)B^2(x)-2B(x)+C(x) = 0
        C(x) = 2B(x)-A(x)B^2(x)
        */
        // #ifdef DEBUG
        //     cout << "at " << n << endl;
        // #endif
        if (n == 1) {
            result[0] = power(A[0], -1);
            return;
        }
        int_t next = n / 2 + n % 2;
        poly_inv(A, next, result);
        //次数不要选错了,应该用n次的A和B去卷
        int_t size2 = upper2n(n + 2 * next + 1);
        static int_t X[LARGE];
        static int_t Y[LARGE];
        int_t len = (1 << size2);
        // 写错设置范围了!
        memset(X + n, 0, sizeof(X[0]) * (len - n));
        memset(Y + next, 0, sizeof(Y[0]) * (len - next));
        memcpy(X, A, sizeof(A[0]) * n);
        memcpy(Y, result, sizeof(result[0]) * next);
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        transform<1>(X, size2, fliparr);
        transform<1>(Y, size2, fliparr);
    
        for (int_t i = 0; i < len; i++) {
            X[i] = (2 * Y[i] - X[i] * Y[i] % mod * Y[i] % mod + mod) % mod;
        }
        transform<-1>(X, size2, fliparr);
        const int_t inv = power(len, -1);
        for (int_t i = 0; i < n; i++)
            result[i] = X[i] * inv % mod;
    #ifdef DEBUG
        cout << "poly inv at " << n << endl;
        cout << "result = ";
        for (int_t i = 0; i < next; i++)
            cout << result[i] << " ";
        cout << endl;
    
    #endif
    }
    // Returns the coefficient of x^k in the power series F(x) / G(x), modulo
    // `mod`. F and G are read at indices 0..n. Uses static scratch buffers, so
    // the function is not reentrant. NOTE(review): with n == 0 the loop
    // condition k >= n never becomes false for k >= 0 - callers presumably
    // pass n >= 1.
    int_t poly_divat(const int_t* F, const int_t* G, int_t n, int_t k) {
        /*
            F and G are polynomials of degree n.
            Computes the coefficient of x^k in F(x)/G(x).
            Consider F(x)*G(-x) / (G(x)*G(-x)): the denominator only has
            even-degree terms, write it as C(x^2); write the numerator as
            x*A(x^2) + B(x^2). If k is odd, recurse on (A, C, n, k/2); if k is
            even, recurse on (B, C, n, k/2).
            Once k is small the answer is computed directly (the loop below
            runs while k >= n, so it stops as soon as k < n).
        */
        int_t size2 = upper2n(2 * n + 1);
        int_t len = 1 << size2;
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        // T1 = current numerator, T2 = current denominator, T3 = scratch
        // (G(-x) inside the loop, the inverse of T2 after it).
        static int_t T1[LARGE], T2[LARGE], T3[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = F[i], T2[i] = G[i];
            else
                T1[i] = T2[i] = T3[i] = 0;
        }
        const int_t inv = power(len, -1);
        while (k >= n) {
    #ifdef DEBUG
            cout << "curr at k = " << k << endl;
    #endif
            // T3 = T2(-x): odd-index coefficients are multiplied by mod - 1.
            // The product is not reduced modulo `mod` here.
            for (int_t i = 0; i < len; i++) {
                if (i <= n) {
                    T3[i] = T2[i] * (i % 2 ? (mod - 1) : 1);
                } else
                    T3[i] = 0;
            }
            transform(T1, size2, fliparr);
            transform(T2, size2, fliparr);
            transform(T3, size2, fliparr);
            // Numerator and denominator are both multiplied by T2(-x).
            for (int_t i = 0; i < len; i++) {
                T1[i] = T1[i] * T3[i] % mod;
                T2[i] = T2[i] * T3[i] % mod;
            }
            transform<-1>(T1, size2, fliparr);
            transform<-1>(T2, size2, fliparr);
    #ifdef DEBUG
            cout << "prod T1 = ";
            for (int_t i = 0; i < len; i++)
                cout << T1[i] * inv % mod << " ";
            cout << endl;
            cout << "prod T2 = ";
            for (int_t i = 0; i < len; i++)
                cout << T2[i] * inv % mod << " ";
            cout << endl;
    
    #endif
            // Keep only the even-index coefficients of the denominator,
            // compacted to indices 0, 1, 2, ...; the 1/len scaling of the
            // inverse transform is applied here. Reading index 2*i before
            // writing index i is safe because i increases.
            for (int_t i = 0; i < len; i++) {
                if (i * 2 < len) {
                    T2[i] = T2[i * 2] * inv % mod;
                } else
                    T2[i] = 0;
            }
            // Keep the numerator coefficients whose index has the same parity
            // as k, compacted in place; every index i > 0 is cleared after it
            // has been read, and index i / 2 < i is written before that.
            int_t b = k % 2;
            for (int_t i = 0; i < len; i++) {
                if (i % 2 == b) {
                    T1[i / 2] = T1[i] * inv % mod;
    #ifdef DEBUG
                    cout << "at " << i << " assgin " << T1[i] * inv << " to "
                         << i / 2 << endl;
    #endif
                }
    #ifdef DEBUG
                cout << "assign 0 to " << i << endl;
    #endif
                if (i > 0)
                    T1[i] = 0;
            }
    #ifdef DEBUG
            cout << "finished T1 = ";
            for (int_t i = 0; i < len; i++)
                cout << T1[i] % mod << " ";
            cout << endl;
            cout << "finished T2 = ";
            for (int_t i = 0; i < len; i++)
                cout << T2[i] % mod << " ";
            cout << endl;
    
    #endif
            k >>= 1;
        }
    
        // k < n now: invert the denominator modulo x^(k + 1) into T3.
        poly_inv(T2, k + 1, T3);
        // #ifdef DEBUG
    
        //     cout << "finished k = " << k << endl;
        //     cout << "T1 = ";
        //     for (int_t i = 0; i < len; i++)
        //         cout << T1[i] % mod << " ";
        //     cout << endl;
        //     cout << "T2 = ";
        //     for (int_t i = 0; i < len; i++)
        //         cout << T2[i] % mod << " ";
        //     cout << endl;
        //     cout << "T2 inv = ";
        //     for (int_t i = 0; i < len; i++)
        //         cout << T3[i] % mod << " ";
        //     cout << endl;
    
        // #endif
        int_t result = 0;
        // Compute the k-th coefficient of T1 * T3 directly.
        for (int_t i = 0; i <= k; i++) {
            result = (result + T1[i] * T3[k - i] % mod) % mod;
        }
        return result;
    }
    void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
        /*
           计算n次多项式A与m次多项式B的乘积
        */
        int_t size2 = upper2n(n + m + 1);
        int_t len = 1 << size2;
        static int_t T1[LARGE], T2[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = A[i];
            else
                T1[i] = 0;
            if (i <= m)
                T2[i] = B[i];
            else
                T2[i] = 0;
        }
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        transform(T1, size2, fliparr);
        transform(T2, size2, fliparr);
        for (int_t i = 0; i < len; i++)
            T1[i] = T1[i] * T2[i] % mod;
        transform<-1>(T1, size2, fliparr);
        int_t inv = power(len, -1);
        for (int_t i = 0; i <= n + m; i++)
            C[i] = T1[i] * inv % mod;
    }
    /*
        Evaluates term n of a linear recurrence of order k.
        F0[1], F0[2], ..., F0[k]   recurrence coefficients (index 0 unused)
        A0[0], A0[1], ..., A0[k-1] initial terms

        A[m] = A[m-1]*F[1] + A[m-2]*F[2] + ... + A[m-k]*F[k]

        The sequence is written as a rational generating function
        numer(x) / denom(x) with
            denom(x) = 1 - sum_{j=1..k} F[j] x^j
            numer(x) = A[0] + sum_{i=1..k-1} (A[i] - [x^i](A*F)) x^i
        and the requested coefficient is extracted with poly_divat.
    */
    int_t poly_linear_rec(const int_t* A0, const int_t* F0, int_t n, int_t k) {
        static int_t numer[LARGE], denom[LARGE];
        static int_t init[LARGE], coef[LARGE];

        // Local copies: coef[0] = 0 and init[k] = 0 pad both to degree k.
        coef[0] = 0;
        for (int_t i = 1; i <= k; ++i)
            coef[i] = F0[i];
        for (int_t i = 0; i < k; ++i)
            init[i] = A0[i];
        init[k] = 0;

        // numer starts as the product init * coef and is then turned into
        // init - init * coef, truncated below degree k.
        poly_mul(init, k, coef, k, numer);
        numer[0] = init[0];
        for (int_t i = 1; i < k; ++i)
            numer[i] = (init[i] - numer[i] + mod) % mod;
        for (int_t i = k; i <= 2 * k; ++i)
            numer[i] = 0;

        denom[0] = 1;
        for (int_t i = 1; i <= k; ++i)
            denom[i] = (mod - coef[i]) % mod;

        return poly_divat(numer, denom, k, n);
    }
    
    /*
        Computes the logarithm of the degree-n polynomial A, writing the
        coefficients to out[0..n]. Uses
            d/dx log F(x) = F'(x) / F(x)
        i.e. the derivative of A is multiplied by the inverse of A and the
        product is integrated term by term.

        out[0] is set to 0 explicitly (the integration constant); earlier it
        was left untouched, so callers depended on `out` being zeroed already.
        Nothing is written when n is negative.

        NOTE(review): the logarithm is only well defined when A[0] == 1; this
        is not checked here.
    */
    void poly_log(const int_t* A, int_t n, int_t* out) {
        if (n < 0)
            return;
        static int_t Ad[LARGE];
        static int_t Finv[LARGE];
        static int_t R[LARGE];
        // Ad = A', a polynomial of degree m = n - 1.
        const int_t m = n - 1;
        for (int_t i = 0; i <= m; i++) {
            Ad[i] = A[i + 1] * (i + 1) % mod;
        }
        Ad[n] = 0;
        // Finv = 1 / A modulo x^(n + 1).
        poly_inv(A, n + 1, Finv);
        poly_mul(Ad, m, Finv, n, R);
        // Integrate: the coefficient of x^i is R[i - 1] / i.
        out[0] = 0;
        for (int_t i = 1; i <= n; i++) {
            out[i] = R[i - 1] * power(i, -1) % mod;
        }
    }
    /*
        Computes exp(A(x)) modulo x^n and writes the n coefficients to
        out[0..n-1]. A is read at indices 0..n-1.

        Newton iteration: with G the result of the previous (half-length)
        step and H the result of the current step,
            H(x) = G(x) * (1 - log G(x) + A(x)).
        The base case n == 1 returns the constant 1.

        Nothing is written when n <= 0; without that guard the recursion
        never reached the base case for such n.

        NOTE(review): the base case assumes A[0] == 0, and the coefficients
        of A are assumed to lie in [0, mod); neither is checked.
    */
    void poly_exp(const int_t* A, int_t n, int_t* out) {
        if (n <= 0)
            return;
        if (n == 1) {
            out[0] = 1;
            return;
        }
        // Solve the half-length problem first (r = ceil(n / 2)).
        int_t r = n / 2 + n % 2;
        poly_exp(A, r, out);
        static int_t G[LARGE];
        static int_t G2[LARGE];
        static int_t logG[LARGE];
        static int_t R[LARGE];
        // G = previous result, zero-padded to n coefficients.
        for (int_t i = 0; i < r; i++)
            G[i] = out[i];
        for (int_t i = r; i < n; i++)
            G[i] = 0;
        poly_log(G, n - 1, logG);
        // for (int_t i = r; i < n; i++)
        //     logG[i] = 0;
        // G2 = 1 - log G + A.
        for (int_t i = 0; i < n; i++) {
            G2[i] = (mod - logG[i] + A[i]) % mod;
        }
        G2[0] = (G2[0] + 1) % mod;
        poly_mul(G, n - 1, G2, n - 1, R);
        // Only the first n coefficients of the product are kept.
        for (int_t i = 0; i < n; i++)
            out[i] = R[i];
    #ifdef DEBUG
        cout << "at " << n << endl;
        cout << "A = ";
        for (int_t i = 0; i < n; i++)
            cout << A[i] << " ";
        cout << endl;
        cout << "G = ";
        for (int_t i = 0; i < n; i++)
            cout << G[i] << " ";
        cout << endl;
        cout << "logG = ";
        for (int_t i = 0; i < n; i++)
            cout << logG[i] << " ";
        cout << endl;
        cout << "G2 = ";
        for (int_t i = 0; i < n; i++)
            cout << G2[i] << " ";
        cout << endl;
        cout << "R = ";
        for (int_t i = 0; i < n; i++)
            cout << R[i] << " ";
        cout << endl;
        cout << endl;
    #endif
    }
    int main() {
        std::ios::sync_with_stdio(false);
        cin.tie(0), cout.tie(0);
    
        int_t n;
        cin >> n;
        static int_t A[LARGE], B[LARGE];
        static int_t C[LARGE];
        for (int_t i = 0; i < n; i++)
            cin >> A[i];
        poly_exp(A, n, B);
        for (int_t i = 0; i < n; i++)
            cout << B[i] << " ";
    
        return 0;
    }

     

  • 常系数线性递推的新做法 – 计算[x^k]P(x)/Q(x)

    $$ a_n=\sum_{1\le i\le k}{f_ia_{n-i}} \\ \left\{ a_0,a_1….a_{k-1} \right\} \text{已知} \\ \\ \text{设}F\left( x \right) \text{表示从第}k\text{项开始的该数列的生成函数} \\ \sum_{i\ge k}{a_ix^i}=\sum_{i\ge k}{\sum_{1\le j\le k}{f_ja_{i-j}x^i}} \\ =\sum_{1\le j\le k}{f_j\sum_{i\ge k}{a_{i-j}x^i}} \\ =\sum_{1\le j\le k}{f_j\sum_{i\ge k-j}{a_ix^{i+j}}} \\ =\sum_{1\le j\le k}{f_jx^j\sum_{i\ge k-j}{a_ix^i}} \\ =\sum_{1\le j\le k}{f_jx^j\left( F\left( x \right) +\sum_{k-j\le i\le k-1}{a_ix^i} \right)} \\ \\ F\left( x \right) =F\left( x \right) \sum_{1\le j\le k}{f_jx^j}+\sum_{1\le j\le k}{f_jx^j\sum_{k-j\le i\le k-1}{a_ix^i}} \\ F\left( x \right) =\frac{\sum_{1\le j\le k}{f_jx^j\sum_{k-j\le i\le k-1}{a_ix^i}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{1\le j\le k}{f_jx^j\sum_{k-j\le i-j\le k-1}{a_{i-j}x^{i-j}}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{1\le j\le k}{f_j\sum_{k\le i\le k-1+j}{a_{i-j}x^i}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{k\le i\le 2k-1}{\begin{array}{c} x^i\sum_{i-k+1\le j\le k}{a_{i-j}f_j}\\ \end{array}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ F\left( x \right) +\sum_{0\le i\le k-1}{a_ix^i}=\frac{\left( 1-\sum_{1\le j\le k}{f_jx^j} \right) \left( \sum_{0\le i\le k-1}{a_ix^i} \right) +\sum_{1\le j\le k}{f_j\sum_{k\le i\le k-1+j}{a_{i-j}x^i}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ \frac{\sum_{0\le i\le k-1}{a_ix^i}-\sum_{1\le j\le k}{f_j\sum_{0\le i\le k-1}{a_ix^{i+j}}}+\sum_{1\le j\le k}{f_j\sum_{k\le i\le k-1+j}{a_{i-j}x^i}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{0\le i\le k-1}{a_ix^i}-\sum_{1\le j\le k}{f_j\sum_{j\le i\le j+k-1}{a_{i-j}x^i}}+\sum_{1\le j\le k}{f_j\sum_{k\le i\le k-1+j}{a_{i-j}x^i}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ \frac{\sum_{0\le i\le k-1}{a_ix^i}+\sum_{1\le j\le k}{f_j\left( \sum_{k\le i\le k-1+j}{a_{i-j}x^i}-\sum_{j\le i\le j+k-1}{a_{i-j}x^i} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ \frac{\sum_{0\le i\le k-1}{a_ix^i}+\sum_{1\le j\le k}{f_j\left( \sum_{k\le i\le k-1+j}{a_{i-j}x^i}-\sum_{k\le i\le j+k-1}{a_{i-j}x^i}-\sum_{j\le 
i\le k-1}{a_{i-j}x^i} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ \frac{\sum_{0\le i\le k-1}{a_ix^i}+\sum_{1\le j\le k}{f_j\left( -\sum_{j\le i\le k-1}{a_{i-j}x^i} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{0\le i\le k-1}{a_ix^i}-\sum_{1\le j\le k}{f_j\left( \sum_{j\le i\le k-1}{a_{i-j}x^i} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{0\le i\le k-1}{a_ix^i}-\sum_{1\le j\le k-1}{f_j\left( \sum_{j\le i\le k-1}{a_{i-j}x^i} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{\sum_{0\le i\le k-1}{a_ix^i}-\sum_{1\le i\le k-1}{x^i\sum_{1\le j\le i}{a_{i-j}f_j}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{a_0+\sum_{1\le i\le k-1}{x^ia_i}-\sum_{1\le i\le k-1}{x^i\sum_{1\le j\le i}{a_{i-j}f_j}}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ =\frac{a_0+\sum_{1\le i\le k-1}{x^i\left( a_i-\sum_{1\le j\le i}{a_{i-j}f_j} \right)}}{1-\sum_{1\le j\le k}{f_jx^j}} \\ \text{我们得到了原数列}\left\{ a_i \right\} \text{的生成函数}G\left( x \right) =\frac{P\left( x \right)}{Q\left( x \right)} \\ \text{考虑计算这个有理函数的}k\text{次项} \\ \text{令}G_k\left( x \right) =\frac{P\left( x \right)}{Q\left( x \right)}=\frac{P\left( x \right) Q\left( -x \right)}{Q\left( x \right) Q\left( -x \right)}=\frac{xA\left( x^2 \right) +B\left( x^2 \right)}{C\left( x^2 \right)} \\ \text{即分母只有偶数次方项},\text{分子的奇数次方项和偶数次方项拆开} \\ \text{如果}k\text{为奇数},\text{那么}x^k\text{只可能存在于}x\frac{A\left( x^2 \right)}{C\left( x^2 \right)}\text{中}\left( \text{只有这边才有奇数次方} \right) \\ \text{同理},\text{如果}k\text{为偶数,那么}x^k\text{只能出现在}\frac{B\left( x^2 \right)}{C\left( x^2 \right)}\text{中,} \\ \text{然后根据}k\text{的奇偶性,继续递归} \\ \text{如果}k\text{是奇数},\text{那么计算}G_{\frac{k-1}{2}}\left( x \right) =\frac{A\left( x \right)}{C\left( x \right)}\left( \text{项的次数除以}2 \right) ,\text{答案为}\left[ x^{\frac{k-1}{2}} \right] G_{\frac{k-1}{2}}\left( x \right) \\ \text{如果}k\text{是偶数},\text{那么计算}G_{\frac{k}{2}}\left( x \right) =\frac{B\left( x \right)}{C\left( x \right)}\text{,答案为}\left[ x^{\frac{k}{2}} \right] G_{\frac{k}{2}}\left( x \right) \\ 
\text{然后有两种选择}:\text{递归到}k=0\text{时,计算}\frac{P\left( 0 \right)}{Q\left( 0 \right)}\mathrm{mod}p\text{,此时不需要多项式求逆运算} \\ \text{但常数较大}\left( \text{递归次数多} \right) \\ \text{递归到}k<n\text{时,直接求逆计算。} \\ \\ \\ \\ \\ $$

    代码1: 递归到$k=0$时执行整数运算,较慢

    #include <cstring>
    #include <iostream>
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 2e6;
    
    const int_t mod = 998244353;
    const int_t g = 3;
    // Modular exponentiation: returns b^i modulo `mod` by binary exponentiation.
    // A negative exponent is first mapped into [0, mod - 2] by reducing it modulo
    // (mod - 1); since `mod` is the prime 998244353, power(x, -1) is x^(mod-2),
    // i.e. the modular inverse of x when x is not a multiple of `mod`.
    // The base is used as given (callers in this file pass values in [0, mod)).
    int_t power(int_t b, int_t i) {
        if (i < 0) {
            const int_t order = mod - 1;
            i = (i % order + order) % order;
        }
        int_t acc = 1;
        for (; i != 0; i >>= 1, b = b * b % mod) {
            if (i & 1)
                acc = acc * b % mod;
        }
        return acc;
    }
    
    // Fills arr[0 .. 2^size2) with the bit-reversal permutation on size2 bits:
    // arr[i] is i with its lowest size2 bits mirrored. Each entry is derived from
    // the already-computed entry for i >> 1.
    void makeflip(int_t* arr, int_t size2) {
        const int_t count = (1 << size2);
        const int_t topShift = size2 - 1;
        arr[0] = 0;
        for (int_t idx = 1; idx < count; ++idx) {
            const int_t lowBit = idx & 1;
            const int_t shiftedHalf = arr[idx >> 1] >> 1;
            arr[idx] = shiftedHalf | (lowBit << topShift);
        }
    }
    
    // Returns the smallest exponent r >= 0 such that 2^r >= x.
    int_t upper2n(int_t x) {
        int_t exponent = 0;
        for (; (1 << exponent) < x; ++exponent) {
        }
        return exponent;
    }
    // In-place iterative number-theoretic transform of length 2^size2 over Z/mod.
    //   arg  : +1 uses the root g^((mod-1)/span); -1 uses its inverse. The -1
    //          variant does NOT divide by the length; callers in this file
    //          multiply by power(len, -1) afterwards.
    //   A    : data buffer with at least 2^size2 entries, overwritten with the result.
    //   flip : bit-reversal table for size2 bits, as produced by makeflip.
    template <int_t arg = 1>
    void transform(int_t* A, int_t size2, int_t* flip) {
        const int_t total = (1 << size2);
        // Reorder into bit-reversed order; swap each pair once.
        for (int_t pos = 0; pos < total; ++pos) {
            const int_t mirrored = flip[pos];
            if (mirrored > pos)
                std::swap(A[pos], A[mirrored]);
        }
        // Butterfly passes over blocks of size 2, 4, ..., total.
        for (int_t span = 2; span <= total; span <<= 1) {
            const int_t half = span / 2;
            const int_t step = power(g, arg * (mod - 1) / span);
            for (int_t base = 0; base < total; base += span) {
                int_t twiddle = 1;
                for (int_t off = 0; off < half; ++off) {
                    int_t& lo = A[base + off];
                    int_t& hi = A[base + off + half];
                    const int_t u = lo;
                    const int_t v = twiddle * hi % mod;
                    lo = (u + v) % mod;
                    hi = (u - v + mod) % mod;
                    twiddle = twiddle * step % mod;
                }
            }
        }
    }
    // Computes the power-series inverse of A modulo x^n.
    //   A      : coefficients A[0 .. n-1]; A[0] must be invertible modulo `mod`.
    //   n      : number of coefficients wanted (n >= 1).
    //   result : output, result[0 .. n-1] with result(x) * A(x) == 1 (mod x^n).
    //
    // Newton doubling: let B be the inverse modulo x^{ceil(n/2)} and C the inverse
    // modulo x^n. Then (A*B - 1)^2 == 0 (mod x^n); expanding and multiplying by C
    // gives A*B^2 - 2B + C == 0, i.e. C = 2B - A*B^2 (mod x^n).
    //
    // The scratch buffers are function-local statics shared by every call; the
    // recursive call finishes before they are touched at this level.
    void poly_inv(const int_t* A, int_t n, int_t* result) {
        if (n == 1) {
            result[0] = power(A[0], -1);
            return;
        }
        const int_t next = n / 2 + n % 2;  // ceil(n / 2)
        poly_inv(A, next, result);
        // The product A * B * B has degree at most (n - 1) + 2 * (next - 1), so a
        // transform length of at least n + 2 * next + 1 avoids cyclic wrap-around.
        const int_t size2 = upper2n(n + 2 * next + 1);
        const int_t len = (1 << size2);
        static int_t X[LARGE];
        static int_t Y[LARGE];
        static int_t fliparr[LARGE];
        // X = first n coefficients of A, zero-padded to len. The padding must
        // start at index n (not next): otherwise the last n - next slots keep
        // whatever an earlier call left in the static buffer.
        memcpy(X, A, sizeof(A[0]) * n);
        memset(X + n, 0, sizeof(X[0]) * (len - n));
        // Y = B (the half-length inverse), zero-padded to len.
        memcpy(Y, result, sizeof(result[0]) * next);
        memset(Y + next, 0, sizeof(Y[0]) * (len - next));
        makeflip(fliparr, size2);
        transform<1>(X, size2, fliparr);
        transform<1>(Y, size2, fliparr);

        // Point-wise 2B - A*B^2.
        for (int_t i = 0; i < len; i++) {
            X[i] = (2 * Y[i] - X[i] * Y[i] % mod * Y[i] % mod + mod) % mod;
        }
        transform<-1>(X, size2, fliparr);
        // The inverse transform is unnormalized; scale by 1/len while truncating.
        const int_t inv = power(len, -1);
        for (int_t i = 0; i < n; i++)
            result[i] = X[i] * inv % mod;
    }
    // Returns the coefficient of x^k in the power series F(x) / G(x), modulo `mod`.
    //   F, G : coefficient arrays, indices 0..n are read.
    //   n    : degree bound of both polynomials.
    //   k    : index of the wanted coefficient (k >= 0).
    int_t poly_divat(const int_t* F, const int_t* G, int_t n, int_t k) {
        /*
            Idea: F(x)/G(x) = F(x)G(-x) / (G(x)G(-x)). The denominator only has
            even-degree terms, write it as C(x^2); write the numerator as
            x*A(x^2) + B(x^2). If k is odd, continue with (A, C, k/2); if k is
            even, continue with (B, C, k/2).
            This version keeps halving until k == 0 and then divides the two
            constant terms, so no polynomial inverse is needed.
        */
        // len >= 2n + 1 so the degree-2n products below do not wrap around.
        int_t size2 = upper2n(2 * n + 1);
        int_t len = 1 << size2;
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        // T1 = numerator, T2 = denominator, T3 = denominator evaluated at -x.
        static int_t T1[LARGE], T2[LARGE], T3[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = F[i], T2[i] = G[i];
            else
                T1[i] = T2[i] = T3[i] = 0;
        }
        // Normalization factor for the unnormalized inverse transform.
        const int_t inv = power(len, -1);
        while (k != 0) {
            // T3(x) = T2(-x): negate the odd-index coefficients.
            // NOTE(review): the product with (mod - 1) is not reduced here; it
            // appears to rely on transform's first butterfly pass (twiddle 1)
            // reducing every entry modulo mod - confirm before changing transform.
            for (int_t i = 0; i < len; i++) {
                if (i <= n) {
                    T3[i] = T2[i] * (i % 2 ? (mod - 1) : 1);
                } else
                    T3[i] = 0;
            }
            transform(T1, size2, fliparr);
            transform(T2, size2, fliparr);
            transform(T3, size2, fliparr);
            // Multiply numerator and denominator by G(-x) point-wise.
            for (int_t i = 0; i < len; i++) {
                T1[i] = T1[i] * T3[i] % mod;
                T2[i] = T2[i] * T3[i] % mod;
            }
            transform<-1>(T1, size2, fliparr);
            transform<-1>(T2, size2, fliparr);
            // Denominator: keep the even-index coefficients, compacted in place
            // (index 2i is always read before it is overwritten), then zero the rest.
            for (int_t i = 0; i < len; i++) {
                if (i * 2 < len) {
                    T2[i] = T2[i * 2] * inv % mod;
                } else
                    T2[i] = 0;
            }
            // Numerator: keep the coefficients whose index parity matches k,
            // compacted in place; every source slot is cleared after it is read.
            int_t b = k % 2;
            for (int_t i = 0; i < len; i++) {
                if (i % 2 == b) {
                    T1[i / 2] = T1[i] * inv % mod;
                }
    
                if (i > 0)  // skip index 0 so the compacted T1[0] is not wiped
                    T1[i] = 0;
            }
            k >>= 1;
        }
    
        // k == 0: the wanted coefficient is numerator(0) / denominator(0).
        return T1[0] * power(T2[0], -1) % mod;
    }
    void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
        int_t size2 = upper2n(n + m + 1);
        int_t len = 1 << size2;
        static int_t T1[LARGE], T2[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = A[i];
            else
                T1[i] = 0;
            if (i <= m)
                T2[i] = B[i];
            else
                T2[i] = 0;
        }
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        transform(T1, size2, fliparr);
        transform(T2, size2, fliparr);
        for (int_t i = 0; i < len; i++)
            T1[i] = T1[i] * T2[i] % mod;
        transform<-1>(T1, size2, fliparr);
        int_t inv = power(len, -1);
        for (int_t i = 0; i <= n + m; i++)
            C[i] = T1[i] * inv % mod;
    }
    int main() {
        std::ios::sync_with_stdio(false);
        static int_t A[LARGE], F[LARGE];
        static int_t T1[LARGE], T2[LARGE];
        int_t n, k;
        cin >> n >> k;
        for (int_t i = 1; i <= k; i++) {
            cin >> F[i];
            F[i] = (F[i] % mod + mod) % mod;
        }
        F[0] = 0;
        for (int_t i = 0; i < k; i++) {
            cin >> A[i];
            A[i] = (A[i] % mod + mod) % mod;
        }
        A[k] = 0;
        poly_mul(A, k, F, k, T1);
        T1[0] = A[0];
        for (int_t i = 1; i <= k - 1; i++) {
            T1[i] = (A[i] - T1[i] + mod) % mod;
        }
        for (int_t i = k; i <= 2 * k; i++)
            T1[i] = 0;
        T2[0] = 1;
        for (int_t i = 1; i <= k; i++)
            T2[i] = (mod - F[i]) % mod;
        int_t r = poly_divat(T1, T2, k, n);
        cout << r << endl;
        return 0;
    }
    /*
    (1+2*x)/(1+x+x^2)
    2 5
    
    1 2 0
    1 1 1
    
    
    ans = -2 = 998244351
    
    */

    代码2: 递归到$k<n$时直接求逆运算,较快,大约是代码1的一半时间

    #include <cstring>
    #include <iostream>
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 2e6;
    
    const int_t mod = 998244353;
    const int_t g = 3;
    // Modular exponentiation: returns b^i modulo `mod` by binary exponentiation.
    // A negative exponent is first mapped into [0, mod - 2] by reducing it modulo
    // (mod - 1); since `mod` is the prime 998244353, power(x, -1) is x^(mod-2),
    // i.e. the modular inverse of x when x is not a multiple of `mod`.
    // The base is used as given (callers in this file pass values in [0, mod)).
    int_t power(int_t b, int_t i) {
        if (i < 0) {
            const int_t order = mod - 1;
            i = (i % order + order) % order;
        }
        int_t acc = 1;
        for (; i != 0; i >>= 1, b = b * b % mod) {
            if (i & 1)
                acc = acc * b % mod;
        }
        return acc;
    }
    
    // Fills arr[0 .. 2^size2) with the bit-reversal permutation on size2 bits:
    // arr[i] is i with its lowest size2 bits mirrored. Each entry is derived from
    // the already-computed entry for i >> 1.
    void makeflip(int_t* arr, int_t size2) {
        const int_t count = (1 << size2);
        const int_t topShift = size2 - 1;
        arr[0] = 0;
        for (int_t idx = 1; idx < count; ++idx) {
            const int_t lowBit = idx & 1;
            const int_t shiftedHalf = arr[idx >> 1] >> 1;
            arr[idx] = shiftedHalf | (lowBit << topShift);
        }
    }
    
    // Returns the smallest exponent r >= 0 such that 2^r >= x.
    int_t upper2n(int_t x) {
        int_t exponent = 0;
        for (; (1 << exponent) < x; ++exponent) {
        }
        return exponent;
    }
    // In-place iterative number-theoretic transform of length 2^size2 over Z/mod.
    //   arg  : +1 uses the root g^((mod-1)/span); -1 uses its inverse. The -1
    //          variant does NOT divide by the length; callers in this file
    //          multiply by power(len, -1) afterwards.
    //   A    : data buffer with at least 2^size2 entries, overwritten with the result.
    //   flip : bit-reversal table for size2 bits, as produced by makeflip.
    template <int_t arg = 1>
    void transform(int_t* A, int_t size2, int_t* flip) {
        const int_t total = (1 << size2);
        // Reorder into bit-reversed order; swap each pair once.
        for (int_t pos = 0; pos < total; ++pos) {
            const int_t mirrored = flip[pos];
            if (mirrored > pos)
                std::swap(A[pos], A[mirrored]);
        }
        // Butterfly passes over blocks of size 2, 4, ..., total.
        for (int_t span = 2; span <= total; span <<= 1) {
            const int_t half = span / 2;
            const int_t step = power(g, arg * (mod - 1) / span);
            for (int_t base = 0; base < total; base += span) {
                int_t twiddle = 1;
                for (int_t off = 0; off < half; ++off) {
                    int_t& lo = A[base + off];
                    int_t& hi = A[base + off + half];
                    const int_t u = lo;
                    const int_t v = twiddle * hi % mod;
                    lo = (u + v) % mod;
                    hi = (u - v + mod) % mod;
                    twiddle = twiddle * step % mod;
                }
            }
        }
    }
    // Computes the power-series inverse of A modulo x^n.
    //   A      : coefficients A[0 .. n-1]; A[0] must be invertible modulo `mod`.
    //   n      : number of coefficients wanted (n >= 1).
    //   result : output, result[0 .. n-1] with result(x) * A(x) == 1 (mod x^n).
    //
    // Newton doubling: let B be the inverse modulo x^{ceil(n/2)} and C the inverse
    // modulo x^n. Then (A*B - 1)^2 == 0 (mod x^n); expanding and multiplying by C
    // gives A*B^2 - 2B + C == 0, i.e. C = 2B - A*B^2 (mod x^n).
    //
    // The scratch buffers are function-local statics shared by every call; the
    // recursive call finishes before they are touched at this level.
    void poly_inv(const int_t* A, int_t n, int_t* result) {
        if (n == 1) {
            result[0] = power(A[0], -1);
            return;
        }
        const int_t next = n / 2 + n % 2;  // ceil(n / 2)
        poly_inv(A, next, result);
        // The product A * B * B has degree at most (n - 1) + 2 * (next - 1), so a
        // transform length of at least n + 2 * next + 1 avoids cyclic wrap-around.
        const int_t size2 = upper2n(n + 2 * next + 1);
        const int_t len = (1 << size2);
        static int_t X[LARGE];
        static int_t Y[LARGE];
        static int_t fliparr[LARGE];
        // X = first n coefficients of A, zero-padded to len. The padding must
        // start at index n (not next): otherwise the last n - next slots keep
        // whatever an earlier call left in the static buffer.
        memcpy(X, A, sizeof(A[0]) * n);
        memset(X + n, 0, sizeof(X[0]) * (len - n));
        // Y = B (the half-length inverse), zero-padded to len.
        memcpy(Y, result, sizeof(result[0]) * next);
        memset(Y + next, 0, sizeof(Y[0]) * (len - next));
        makeflip(fliparr, size2);
        transform<1>(X, size2, fliparr);
        transform<1>(Y, size2, fliparr);

        // Point-wise 2B - A*B^2.
        for (int_t i = 0; i < len; i++) {
            X[i] = (2 * Y[i] - X[i] * Y[i] % mod * Y[i] % mod + mod) % mod;
        }
        transform<-1>(X, size2, fliparr);
        // The inverse transform is unnormalized; scale by 1/len while truncating.
        const int_t inv = power(len, -1);
        for (int_t i = 0; i < n; i++)
            result[i] = X[i] * inv % mod;
    }
    // Returns the coefficient of x^k in the power series F(x) / G(x), modulo `mod`.
    //   F, G : coefficient arrays, indices 0..n are read.
    //   n    : degree bound of both polynomials (the halving loop runs while k >= n).
    //   k    : index of the wanted coefficient (k >= 0).
    int_t poly_divat(const int_t* F, const int_t* G, int_t n, int_t k) {
        /*
            Idea: F(x)/G(x) = F(x)G(-x) / (G(x)G(-x)). The denominator only has
            even-degree terms, write it as C(x^2); write the numerator as
            x*A(x^2) + B(x^2). If k is odd, continue with (A, C, k/2); if k is
            even, continue with (B, C, k/2).
            Once k < n the remaining coefficient is computed directly with a
            truncated power-series inverse of the denominator.
        */
        // len >= 2n + 1 so the degree-2n products below do not wrap around.
        int_t size2 = upper2n(2 * n + 1);
        int_t len = 1 << size2;
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        // T1 = numerator, T2 = denominator, T3 = denominator evaluated at -x.
        static int_t T1[LARGE], T2[LARGE], T3[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = F[i], T2[i] = G[i];
            else
                T1[i] = T2[i] = T3[i] = 0;
        }
        // Normalization factor for the unnormalized inverse transform.
        const int_t inv = power(len, -1);
        while (k >= n) {
            // T3(x) = T2(-x): negate the odd-index coefficients.
            // NOTE(review): the product with (mod - 1) is not reduced here; it
            // appears to rely on transform's first butterfly pass (twiddle 1)
            // reducing every entry modulo mod - confirm before changing transform.
            for (int_t i = 0; i < len; i++) {
                if (i <= n) {
                    T3[i] = T2[i] * (i % 2 ? (mod - 1) : 1);
                } else
                    T3[i] = 0;
            }
            transform(T1, size2, fliparr);
            transform(T2, size2, fliparr);
            transform(T3, size2, fliparr);
            // Multiply numerator and denominator by G(-x) point-wise.
            for (int_t i = 0; i < len; i++) {
                T1[i] = T1[i] * T3[i] % mod;
                T2[i] = T2[i] * T3[i] % mod;
            }
            transform<-1>(T1, size2, fliparr);
            transform<-1>(T2, size2, fliparr);
            // Denominator: keep the even-index coefficients, compacted in place
            // (index 2i is always read before it is overwritten), then zero the rest.
            for (int_t i = 0; i < len; i++) {
                if (i * 2 < len) {
                    T2[i] = T2[i * 2] * inv % mod;
                } else
                    T2[i] = 0;
            }
            // Numerator: keep the coefficients whose index parity matches k,
            // compacted in place; every source slot is cleared after it is read.
            int_t b = k % 2;
            for (int_t i = 0; i < len; i++) {
                if (i % 2 == b) {
                    T1[i / 2] = T1[i] * inv % mod;
                }
    
                if (i > 0)  // skip index 0 so the compacted T1[0] is not wiped
                    T1[i] = 0;
            }
            k >>= 1;
        }
    
        // Now k < n: invert the denominator modulo x^(k+1) into T3.
        poly_inv(T2, k + 1, T3);
        int_t result = 0;
        // Coefficient k of T1 * T3, i.e. sum of T1[i] * T3[k - i].
        for (int_t i = 0; i <= k; i++) {
            result = (result + T1[i] * T3[k - i] % mod) % mod;
        }
        return result;
    }
    void poly_mul(const int_t* A, int_t n, const int_t* B, int_t m, int_t* C) {
        int_t size2 = upper2n(n + m + 1);
        int_t len = 1 << size2;
        static int_t T1[LARGE], T2[LARGE];
        for (int_t i = 0; i < len; i++) {
            if (i <= n)
                T1[i] = A[i];
            else
                T1[i] = 0;
            if (i <= m)
                T2[i] = B[i];
            else
                T2[i] = 0;
        }
        static int_t fliparr[LARGE];
        makeflip(fliparr, size2);
        transform(T1, size2, fliparr);
        transform(T2, size2, fliparr);
        for (int_t i = 0; i < len; i++)
            T1[i] = T1[i] * T2[i] % mod;
        transform<-1>(T1, size2, fliparr);
        int_t inv = power(len, -1);
        for (int_t i = 0; i <= n + m; i++)
            C[i] = T1[i] * inv % mod;
    }
    int main() {
        std::ios::sync_with_stdio(false);
        static int_t A[LARGE], F[LARGE];
        static int_t T1[LARGE], T2[LARGE];
        int_t n, k;
        cin >> n >> k;
        for (int_t i = 1; i <= k; i++) {
            cin >> F[i];
            F[i] = (F[i] % mod + mod) % mod;
        }
        F[0] = 0;
        for (int_t i = 0; i < k; i++) {
            cin >> A[i];
            A[i] = (A[i] % mod + mod) % mod;
        }
        A[k] = 0;
        poly_mul(A, k, F, k, T1);
        T1[0] = A[0];
        for (int_t i = 1; i <= k - 1; i++) {
            T1[i] = (A[i] - T1[i] + mod) % mod;
        }
        for (int_t i = k; i <= 2 * k; i++)
            T1[i] = 0;
        T2[0] = 1;
        for (int_t i = 1; i <= k; i++)
            T2[i] = (mod - F[i]) % mod;
        int_t r = poly_divat(T1, T2, k, n);
        cout << r << endl;
        return 0;
    }
    /*
    (1+2*x)/(1+x+x^2)
    2 5
    
    1 2 0
    1 1 1
    
    
    ans = -2 = 998244351
    
    */

     

  • Nginx路由匹配模拟器

    https://nginx.viraptor.info/

    真是我的救星,调路由调吐了

  • CFGYM103104I-WHUPC I Sequence

    题意:有个$n\times n,n\leq 5000$的矩阵,挖掉$m,m\leq 10^6$个格子,问不包括这些格子的子矩阵个数。

    做法很容易考虑:枚举一行,这一行从左往右扫,依次计算每个格子为右下角的子矩阵个数。

    扫的时候维护一个单调栈,里面存(元素,以这个元素为高度的格子最长向右延伸了多少格)

    为什么要存第二个项呢?我们维护的单调栈里面,每个东西实际上表示的是一个矩形,含义是:在我们处理完当前这个列后,这些矩形里的元素都可以作为合法的子矩阵左上顶点。

    扫的时候维护一个东西:单调栈里所有矩形的大小和$sum$,即合法的左上顶点个数。

    每次加入一个元素$x$,分以下三种情况讨论:

    • $x$大于单调栈最后的元素:直接加入。长度记为1,sum加上$x$(显然多了这么多合法的格子)
    • $x$等于单调栈最后的元素:单调栈最后的元素的长度加1,sum加上$x$
    • $x$小于单调栈最后的元素:弹出单调栈最后的元素,并把其长度加到倒数第二个元素上,然后重复这三个check

    最后我们每插入一个元素后,所维护的sum就是以这个点为右下角的答案。

    另外有一个坑:输入1625 0,答案输出正确,但是从1626 0开始,答案越来越小,看起来像是溢出了,但是开-fsanitize=undefined没报任何问题,结果最后查出来了,算一行的答案时使用的是std::accumulate,初始值传的是int类型,而后累加变量自动推导为int,导致这个函数里面溢出了,然后由于ubsan并不会对包含的其他文件的代码做检查,于是挂掉了。

     

    #include <assert.h>
    #include <algorithm>
    #include <iostream>
    #include <numeric>
    #include <utility>
    #include <vector>
    using int_t = long long;
    using std::cin;
    using std::cout;
    using std::endl;
    /**
     * 1624 0 -> R
     * 1625 0 -> R
     * 1626 0 -> E 正确1749670208001 本程序 1736785306113
     * 1627 0 -> E
     */
    bool block[5001][5001];
    int_t upmost[5001][5001];
    int_t n, m;
    int main() {
        std::ios::sync_with_stdio(false);
        cin >> n >> m;
        for (int_t i = 1; i <= m; i++) {
            int_t r, c;
            cin >> r >> c;
            block[r][c] = true;
        }
        for (int_t i = 1; i <= n; i++) {  //列
            int_t lastpos = 0;
            for (int_t j = 1; j <= n; j++) {
    #ifdef DEBUG
                // cout << "block " << j << " " << i << " = " << block[j][i] <<
                // endl;
    #endif
                if (block[j][i])
                    lastpos = j;
                upmost[j][i] = lastpos;
    #ifdef DEBUG
                cout << "upmost " << j << " " << i << " = " << upmost[j][i] << endl;
    #endif
            }
        }
        int_t result = 0;
        //枚举底边行号
        for (int_t i = 1; i <= n; i++) {
            static int_t answer[5001];  // j->以(i,j)为右下角的矩形个数
            //以当前元素为右下角的矩形个数
            int_t sum = 0;
            // first为元素,second为出现次数
            std::vector<std::pair<int_t, int_t>> stack;
            stack.emplace_back(0, 0);
            for (int_t j = 1; j <= n; j++) {
                int_t x = i - upmost[i][j];  //向上延伸的空格子数(包括i,j)
                int_t count = 1;
                while (x < stack.back().first) {
                    count += stack.back().second;
                    sum -= stack.back().first * stack.back().second;
                    stack.pop_back();
                }
                assert(sum >= 0);
                if (x == stack.back().first) {
                    stack.back().second += count;
                } else {
                    stack.emplace_back(x, count);
                }
                sum += x * count;
                answer[j] = sum;
                // cout << "answer col " << j << " = " << answer[j] << endl;
    #ifdef DEBUG
                cout << "pushed col " << j << " val " << x << " sum = " << sum
                     << " answer = " << answer[j] << endl;
    #endif
            }
            int_t currrow = std::accumulate(answer + 1, answer + 1 + n, (int_t)0);
            result += currrow;
            // cout << "answer row " << i << " = " << currrow << " " << currrow / i
            //      << " " << currrow % i << endl;
    #ifdef DEBUG
            cout << "answer row " << i << " = " << currrow << endl;
    #endif
        }
        cout << result << endl;
        return 0;
    }

     

  • CF1521C Nastia and a Hidden Permutation

    真是个沙比提,比赛的时候也想出来了,但是因为太迷糊了没写完

    首先我们先考虑如何确定下最大值所在的位置。

    执行询问$max(min(n-1,p_i),min(n,p_{i+1}))$,这个询问实际上等价于$max(min(n-1,p_i),p_{i+1})$。

    • 如果结果是n,那么说明$p_{i+1}=n$(很显然,max里第一项不会超过n-1)。
    • 如果结果比n-1要小,那么说明$p_i$和$p_{i+1}$都不是n。
    • 如果结果是n-1,那么就要讨论下了。这个时候n可能出现在这两个数里,也可能没出现。

    我们再构造另一个询问:$min(max(n-1,p_i),max(n,p_{i+1}))$,显然这个询问等价于$max(n-1,p_i)$,所以我们可以利用这个询问来检查$p_i$的值是不是n。

    在找到最大值的位置,设它为$maxpos$,那么我们对于每一个$i\neq maxpos$,构造询问$min(max(1,p_i),max(2,p_{maxpos}))$,这个询问等价于$p_i$,然后就可以求出来每个位置的值了。

    #include <algorithm>
    #include <iostream>
    using std::cin;
    using std::cout;
    using std::endl;
    using int_t = long long;
    int_t result[int_t(1e5)];
    // Interactive solver: for each test case, finds the position of the value n
    // in the hidden permutation, then recovers every other element with one
    // query each, and prints the permutation as "! p_1 ... p_n".
    int main() {
        std::ios::sync_with_stdio(false);
        int_t T;
        cin >> T;
        while (T--) {
            int_t n;
            cin >> n;
            int_t maxpos = -1;
            // Examines the pair (i, j) and sets maxpos if either holds n.
            const auto check = [&](int_t i, int_t j) {
                // Type 1 with x = n-1: max(min(n-1, p_i), min(n, p_j))
                //                    = max(min(n-1, p_i), p_j).
                cout << "? 1 " << i << " " << j << " " << n - 1 << endl;
                int_t x;
                cin >> x;
                if (x == n) {
                    // Only p_j can push the result up to n.
                    maxpos = j;
                    return;
                }
                if (x < n - 1)
                    return;  // neither element is n
                // x == n-1: p_j is not n (it would have answered n), but p_i may
                // be. Type 2 with x = n-1 yields max(n-1, p_i), which is n
                // exactly when p_i == n. One query suffices; repeating it with
                // the indices swapped could never succeed and would only spend
                // query budget.
                cout << "? 2 " << i << " " << j << " " << n - 1 << endl;
                cin >> x;
                if (x == n)
                    maxpos = i;
            };
            // Scan disjoint pairs (1,2), (3,4), ... until n is located.
            for (int_t i = 1; i + 1 <= n && maxpos == -1; i += 2) {
                check(i, i + 1);
            }
            // With odd n the last index is unpaired; if no pair held n, it does.
            if (maxpos == -1) {
                maxpos = n;
            }
            result[maxpos] = n;
            for (int_t i = 1; i <= n; i++) {
                if (i == maxpos)
                    continue;
                // Type 2 with x = 1: min(max(1, p_i), max(2, n)) = p_i.
                cout << "? 2 " << i << " " << maxpos << " 1" << endl;
                int_t x;
                cin >> x;
                result[i] = x;
            }
            cout << "! ";
            for (int_t i = 1; i <= n; i++) {
                cout << result[i] << " ";
            }
            cout << endl;
        }
        return 0;
    }
    

     

  • CFGYM102394L CCPC2019哈尔滨 L题 LRU Algorithm

    首先做法很简单:令缓存大小为n,然后直接把操作模拟一遍,后期如果我们限制了缓存大小为x,那就等价于取我们在模拟时长度为x的前缀。

    然后我们显然可以把前缀哈希算一算,插到`std::unordered_map`里,然后成功TLE。

    然后我们考虑另一个做法:把询问插到字典树里,仍然对操作进行模拟,每模拟一次后,在字典树上把序列走一遍,并标记对应的询问为Yes。

    但是有几个地方要注意

    • 一个点可能对应多个询问,所以用一个vector来挂询问吧。
    • 一组询问在去除后缀0之后可能长度为0,此时询问是被挂在根上的,务必进行处理,否则WA!
    #include <assert.h>
    #include <algorithm>
    #include <cstring>
    #include <iostream>
    #include <unordered_map>
    #include <vector>
    using int_t = int;
    
    using std::cin;
    using std::cout;
    using std::endl;
    
    using map_t = std::unordered_map<int_t, struct Node*>;
    const int_t LARGE = 5e3 + 20;
    char inputbuf[(int64_t)1e8];
    char* head = inputbuf;
    // Slurps standard input into the global inputbuf (up to its full capacity)
    // in one call; the number of bytes actually read is not recorded.
    void initinput() {
        fread(inputbuf, sizeof(char), sizeof(inputbuf), stdin);
    }
    // Returns the next byte of the buffered input and advances the cursor.
    // The assertion must be strict: `head` equal to the end of inputbuf would
    // already dereference one element past the array.
    char nextchar() {
        assert(head < inputbuf + sizeof(inputbuf));
        return *(head++);
    }
    // Parses the next unsigned decimal integer from the buffered input into x.
    // Skips every non-digit byte before the number and consumes the single byte
    // that follows the last digit. No sign handling.
    template <class T>
    void read(T& x) {
        char chr;
        do {
            chr = nextchar();
        } while (chr < '0' || chr > '9');
        x = 0;
        do {
            x = x * 10 + chr - '0';
            chr = nextchar();
        } while (chr >= '0' && chr <= '9');
    }
    
    // Trie node for the query sequences.
    //   result : flags of the queries that end at this node; the trie does not
    //            own them, it only sets them.
    //   chd    : children keyed by the next value of the sequence; owned by this
    //            node and released recursively in the destructor.
    // Copying is disabled: a copied node would share the raw child pointers with
    // the original and both destructors would delete them.
    struct Node {
        std::vector<bool*> result;
        map_t chd;
        Node() = default;
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;
        ~Node() {
            for (const auto& kvp : chd)
                delete kvp.second;
        }
    };
    
    int_t arr[LARGE + 1];
    int_t arr1[LARGE + 10], queue[LARGE + 1];
    bool result[LARGE + 1];
    int main() {
        initinput();
        int_t T;
        read(T);
        while (T--) {
            queue[0] = 0;
            int_t n, q;
            read(n), read(q);
            Node* root = new Node;
            for (int_t i = 1; i <= n; i++) {
                read(arr[i]);
            }
            for (int_t i = 1; i <= q; i++) {
                result[i] = false;
                Node* curr = root;
                int_t len;
                read(len);
    #ifdef DEBUG
                cout << "insert len " << len << endl;
    #endif
                for (int_t j = 1; j <= len; j++) {
                    int_t x;
                    read(x);
                    if (x == 0)
                        continue;
    #ifdef DEBUG
                    cout << "insert " << x << endl;
    #endif
                    auto& ref = curr->chd[x];
                    if (ref == nullptr)
                        ref = new Node;
                    curr = ref;
                }
                curr->result.push_back(&result[i]);
    #ifdef DEBUG
                cout << "insert ok, result to " << i << endl;
    #endif
            }
            for (int_t i = 1; i <= n; i++) {
                int_t x = arr[i];
                arr1[0] = 0;
                arr1[++arr1[0]] = x;
                for (int_t j = 1; j <= queue[0]; j++) {
                    if (queue[j] != x)
                        arr1[++arr1[0]] = queue[j];
                }
                // arr1[0] = std::min(arr1[0], n);
                assert(arr1[0] <= n);
                memcpy(queue, arr1, sizeof(arr1[0]) * (n + 1));
                Node* curr = root;
    #ifdef DEBUG
                cout << "mapping seq ";
                for (int_t i = 1; i <= queue[0]; i++)
                    cout << queue[i] << " ";
                cout << endl;
    #endif
                for (int_t i = 1; i <= queue[0]; i++) {
                    if (!curr->chd.count(queue[i]))
                        break;
                    else
                        curr = curr->chd[queue[i]];
    #ifdef DEBUG
                    cout << "walk with " << queue[i] << endl;
    #endif
                    for (auto ptr : curr->result) {
                        *ptr = true;
    #ifdef DEBUG
                        cout << "mark result " << (ptr - result) << " to true"
                             << endl;
    #endif
                    }
                }
            }
            for (auto x : root->result)
                *x = true;
            for (int_t i = 1; i <= q; i++) {
                if (result[i]) {
                    puts("Yes");
                } else {
                    puts("No");
                }
            }
            delete root;
        }
        return 0;
    }

     

  • CFGYM102394E CCPC2019哈尔滨 E题 Exchanging Gifts

    假设我们能维护出最终序列的长度L和最终序列出现最多的数的个数cnt,假设$cnt\leq\frac L 2$ ,那么答案是L,这时候我们把序列升序和降序对起来就能构造出答案为L的结果。
    假设$cnt>\frac L 2$,那么答案是$2(L-cnt)$,我们仍然把序列升序和降序对应起来,答案是$L-(cnt-(L-cnt))=2*(L-cnt)$
    比如
    ```
    1 2 3 3 3 3 3
    3 3 3 3 3 2 1
    ```
    中间有3个3是重了的,那么重的部分有cnt-(L-cnt)个,其中L-cnt表示的是`1 2`的长度,所以总答案是L-(cnt-(L-cnt))
    现在的问题在于如何维护出这个众数,并且判定$cnt\leq \frac L 2$是否成立。

     

    考虑一种线性时间求序列众数(出现次数大于一半)的方法:
    令f(i)表示我们从头开始扫到第i个元素的时候(用$a_i$表示),序列中出现次数最多的数与其他的数出现次数之和的差值。同时我们需要维护这个数是多少(用x来表示)。
    每次扫到$a_i$的时候:
    – 如果$a_i=x$,那么f(i)=f(i-1)+1,x不变
    – 如果$a_i\neq x$,那么f(i)=f(i-1)-1,然后如果$f(i)<0$,那么x变为$a_i$,同时$f(i)$取反。

     

    对于这个题,如果我们要使用这种求众数的方法,核心在于如何考虑操作2(合并两个序列时)如何处理。
    我们维护每个序列的f和x,合并两个序列的时候:
    – 如果他们的x相同,那么新序列的x不变,f相加。
    – 如果他们的x不同,那么考虑下两个序列的哪一个的f比较大。如果f相同,那么新序列的f就写为0(这时候可能存在两个众数),x从原来的x里随便选一个。如果f不同,那x选择为f较大的那一个序列的x,同时f设置为较大值减掉较小值。

     

    所以对于这个题,我们先用这种求众数的方法算出来最终序列的众数。
    但此时求出来的众数,仅仅是在保证`存在一个出现次数超过序列长度一半时候`的众数,如果不存在这样子的众数,也会得出来一个结果,但是并不具有意义。
    所以我们再跑一遍递推,求出来我们上一步求的众数的出现次数,然后我们就可以照着前文所述来算答案了。
    *此外,这个题卡常*
    我跑了半天性能分析后大概知道了几个卡常的点:
    1. 输入量非常巨大,请考虑一次性读入输入数据后自行用缓冲区处理输入。
    2. `std::vector`的构造函数非常慢(性能分析显示,$10^6$次调用大约花了80ms),所以不要使用vector来存储变长的序列,考虑自行分配-回收内存。
    3. `State`结构体的构造函数占了大概40ms的时间,考虑强制内联。
    另外,读入函数千万不要写错了!不要写成`chr>='9'`!
    #pragma GCC optimize("O2")
    #include <assert.h>
    #include <inttypes.h>
    #include <algorithm>
    #include <iostream>
    #include <vector>
    
    
    using int_t = int;
    using std::cin;
    using std::cout;
    using std::endl;
    const int_t LARGE = 1e6;
    char inputbuf[(int(2e8) / 1024) * 1024];
    char* head = inputbuf;
    // Returns the byte under the input cursor and moves the cursor forward by
    // one. No bounds check is performed.
    inline char nextchar() {
        const char current = *head;
        ++head;
        return current;
    }
    // Slurps standard input into the global inputbuf in one call.
    // Reads byte-wise (element size 1): with a larger element size, a trailing
    // partial element has an indeterminate value per the C standard, so input
    // whose length is not a multiple of that size is not reliably stored.
    // One byte is held back so the zero-initialized buffer always ends with a
    // '\0' that stops the digit scanner.
    void initinput() {
        fread(inputbuf, 1, sizeof(inputbuf) - 1, stdin);
    }
    struct State {
        int64_t len;
        int_t mostval;
        int64_t mostcount;
        inline State(int64_t len = 0, int_t mostval = 0, int64_t mostcount = 0)
            : len(len), mostval(mostval), mostcount(mostcount) {}
        State operator+(const State& rhs) const {
            State result(len + rhs.len, 0, 0);
            if (mostval == rhs.mostval)
                result.mostval = mostval,
                result.mostcount = mostcount + rhs.mostcount;
            else {
                if (mostcount > rhs.mostcount) {
                    result.mostcount = mostcount - rhs.mostcount;
                    result.mostval = mostval;
                } else {
                    result.mostcount = rhs.mostcount - mostcount;
                    result.mostval = rhs.mostval;
                }
            }
            return result;
        }
    };
    // Debug printer: writes a State as "State{len=...,mostval=...,mostcount=...}".
    std::ostream& operator<<(std::ostream& os, const State& state) {
        return os << "State{len=" << state.len << ",mostval=" << state.mostval
                  << ",mostcount=" << state.mostcount << "}";
    }
    // One operation of the input, stored in the global table opts[1..n].
    // The table has static storage, so every field starts out zero / nullptr.
    struct Opt {
        // Operation kind as read from the input: main treats 1 as "literal
        // sequence" and 2 as "concatenation of two earlier sequences".
        int_t type;
        // For type 1: heap array holding the sequence in data[1..datalen]
        // (allocated with new[] in main and released there); nullptr otherwise.
        int_t* data;
        // Number of elements stored in data (type 1 only).
        int_t datalen;
        // For type 2: indices of the two operations being concatenated.
        int_t x1, x2;
        // Occurrences of the final majority candidate in this operation's
        // sequence; filled in by main's counting pass.
        int64_t mostcount = 0;
    } opts[LARGE + 1];
    State dp[LARGE + 1];
    int_t n;
    // Parses the next unsigned decimal integer from the buffered input into x.
    // Skips every non-digit byte before the number and consumes the single byte
    // that follows the last digit. Asserts that the parsed value is non-negative.
    template <class T>
    void read(T& x) {
        char chr;
        do {
            chr = nextchar();
        } while (chr < '0' || chr > '9');
        x = 0;
        do {
            x = x * 10 + chr - '0';
            chr = nextchar();
        } while (chr >= '0' && chr <= '9');
        assert(x >= 0);
    }
    // Prints the non-negative integer x in decimal to stdout via putchar,
    // without a trailing newline.
    template <class T>
    void write(T x) {
        assert(x >= 0);
        // Collect the low-order digits while more than one digit remains.
        char digits[40];
        int used = 0;
        T rest = x;
        while (rest > 9) {
            digits[used++] = '0' + rest % 10;
            rest /= 10;
        }
        // Leading digit first, then the collected ones from high to low.
        putchar('0' + rest % 10);
        while (used > 0)
            putchar(digits[--used]);
    }
    // For each test case: reads n operations, derives a majority-vote summary
    // (State) for every sequence, counts how often the final candidate really
    // occurs in the last sequence, and prints either the full length or
    // 2 * (length - occurrences) depending on whether the candidate fills more
    // than half of it.
    int main() {
        // freopen("input.txt", "r", stdin);
        initinput();
        int_t T;
        read(T);
        while (T--) {
            read(n);
            // Pass 1: read the operations; literal sequences get their summary
            // immediately.
            for (int_t i = 1; i <= n; i++) {
                auto& ref = opts[i];
                // Release the array left in this slot by an earlier test case.
                if (ref.data) {
                    delete[] ref.data;
                    ref.data = nullptr;
                }
                dp[i] = State();
                read(ref.type);
                if (ref.type == 1) {
                    int_t k;
                    read(k);
                    // Majority vote over the literal: val is the candidate,
                    // sum its current surplus.
                    int_t sum = 0, val = 0;
                    ref.data = new int_t[k + 1];  // 1-based storage
                    ref.datalen = k;
                    // Note: this inner i shadows the operation index only
                    // inside the loop.
                    for (int_t i = 1; i <= k; i++) {
                        int_t x;
                        read(x);
                        ref.data[i] = x;
                        if (x == val)
                            sum++;
                        else
                            sum--;
                        // Surplus went negative: x becomes the new candidate.
                        if (sum < 0) {
                            sum *= -1, val = x;
                        }
                    }
                    // ref.seq.shrink_to_fit();
                    dp[i] = State(k, val, sum);
                } else {
                    read(ref.x1);
                    read(ref.x2);
                }
            }
            // Pass 2: summaries of concatenations, in index order.
            // NOTE(review): assumes x1 and x2 refer to earlier operations -
            // confirm against the problem statement.
            for (int_t i = 1; i <= n; i++) {
                const auto& ref1 = opts[i];
                if (ref1.type == 2) {
                    dp[i] = dp[ref1.x1] + dp[ref1.x2];
                }
            }
    #ifdef DEBUG
            for (int_t i = 1; i <= n; i++) {
                cout << "dp " << i << " = " << dp[i] << endl;
            }
    #endif
            // Pass 3: the vote only yields a candidate; count its real number of
            // occurrences in every sequence.
            int_t mostval = dp[n].mostval;
            for (int_t i = 1; i <= n; i++) {
                auto& ref = opts[i];
                if (ref.type == 1)
                    ref.mostcount = std::count(ref.data + 1,
                                               ref.data + 1 + ref.datalen, mostval);
                else
                    ref.mostcount = opts[ref.x1].mostcount + opts[ref.x2].mostcount;
            }
            int64_t mostcount = opts[n].mostcount;
    #ifdef DEBUG
            cout << "final len " << dp[n].len << endl;
            cout << "mostcount " << mostcount << endl;
            cout << "mostval " << mostval << endl;
    #endif
            // No value fills more than half of the sequence: answer is its length.
            if (mostcount * 2 <= dp[n].len) {
                write(dp[n].len);
            } else {
                write(2 * (dp[n].len - mostcount));
            }
            puts("");
        }
        return 0;
    }

     

  • CFGYM102394I CCPC2019哈尔滨 I题 Interesting Permutations

    签到题我都不会做,我爬了。

    首先检测一些明显非法的情况:

    1. $h_i\geq n$,明显不可能,$h_i$上限是$n-1$

    2. $h_i<h_{i-1}$(对于$i\geq 2$),最大值不会减小,最小值不会增大,所以$h_i$一定是单调不减的。

    3. $h_1\neq 0$,这很显然。

    4. $h_n\neq n-1$这也很显然。

    然后我们考虑从第$2$个元素开始枚举,我们维护一个`gap`变量,用以存储”在当前这个位置,一共有多少个位置可以填数,并且保证填了数之后不影响当前位置前缀最大值和前缀最小值的取值”。当前枚举到第$i$个元素时:

    • 如果$h_i=h_{i-1}$,那就说明当前这个位置的值没有改变前缀最值的分布,那么总答案就可以乘上`gap`(因为有这么多的方案数让我们来填,并且填了后不影响前缀最值),并且 `gap`要减掉1,因为我们填了个数后占了一个空位。
    • 如果$h_i>h_{i-1}$,那就说明当前这个位置的值改变了前缀最大值或者前缀最小值二者之一,那么总答案乘上2。同时`gap`要加上$h_i-h_{i-1}-1$,因为我们新创造了这么多的空位。
    #include <algorithm>
    #include <cstdio>
    #include <iostream>
    
    using int_t = long long;
    
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t mod = 1e9 + 7;
    const int_t LARGE = 1e5 + 10;
    int_t arr[LARGE + 1];
    int_t n;
    int main() {
        std::ios::sync_with_stdio(false);
        int_t T;
        cin >> T;
        while (T--) {
            cin >> n;
            bool fail = false;
            for (int_t i = 1; i <= n; i++) {
                cin >> arr[i];
                if (arr[i] < arr[i - 1])
                    fail = true;
                if (i == 1 && arr[i] != 0)
                    fail = true;
                if (i == n && arr[i] != n - 1)
                    fail = true;
                if (arr[i] >= n)
                    fail = true;
            }
            if (fail) {
                cout << 0 << endl;
                continue;
            }
            int_t prod = 1;
            int_t sec = 0;
            for (int_t i = 2; i <= n; i++) {
                if (arr[i] == arr[i - 1]) {
                    prod = prod * sec % mod;
                    sec--;  //消耗了一个中间空位
                }
                if (arr[i] > arr[i - 1]) {
                    prod = prod * 2 % mod;
                    sec = (sec + arr[i] - arr[i - 1] - 1 + mod) % mod;
                }
            }
            cout << prod << endl;
        }
        return 0;
    }