标签: 快速幂

  • SDOI2011 计算器

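    快速幂 + exgcd + BSGS 的模板：k=1 用快速幂求 $y^z \bmod p$，k=2 用 exgcd 解线性同余方程 $xy\equiv z\pmod{p}$，k=3 用 BSGS 求离散对数 $y^x\equiv z\pmod{p}$。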

    关于BSGS:

    $$\text{已知}a,b,p\ (p\text{为质数},\ p\nmid a),\ \text{求最小的非负整数}x\text{使得}a^x\equiv b\pmod{p}\\\text{设}sqr=\left\lceil\sqrt{p}\right\rceil,\ x=i\times sqr+j\ (0\le i,j\le sqr)\\a^x=a^{i\times sqr}\times a^j=\left( a^{sqr} \right)^i\times a^j\equiv b\pmod{p}\\\left( a^{sqr} \right)^i\equiv b\times a^{-j}\pmod{p}\\\text{枚举}j\text{从0到}sqr,\text{把}b\times a^{-j}\text{存在哈希表里(同一个值只保留最小的}j\text{)}\\\text{然后从小到大枚举}i,\text{查询哈希表中是否存在}b\times a^{-j}\text{使得}\left( a^{sqr} \right)^i\equiv b\times a^{-j}\text{成立,第一次命中时}x=i\times sqr+j\text{即为最小解}$$

    #include <iostream>
    #include <algorithm>
    #include <map>
    #include <unordered_map>
    #include <utility>
    #include <cmath>
    #include <assert.h>
    // 64-bit signed integer used for every value in this program.
    using int_t = long long int;
    // Pair of int_t, used for the (x, y) coefficients returned by exgcd.
    using pair_t = std::pair<int_t, int_t>;
    using std::cin;
    using std::cout;
    using std::endl;

    // Message printed when a type-2 or type-3 query has no solution.
    // The pointer itself is const too, so the global cannot be reassigned.
    const char *const ERROR = "Orz, I cannot find x!";
    
    // Computes base^index mod `mod` by binary exponentiation; the result is in [0, mod).
    // A negative index raises the modular inverse of base instead. The inverse
    // is taken via Fermat's little theorem (base^(mod-2)), so that case assumes
    // `mod` is prime and does not divide base.
    // Intermediate products are below mod^2, so mod must stay below ~3e9 to fit
    // in a signed 64-bit integer.
    int_t power(int_t base, int_t index, int_t mod)
    {
        // Every integer is congruent to 0 modulo 1; this also avoids the
        // endless recursion power(x, -1, 1) -> power(x, -1, 1) -> ...
        if (mod == 1)
            return 0;
        // Reduce up front: an unreduced base can overflow base * base, and a
        // negative base would otherwise produce a negative result.
        base %= mod;
        if (base < 0)
            base += mod;
        if (index < 0)
        {
            index = -index;
            base = power(base, mod - 2, mod);
        }
        int_t result = 1;
        while (index > 0)
        {
            if (index & 1)
                result = result * base % mod;
            base = base * base % mod;
            index >>= 1;
        }
        return result;
    }
    // Extended Euclidean algorithm: returns (x, y) with a * x + b * y == gcd(a, b)
    // for non-negative inputs, and (1, 0) when b == 0.
    // Iterative form: (x, y) and (next_x, next_y) are the Bezout coefficients of
    // the two running remainders. It yields exactly the same pair as the
    // classic recursive formulation.
    pair_t exgcd(int_t a, int_t b)
    {
        int_t x = 1;
        int_t y = 0;
        int_t next_x = 0;
        int_t next_y = 1;
        while (b != 0)
        {
            const int_t q = a / b;
            const int_t remainder = a - q * b;
            a = b;
            b = remainder;

            const int_t shifted_x = x - q * next_x;
            x = next_x;
            next_x = shifted_x;

            const int_t shifted_y = y - q * next_y;
            y = next_y;
            next_y = shifted_y;
        }
        return pair_t(x, y);
    }
    // Greatest common divisor by Euclid's algorithm, written as a loop.
    // gcd(a, 0) is a; for non-negative inputs the result is the usual gcd.
    int_t gcd(int_t a, int_t b)
    {
        while (b != 0)
        {
            const int_t remainder = a % b;
            a = b;
            b = remainder;
        }
        return a;
    }
    // Baby-step giant-step (BSGS): returns the smallest non-negative x with
    // a^x ≡ b (mod p), or -1 when no such x exists.
    //
    // Write x = i * sqr + j with 0 <= i, j <= sqr and sqr = ceil(sqrt(p)).
    // Baby steps store b * a^{-j}, keeping the smallest j for each value.
    // Giant steps walk (a^sqr)^i for increasing i, so the first hit is the
    // smallest x. The inverse of a comes from `power`, i.e. Fermat's little
    // theorem, so p is assumed to be prime.
    int_t log_mod(int_t a, int_t b, int_t p)
    {
        // Normalize both operands into [0, p).
        a = (a % p + p) % p;
        b = (b % p + p) % p;
        // x = 0 is always the smallest candidate: a^0 is 1, the same convention
        // `power` uses for exponent 0. `1 % p` also covers p == 1.
        if (b == 1 % p)
            return 0;
        // a ≡ 0 has no inverse; its positive powers are all 0, first reached at x = 1.
        if (a == 0)
            return b == 0 ? 1 : -1;

        // Smallest sqr with sqr * sqr >= p; the loop corrects floating-point rounding.
        int_t sqr = static_cast<int_t>(std::sqrt(static_cast<double>(p)));
        while (sqr * sqr < p)
            sqr++;

        // Baby steps, computed incrementally instead of one power() call per j.
        const int_t inv_a = power(a, p - 2, p);
        std::unordered_map<int_t, int_t> memory;
        memory.reserve(static_cast<std::size_t>(sqr) + 1);
        int_t curr = b;
        for (int_t j = 0; j <= sqr; j++)
        {
            // emplace does nothing if the value is already stored, so the smallest j wins.
            memory.emplace(curr, j);
            curr = curr * inv_a % p;
        }

        // Giant steps: (a^sqr)^i for i = 0, 1, ..., sqr.
        const int_t giant = power(a, sqr, p);
        curr = 1;
        for (int_t i = 0; i <= sqr; i++)
        {
            const auto it = memory.find(curr);
            if (it != memory.end())
                return i * sqr + it->second;
            curr = curr * giant % p;
        }
        return -1;
    }
    // Solves the linear congruence y * x ≡ z (mod p) and returns the smallest
    // non-negative x, or -1 when no solution exists.
    //
    // Let A = y mod p, C = z mod p and g = gcd(A, p). A solution exists iff
    // g divides C. Dividing through by g gives (A/g) * x ≡ C/g (mod p/g),
    // whose solution is unique modulo p/g. exgcd supplies the inverse of A/g
    // modulo p/g. For prime p, g is 1 unless y ≡ 0 (mod p).
    int_t solve(int_t y, int_t z, int_t p)
    {
        // Normalize into [0, p) so negative inputs are handled too.
        const int_t A = (y % p + p) % p;
        const int_t C = (z % p + p) % p;
        const int_t g = gcd(A, p);
        if (C % g != 0)
            return -1;
        // Modulus of the reduced congruence; all solutions agree modulo m.
        const int_t m = p / g;
        // exgcd gives A * x0 + p * t == g, i.e. (A / g) * x0 ≡ 1 (mod m).
        int_t x0 = exgcd(A, p).first % m;
        if (x0 < 0)
            x0 += m;
        // Both factors are below m, so the product cannot overflow for p < ~3e9.
        return x0 * (C / g) % m;
    }
    int main()
    {
        int_t T, k;
        cin >> T >> k;
        for (int_t i = 1; i <= T; i++)
        {
            int_t y, z, p;
            cin >> y >> z >> p;
            if (k == 1)
            {
                cout << power(y, z, p) << endl;
            }
            else if (k == 2)
            {
                int_t result = solve(y, z, p);
                if (result == -1)
                    cout << ERROR << endl;
                else
                    cout << result << endl;
            }
            else if (k == 3)
            {
                int_t result = log_mod(y, z, p);
                if (result == -1)
                    cout << ERROR << endl;
                else
                    cout << result << endl;
            }
        }
        return 0;
    }