标签: 生成函数

  • CF438E 多项式开根 多项式牛顿迭代

    多项式真好玩qwq.jpg

    $$\text{显然满足权值为}S\text{的二叉树的个数是有限的(因为每个节点的权值}c_i\ge 1\text{)}\\
    \text{设}T\left( z \right) =\sum_{s\in \left\{ c_i \right\}}{z^s}\text{,即如果集合}\left\{ c_i \right\} \text{中存在元素}x,\text{那么}T\left( z \right) \text{中就存在项}z^x\\
    \text{设}G\left( z \right) \text{表示有根二叉树的方案数的生成函数,其中}G^{\left[ i \right]}\left( z \right) \text{(即}G\left( z \right) \text{的}i\text{次项系数)为权值为}i\text{的二叉树的数目}\\
    \text{则有递归}\\
    G\left( z \right) =1+T\left( z \right) G^2\left( z \right) \\
    \text{其中1表示这是棵空树;非空情况下,}T\left( z \right) \text{表示树根的取值,而}G^2\left( z \right) \text{表示两棵子树的取值}\\
    \text{解方程}\\
    T\left( z \right) G^2\left( z \right) -G\left( z \right) +1=0\\
    G\left( z \right) =\frac{1\pm \sqrt{1-4T\left( z \right)}}{2T\left( z \right)}\\
    \text{因为}T\left( 0 \right) =0\text{,要使}G\left( z \right) \text{是常数项}G\left( 0 \right) =1\text{的幂级数,分子的常数项必须为0,所以取减号:}G\left( z \right) =\frac{1-\sqrt{1-4T\left( z \right)}}{2T\left( z \right)}\\
    \text{因为}T\left( z \right) \text{的常数项为0,不能求逆,所以我们变一下形(分子分母同乘}1+\sqrt{1-4T\left( z \right)}\text{)}\\
    G\left( z \right) =\frac{\left( 1-\sqrt{1-4T\left( z \right)} \right) \left( 1+\sqrt{1-4T\left( z \right)} \right)}{2T\left( z \right) \left( 1+\sqrt{1-4T\left( z \right)} \right)}=\frac{4T\left( z \right)}{2T\left( z \right) \left( 1+\sqrt{1-4T\left( z \right)} \right)}=\frac{2}{1+\sqrt{1-4T\left( z \right)}}\\
    \text{此时分母的常数项为2,可以求逆,我们需要做的只剩下多项式开根和求逆了}\\
    \\
    \text{多项式开根与牛顿迭代:}\\
    \text{设}F\left( X\left( z \right) \right) =X^2\left( z \right) -A\left( z \right) \\
    \text{要计算}\sqrt{A\left( z \right)}\text{,等价于解方程}F\left( X\left( z \right) \right) \equiv 0\pmod{z^n}\\
    n=1\text{时,}X\left( z \right) \equiv \sqrt{A^{\left[ 0 \right]}\left( z \right)}\\
    n\ge 2\text{时,考虑倍增}\\
    \text{假设已经计算出了}X_0\left( z \right) ,\text{使得}F\left( X_0\left( z \right) \right) \equiv 0\pmod{z^{\lceil n/2 \rceil}}\\
    \text{现在要求}X_1\left( z \right) \text{,使得}F\left( X_1\left( z \right) \right) \equiv 0\pmod{z^n}\\
    \text{将}F\left( X\left( z \right) \right) \text{在}X\left( z \right) =X_0\left( z \right) \text{处泰勒展开}\\
    F\left( X_1\left( z \right) \right) \equiv \sum_{i\ge 0}{\frac{\left( X_1\left( z \right) -X_0\left( z \right) \right) ^i}{i!}F^{\left( i \right)}\left( X_0\left( z \right) \right)}\\
    \equiv F\left( X_0\left( z \right) \right) +\left( X_1\left( z \right) -X_0\left( z \right) \right) F'\left( X_0\left( z \right) \right) +\sum_{i\ge 2}{\frac{\left( X_1\left( z \right) -X_0\left( z \right) \right) ^i}{i!}F^{\left( i \right)}\left( X_0\left( z \right) \right)}\\
    \text{因为}X_1\left( z \right) \text{与}X_0\left( z \right) \text{的低}\lceil \frac{n}{2} \rceil \text{次项相同,所以}X_1\left( z \right) -X_0\left( z \right) \text{的不为0的次数最低的项不会小于}\lceil \frac{n}{2} \rceil \\
    \therefore \left( X_1\left( z \right) -X_0\left( z \right) \right) ^i\left( i\ge 2 \right) \text{的次数最低的项不会小于}2\lceil \frac{n}{2} \rceil \ge n\\
    \therefore \left( X_1\left( z \right) -X_0\left( z \right) \right) ^i\equiv 0\pmod{z^n}\ \left( i\ge 2 \right) \\
    \therefore F\left( X_1\left( z \right) \right) \equiv F\left( X_0\left( z \right) \right) +\left( X_1\left( z \right) -X_0\left( z \right) \right) F'\left( X_0\left( z \right) \right) \equiv 0\pmod{z^n}\\
    \text{整理一下}F\left( X_0\left( z \right) \right) +X_1\left( z \right) F'\left( X_0\left( z \right) \right) -X_0\left( z \right) F'\left( X_0\left( z \right) \right) \equiv 0\pmod{z^n}\\
    \text{整理出}X_1\left( z \right) \\
    X_1\left( z \right) \equiv X_0\left( z \right) -\frac{F\left( X_0\left( z \right) \right)}{F'\left( X_0\left( z \right) \right)}\pmod{z^n}\\
    \text{在计算}\sqrt{A\left( z \right)}\text{的情况下,}F\left( X_0\left( z \right) \right) =X_{0}^{2}\left( z \right) -A\left( z \right) \\
    F'\left( X\left( z \right) \right) =2X\left( z \right) \\
    X_1\left( z \right) \equiv X_0\left( z \right) -\frac{X_{0}^{2}\left( z \right) -A\left( z \right)}{2X_0\left( z \right)}\pmod{z^n}\\$$

    #include <algorithm>
    #include <assert.h>
    #include <cmath>
    #include <cstdio>
    #include <iostream>
    #include <sstream>
    #include <string>
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    // Prime modulus for all arithmetic; 998244353 = 119 * 2^23 + 1, so power-of-two NTT lengths work.
    const int_t mod = 998244353;
    // Primitive root of `mod`; transform() derives its roots of unity from it.
    const int_t g = 3;
    // base^index mod `mod`; a negative index raises the modular inverse of base instead.
    int_t power(int_t base, int_t index);
    // Smallest power of two that is >= x.
    int_t upper2n(int_t x);
    // Reverses the low `size2` bits of `bits` (index permutation for the iterative NTT).
    int_t bitReverse(int_t bits, int_t size2);
    // In-place NTT: arg = 1 is the forward transform, arg = -1 the inverse one WITHOUT the
    // 1/size factor (callers multiply by size^{-1} themselves).
    template <int_t arg = 1>
    void transform(int_t *A, int_t size);
    // inv = 1 / A mod z^n.
    void poly_inv(int_t *A, int_t *inv, int_t n);
    // result = sqrt(A) mod z^n.
    void poly_sqrt(int_t *A, int_t *result, int_t n);
    // Capacity (in coefficients) of every static work buffer; NTT lengths must not exceed it.
    const int_t LARGE = 1 << 21;
    // Debug helper: formats the first `size` entries of A, each one followed by a single space.
    std::string tostr(int_t *A, int_t size)
    {
        std::ostringstream out;
        int_t idx = 0;
        while (idx < size)
        {
            out << A[idx] << ' ';
            ++idx;
        }
        return out.str();
    }
    int main()
    {
        static int_t T[LARGE];
        static int_t Tsqrt[LARGE];
        static int_t Tinv[LARGE];
        static int_t result[LARGE];
        int_t n, m;
        scanf("%lld%lld", &n, &m);
        int_t size = upper2n(2 * m + 1);
        {
            for (int_t i = 1; i <= n; i++)
            {
                int_t x;
                scanf("%lld", &x);
                T[x] = 1;
            }
            // cout << "T=" << tostr(T, size) << endl;
        }
    
        {
            for (int_t i = 0; i <= m; i++)
            {
                T[i] = (-4 * T[i] + mod) % mod;
            }
            T[0] = 1;
            poly_sqrt(T, Tsqrt, m + 1);
            Tsqrt[0] = (Tsqrt[0] + 1);
            poly_inv(Tsqrt, result, m + 1);
        }
        for (int_t i = 1; i <= m; i++)
        {
            printf("%d\n", (int)(result[i] * 2 % mod));
        }
        return 0;
    }
    // Computes result = sqrt(A) mod z^n by Newton iteration
    //   X1 = X0 - (X0^2 - A) / (2 X0)   (mod z^n),
    // doubling the number of correct coefficients at every recursion level.
    //   A      : input; only A[0..n) is read. A[0] must be 1 (0 is accepted only when n == 1).
    //   result : output; must have room for upper2n(3n - 1) entries, entries >= n end up 0.
    //   n      : number of coefficients wanted; n <= 0 is a no-op.
    // Shares static scratch buffers across calls (not thread-safe).
    void poly_sqrt(int_t *A, int_t *result, int_t n)
    {
        static int_t Ax[LARGE];
        static int_t inv[LARGE];
        if (n <= 0)
            return;
        if (n == 1)
        {
            // Only the square roots of 0 and 1 are supported for the constant term.
            assert(A[0] == 1 || A[0] == 0);
            result[0] = A[0];
            return;
        }
        int_t half = n / 2 + n % 2;
        poly_sqrt(A, result, half);
        // The Newton step divides by X0, so its constant term must be invertible.
        assert(result[0] != 0);

        int_t size = upper2n(3 * n - 1);
        // X0 is only meaningful mod z^half; make sure nothing stale sits above it.
        std::fill(result + half, result + size, 0);
        // A mod z^n, zero-padded to the transform length (never reads A beyond n).
        for (int_t i = 0; i < size; i++)
            Ax[i] = i < n ? A[i] : 0;

        // inv <- 1 / (2 X0) mod z^n
        std::fill(inv, inv + size, 0);
        poly_inv(result, inv, n);
        const int_t inv2 = power(2, -1);
        for (int_t i = 0; i < size; i++)
            inv[i] = inv[i] * inv2 % mod;

        transform(result, size);
        transform(Ax, size);
        transform(inv, size);
        for (int_t i = 0; i < size; i++)
        {
            int_t diff = (result[i] * result[i] % mod - Ax[i] + mod) % mod; // X0^2 - A
            result[i] = (result[i] - diff * inv[i] % mod + mod) % mod;
        }
        transform<-1>(result, size);

        // Undo the unscaled inverse NTT and truncate to mod z^n.
        const int_t invSize = power(size, -1);
        for (int_t i = 0; i < size; i++)
            result[i] = i < n ? result[i] * invSize % mod : 0;
    }
    // Computes inv = 1 / A mod z^n by Newton iteration
    //   B1 = 2 B0 - B0^2 A   (mod z^n),
    // doubling the number of correct coefficients at every recursion level.
    //   A   : input; only A[0..n) is read. A[0] must be invertible modulo `mod`.
    //   inv : output; must have room for upper2n(3n - 1) entries, entries >= n end up 0.
    //   n   : number of coefficients wanted; n <= 0 is a no-op.
    // Shares a static scratch buffer across calls (not thread-safe).
    void poly_inv(int_t *A, int_t *inv, int_t n)
    {
        static int_t Ax[LARGE];
        if (n <= 0)
            return;
        if (n == 1)
        {
            assert(A[0] % mod != 0);
            inv[0] = power(A[0], -1);
            return;
        }
        int_t half = n / 2 + n % 2;
        poly_inv(A, inv, half);

        int_t size = upper2n(n * 3 - 1);
        // B0 is only meaningful mod z^half; make sure nothing stale sits above it.
        std::fill(inv + half, inv + size, 0);
        // Pitfall: every coefficient of Ax with degree >= n must be zero, otherwise it
        // pollutes the product below.
        for (int_t i = 0; i < size; i++)
            Ax[i] = i < n ? A[i] : 0;

        // C(x) <- 2B(x) - B^2(x)A(x)
        transform(inv, size);
        transform(Ax, size);
        for (int_t i = 0; i < size; i++)
            inv[i] = (2 * inv[i] - inv[i] * inv[i] % mod * Ax[i] % mod + mod) % mod;
        transform<-1>(inv, size);

        // Undo the unscaled inverse NTT and truncate to mod z^n.
        const int_t invSize = power(size, -1);
        for (int_t i = 0; i < size; i++)
            inv[i] = i < n ? inv[i] * invSize % mod : 0;
    }
    
    // In-place iterative radix-2 NTT of A[0..size) modulo `mod`; size must be a power of two.
    // arg = 1: forward transform with principal roots g^((mod-1)/len).
    // arg = -1: inverse roots, WITHOUT the final 1/size factor -- callers multiply by size^{-1}.
    // The default `arg = 1` is provided by the forward declaration; a later declaration of the
    // same template may not give the default again ([temp.param]), so it is omitted here.
    template <int_t arg>
    void transform(int_t *A, int_t size)
    {
        assert(size > 0 && (size & (size - 1)) == 0);
        // size == 2^size2, computed exactly with integers instead of floating-point log2.
        int_t size2 = 0;
        while ((int_t(1) << size2) < size)
            size2++;

        // Bit-reversal permutation so the butterflies below can run bottom-up in place.
        for (int_t i = 0; i < size; i++)
        {
            int_t x = bitReverse(i, size2);
            if (x > i)
                std::swap(A[i], A[x]);
        }

        for (int_t len = 2; len <= size; len *= 2)
        {
            int_t half = len / 2;
            // Principal len-th root of unity (its inverse when arg == -1).
            int_t wn = power(g, arg * (mod - 1) / len);
            for (int_t j = 0; j < size; j += len)
            {
                int_t w = 1;
                for (int_t k = 0; k < half; k++)
                {
                    int_t u = A[j + k];
                    int_t t = A[j + k + half] * w % mod;
                    A[j + k] = (u + t) % mod;
                    A[j + k + half] = (u - t + mod) % mod;
                    w = w * wn % mod;
                }
            }
        }
    }
    
    // Reverses the lowest `size2` bits of `bits`; higher bits are ignored.
    // For size2 <= 1 this is just the lowest bit of `bits`.
    int_t bitReverse(int_t bits, int_t size2)
    {
        int_t reversed = bits & 1;
        for (int_t step = 1; step < size2; step++)
        {
            bits >>= 1;
            reversed = (reversed << 1) | (bits & 1);
        }
        return reversed;
    }
    // Smallest power of two that is >= x (1 for any x <= 1).
    int_t upper2n(int_t x)
    {
        int_t p = 1;
        while (x > p)
            p <<= 1;
        return p;
    }
    // Returns base^index mod `mod` by binary exponentiation.
    // A negative index raises the modular inverse base^(mod-2) (Fermat; mod is prime) to |index|.
    // base is first reduced into [0, mod), so the squarings below cannot overflow 64 bits and
    // the result is always a canonical residue, even for negative or large bases.
    int_t power(int_t base, int_t index)
    {
        base %= mod;
        if (base < 0)
            base += mod;
        if (index < 0)
        {
            index = -index;
            base = power(base, mod - 2);
        }
        int_t result = 1;
        while (index)
        {
            if (index & 1)
                result = result * base % mod;
            index >>= 1;
            base = base * base % mod;
        }
        return result;
    }