组合数对合数取模

$$ \text{计算}\left( \begin{array}{c} n\\ m\\ \end{array} \right) \ mod\ p,\text{满足}p\le 10^6,n,m\le 10^{18},\text{不保证}p\text{是质数} \\ \text{首先可以考虑将}p\text{进行分解,得到}p_1^{k_1}p_2^{k_2}p_{3}^{k_3}…. \\ \text{然后考虑分别计算出}\left( \begin{array}{c} n\\ m\\ \end{array} \right) \ mod\ p_i^{k_i},\text{然后用}CRT\text{进行合并} \\ \text{现在的问题在于如何计算}\frac{n!}{m!\left( n-m \right) !}\ mod\ p_i^{k_i} \\ \text{因为}p_i^{k_i}\text{不是质数,所以不能用卢卡斯定理,但是我们考虑到}p_i^{k_i}\text{中只有}p_i\text{一个质因子,} \\ \text{所以我们可以把}n!\text{中所有的}p_i\text{提出来} \\ \text{例如现在计算11!对}3^2\text{取模} \\ \text{11!}=1\times 2\times 3\times 4\times 5\times 6\times 7\times 8\times 9\times 10\times \text{11\ } \\ =\left( 1\times 2 \right) \times \left( 3 \right) \times \left( 4\times 5 \right) \times \left( 2\times 3 \right) \times \left( 7\times 8 \right) \times \left( 3\times 3 \right) \times \left( 10\times 11 \right) \\ =3^3\times \left( 1\times 2\times 3 \right) \times \left( 1\times 2\times 4\times 5\times 7\times 8\times 10\times 11 \right) \\ =3^3\times \left( 1\times 2\times 3 \right) \times \left( 1\times 2\times 4\times 5\times 7\times 8\times 1\times 2 \right) \\ \text{对于}3^3,\text{额外记录下来,对于}\left( 1\times 2\times 3 \right) ,\text{递归计算,对于}\left( 1\times 2\times 4\times 5\times 7\times 8 \right) \times \left( 1\times 2 \right) ,\text{求出单个循环节和不足一个循环节的部分} \\ \text{每个循环节的长度是}p_i^{k_i}-p_i^{k_i-1},\text{一共循环了}\lfloor \frac{n-\lfloor \frac{n}{p_i} \rfloor}{p_i^{k_i}-p_i^{k_i-1}} \rfloor \text{次,不足一个循环节部分的长度是}\left( n-\lfloor \frac{n}{p_i} \rfloor \right) \ mod\left( p_i^{k_i}-p_i^{k_i-1} \right) \ \\ \text{对于不足一个循环节的部分暴力计算} \\ \text{最后可以得到11!中的3的幂次和除掉3之后的数在模}3^2\text{下的值} \\ \text{复杂度}O\left( p_i^{k_i}\log n \right) \\ \text{这样对于}\left( \begin{array}{c} n\\ m\\ \end{array} \right) =\frac{n!}{m!\left( n-m \right) !},\text{我们可以分别计算出}n!,m!,\left( n-m \right) !\text{中}p_i\text{的幂次和除掉}p_i\text{的幂次的部分,} \\ \text{然后就可以计算了} \\ \\ $$

#include <iostream>
#include <algorithm>
#include <utility>
#include <inttypes.h>
using int_t = int64_t;
using pair_t = std::pair<int_t, int_t>;

using std::cin;
using std::cout;
using std::endl;

const int_t LARGE = 1e6;
int_t power(int_t base, int_t index, int_t mod);
int_t gcd(int_t a, int_t b)
{
    if (b == 0)
        return a;
    return gcd(b, a % b);
}

pair_t exgcd(int_t a, int_t b)
{
    if (b == 0)
        return pair_t(1, 0);
    auto ret = exgcd(b, a % b);
    return pair_t(ret.second, ret.first - a / b * ret.second);
}

int_t excrt(int64_t *coef, int64_t *mod, int64_t n)
{
    int_t prex = coef[1], M = mod[1];
    for (int_t i = 2; i <= n; i++)
    {
        int_t A = M, B = mod[i];
        int_t C = coef[i] - prex;
        C = (C % B + B) % B;
        int_t gcd = ::gcd(A, B);
        auto x = exgcd(A, B).first;
        x = (x % (B / gcd) + (B / gcd)) % (B / gcd) * (C / gcd);
        int_t preM = M;
        M = M / gcd * mod[i];
        prex = ((prex + x * preM) % M + M) % M;
    }
    return prex;
}
//计算n!中,p的次数和除掉p的幂之后的数对p^k取模的值
void fact(int_t n, int_t p, int_t k, int_t *idx, int_t *rem, int_t *facts)
{
    if (n < p)
    {
        *idx = 0;
        *rem = facts[n];
        return;
    }
    fact(n / p, p, k, idx, rem, facts);
    *idx += n / p;
    int_t prod = 1, remaining = 1;
    int_t count = 0;
    const int_t mod = power(p, k, 998244353);
    int_t length = mod - mod / p;
    for (int_t i = 1; i < mod; i++)
    {
        if (i % p != 0)
        {
            prod = prod * i % mod;
            if (count < (n - n / p) % length)
            {
                remaining = remaining * i % mod;
                count++;
            }
        }
    }
    (*rem) = (*rem) * power(prod, (n - n / p) / (length), mod) % mod * remaining % mod;
}

//计算C(n,m)对p^k取模的结果
int_t C(int_t n, int_t m, int_t p, int_t k)
{
    static int_t fact[LARGE + 1];
    int_t mod = power(p, k, 998244353);
    const int_t phi = mod - (mod / p);
    fact[0] = 1;
    for (int_t i = 1; i < p; i++)
        fact[i] = fact[i - 1] * i % mod;
    int_t idx1, idx2, idx3, rem1, rem2, rem3;
    ::fact(n, p, k, &idx1, &rem1, fact);
    ::fact(m, p, k, &idx2, &rem2, fact);
    ::fact(n - m, p, k, &idx3, &rem3, fact);
    int_t index = idx1 - (idx2 + idx3);
    int_t remaining = rem1 * power(rem2 * rem3 % mod, phi - 1, mod) % mod;
    return power(p, index, mod) * remaining % mod;
}

int main()
{
    int_t n, m, p;
    cin >> n >> m >> p;
    static int_t index[LARGE + 1];
    static int_t prime[LARGE + 1];
    static int_t mod[LARGE + 1];
    static int_t coef[LARGE + 1];
    int_t used = 0;
    for (int_t i = 2; i <= p && p != 1; i++)
    {
        if (p % i == 0)
        {
            used++;
            prime[used] = i;
            while (p % i == 0)
            {
                index[used]++;
                p /= i;
            }
        }
    }
    for (int_t i = 1; i <= used; i++)
    {
        mod[i] = power(prime[i], index[i], 998244353);
        coef[i] = C(n, m, prime[i], index[i]);
    }
    cout << excrt(coef, mod, used) << endl;
    return 0;
}
int_t power(int_t base, int_t index, int_t mod)
{
    int_t result = 1;
    while (index)
    {
        if (index & 1)
            result = result * base % mod;
        base = base * base % mod;
        index >>= 1;
    }
    return result;
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理