$$ \text{计算}\left( \begin{array}{c} n\\ m\\ \end{array} \right) \ mod\ p,\text{满足}p\le 10^6,n,m\le 10^{18},\text{不保证}p\text{是质数} \\ \text{首先可以考虑将}p\text{进行分解,得到}p_1^{k_1}p_2^{k_2}p_{3}^{k_3}…. \\ \text{然后考虑分别计算出}\left( \begin{array}{c} n\\ m\\ \end{array} \right) \ mod\ p_i^{k_i},\text{然后用}CRT\text{进行合并} \\ \text{现在的问题在于如何计算}\frac{n!}{m!\left( n-m \right) !}\ mod\ p_i^{k_i} \\ \text{因为}p_i^{k_i}\text{不是质数,所以不能用卢卡斯定理,但是我们考虑到}p_i^{k_i}\text{中只有}p_i\text{一个质因子,} \\ \text{所以我们可以把}n!\text{中所有的}p_i\text{提出来} \\ \text{例如现在计算11!对}3^2\text{取模} \\ \text{11!}=1\times 2\times 3\times 4\times 5\times 6\times 7\times 8\times 9\times 10\times \text{11\ } \\ =\left( 1\times 2 \right) \times \left( 3 \right) \times \left( 4\times 5 \right) \times \left( 2\times 3 \right) \times \left( 7\times 8 \right) \times \left( 3\times 3 \right) \times \left( 10\times 11 \right) \\ =3^3\times \left( 1\times 2\times 3 \right) \times \left( 1\times 2\times 4\times 5\times 7\times 8\times 10\times 11 \right) \\ =3^3\times \left( 1\times 2\times 3 \right) \times \left( 1\times 2\times 4\times 5\times 7\times 8\times 1\times 2 \right) \\ \text{对于}3^3,\text{额外记录下来,对于}\left( 1\times 2\times 3 \right) ,\text{递归计算,对于}\left( 1\times 2\times 4\times 5\times 7\times 8 \right) \times \left( 1\times 2 \right) ,\text{求出单个循环节和不足一个循环节的部分} \\ \text{每个循环节的长度是}p_i^{k_i}-p_i^{k_i-1},\text{一共循环了}\lfloor \frac{n-\lfloor \frac{n}{p_i} \rfloor}{p_i^{k_i}-p_i^{k_i-1}} \rfloor \text{次,不足一个循环节部分的长度是}\left( n-\lfloor \frac{n}{p_i} \rfloor \right) \ mod\left( p_i^{k_i}-p_i^{k_i-1} \right) \ \\ \text{对于不足一个循环节的部分暴力计算} \\ \text{最后可以得到11!中的3的幂次和除掉3之后的数在模}3^2\text{下的值} \\ \text{复杂度}O\left( p_i^{k_i}\log n \right) \\ \text{这样对于}\left( \begin{array}{c} n\\ m\\ \end{array} \right) =\frac{n!}{m!\left( n-m \right) !},\text{我们可以分别计算出}n!,m!,\left( n-m \right) !\text{中}p_i\text{的幂次和除掉}p_i\text{的幂次的部分,} \\ \text{然后就可以计算了} \\ \\ $$
#include <iostream>
#include <algorithm>
#include <utility>
#include <inttypes.h>
using int_t = int64_t;
using pair_t = std::pair<int_t, int_t>;
using std::cin;
using std::cout;
using std::endl;
const int_t LARGE = 1e6;
int_t power(int_t base, int_t index, int_t mod);
int_t gcd(int_t a, int_t b)
{
if (b == 0)
return a;
return gcd(b, a % b);
}
pair_t exgcd(int_t a, int_t b)
{
if (b == 0)
return pair_t(1, 0);
auto ret = exgcd(b, a % b);
return pair_t(ret.second, ret.first - a / b * ret.second);
}
int_t excrt(int64_t *coef, int64_t *mod, int64_t n)
{
int_t prex = coef[1], M = mod[1];
for (int_t i = 2; i <= n; i++)
{
int_t A = M, B = mod[i];
int_t C = coef[i] - prex;
C = (C % B + B) % B;
int_t gcd = ::gcd(A, B);
auto x = exgcd(A, B).first;
x = (x % (B / gcd) + (B / gcd)) % (B / gcd) * (C / gcd);
int_t preM = M;
M = M / gcd * mod[i];
prex = ((prex + x * preM) % M + M) % M;
}
return prex;
}
//计算n!中,p的次数和除掉p的幂之后的数对p^k取模的值
void fact(int_t n, int_t p, int_t k, int_t *idx, int_t *rem, int_t *facts)
{
if (n < p)
{
*idx = 0;
*rem = facts[n];
return;
}
fact(n / p, p, k, idx, rem, facts);
*idx += n / p;
int_t prod = 1, remaining = 1;
int_t count = 0;
const int_t mod = power(p, k, 998244353);
int_t length = mod - mod / p;
for (int_t i = 1; i < mod; i++)
{
if (i % p != 0)
{
prod = prod * i % mod;
if (count < (n - n / p) % length)
{
remaining = remaining * i % mod;
count++;
}
}
}
(*rem) = (*rem) * power(prod, (n - n / p) / (length), mod) % mod * remaining % mod;
}
//计算C(n,m)对p^k取模的结果
int_t C(int_t n, int_t m, int_t p, int_t k)
{
static int_t fact[LARGE + 1];
int_t mod = power(p, k, 998244353);
const int_t phi = mod - (mod / p);
fact[0] = 1;
for (int_t i = 1; i < p; i++)
fact[i] = fact[i - 1] * i % mod;
int_t idx1, idx2, idx3, rem1, rem2, rem3;
::fact(n, p, k, &idx1, &rem1, fact);
::fact(m, p, k, &idx2, &rem2, fact);
::fact(n - m, p, k, &idx3, &rem3, fact);
int_t index = idx1 - (idx2 + idx3);
int_t remaining = rem1 * power(rem2 * rem3 % mod, phi - 1, mod) % mod;
return power(p, index, mod) * remaining % mod;
}
int main()
{
int_t n, m, p;
cin >> n >> m >> p;
static int_t index[LARGE + 1];
static int_t prime[LARGE + 1];
static int_t mod[LARGE + 1];
static int_t coef[LARGE + 1];
int_t used = 0;
for (int_t i = 2; i <= p && p != 1; i++)
{
if (p % i == 0)
{
used++;
prime[used] = i;
while (p % i == 0)
{
index[used]++;
p /= i;
}
}
}
for (int_t i = 1; i <= used; i++)
{
mod[i] = power(prime[i], index[i], 998244353);
coef[i] = C(n, m, prime[i], index[i]);
}
cout << excrt(coef, mod, used) << endl;
return 0;
}
int_t power(int_t base, int_t index, int_t mod)
{
int_t result = 1;
while (index)
{
if (index & 1)
result = result * base % mod;
base = base * base % mod;
index >>= 1;
}
return result;
}
发表回复