https://www.acmicpc.net/problem/1067
26/01/27
FFT/NTT 기초문제이다. 이 문제의 아이디어를 잘 익혀두면 관련된 문제들도 쉽게 해결할 수 있을 것이다.
문제 접근 방식:
고속 푸리에 변환의 정의를 잘 생각해보자.
PS에서 사용되는 고속 푸리에 변환은 엄밀하게 따지면 이산 푸리에 변환(DFT)이다.
이는 어떤 두 다항식 $A(x), B(x)$가 주어져 있을 때, 두 다항식의 곱 $A(x)B(x)$를 $\mathcal{O}(N^{2})$이 아닌 $\mathcal{O}(N\log N)$의 시간 복잡도로 빠르게 구할 수 있는 알고리즘으로, 분할 정복의 아이디어에 기반을 두고 있다.
두 다항식을 곱한 다항식 $C(x)$의 $x^{k}$의 계수 $\displaystyle c_{k} = \sum_{i=0}^{k}a_{i}b_{k-i}$로 계산되는데, 이러한 연산을 Convolution이라고 이야기 한다.
잡소리는 길었으나, 결국 이 문제에서 요구하는 핵심 관찰은 주어진 식의 형태가 Convolution의 형태임을 빠르게 파악하는 것이다.
문제의 요구 사항은 두 배열 $X, Y$가 주어져 있을 때, $S = X[0]Y[0] + \dots + X[N-1]Y[N-1]$의 값을 최대화하도록 리스트를 Rotate했을 경우 $S$의 값을 구하는 것이다.
주어진 $S$의 형태가 Convolution과 매우 유사함을 확인할 수 있고, 실제로 배열 $Y$를 뒤집어서 계산할 경우 $S$의 형태와 동일해짐을 확인할 수 있다.
당연히, 배열을 한번 Rotate하고 FFT로 Convolution계산하고, 한번 Rotate하고 Convolution계산하고, 이를 반복하는 행위는 굉장히 비효율적인 풀이임을 알아야 한다.
여기서 하나의 아이디어를 떠올리면 문제는 해결된다. 배열 $X$를 한번 더 덧붙여서, $X = X+X$로 만들어준다.
이후 FFT로 Convolution을 진행하면 한번의 FFT로 모든 $S$값들을 구할 수 있다.
예제 입력 1을 그림으로 표현하면 다음과 같다.

예시의 경우 FFT로 $X$랑 $Y$를 Convolution한 곱을 계산하고, 그 곱에서 $x^{3}$의 계수부터 $x^{6}$의 계수까지 $4$개의 항에 대한 결과들 중 최댓값을 찾으면 된다.
문제는 길이가 $N$이므로, 나온 결과 중 $x^{N-1}$부터 $x^{2N-1}$까지 $N$개의 항의 계수를 보면 된다.
나는 실수 계산을 피하고 싶었기 때문에 FFT대신 NTT로 구현하였고, $N \leq 60 \ 000$, $X[i], Y[i] \leq 100$을 만족시키기 때문에 결과물로 나오는 계수 $c_{k} \leq 600\ 000\ 000$이다.
따라서, 모듈로 $998\ 244\ 353$을 취해도 변하지 않으므로 그렇게 나온 NTT값을 그대로 사용하였다.
아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.
// 1067번 이동
// NTT
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'
/*
Notes:
MOD = 998'244'353 = 119*2^23 + 1
MOD = 1'004'535'809 = 479*2^21 + 1 / G = 3 / Works for lengths up to 2^21. / Use with CRT and 998'244'353
MOD = 469'762'049 = 7*2^26 + 1 / G = 3 / Works for lengths up to 2^26.
MOD = 167'772'161 = 5*2^25 + 1 / G = 3 / Works for lengths up to 2^25.
MOD = 1'224'736'769 = 73*2^24 + 1 / G = 3 / Works for lengths up to 2^24.
Combine above 3-NTT primes with CRT -> 75bits fast product ok.
*/
template<long long MOD, long long G>
struct NTT {
using ll = long long;
private :
// HELPER FUNCTION
inline int mul_mod(int a, int b) const {return (ll)a * b % MOD;}
inline int pow_mod(int a, int e) const {
int r = 1;
while(e > 0){
if(e & 1) r = mul_mod(r, a);
a = mul_mod(a, a);
e >>= 1;
}
return r;
}
inline int mod_inv(int a) const {return pow_mod(a, MOD-2);}
public :
void iterative_NTT(vector<int>& a, int invert){
int n = (int)a.size();
// 1) bit-reversal permutation
for (int i = 1, j = 0; i < n; i++){
int bit = n >> 1;
while (j & bit){
j ^= bit;
bit >>= 1;
}
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
// 2) butterflies by length = 2, 4, 8, ...
for (int len = 2; len <= n; len <<= 1){
// primitive len-th root of unity
int w_len = pow_mod(G, (MOD - 1) / len);
if (invert) w_len = mod_inv(w_len);
for (int i = 0; i < n; i += len){
int w = 1;
for (int j = 0; j < len / 2; j++){
int u = a[i+j];
int v = mul_mod(w, a[i+j+len/2]);
int x = u+v; if (x >= MOD) x -= MOD;
int y = u-v; if (y < 0) y += MOD;
a[i+j] = x;
a[i+j+len/2] = y;
w = mul_mod(w, w_len);
}
}
}
// 3) Divide by n (multiply by inverse) for inverse transform
if (invert){
int inv_n = mod_inv(n);
for (int i = 0; i < n; ++i){
a[i] = mul_mod(a[i], inv_n);
}
}
}
// Input : Coefficient vector {a_0, a_1, ...}, {b_0, b_1, ...}
// Output : Convolution of two coefficient vector
vector<int> convolution(vector<int> a, vector<int> b) {
if (a.empty() || b.empty()) return {};
int S_a = (int)a.size(), S_b = (int)b.size();
// Make vector size easy to dnc(2^n).
int n = 1;
while (n < S_a + S_b - 1){
n <<= 1;
}
a.resize(n); b.resize(n);
// Normalize to [0, MOD)
for (int i = 0; i < n; ++i){
a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
b[i] %= MOD; if (b[i] < 0) b[i] += MOD;
}
// NTT
iterative_NTT(a, 0); iterative_NTT(b, 0);
// Pointwise product
for (int i = 0; i < n; i++) {
a[i] = mul_mod(a[i], b[i]);
}
// INTT
iterative_NTT(a, 1);
a.resize(S_a + S_b - 1);
return a;
}
};
int main(void){
fastio
int N; cin >> N;
vector<int> a(2*N), b(N);
int v;
for (int i = 0; i < N; ++i){
cin >> v;
a[i] = v; a[i+N] = v;
}
for (int i = 0; i < N; ++i){
cin >> b[N-i-1];
}
NTT<998244353, 3> ntt;
vector<int> c = ntt.convolution(a, b);
int ans = 0;
for (int i = N-1; i < 2*N; ++i){
ans = max(ans, c[i]);
}
cout << ans;
return 0;
}
'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글
| [C++] 10531번 Golf Bot (0) | 2026.02.15 |
|---|---|
| [C++] 25456번 궁금한 시프트 (0) | 2026.02.14 |
| [Python] 1165번 단어퍼즐 (0) | 2026.02.07 |
| [C++] 11670번 초등 수학 / 30887번 Basic Math (0) | 2026.02.06 |
| [C++] 25597번 푸앙이와 러닝머신 (0) | 2026.02.05 |