본문 바로가기

알고리즘/백준 문제 풀이

[C++] 10531번 Golf Bot

728x90

https://www.acmicpc.net/problem/10531


 

26/01/29

 

 

FFT의 기본 활용 방법을 배울 수 있는 문제이다. 해당 문제를 잘 익히면 이후의 FFT 활용 문제들도 해당 아이디어를 사용하여 해결할 수 있다.


 

문제 접근 방식:

 

 

골프 봇이 칠 수 있는 거리들 $k_{i}$가 크기 $N$짜리 리스트로 주어지고, 각 코스의 거리들 $d_{j}$가 크기 $M$짜리 리스트로 주어진다.

문제에서 요구하는 것은 골프 봇이 최대 $2$번까지 칠 수 있을 때, 몇 개의 코스를 골프 봇이 정복할 수 있냐를 물어보고 있다.

골프 봇이 $1$번만 칠 때 각 코스를 정복할 수 있는지의 여부는 금방 판별할 수 있다.

문제는 골프 봇이 $2$번을 칠 때 각 코스를 정복할 수 있는지의 여부이다.

수학적으로 문제를 다시 환원하자면, 어떤 수들의 리스트에서 중복을 포함하여 $2$개의 수를 뽑았을 때, 두 수의 합이 $M$개의 수들 중에서 존재하는지 찾아야 하는 문제로 바뀐다.

처음 보면 이게 어떻게 고속 푸리에 변환과 연결될 수 있는지 의문이 들 수 있다.

하지만 우리는 같은 문자를 여러번 "곱하면", 지수 관점에서는 "더하는 것"과 동일하다는 사실을 생각할 수 있다. 즉, $x^{i}\times x^{j} = x^{i+j}$를 생각하자.

따라서, 계수 배열 $A(x)$를 만들 때 $p$라는 수가 $k_{i}$들 중 존재한다면 $a_{p} = 1$, 그렇지 않다면 $0$으로 두자.

이제 배열 $A(x)$끼리 곱하면 $a_{p} = 1, a_{q} = 1$인 경우에만 $x^{p+q}$의 계수 $a_{p+q}$가 살아남을 것이다.

그리고 해당 계수 $a_{p+q}$는 $k_{i}$들 중에서 중복을 포함하여 $2$개를 뽑아 더했을 때 $p+q$를 만들 수 있는 경우의 수와 같다. 

따라서, FFT로 $A(x)$끼리 곱해 $A^{2}(x)$를 만든 후, $A^{2}(x)$의 계수들을 쭉 보고 $d_{j}$와 같은 것이 있는지 확인하면 된다.

이 문제의 경우 NTT를 사용해도 되고, FFT를 사용해도 되나 나는 생각하기 귀찮아서 FFT를 사용하여 구현했다.


아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
// 10531번 Golf Bot
// FFT, 생성함수
#include <iostream>
#include <vector>
#include <cmath>

using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'

struct FFT {
    using ld = long double;
    using ll = long long;
    static constexpr ld PI = acosl(-1.0L);
private :
    // HELPER struct
    struct cd {
        ld re, im;
        cd(ld r=0, ld i=0): re(r), im(i) {}
        cd operator + (const cd& o) const { return cd(re+o.re, im+o.im); }
        cd operator - (const cd& o) const { return cd(re-o.re, im-o.im); }
        cd operator * (const cd& o) const { return cd(re*o.re - im*o.im, re*o.im + im*o.re); }
        cd& operator += (const cd& o) { re+=o.re; im+=o.im; return *this; }
        cd& operator -= (const cd& o) { re-=o.re; im-=o.im; return *this; }
        cd& operator *= (const cd& o) { *this = (*this)*o; return *this; }
    };
public : 
    void iterative_FTT(vector<cd>& a, int invert){
        int n = (int)a.size();
        // 1) bit-reversal permutation
        for (int i = 1, j = 0; i < n; i++){
            int bit = n >> 1;
            while (j & bit){
                j ^= bit;
                bit >>= 1;
            }
            j ^= bit;
            if (i < j) swap(a[i], a[j]);
        }
        // 2) butterflies by length = 2, 4, 8, ...
        for (int len = 2; len <= n; len <<= 1){
            ld ang = 2 * PI / len * (invert ? -1 : 1);
            cd w_len(cosl(ang), sinl(ang)); // primitive len-th root of unity on unit circle
            for (int i = 0; i < n; i += len){
                cd w(1, 0);
                for (int j = 0; j < len / 2; j++){
                    cd u = a[i+j];
                    cd v = w * a[i+j+len/2];
                    a[i+j] = u+v;
                    a[i+j+len/2] = u-v;
                    w *= w_len;
                }
            }
        }
        // 3) Divide by n (multiply by inverse) for inverse transform
        if (invert){
            ld inv_n = 1.0L / n;
            for (int i = 0; i < n; ++i){
                a[i].re *= inv_n;
                a[i].im *= inv_n;
            }
        }
    }
    // Input : Coefficient vector {a_0, a_1, ...}, {b_0, b_1, ...}
    // Output : Convolution of two coefficient vector
    vector<ll> convolution(const vector<ll>& a, const vector<ll>& b) {
        if (a.empty() || b.empty()) return {};
        int S_a = (int)a.size(), S_b = (int)b.size();
        // Make vector size easy to dnc(2^n).
        int n = 1;
        while (n < S_a + S_b - 1){
            n <<= 1;
        }
        vector<cd> fa(n), fb(n);
        for (int i = 0; i < S_a; i++) fa[i] = cd((ld)a[i], 0);
        for (int i = 0; i < S_b; i++) fb[i] = cd((ld)b[i], 0);
        // FTT
        iterative_FTT(fa, 0); iterative_FTT(fb, 0);
        // Pointwise product
        for (int i = 0; i < n; i++) fa[i] *= fb[i];
        // IFTT
        iterative_FTT(fa, 1);
        vector<ll> res(S_a + S_b - 1);
        for (int i = 0; i < res.size(); i++) {
            res[i] = (ll) llround(fa[i].re); // rounding to nearest integer
        }
        return res;
    }
};

int main(void){
    fastio

    int N, M;
    vector<long long> A(200'001, 0);
    cin >> N;
    long long k;
    for (int i = 0; i < N; ++i){
        cin >> k;
        ++A[k];
    }
    vector<long long> check(200'001, 0);
    cin >> M;
    long long ans = 0;
    long long d;
    for (int i = 0; i < M; ++i){
        cin >> d;
        if (A[d]) ++ans;
        else ++check[d];
    }
    FFT fft;
    auto C = fft.convolution(A, A);
    for (int i = 0; i < 200'001; ++i){
        if (C[i]) ans += check[i];
    }
    cout << ans;
    return 0;
}
728x90

'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글

[C++] 17134번 르모앙의 추측  (0) 2026.02.17
[C++] 20176번 Needle  (0) 2026.02.16
[C++] 25456번 궁금한 시프트  (0) 2026.02.14
[C++] 1067번 이동  (0) 2026.02.13
[Python] 1165번 단어퍼즐  (0) 2026.02.07