2019-3-14 更新：改用拉格朗日插值
$$ \text{注意到}F_k\left( n \right) \text{是积性函数} \\ \text{对于质数幂，}F_k\left( p^c \right) =\sum_{0\le i\le c}{F_{k-1}\left( p^i \right)} \\ \text{可以注意到}F_k\left( p^c \right) \text{的取值与}p\text{无关} \\ \text{设}F_k\left( p^n \right) \text{的生成函数为}G_k\left( x \right) =\sum_{n\ge 0}{F_k\left( p^n \right) x^n} \\ \text{由上述可得}G_k\left( x \right) =\frac{G_{k-1}\left( x \right)}{1-x} \\ \text{其中}G_0\left( x \right) =\frac{1}{1-x} \\ \text{故}G_k\left( x \right) =\frac{1}{\left( 1-x \right) ^{k+1}} \\ \text{由广义二项式定理可得} \\ F_k\left( p^n \right) =\binom{n+k}{n} \\ \text{同时又注意到，}F_k\left( p_{1}^{c_1}p_{2}^{c_2}\cdots p_{t}^{c_t} \right) \text{的取值只与}\left\{ c_1,c_2,\ldots ,c_t \right\} \text{有关。} \\ \text{写个 DFS 看一下，在}n\le 5\times 10^5\text{时可行的}\left\{ c_1,c_2,\ldots ,c_t \right\} \text{只有 244 种。} \\ \text{又因为}\binom{c+x}{c} =\frac{\left( c+x \right) !}{c!\,x!}\text{是一个关于}x\text{的}c\text{次多项式，} \\ \text{所以}\prod_i{\binom{c_i+x}{c_i}}\text{是一个关于}x\text{的}\sum_i{c_i}\text{次多项式。} \\ \text{又因为}\sum_i{c_i}\text{的上界为}O\left( \log n \right) \text{，} \\ \text{所以直接用拉格朗日插值求出这个多项式，就可以在}O\left( \log n \right) \text{的时间内计算出}F_k\left( n \right) \text{。} \\ \text{对于区间询问，按状态}s\text{维护前缀计数}\mathrm{cnt}_s\text{，答案即为}\sum_s{\mathrm{cnt}_s\cdot P_s\left( k \right)}\text{，其中}P_s\text{为状态}s\text{对应的多项式；}F_k\left( 1 \right) =1\text{需单独计入。} $$
#pragma GCC optimize("unroll-loops")
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <vector>
// Integer type used for counters and modular helpers throughout the file.
using int_t = int;
using std::cin;
using std::cout;
using std::endl;

// Prime modulus; all arithmetic (including inverses via Fermat) is mod this.
constexpr int_t mod = 998244353;
// Size bound for the factorial tables and the sieve.
constexpr int_t LIMIT = 500000;
// Size bound for the per-n tables (state[], prefix[]).
constexpr int_t LARGE = 500000;

// fact[i] = i! mod p and factInv[i] = (i!)^{-1} mod p, for 0 <= i <= LIMIT.
int fact[LIMIT + 1];
int factInv[LIMIT + 1];
// Sieve results: primality flag and one prime divisor for each i <= LIMIT.
bool isPrime[LIMIT + 1];
int_t factor[LIMIT + 1];
// Orders exponent vectors by length first, then lexicographically.
// NOTE(review): std::set / std::map compare keys through std::less, which
// evaluates `a < b` inside namespace std; ADL for std::vector<int> only looks
// in std, so those containers most likely use std's plain lexicographic
// operator< rather than this overload.
bool operator<(const std::vector<int_t>& A, const std::vector<int_t>& B) {
    if (A.size() == B.size()) {
        return std::lexicographical_compare(A.begin(), A.end(), B.begin(),
                                            B.end());
    }
    return A.size() < B.size();
}
// Two exponent vectors are equal when they have the same length and the same
// elements in the same order.
bool operator==(const std::vector<int_t>& A, const std::vector<int_t>& B) {
    return A.size() == B.size() && std::equal(A.begin(), A.end(), B.begin());
}
// Debug printer: writes each element of `vec` followed by one space.
template <class T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
    for (typename std::vector<T>::const_iterator it = vec.begin();
         it != vec.end(); ++it) {
        os << *it << " ";
    }
    return os;
}
// Distinct exponent signatures (sorted lists of prime exponents), copied from
// `set` in main(); a signature's position here is its id.
std::vector<std::vector<int>> states;
// Deduplicates signatures while scanning 2..maxn.
std::set<std::vector<int>> set;
// polys[id][0] = m (coefficient count), polys[id][1..m] = coefficients in
// ascending degree of the polynomial in k for signature `id`.
// NOTE(review): 244 signatures / 24 coefficients is presumably sized for
// n <= 5e5 (see the 244 claim in the write-up) — re-check if LARGE changes.
int polys[244][25];
// state[i] = sorted exponent signature of i, for 2 <= i <= maxn.
std::vector<int> state[LARGE + 1];
// Signature -> id (index into states / polys / prefix columns).
std::map<std::vector<int>, int> hash;
// prefix[i][id] = how many n in [2, i] have signature `id`.
// NOTE: (LARGE + 1) * 244 ints is about 488 MB of static storage.
int prefix[LARGE + 1][244];
// Binomial coefficient C(n, m) mod p from the precomputed factorial tables.
// Returns 0 when m < 0 or m > n (the usual convention) instead of reading
// factInv[] at a negative index. Requires n <= LIMIT.
int C(int n, int m) {
    if (m < 0 || m > n) return 0;
    return 1ll * fact[n] * factInv[m] % mod * factInv[n - m] % mod;
}
// Binary exponentiation: returns base^index mod p (index >= 0).
// A negative base is normalized into [0, p) first.
int_t power(int_t base, int_t index) {
    long long acc = 1;
    long long square = ((base % mod) + mod) % mod;
    for (; index != 0; index >>= 1) {
        if (index & 1) acc = acc * square % mod;
        square = square * square % mod;
    }
    return static_cast<int_t>(acc);
}
// Values read from the first input line in main(): the largest n, presumably
// the upper bound on k (not otherwise used in the visible code), and the
// number of queries.
int maxn;
int maxk;
int q;
// Substitutes k into a state: returns prod_i C(c_i + k, c_i) mod p for the
// exponent list state = {c_1, ..., c_t}.
int subState(const std::vector<int>& state, int_t k) {
    return std::accumulate(
        state.begin(), state.end(), 1, [k](int acc, int c) {
            return static_cast<int>(1ll * acc * C(c + k, c) % mod);
        });
}
// Evaluates a polynomial stored as poly[0] = m, poly[1..m] = coefficients in
// ascending degree, at the point x (mod p), via Horner's rule.
int __attribute__((hot)) subPoly(int* poly, int x) {
    long long value = 0;
    int idx = poly[0];
    while (idx > 0) {
        value = (value * x + poly[idx]) % mod;
        --idx;
    }
    return static_cast<int>(value);
}
// Multiplies polynomial A (degree n, coefficients A[0..n]) by polynomial B
// (degree m, coefficients B[0..m]) mod p, writing the n + m + 1 product
// coefficients back into A[0..n+m]; A must have room for that many ints.
// Results are canonical residues in [0, mod).
void poly_mul(int* A, int n, const int* B, int m) {
    if (n < 0 || m < 0) return;
    // Scratch sized to the product: no fixed-capacity overflow, re-entrant,
    // and it no longer shadows the binomial function C.
    std::vector<long long> product(n + m + 1, 0);
    for (int i = 0; i <= n; i++) {
        for (int j = 0; j <= m; j++) {
            product[i + j] = (product[i + j] + 1ll * A[i] * B[j] % mod) % mod;
        }
    }
    for (int t = 0; t <= n + m; t++) {
        // product[t] lies in (-mod, mod); shift into [0, mod).
        A[t] = static_cast<int>((product[t] + mod) % mod);
    }
}
//计算n次多项式A整除m次多项式B,忽略余数
void poly_div(const int* A, int n, const int* B, int m, int* C) {
static int R[50];
std::copy(A, A + n + 1, R);
for (int i = n - m; i >= 0; i--) {
int_t x = (int_t)R[i + m] * power(B[m], mod - 2) % mod;
C[i] = x;
for (int_t j = m; j >= 0; j--) {
R[i + j] = (R[i + j] - 1ll * B[j] * x + 2ll * mod) % mod;
}
}
}
// `state` is one exponent list {c_1, c_2, ..., c_t}. Computes, by Lagrange
// interpolation, the coefficients of the polynomial in x
//     P(x) = prod_i C(c_i + x, c_i)   (mod p),
// whose degree is d = c_1 + ... + c_t, from its values at x = 0, 1, ..., d.
// On return output[0] = d + 1 (the number of coefficients) and
// output[1..d+1] are the coefficients in ascending degree; the caller must
// provide room for d + 2 ints.
void interpolate(const std::vector<int>& state, int* output) {
    // Number of interpolation nodes == number of coefficients of P.
    const int k = std::accumulate(state.begin(), state.end(), 0) + 1;

    // full(x) = x (x - 1) ... (x - (k - 1)), degree k. Two spare slots because
    // poly_mul called with degree k writes k + 2 coefficients.
    std::vector<int> full(k + 2, 0);
    full[0] = 1;
    for (int node = 0; node < k; node++) {
        const int linear[2] = {(mod - node) % mod, 1};  // x - node
        poly_mul(full.data(), k, linear, 1);
    }
#ifdef DEBUG
    cout << "prod = ";
    for (int_t j = 0; j <= k; j++) cout << full[j] << " ";
    cout << endl;
#endif

    std::vector<int> coeffs(k, 0);  // running sum of weighted basis polynomials
    std::vector<int> basis(k, 0);   // full(x) / (x - node), degree k - 1
    for (int node = 0; node < k; node++) {
        int weight = subState(state, node);  // P(node)
#ifdef DEBUG
        cout << "y of " << node << " = " << weight << endl;
#endif
        const int linear[2] = {(mod - node) % mod, 1};
        poly_div(full.data(), k, linear, 1, basis.data());
#ifdef DEBUG
        cout << "divide x-" << node << " = ";
        for (int_t j = 0; j <= k - 1; j++) cout << basis[j] << " ";
        cout << endl;
#endif
        // denom = prod_{other != node} (node - other): basis evaluated at node.
        int denom = 1;
        for (int other = 0; other < k; other++) {
            if (other != node) denom = 1ll * denom * (node - other + mod) % mod;
        }
#ifdef DEBUG
        cout << "x0 of " << node << " = " << denom << endl;
#endif
        weight = 1ll * weight * power(denom, mod - 2) % mod;
        for (int j = 0; j < k; j++) {
            coeffs[j] =
                (1ll * coeffs[j] + 1ll * weight * basis[j] % mod + mod) % mod;
        }
    }
    output[0] = k;
    std::copy(coeffs.begin(), coeffs.end(), output + 1);
}
int main() {
// freopen("qwq.txt", "r", stdin);
fact[0] = fact[1] = factInv[0] = factInv[1] = 1;
for (int_t i = 2; i <= LIMIT; i++) {
fact[i] = 1ll * fact[i - 1] * i % mod;
factInv[i] = 1ll * (mod - mod / i) * factInv[mod % i] % mod;
}
for (int_t i = 2; i <= LIMIT; i++) {
factInv[i] = 1ll * factInv[i - 1] * factInv[i] % mod;
}
memset(isPrime, true, sizeof(isPrime));
isPrime[1] = 0;
for (int_t i = 2; i <= LIMIT; i++) {
if (isPrime[i]) {
for (long long j = 1ll * i * i; j <= LIMIT; j += i) {
isPrime[j] = false;
factor[j] = i;
}
factor[i] = i;
}
}
scanf("%d%d%d", &maxn, &maxk, &q);
for (int_t i = 2; i <= maxn; i++) {
int x = i;
std::vector<int> index;
while (x != 1) {
int times = 0;
int p = factor[x];
while (x % p == 0) {
times++;
x /= p;
}
index.push_back(times);
}
std::sort(index.begin(), index.end());
state[i] = index;
set.insert(index);
}
states.resize(set.size());
std::copy(set.begin(), set.end(), states.begin());
for (const auto& vec : states) {
hash[vec] = hash.size();
interpolate(vec, polys[hash[vec]]);
#ifdef DEBUG
cout << "poly of ";
for (int_t x : vec) cout << x << " ";
cout << "is ";
for (int_t x : polys[hash[vec]]) cout << x << " ";
cout << endl;
#endif
}
for (int i = 2; i <= maxn; i++) {
memcpy(prefix[i], prefix[i - 1], sizeof(prefix[i]));
auto hash0 = hash[state[i]];
prefix[i][hash0] = (prefix[i][hash0] + 1) % mod;
#ifdef DEBUG
cout << "count 1 at " << i << " for state " << state[i] << "("
<< hash[state[i]] << ")" << endl;
#endif
}
for (int_t i = 1; i <= q; i++) {
int k, left, right;
scanf("%d%d%d", &left, &right, &k);
int result = 0;
for (int j = 0; j < states.size(); j++) {
int count = prefix[right][j] - prefix[left - 1][j];
if (count > 0) {
result =
(1ll * result + 1ll * count * subPoly(polys[j], k) % mod) %
mod;
}
}
if (left == 1) result = (result + 1) % mod;
printf("%d\n", result);
}
return 0;
}
发表回复