快速幂+exgcd+BSGS的模板.
关于BSGS:
$$\text{已知}a,b,p\ (p\text{为质数},\ p\nmid a),\ \text{求最小的非负整数}x\text{使得}a^x\equiv b\pmod{p}\\\text{取}sqr\ge\sqrt{p},\ \text{令}x=i\times sqr+j\ (0\le j\le sqr)\\a^x=a^{i\times sqr}\times a^j=\left( a^{sqr} \right) ^i\times a^j\equiv b\pmod{p}\\\left( a^{sqr} \right) ^i\equiv b\times a^{-j}\pmod{p}\\\text{枚举}j\text{从0到}sqr,\text{把}b\times a^{-j}\text{存在哈希表里(同一个值只保留最小的}j\text{)}\\\text{然后从小到大枚举}i,\text{查询是否存在}j\text{使得}\left( a^{sqr} \right) ^i\equiv b\times a^{-j}\text{成立},\text{第一次命中时}x=i\times sqr+j\text{即为最小解}$$
#include <iostream>
#include <algorithm>
#include <map>
#include <unordered_map>
#include <utility>
#include <cmath>
#include <assert.h>
// Integer type used for all residues and exponents. It is wide enough that the
// product of two residues modulo p fits (for p below roughly 3.03e9).
using int_t = long long int;
// (x, y) coefficient pair, as returned by exgcd().
using pair_t = std::pair<int_t, int_t>;
using std::cin;
using std::cout;
using std::endl;
// Message printed when a congruence has no solution. The pointer itself is
// constant, so the message cannot be re-pointed at runtime.
constexpr const char *ERROR = "Orz, I cannot find x!";
// Fast modular exponentiation (square-and-multiply). Returns base^index mod
// `mod`, always in [0, mod).
//
// - `base` may be negative or >= mod. It is reduced into [0, mod) first, which
//   also keeps every intermediate product below mod^2.
// - A negative `index` means (base^-1)^|index|. The inverse is computed as
//   base^(mod-2) (Fermat's little theorem). That is only a true inverse when
//   `mod` is prime and does not divide `base`.
// - Preconditions: mod >= 1, and (mod-1)^2 must fit in int_t (mod below
//   roughly 3.03e9).
int_t power(int_t base, int_t index, int_t mod)
{
    // Everything is 0 modulo 1. Returning early also avoids the endless
    // recursion power(b, -k, 1) -> power(b, -1, 1) -> ... for negative index.
    if (mod == 1)
        return 0;

    base %= mod;
    if (base < 0)
        base += mod;

    if (index < 0)
    {
        index = -index;
        base = power(base, mod - 2, mod);
    }

    int_t result = 1;
    for (; index > 0; index >>= 1)
    {
        if (index & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Extended Euclidean algorithm, iterative form.
// Returns (x, y) such that a * x + b * y equals the gcd that Euclid's
// algorithm reaches for (a, b); exgcd(a, 0) is (1, 0).
// Both (r0, x0, y0) and (r1, x1, y1) maintain r == a * x + b * y, and the
// remainders follow the same sequence as gcd(a, b).
pair_t exgcd(int_t a, int_t b)
{
    int_t r0 = a;
    int_t r1 = b;
    int_t x0 = 1;
    int_t x1 = 0;
    int_t y0 = 0;
    int_t y1 = 1;
    while (r1 != 0)
    {
        const int_t q = r0 / r1;
        const int_t next_r = r0 % r1;
        const int_t next_x = x0 - q * x1;
        const int_t next_y = y0 - q * y1;
        r0 = r1;
        r1 = next_r;
        x0 = x1;
        x1 = next_x;
        y0 = y1;
        y1 = next_y;
    }
    return pair_t(x0, y0);
}
// Greatest common divisor by Euclid's algorithm, written as a loop.
// gcd(a, 0) == a.
int_t gcd(int_t a, int_t b)
{
    while (b != 0)
    {
        const int_t remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
int_t log_mod(int_t a, int_t b, int_t p)
{
    // Every value is congruent modulo 1, so x = 0 already works.
    if (p == 1)
        return 0;

    a %= p;
    if (a < 0)
        a += p;
    b %= p;
    if (b < 0)
        b += p;

    // a^0 is taken as 1 (power() returns 1 for index 0), so b ≡ 1 is always
    // solved by x = 0. This also covers a ≡ 0.
    if (b == 1)
        return 0;
    // a ≡ 0: a^x ≡ 0 for every x >= 1, so only b ≡ 0 is reachable (at x = 1).
    if (a == 0)
        return b == 0 ? 1 : -1;

    // sqr >= sqrt(p); the +3 absorbs floating-point error in sqrt.
    const int_t sqr = static_cast<int_t>(std::sqrt(static_cast<double>(p - 1))) + 3;
    assert(sqr * sqr >= p);

    // Baby steps: map b * a^-j -> smallest j in [0, sqr]. j increases, so
    // emplace (which never overwrites) keeps the smallest j for each value.
    std::unordered_map<int_t, int_t> memory;
    memory.reserve(static_cast<std::size_t>(sqr) + 1);
    const int_t a_inv = power(a, -1, p);
    int_t baby = b;
    for (int_t j = 0; j <= sqr; ++j)
    {
        memory.emplace(baby, j);
        baby = baby * a_inv % p;
    }

    // Giant steps: look up (a^sqr)^i for i = 0, 1, ... . The windows
    // [i*sqr, i*sqr + sqr] are visited in increasing order, so the first hit
    // (with its smallest j) is the smallest x. They cover [0, sqr^2 + sqr],
    // which includes a full period (order of a <= p - 1 < sqr^2).
    const int_t giant_step = power(a, sqr, p);
    int_t giant = 1;
    for (int_t i = 0; i <= sqr; ++i)
    {
        const auto it = memory.find(giant);
        if (it != memory.end())
            return i * sqr + it->second;
        giant = giant * giant_step % p;
    }
    return -1;
}
// Linear congruence: returns the smallest non-negative x with
// y * x ≡ z (mod p), or -1 if no solution exists.
//
// Let g = gcd(y, p). A solution exists iff g divides z. All solutions then
// form one residue class modulo p / g:
//     x ≡ x0 * (z / g)  (mod p / g),
// where x0 is the Bezout coefficient from exgcd(y, p), i.e. y*x0 + p*t = g.
// Preconditions: p >= 1 and (p-1)^2 fits in int_t.
int_t solve(int_t y, int_t z, int_t p)
{
    // Normalise both operands into [0, p) so negative inputs are handled too.
    int_t A = y % p;
    if (A < 0)
        A += p;
    int_t C = z % p;
    if (C < 0)
        C += p;

    const int_t g = gcd(A, p);
    if (C % g != 0)
        return -1;

    // Dividing by g is required whenever g > 1: x0 alone solves
    // y*x ≡ g (mod p), so x0 * C would solve y*x ≡ g*C, not y*x ≡ C.
    const int_t m = p / g;
    int_t x0 = exgcd(A, p).first % m;
    if (x0 < 0)
        x0 += m;
    return x0 * (C / g) % m;
}
int main()
{
int_t T, k;
cin >> T >> k;
for (int_t i = 1; i <= T; i++)
{
int_t y, z, p;
cin >> y >> z >> p;
if (k == 1)
{
cout << power(y, z, p) << endl;
}
else if (k == 2)
{
int_t result = solve(y, z, p);
if (result == -1)
cout << ERROR << endl;
else
cout << result << endl;
}
else if (k == 3)
{
int_t result = log_mod(y, z, p);
if (result == -1)
cout << ERROR << endl;
else
cout << result << endl;
}
}
return 0;
}