https://www.acmicpc.net/problem/5051
26/02/02
FFT를 활용할 수 있는 전형적인 유형의 문제이다.
문제 접근 방식:
결국 이 문제도 수의 목록에서 $2$개의 수를 뽑아 더했을 때, 특정 수가 되도록 하는 경우의 수를 세는 문제로 변환이 된다.
문제의 조건은 $2$가지가 있다.
1. $a \leq b$
2. $a^{2}+b^{2}=c^{2}(\text{mod }n)$
1번 조건은 순서쌍 $(a, b)$와 $(b, a)$를 구분하지 않겠다는 뜻이다. 따라서 우리가 FFT로 구한 경우의 수에서 중복되는 경우의 수를 제거한 경우의 수를 계산해야 한다.
먼저, 이런 두 수의 합을 만드는 경우의 수를 구하는 문제에서 자주 했던 것처럼, 특정 수가 존재하지 않으면 $0$, 존재하면 $1$씩 더해주는 배열 $A$를 하나 만들어주자.
$1$부터 시작해서 $N-1$까지 $A[i^{2} \text{ mod }N]$에 $+1$씩 더해준다.(이 뜻은, 만약 $A[0] = 3$이라면 $a^{2} (\text{mod }N) = 0$을 만족시키는 $a$의 값이 $3$개 있다는 뜻이다.)
이제 이 배열 $A$를 서로 곱한 배열 $\text{res} = A^{2}$를 만들자.
이제 이 배열 $\text{res}$를 $\text{mod }N$에 대하여 접은 배열 $C$를 만들자.(즉, $C[i] = \text{res}[i] + \text{res}[i+N]$)
이제 이 배열 $C[i]$는 $(a^{2}+b^{2})\text{ mod }N = i$를 만족하는 경우의 수를 의미한다.
물론, $a \leq b$와 같은 제약은 걸려있지 않은 상태이다.
제약조건 $a < b$가 걸린 순서쌍을 구하기 위해서는 $C[i]$에서 $a=b$인 경우를 뺀 후에, $2$로 나누면 된다.
우리가 구하고자 하는 제약조건 $a \leq b$가 걸린 순서쌍이라면, 다시 그 경우의 수에서 $a=b$인 경우의 수를 더하면 된다.
즉, 결과적으로 $a=b$인 경우의 수를 담은 배열을 $B$라고 한다면, $(C[i]+B[i]) / 2$가 $(a^{2} + b^{2})\text{ mod }N = i$와 $a \leq b$를 모두 만족하는 경우의 수이다.
따라서, $2a^{2} = r (\text{mod }N)$을 만족시키는 경우의 수를 구하기 위해 길이 $N$짜리 배열 $B$를 선언한다.
이후 $0$부터 $N-1$까지 모든 $i$에 대해 $B[2i \text{ mod }N]$에 $A[i]$를 더해주면 $B$를 구할 수 있다.
이제, $a^{2} + b^{2} = c^{2} (\text{mod }N)$을 만족하는 경우의 수는, $(a^{2}+b^{2}) \text{ mod }N$을 만드는 경우의 수와 $c^{2} \text{ mod }N$을 만드는 경우의 수(배열 $A$)를 곱하면 구할 수 있다.
아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.
// 5051번 피타고라스의 정리
// FFT
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'
/*
Notes:
MOD = 998'244'353 = 119*2^23 + 1
MOD = 1'004'535'809 = 479*2^21 + 1 / G = 3 / Works for lengths up to 2^21. / Use with CRT and 998'244'353
MOD = 469'762'049 = 7*2^26 + 1 / G = 3 / Works for lengths up to 2^26.
MOD = 167'772'161 = 5*2^25 + 1 / G = 3 / Works for lengths up to 2^25.
MOD = 1'224'736'769 = 73*2^24 + 1 / G = 3 / Works for lengths up to 2^24.
Combine above 3-NTT primes with CRT -> 75bits fast product ok.
*/
template<int MOD, int G>
struct NTT {
using ll = long long;
private :
// HELPER FUNCTION
inline int mul_mod(int a, int b) const {return (ll)a * b % MOD;}
inline int pow_mod(int a, int e) const {
int r = 1;
while(e > 0){
if(e & 1) r = mul_mod(r, a);
a = mul_mod(a, a);
e >>= 1;
}
return r;
}
inline int mod_inv(int a) const {return pow_mod(a, MOD-2);}
public :
void iterative_NTT(vector<int>& a, int invert){
int n = (int)a.size();
// 1) bit-reversal permutation
for (int i = 1, j = 0; i < n; i++){
int bit = n >> 1;
while (j & bit){
j ^= bit;
bit >>= 1;
}
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
// 2) butterflies by length = 2, 4, 8, ...
for (int len = 2; len <= n; len <<= 1){
// primitive len-th root of unity
int w_len = pow_mod(G, (MOD - 1) / len);
if (invert) w_len = mod_inv(w_len);
for (int i = 0; i < n; i += len){
int w = 1;
for (int j = 0; j < len / 2; j++){
int u = a[i+j];
int v = mul_mod(w, a[i+j+len/2]);
int x = u+v; if (x >= MOD) x -= MOD;
int y = u-v; if (y < 0) y += MOD;
a[i+j] = x;
a[i+j+len/2] = y;
w = mul_mod(w, w_len);
}
}
}
// 3) Divide by n (multiply by inverse) for inverse transform
if (invert){
int inv_n = mod_inv(n);
for (int i = 0; i < n; ++i){
a[i] = mul_mod(a[i], inv_n);
}
}
}
// Input : Coefficient vector {a_0, a_1, ...}, {b_0, b_1, ...}
// Output : Convolution of two coefficient vector
vector<int> convolution(vector<int> a, vector<int> b) {
if (a.empty() || b.empty()) return {};
int S_a = (int)a.size(), S_b = (int)b.size();
// Make vector size easy to dnc(2^n).
int n = 1;
while (n < S_a + S_b - 1){
n <<= 1;
}
a.resize(n); b.resize(n);
// Normalize to [0, MOD)
for (int i = 0; i < n; ++i){
a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
b[i] %= MOD; if (b[i] < 0) b[i] += MOD;
}
// NTT
iterative_NTT(a, 0); iterative_NTT(b, 0);
// Pointwise product
for (int i = 0; i < n; i++) {
a[i] = mul_mod(a[i], b[i]);
}
// INTT
iterative_NTT(a, 1);
a.resize(S_a + S_b - 1);
return a;
}
};
int main(void){
fastio
int N; cin >> N;
vector<int> A(N, 0);
for (int i = 1; i < N; ++i){
A[(int)((long long)i*i % N)] += 1;
}
NTT<998'244'353, 3> ntt;
auto res = ntt.convolution(A, A);
vector<int> B(N, 0), C(N, 0);
for (int i = 0; i < N; ++i){
B[2*i % N] += A[i];
C[i] = res[i] + res[i+N];
}
long long ans = 0;
for (int i = 0; i < N; ++i){
if (C[i] != 0 && A[i] != 0){
ans += ((long long)(C[i] + B[i])*A[i] / 2);
}
}
cout << ans;
return 0;
}
'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글
| [C++] 21624번 Fence (0) | 2026.02.23 |
|---|---|
| [C++] 2350번 대운하 (0) | 2026.02.22 |
| [C++] 31435번 기숙사 비밀번호 구하기 (0) | 2026.02.20 |
| [C++] 7881번 YAPTCHA (0) | 2026.02.19 |
| [C++] 17104번 골드바흐 파티션 2 (0) | 2026.02.18 |