SDOI2011 计算器

快速幂+exgcd+BSGS的模板.

关于BSGS:

$$\text{已知}a,b,p\ \text{求}x\text{使得}a^x\equiv b\pmod p\\\text{设}sqr=\lceil\sqrt{p}\rceil\\x=i\times sqr+j\quad(0\le j\le sqr)\\a^x=a^{i\times sqr}\times a^j=\left( a^{sqr} \right) ^i\times a^j\equiv b\pmod p\\\left( a^{sqr} \right) ^i\equiv b\times a^{-j}\pmod p\\\text{其中}a^{-j}\text{表示}a^j\text{在模}p\text{意义下的逆元(要求}p\text{为质数且}p\nmid a\text{)}\\\text{枚举}j\text{从0到}sqr,\text{把}b\times a^{-j}\bmod p\text{存在哈希表里}\\\text{然后枚举}i,\text{查询是否存在}b\times a^{-j}\text{使得}\left( a^{sqr} \right) ^i\equiv b\times a^{-j}\pmod p\text{成立即可}$$

#include <iostream>
#include <algorithm>
#include <map>
#include <unordered_map>
#include <utility>
#include <cmath>
#include <assert.h>
using int_t = long long int;
using pair_t = std::pair<int_t, int_t>;
using std::cin;
using std::cout;
using std::endl;

const char *ERROR = "Orz, I cannot find x!";

/**
 * Modular exponentiation by repeated squaring.
 *
 * @param base   Value to raise. It may be negative or larger than mod; it is
 *               reduced into [0, mod) before any multiplication so that the
 *               intermediate products stay below mod * mod.
 * @param index  Exponent. A negative exponent is evaluated as
 *               (base^(mod-2))^|index|, i.e. through the Fermat inverse, which
 *               is a genuine inverse only when mod is prime and does not
 *               divide base.
 * @param mod    Modulus; must be positive, and (mod-1)^2 must fit in int_t.
 * @return base^index reduced into [0, mod). 0^0 is treated as 1.
 */
int_t power(int_t base, int_t index, int_t mod)
{
    // Everything is congruent to 0 modulo 1. Handling it here also avoids the
    // endless recursion a negative index would otherwise cause (mod - 2 < 0).
    if (mod == 1)
        return 0;

    // Bring the base into [0, mod) first: squaring an unreduced base can
    // overflow, and a negative base would otherwise yield a negative result.
    base %= mod;
    if (base < 0)
        base += mod;

    if (index < 0)
    {
        index = -index;
        base = power(base, mod - 2, mod); // Fermat inverse of the base
    }

    int_t result = 1;
    while (index > 0)
    {
        if (index & 1)
            result = result * base % mod;
        base = base * base % mod;
        index >>= 1;
    }
    return result;
}
/**
 * Extended Euclidean algorithm.
 *
 * Runs the ordinary Euclidean remainder sequence on (a, b) while tracking how
 * each remainder is expressed as a combination of the original inputs.
 *
 * @return A pair (x, y) such that a * x + b * y equals the last non-zero
 *         remainder of the sequence (gcd(a, b) for non-negative inputs).
 *         For b == 0 the result is (1, 0).
 */
pair_t exgcd(int_t a, int_t b)
{
    // (prev_x, prev_y) expresses the current `a`, (cur_x, cur_y) the current
    // `b`, both as coefficients of the ORIGINAL a and b.
    int_t prev_x = 1;
    int_t prev_y = 0;
    int_t cur_x = 0;
    int_t cur_y = 1;

    while (b != 0)
    {
        const int_t quotient = a / b;
        const int_t remainder = a % b;

        const int_t next_x = prev_x - quotient * cur_x;
        const int_t next_y = prev_y - quotient * cur_y;

        prev_x = cur_x;
        prev_y = cur_y;
        cur_x = next_x;
        cur_y = next_y;

        a = b;
        b = remainder;
    }
    return pair_t(prev_x, prev_y);
}
/**
 * Greatest common divisor via the Euclidean algorithm.
 *
 * Repeatedly replaces (a, b) with (b, a % b) until the second component
 * becomes zero, then returns the first one. gcd(a, 0) is a.
 */
int_t gcd(int_t a, int_t b)
{
    while (b != 0)
    {
        const int_t remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/**
 * Discrete logarithm by baby-step giant-step.
 *
 * Finds the smallest x >= 0 with a^x congruent to b modulo p.
 *
 * Every candidate exponent is written as x = i * sqr + j with 0 <= i, j <= sqr,
 * which turns the equation into (a^sqr)^i == b * a^(-j) (mod p). The right-hand
 * sides for all j are stored in a hash table (baby steps); the left-hand sides
 * are then generated for increasing i (giant steps) and looked up.
 *
 * The inverse of a is taken as a^(p-2), so p is expected to be prime whenever
 * a is not a multiple of p.
 *
 * @param a  Base of the power (reduced modulo p internally).
 * @param b  Target residue (reduced modulo p internally).
 * @param p  Modulus; must be positive.
 * @return The smallest non-negative exponent, or -1 when none exists.
 */
int_t log_mod(int_t a, int_t b, int_t p)
{
    // Normalise both residues into [0, p), also for negative arguments.
    a = (a % p + p) % p;
    b = (b % p + p) % p;

    if (a == 0)
    {
        // A base divisible by p has power 1 for exponent 0 and power 0 for
        // every positive exponent, so only those two targets are reachable.
        // (1 % p makes the comparison correct for p == 1 as well.)
        if (b == 1 % p)
            return 0;
        return b == 0 ? 1 : -1;
    }

    // Block size: floor(sqrt(p - 1)) + 3, deliberately a little larger than
    // sqrt(p) so that (sqr + 1)^2 exponents certainly cover [0, p).
    const int_t sqr = static_cast<int_t>(std::sqrt(static_cast<double>(p - 1))) + 3;
    assert(sqr * sqr >= p);

    // Baby steps: memory[b * a^(-j) mod p] = smallest such j.
    // The inverse is computed once and applied incrementally instead of
    // running a full exponentiation for every j.
    std::unordered_map<int_t, int_t> memory;
    memory.reserve(static_cast<std::size_t>(sqr) + 1);
    const int_t inverse = power(a, p - 2, p);
    int_t curr = b;
    for (int_t j = 0; j <= sqr; j++)
    {
        // emplace never overwrites, and j only grows, so the stored value is
        // the minimum j for each residue.
        memory.emplace(curr, j);
        curr = curr * inverse % p;
    }

    // Giant steps: curr runs through (a^sqr)^i. The first i that hits the
    // table, combined with the smallest j stored there, gives the smallest x.
    const int_t stride = power(a, sqr, p);
    curr = 1 % p;
    for (int_t i = 0; i <= sqr; i++)
    {
        const auto hit = memory.find(curr);
        if (hit != memory.end())
            return hit->second + sqr * i;
        curr = curr * stride % p;
    }
    return -1;
}
int_t solve(int_t y, int_t z, int_t p)
{
    /*
    Sy+Tp=z
    */
    int_t A = y % p;
    int_t B = p;
    int_t C = z % p;
    int_t gcd = ::gcd(A, B);
    if (z % gcd != 0)
        return -1;
    int_t x0 = exgcd(A, B).first;
    x0 = (x0 % p + p) % p;
    return x0 * C % p;
}
int main()
{
    int_t T, k;
    cin >> T >> k;
    for (int_t i = 1; i <= T; i++)
    {
        int_t y, z, p;
        cin >> y >> z >> p;
        if (k == 1)
        {
            cout << power(y, z, p) << endl;
        }
        else if (k == 2)
        {
            int_t result = solve(y, z, p);
            if (result == -1)
                cout << ERROR << endl;
            else
                cout << result << endl;
        }
        else if (k == 3)
        {
            int_t result = log_mod(y, z, p);
            if (result == -1)
                cout << ERROR << endl;
            else
                cout << result << endl;
        }
    }
    return 0;
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理