본문 바로가기

알고리즘/백준 문제 풀이

[C++] 11385번 씽크스몰

반응형

https://www.acmicpc.net/problem/11385


 

26/02/03

 

 

풀이 방향은 중국인의 나머지 정리(CRT)를 알고 있다는 전제하에 굉장히 쉽다.


 

문제 접근 방식:

 

 

여기서는 FFT 대신 NTT를 사용한다. (물론 FFT를 사용하는 방식도 있고 실제로 그렇게 풀리기도 하지만, 나는 NTT를 사용한 풀이로 해결하였다.)

알다시피 NTT는 NTT-friendly한 소수 p(즉, p = c·2^k + 1 꼴의 소수)를 하나 정하고, 모든 연산을 mod p 위에서 수행한다.

즉, 우리가 구하고자 하는 곱 다항식의 실제 계수가 p 이상이면, NTT로 얻은 계수는 실제 계수를 p로 나눈 나머지일 뿐이므로 실제 계수와 다를 수 있다.

따라서 서로 다른 NTT-friendly한 소수 여러 개(여기서는 3개)로 각각 곱셈을 수행한 뒤, 각 차수별로 얻은 나머지들에 중국인의 나머지 정리를 적용하여 원래 다항식의 계수를 복원하면 된다. 세 소수의 곱은 약 4.7×10^26이므로, 실제 계수가 이보다 작다면 CRT로 유일하게 복원된다.

여기서 중국인의 나머지 정리에 대한 내용을 설명하기에는 한 글의 볼륨이 너무 커지므로, 이후 기회가 된다면 따로 중국인의 나머지 정리에 대한 글을 작성하여 첨부하도록 하겠다.

이 문제에서 가장 중요한 것은 중국인의 나머지 정리를 얼마나 정확하게 구현하느냐이다. 세 소수의 곱은 64비트 정수 범위(약 9.2×10^18)를 넘으므로, 중간 계산에서 overflow가 나지 않도록(예: __int128 사용) 구현하는 것이 가장 큰 관건이다.


아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
// 11385번 씽크스몰
// NTT, CRT
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'

/*
Notes:
MOD = 998'244'353 = 119*2^23 + 1
MOD = 1'004'535'809 = 479*2^21 + 1 / G = 3 / Works for lengths up to 2^21. / Use with CRT and 998'244'353
MOD = 469'762'049 = 7*2^26 + 1 / G = 3 / Works for lengths up to 2^26.
MOD = 167'772'161 = 5*2^25 + 1 / G = 3 / Works for lengths up to 2^25.
MOD = 1'224'736'769 = 73*2^24 + 1 / G = 3 / Works for lengths up to 2^24.
Combine above 3-NTT primes with CRT -> 75bits fast product ok.
*/
// Number-theoretic transform over Z/MOD.
// MOD must be a prime of the form c*2^k + 1 and G a primitive root modulo MOD; transform
// lengths must then be powers of two dividing MOD - 1 (i.e. at most 2^k).
// For MOD <= 2^31 (every prime in the notes above) the product of two residues fits in
// long long, so mul_mod avoids the slow __int128 division; larger moduli fall back to __int128.
template<long long MOD, long long G>
struct NTT {
    using i128 = __int128_t;
    using ll = long long;
private :
    // (a * b) mod MOD for residues a, b in [0, MOD).
    inline ll mul_mod(ll a, ll b) const {
        if constexpr (MOD <= (1LL << 31)) {
            // Both factors are below 2^31, so the product is below 2^62: 64-bit math is exact.
            return a * b % MOD;
        } else {
            return static_cast<ll>(static_cast<i128>(a) * b % MOD);
        }
    }
    // a^e mod MOD by binary exponentiation (e >= 0).
    inline ll pow_mod(ll a, ll e) const {
        ll r = 1;
        while (e > 0){
            if (e & 1) r = mul_mod(r, a);
            a = mul_mod(a, a);
            e >>= 1;
        }
        return r;
    }
    // Inverse of a mod MOD via Fermat's little theorem (MOD prime, a not a multiple of MOD).
    inline ll mod_inv(ll a) const {return pow_mod(a, MOD - 2);}
public :
    // In-place transform of `a`; performs the inverse transform when `invert` is non-zero.
    // Preconditions: every a[i] is in [0, MOD).
    // Throws std::invalid_argument when a.size() is not a power of two dividing MOD - 1;
    // for such lengths (MOD - 1) / len is not an exact root exponent and the result would be garbage.
    void iterative_NTT(std::vector<ll>& a, int invert){
        const int n = static_cast<int>(a.size());
        if (n == 0) return;
        if ((n & (n - 1)) != 0)
            throw std::invalid_argument("NTT length must be a power of two");
        if ((MOD - 1) % n != 0)
            throw std::invalid_argument("NTT length does not divide MOD - 1");
        // 1) bit-reversal permutation
        for (int i = 1, j = 0; i < n; i++){
            int bit = n >> 1;
            while (j & bit){
                j ^= bit;
                bit >>= 1;
            }
            j ^= bit;
            if (i < j) std::swap(a[i], a[j]);
        }
        // 2) butterflies by length = 2, 4, 8, ...
        for (int len = 2; len <= n; len <<= 1){
            const int half = len >> 1;
            // primitive len-th root of unity (exact because len divides MOD - 1)
            ll w_len = pow_mod(G, (MOD - 1) / len);
            if (invert) w_len = mod_inv(w_len);
            for (int i = 0; i < n; i += len){
                ll w = 1;
                for (int j = 0; j < half; j++){
                    const ll u = a[i + j];
                    const ll v = mul_mod(w, a[i + j + half]);
                    ll x = u + v; if (x >= MOD) x -= MOD;
                    ll y = u - v; if (y < 0) y += MOD;
                    a[i + j] = x;
                    a[i + j + half] = y;
                    w = mul_mod(w, w_len);
                }
            }
        }
        // 3) Divide by n (multiply by its inverse) for the inverse transform
        if (invert){
            const ll inv_n = mod_inv(n);
            for (ll& x : a) x = mul_mod(x, inv_n);
        }
    }
    // Input : coefficient vectors {a_0, a_1, ...}, {b_0, b_1, ...} (any sign; reduced mod MOD here)
    // Output : their convolution mod MOD, of length a.size() + b.size() - 1 (empty if either is empty)
    std::vector<ll> convolution(std::vector<ll> a, std::vector<ll> b) {
        if (a.empty() || b.empty()) return {};
        const int out_len = static_cast<int>(a.size() + b.size() - 1);
        // Pad to the next power of two so the radix-2 transform applies.
        int n = 1;
        while (n < out_len) n <<= 1;
        a.resize(n); b.resize(n);
        // Normalize to [0, MOD)
        for (int i = 0; i < n; ++i){
            a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
            b[i] %= MOD; if (b[i] < 0) b[i] += MOD;
        }
        iterative_NTT(a, 0); iterative_NTT(b, 0);
        // Pointwise product in the transformed domain
        for (int i = 0; i < n; ++i) a[i] = mul_mod(a[i], b[i]);
        iterative_NTT(a, 1);
        a.resize(out_len);
        return a;
    }
};

// CRT (Chinese Remainder Theorem) utilities
// - mergeTwo / mergeAll : merge congruences (moduli need not be coprime) into one __int128 result
// - Garner              : compute the solution modulo some MOD without constructing the huge integer
// - i128toStr           : decimal formatting for __int128 values
// Moduli are assumed to be positive and to fit in long long.
struct CRT{
    using ll = long long;
    using i128 = __int128_t;
    using u128 = __uint128_t;
private:
    // (a * b) mod m; the 128-bit intermediate cannot overflow for |a|, |b| <= 2^63.
    static inline ll mul_mod(ll a, ll b, ll m){return static_cast<ll>(static_cast<i128>(a) * b % m);}
    // Normalize to [0, m) (long long version)
    static inline ll norm(ll a, ll m){
        a %= m;
        if (a < 0) a += m;
        return a;
    }
    // Normalize to [0, m) (__int128 version)
    static inline i128 norm(i128 a, i128 m){
        a %= m;
        if (a < 0) a += m;
        return a;
    }
    // Non-negative gcd(a, b)
    static inline ll gcd(ll a, ll b){return b == 0 ? (a >= 0 ? a : -a) : gcd(b, a % b);}
    // Extended gcd: returns gcd(a, b) and finds x, y with a*x + b*y = gcd(a, b) (for a, b >= 0)
    static inline ll extended_gcd(ll a, ll b, ll& x, ll& y){
        if (b == 0){x = 1; y = 0; return a >= 0 ? a : -a;}
        ll x1, y1;
        const ll g = extended_gcd(b, a % b, x1, y1);
        x = y1; y = x1 - (a / b) * y1;
        return g;
    }
    // Modular inverse of a mod m. Precondition: gcd(a, m) == 1 (otherwise the result is meaningless).
    static inline ll mod_inv(ll a, ll m){
        ll x, y;
        extended_gcd(a, m, x, y); // a*x + m*y = 1, so x is the inverse of a
        return norm(x, m);
    }
    // Merge step used by mergeAll.
    // First congruence is x = r (mod M) with r, M possibly beyond long long; second is x = a2 (mod m2).
    // Returns {nr, nM} where nM = lcm(M, m2) and nr in [0, nM), or {-1, -1} if inconsistent.
    static inline std::pair<i128, i128> mergeTwoi128(i128 r, i128 M, ll a2, ll m2) {
        a2 = norm(a2, m2);
        const ll r_mod_m2 = static_cast<ll>(norm(r, static_cast<i128>(m2)));
        const ll M_mod_m2 = static_cast<ll>(norm(M, static_cast<i128>(m2)));
        const ll g = gcd(M_mod_m2, m2); // gcd(M, m2) == gcd(M mod m2, m2)
        const ll diff = a2 - r_mod_m2;
        if (diff % g != 0) return {-1, -1};
        // Reduce by g: (M / g) * t = diff / g (mod m2 / g), and (M / g) mod (m2 / g) = (M mod m2) / g
        const ll m2_g = m2 / g;
        const ll M_mod_m2_g = M_mod_m2 / g;
        const ll diff_g = diff / g; // may be negative; normalized below
        const ll M_mod_m2_g_inv = mod_inv(norm(M_mod_m2_g, m2_g), m2_g);
        const ll t = mul_mod(norm(diff_g, m2_g), M_mod_m2_g_inv, m2_g);
        // nr = r + M * t, nM = lcm(M, m2) = M * (m2 / g)
        i128 nr = r + M * static_cast<i128>(t);
        const i128 nM = M * static_cast<i128>(m2_g);
        nr = norm(nr, nM);
        return {nr, nM};
    }
public:
    // Merge x = a1 (mod m1) and x = a2 (mod m2); m1 and m2 need not be coprime.
    // When solvable, returns {r, M} where M = lcm(m1, m2) and r in [0, M).
    // When inconsistent, returns {-1, -1}.
    static inline std::pair<i128, i128> mergeTwo(ll a1, ll m1, ll a2, ll m2){
        a1 = norm(a1, m1); a2 = norm(a2, m2);
        const ll g = gcd(m1, m2);
        const ll diff = a2 - a1;
        // Solvable iff gcd(m1, m2) divides a2 - a1
        if (diff % g != 0) return {-1, -1};
        // Reduce: (m1 / g) * t = diff / g (mod m2 / g)
        const ll m1_g = m1 / g, m2_g = m2 / g, diff_g = diff / g;
        const ll m1_g_inv = mod_inv(m1_g, m2_g);
        const ll t = mul_mod(m1_g_inv, norm(diff_g, m2_g), m2_g);
        // Recover: x = a1 + m1 * t
        i128 r = static_cast<i128>(a1) + static_cast<i128>(m1) * t;
        const i128 M = static_cast<i128>(m1_g) * static_cast<i128>(m2); // lcm(m1, m2)
        r = norm(r, M);
        return {r, M};
    }
    // Merge many congruences x = a[i] (mod m[i]) sequentially.
    // The combined modulus must fit in __int128 (otherwise use Garner).
    // Returns {r, M} on success, {-1, -1} if the sizes differ or the system is inconsistent.
    static inline std::pair<i128, i128> mergeAll(const std::vector<ll>& a, const std::vector<ll>& m){
        if (a.size() != m.size()) return {-1, -1};
        i128 r = 0, M = 1;
        for (size_t i = 0; i < a.size(); ++i){
            auto [nr, nM] = mergeTwoi128(r, M, a[i], m[i]);
            if (nr == -1) return {-1, -1};
            r = nr; M = nM;
        }
        return {r, M};
    }
    // Garner's algorithm.
    // Given x = aIn[i] (mod mIn[i]) with pairwise coprime moduli (each also coprime with MOD),
    // returns x mod MOD without building x. Writes x = c0 + c1*m0 + c2*m0*m1 + ...; at step k,
    // with Pk = m0*...*m(k-1): ck = (ak - current_x) * Pk^{-1} (mod mk), which keeps all
    // previously satisfied congruences intact.
    static inline ll Garner(const std::vector<ll>& aIn, const std::vector<ll>& mIn, ll MOD){
        const int n = static_cast<int>(mIn.size());
        std::vector<ll> a(n + 1), m(n + 1);
        for (int i = 0; i < n; ++i){
            m[i] = mIn[i];
            a[i] = norm(aIn[i], m[i]);
        }
        m[n] = MOD; a[n] = 0; // extra slot tracks the answer modulo MOD
        std::vector<ll> Pk(n + 1, 1);        // Pk[i] = (m0*...*m(k-1)) mod m[i]
        std::vector<ll> current_x(n + 1, 0); // current_x[i] = (partial x) mod m[i]
        for (int k = 0; k < n; ++k){
            const ll mk = m[k];
            const ll leftTerm = norm(a[k] - current_x[k], mk);
            const ll Pkinv = mod_inv(Pk[k] % mk, mk);
            const ll ck = mul_mod(leftTerm, Pkinv, mk);
            for (int i = k + 1; i <= n; ++i){
                current_x[i] = norm(current_x[i] + mul_mod(Pk[i], ck, m[i]), m[i]); // += ck * Pk
                Pk[i] = mul_mod(Pk[i], mk % m[i], m[i]);                            // Pk *= mk
            }
        }
        return current_x[n] % MOD;
    }
    // Decimal representation of an __int128 (iostream has no operator<< for it).
    static inline std::string i128toStr(i128 v){
        if (v == 0) return "0";
        const bool neg = v < 0;
        // Take the magnitude in unsigned arithmetic: negating the minimum i128 as a signed value is UB.
        u128 mag = neg ? -static_cast<u128>(v) : static_cast<u128>(v);
        std::string s;
        while (mag > 0){
            s.push_back(static_cast<char>('0' + static_cast<int>(mag % 10)));
            mag /= 10;
        }
        if (neg) s.push_back('-');
        std::reverse(s.begin(), s.end());
        return s;
    }
};

int main() {
    fastio
    using ll = long long;
    using i128 = __int128_t;
    
    NTT<998'244'353,3> ntt1;
    NTT<1'004'535'809,3> ntt2;
    NTT<469'762'049,3> ntt3;
    int N, M; cin >> N >> M;
    vector<long long> f(N+1), g(M+1);
    for (int i = 0; i < N+1; ++i) cin >> f[i];
    for (int i = 0; i < M+1; ++i) cin >> g[i];
    auto C1 = ntt1.convolution(f, g);
    auto C2 = ntt2.convolution(f, g);
    auto C3 = ntt3.convolution(f, g);
    vector<i128> total;
    vector<ll> modulars = {998'244'353, 1'004'535'809, 469'762'049};
    i128 ans = 0;
    for (int i = 0; i < C1.size(); ++i){
        vector<ll> a = {C1[i], C2[i], C3[i]};
        pair<i128, i128> p = CRT::mergeAll(a, modulars);
        ans = ans^p.first;
    }
    cout << CRT::i128toStr(ans);
    return 0;
}
반응형

'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글

[C++] 14756번 Telescope  (0) 2026.02.26
[C++] 14958번 Rock Paper Scissors  (0) 2026.02.26
[C++] 13575번 보석 가게  (0) 2026.02.25
[C++] 15576번 큰 수 곱셈 (2)  (0) 2026.02.24
[Python] 25214번 크림 파스타  (0) 2026.02.24