반응형
https://www.acmicpc.net/problem/11385
26/02/03
풀이 방향은 중국인의 나머지 정리를 알고있는 전제 하에 굉장히 쉽다.
문제 접근 방식:
여기서는 FFT대신 NTT를 사용한다.(물론 FFT를 사용해도 되는 방식이 있고, 실제로 그렇게 풀리기도 하나 일단 나는 NTT를 사용한 풀이로 해결하였다.)
NTT를 사용하면 알다시피, 항상 NTT-friendly한 소수가 있고 그 소수의 모듈로 안에서 연산을 하게 된다.
즉, 우리가 구하고자 하는 다항식의 실제 계수가 모듈로 안에서 연산한 계수가 다를 수 있다.
따라서, 다양한 NTT-friendly한 소수를 사용하여(대략 3개 정도), 나온 다항식의 계수들에 대한 결과를 각 차수 별로 중국인의 나머지 정리를 적용하여 원래 다항식의 계수를 복원하면 된다.
여기서 중국인의 나머지 정리에 대한 내용을 설명하기에는 한 글의 볼륨이 너무 커지므로, 이후 기회가 된다면 따로 중국인의 나머지 정리에 대한 글을 작성하여 첨부하도록 하겠다.
해당 문제에서 가장 중요한 것은, 중국인의 나머지 정리를 얼마나 잘 구현했냐의 여부이다. overflow가 나지 않도록 구현하는 것이 가장 주요한 관건이라고 할 수 있을 것이다.
아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.
더보기
// 11385번 씽크스몰
// NTT, CRT
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'
/*
Notes:
MOD = 998'244'353 = 119*2^23 + 1
MOD = 1'004'535'809 = 479*2^21 + 1 / G = 3 / Works for lengths up to 2^21. / Use with CRT and 998'244'353
MOD = 469'762'049 = 7*2^26 + 1 / G = 3 / Works for lengths up to 2^26.
MOD = 167'772'161 = 5*2^25 + 1 / G = 3 / Works for lengths up to 2^25.
MOD = 1'224'736'769 = 73*2^24 + 1 / G = 3 / Works for lengths up to 2^24.
Combine above 3-NTT primes with CRT -> 75bits fast product ok.
*/
// NOTE THAT FOR FAST WORKING, CHANGE i128 -> ll, ll -> int
template<long long MOD, long long G>
struct NTT {
using i128 = __int128_t;
using ll = long long;
private :
// HELPER FUNCTION
inline ll mul_mod(ll a, ll b) const {return (i128)a * b % MOD;}
inline ll pow_mod(ll a, ll e) const {
ll r = 1;
while(e > 0){
if(e & 1) r = mul_mod(r, a);
a = mul_mod(a, a);
e >>= 1;
}
return r;
}
inline ll mod_inv(ll a) const {return pow_mod(a, MOD-2);}
public :
void iterative_NTT(vector<ll>& a, int invert){
int n = (int)a.size();
// 1) bit-reversal permutation
for (int i = 1, j = 0; i < n; i++){
int bit = n >> 1;
while (j & bit){
j ^= bit;
bit >>= 1;
}
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
// 2) butterflies by length = 2, 4, 8, ...
for (int len = 2; len <= n; len <<= 1){
// primitive len-th root of unity
ll w_len = pow_mod(G, (MOD - 1) / len);
if (invert) w_len = mod_inv(w_len);
for (int i = 0; i < n; i += len){
ll w = 1;
for (int j = 0; j < len / 2; j++){
ll u = a[i+j];
ll v = mul_mod(w, a[i+j+len/2]);
ll x = u+v; if (x >= MOD) x -= MOD;
ll y = u-v; if (y < 0) y += MOD;
a[i+j] = x;
a[i+j+len/2] = y;
w = mul_mod(w, w_len);
}
}
}
// 3) Divide by n (multiply by inverse) for inverse transform
if (invert){
ll inv_n = mod_inv(n);
for (int i = 0; i < n; ++i){
a[i] = mul_mod(a[i], inv_n);
}
}
}
// Input : Coefficient vector {a_0, a_1, ...}, {b_0, b_1, ...}
// Output : Convolution of two coefficient vector
vector<ll> convolution(vector<ll> a, vector<ll> b) {
if (a.empty() || b.empty()) return {};
int S_a = (int)a.size(), S_b = (int)b.size();
// Make vector size easy to dnc(2^n).
int n = 1;
while (n < S_a + S_b - 1){
n <<= 1;
}
a.resize(n); b.resize(n);
// Normalize to [0, MOD)
for (int i = 0; i < n; ++i){
a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
b[i] %= MOD; if (b[i] < 0) b[i] += MOD;
}
// NTT
iterative_NTT(a, 0); iterative_NTT(b, 0);
// Pointwise product
for (int i = 0; i < n; i++) {
a[i] = mul_mod(a[i], b[i]);
}
// INTT
iterative_NTT(a, 1);
a.resize(S_a + S_b - 1);
return a;
}
};
// CRT utilities
// - merge two congruences
// - merge many congruences (sequential)
// - Garner: compute result modulo some MOD without constructing huge integer
struct CRT{
using ll = long long;
using i128 = __int128_t;
private:
static inline ll mul_mod(ll a, ll b, ll m){return (ll)((i128)a * b % m);}
// Normalize to [0, m) (work for long long)
static inline ll norm(ll a, ll m){
a %= m;
if(a < 0) a += m;
return a;
}
static inline i128 norm(i128 a, i128 m){
a %= m;
if(a < 0) a += m;
return a;
}
static inline ll gcd(ll a, ll b){return b == 0 ? (a >= 0 ? a : -a) : gcd(b, a % b);}
// Extended gcd : returns gcd(a, b) and find x, y s.t. ax+by = gcd(a,b)
static inline ll extended_gcd(ll a, ll b, ll& x, ll& y){
if (b == 0){x = 1; y = 0; return a >= 0 ? a : -a;}
ll x1, y1;
ll g = extended_gcd(b, a%b, x1, y1);
x = y1; y = x1 - (a / b)*y1;
return g;
}
// Modular inverse of a mod m (Assumes gcd(a, m) = 1)
static inline ll mod_inv(ll a, ll m){
ll x, y;
ll g = extended_gcd(a, m, x, y);
return norm(x, m);
}
// Special merge for mergeAll
// First is x = r (mod M) (r, M can be i128), second is x = a2 (mod m2) (m2 fits ll)
// Returns {nr, nM} where nM = lcm(M, m2), nr in [0, nM)
static inline pair<i128, i128> mergeTwoi128(i128 r, i128 M, ll a2, ll m2) {
a2 = norm(a2, m2);
ll r_mod_m2 = (ll)norm(r, (i128)m2); // r mod m2 (safe because m2 fits ll)
ll M_mod_m2 = (ll)norm(M, (i128)m2);
ll g = gcd(M_mod_m2, m2); // g = gcd(M, m2) = gcd(M % m2, m2)
ll diff = a2 - r_mod_m2;
if (diff % g != 0) return {-1, -1};
// Reduce : M / g = M', m2 / g = m2'
// (M' mod m2') = ( (M mod m2) / g )
ll m2_g = m2 / g;
ll M_mod_m2_g = (ll)((i128)M_mod_m2 / (i128)g);
ll diff_g = diff / g; // may be negative
ll M_mod_m2_g_inv = mod_inv(norm(M_mod_m2_g, m2_g), m2_g);
ll t = mul_mod(norm(diff_g, m2_g), M_mod_m2_g_inv, m2_g);
// nr = r + M * t, nM = lcm(M, m2) = (M/g) * m2
i128 nr = r + M * (i128)t;
i128 nM = M * (i128)m2_g;
nr = norm(nr, nM);
return {nr, nM};
}
public:
// Merge x ≡ a1 (mod m1) and x ≡ a2 (mod m2)
// Note that this code even works when m1 and m2 are not coprime.
// When solvable, returns {r, M} where M = lcm(m1,m2), r in [0,M).
// When not solvable(not coprime and inconsistent), returns {-1, -1}.
static inline pair<i128, i128> mergeTwo(ll a1, ll m1, ll a2, ll m2){
a1 = norm(a1, m1); a2 = norm(a2, m2); // Normalize
// General Merge works even if not coprime.
ll x, y;
ll g = extended_gcd(m1, m2, x, y); // x*m1 + y*m2 = g
ll diff = a2 - a1;
// Check solvablity
if (diff % g != 0) return {-1, -1};
// Reduce : m1/g * t = diff / g (mod m2/g)
ll m1_g = m1/g, m2_g = m2/g, diff_g = diff/g;
ll m1_g_inv = mod_inv(m1_g, m2_g);
ll t = mul_mod(m1_g_inv, norm(diff_g, m2_g), m2_g);
// Recover : x = a1 + m1*t
i128 r = (i128)a1 + (i128)m1 * t;
i128 M = (i128)m1_g * (i128)m2; // lcm(m1, m2)
r = norm(r, M);
return {r, M};
}
// Merge many congruences sequentially.
// Note that final result x can be overflow when M is too large. (Then, you have to use GarnersMod)
// return {r, M} if ok. If not, return {-1, -1};
static inline pair<i128, i128> mergeAll(const vector<ll>& a, const vector<ll>& m){
if (a.size() != m.size()) return {-1, -1};
i128 r = 0, M = 1;
for (size_t i = 0; i < a.size(); ++i){
auto [nr, nM] = mergeTwoi128(r, M, a[i], m[i]);
if (nr == -1) return {-1, -1};
r = nr; M = nM;
}
return {r, M};
}
// Garner Algorithm
// Given congruences(all mod are pairwise coprime), compute x mod MOD without building huge x. (all mod are coprime with MOD)
// Assume that x = c0 + c1*m0 + c2*m0*m1 + ...
//
// Step 0:
// x ≡ c0 (mod m0)
// -> c0 ≡ a0 (mod m0)
//
// Step 1:
// x ≡ c0 + c1*m0 (mod m1)
// -> c1*m0 ≡ a1 - c0 (mod m1)
// -> c1 ≡ (a1 - c0) * m0^{-1} (mod m1)
//
// Step 2:
// x ≡ c0 + c1*m0 + c2*m0*m1 (mod m2)
// -> c2*(m0*m1) ≡ a2 - (c0 + c1*m0) (mod m2)
// -> c2 ≡ (a2 - current_x) * (m0*m1)^{-1} (mod m2)
// ...
// In general, at step k:
// Let Pk = m0*m1*...*m(k-1)
// x ≡ current_x + ck*Pk (mod mk)
// -> ck ≡ (ak - current_x) * Pk^{-1} (mod mk)
//
// Each ck is determined so that previously satisfied congruences
// (mod m0, ..., mod m(k-1)) are not broken.
static inline ll Garner(const vector<ll>& aIn, const vector<ll>& mIn, ll MOD){
int n = (int)mIn.size();
vector<ll> a(n+1), m(n+1);
for (int i = 0; i < n; ++i){
m[i] = mIn[i];
a[i] = norm(aIn[i], m[i]); // Normalize [0, m[i])
}
m[n] = MOD; a[n] = 0;
vector<ll> Pk(n+1, 1); // Pk[i] = (m0*m1*...*m(k-1)) mod m[i]
vector<ll> current_x(n+1, 0); // current_x[i] = (current x) mod m[i]
for (int k = 0; k < n; ++k){
ll mk = m[k];
ll leftTerm = norm((a[k] - current_x[k]), mk);
ll Pkinv = mod_inv(Pk[k] % mk, mk);
ll ck = mul_mod(leftTerm, Pkinv, mk);
for (int i = k+1; i <= n; ++i){
current_x[i] = norm(current_x[i] + mul_mod(Pk[i], ck, m[i]), m[i]); // Add Pk[i]*ck
Pk[i] = mul_mod(Pk[i], mk % m[i], m[i]); // Pk[i] *= (mk % m[i])
}
}
return current_x[n] % MOD;
}
// Convert i128 to string
static inline string i128toStr(i128 v){
if(v == 0) return "0";
int neg = (v < 0 ? 1 : 0);
if(neg) v = -v;
string s;
while(v > 0){
int digit = (int)(v % 10);
s.push_back(char('0' + digit));
v /= 10;
}
if(neg) s.push_back('-');
reverse(s.begin(), s.end());
return s;
}
};
int main() {
fastio
using ll = long long;
using i128 = __int128_t;
NTT<998'244'353,3> ntt1;
NTT<1'004'535'809,3> ntt2;
NTT<469'762'049,3> ntt3;
int N, M; cin >> N >> M;
vector<long long> f(N+1), g(M+1);
for (int i = 0; i < N+1; ++i) cin >> f[i];
for (int i = 0; i < M+1; ++i) cin >> g[i];
auto C1 = ntt1.convolution(f, g);
auto C2 = ntt2.convolution(f, g);
auto C3 = ntt3.convolution(f, g);
vector<i128> total;
vector<ll> modulars = {998'244'353, 1'004'535'809, 469'762'049};
i128 ans = 0;
for (int i = 0; i < C1.size(); ++i){
vector<ll> a = {C1[i], C2[i], C3[i]};
pair<i128, i128> p = CRT::mergeAll(a, modulars);
ans = ans^p.first;
}
cout << CRT::i128toStr(ans);
return 0;
}
반응형
'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글
| [C++] 14756번 Telescope (0) | 2026.02.26 |
|---|---|
| [C++] 14958번 Rock Paper Scissors (0) | 2026.02.26 |
| [C++] 13575번 보석 가게 (0) | 2026.02.25 |
| [C++] 15576번 큰 수 곱셈 (2) (0) | 2026.02.24 |
| [Python] 25214번 크림 파스타 (0) | 2026.02.24 |