본문 바로가기

알고리즘/백준 문제 풀이

[C++] 29021번 IPvK

반응형

https://www.acmicpc.net/problem/29021


 

26/02/17

 

 

자꾸 시간초과가 나서 애를 먹었던 문제이다.


 

문제 접근 방식:

 

 

우리가 원하는 요구 사항을 수학적으로 기술하자면 다음과 같다.

- $a_{1} + a_{2}+\dots+a_{k} = N$
- $0 \leq a_{i} \leq 255$
- $\displaystyle \sum_{\begin{align}&a_{1}+\dots+a_{k}=N\\&0\leq a_{i} \leq 255\end{align}}\prod_{i=1}^{k} a_{i}$

결국, 이걸 어떻게 구하냐가 문제다.

우리는 지수의 합에서 $a_{1}+a_{2}+\dots+a_{k}=N$조건을 만족시킬 수 있음을 이전의 수 많은 FFT 문제들을 통해 확인해 왔었다.

결국 곱이 문제인데, 이 곱 또한 사실 문제에서 요구하는 $a_{i}$ 그대로이기 때문에, 다음과 같은 다항식을 $K$번 거듭제곱하면 됨을 알 수 있다.
$$\sum_{i=0}^{n}ix^{i}$$
즉, $0 \leq a_{i} \leq 255$이므로 $0+1x^{1}+2x^{2}+\dots+255x^{255}$를 $K$번 거듭제곱해서 $x^{N}$ 항의 계수를 찾으면 된다.

이제 여기서 여러 번의 난관이 있었다.

나는 지금까지 같은 다항식을 거듭제곱할 때, 분할 정복을 이용한 거듭제곱 방식을 사용하여 구현해왔었다.

즉, Convolution함수 자체를 $\log$번 호출하는 형태로 진행했는데, 내가 구현했던 NTT 및 FFT구현체의 Convolution함수에는 FFT변환이 정변환 $2$번, 역변환 $1$번으로 총 $3$번의 변환이 필요했다.

따라서, $K=2000$과 같은 데이터가 많이 있다면 최악의 경우 한번의 테스트케이스 당 $\lceil \log K \rceil\cdot 3 = 24$번의 FFT호출이 필요하고, 테스트케이스는 최대 $50$개가 주어지니 최악의 경우 $1200$번의 FFT함수 호출이 이뤄짐을 확인할 수 있었다.

이는 $8$초라는 시간 제한에 비해 매우 오래 걸리는 동작이다.

이것을 해결하기 위해 나는 Convolution함수를 $\log$번 호출하는 형태가 아니라 다항식을 $K$번 거듭제곱하는 power함수를 따로 만들었다.

FFT의 기본 동작 원리를 생각해보자.

다항식을 곱하기 위해 $1 \text{<<}\lceil \log(S_{a}+S_{b}-1) \rceil$만큼 길이를 미리 늘려놓고, 해당 길이에 대해서 변환한 다음에, 점별 곱을 진행한다.

하지만 우리는 같은 다항식을 여러 번 곱하기 때문에, 점별 곱을 여러 번 하면 충분하다.

물론 이때 다항식의 길이는 $K$번 거듭제곱 했을 때의 길이로 미리 늘려놔야 한다.

이런 식으로 다항식을 거듭제곱한다면, 두번의 FFT변환(정변환, 역변환)으로도 충분히 $K$제곱을 구할 수 있다. 즉, 테스트케스 1개 당 $2$번의 변환만 필요한 셈이다.

더불어, 기존 NTT코드를 손봤다. 기존 NTT코드는 NTT변환 후 진행하는 bit-reversal permutation이나 버터플라이 곱이나 모두 그때 그때 근이나 인덱스를 구하는 형식을 사용했었다.

해당 방식을 미리 인덱스와 근을 구하는 precompute함수와 이를 저장하는 캐시들을 만들어 개선하였다. NTT코드가 실행되기 전에 캐시에 근이나 인덱스가 미리 저장되어있으면 그걸 써먹도록 하였다.

이러한 최적화는 같은 크기의 입력이 여러 번 주어질 때 매우 효율적인 최적화라고 생각했고, 실제로도 약간의 도움이 되었다.


아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
// 29021번 IPvK
// 수학, 고속 푸리에 변환, 생성함수
#include <iostream>
#include <vector>

using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define endl '\n'

template<int MOD, int G>
struct NTT {
private :
    // HELPER FUNCTION
    inline int mul_mod(int a, int b) const {return (long long)a * b % MOD;}
    inline int pow_mod(int a, int e) const {
        int r = 1;
        while(e > 0){
            if(e & 1) r = mul_mod(r, a);
            a = mul_mod(a, a);
            e >>= 1;
        }
        return r;
    }
    inline int mod_inv(int a) const {return pow_mod(a, MOD-2);}
    // ---- precompute cache (per template instantiation) ----
    static inline vector<int> rev;
    static inline vector<int> roots_fwd, roots_inv;
    static inline int prepared_n = 0; // current prepared size (power of two)
    void precompute(int n) {
        if (prepared_n == n) return;
        prepared_n = n;
        rev.assign(n, 0);
        int lg = 0;
        while ((1 << lg) < n) ++lg;
        for (int i = 0; i < n; i++) {
            rev[i] = 0;
            for (int b = 0; b < lg; b++){
                if (i & (1 << b)) rev[i] |= 1 << (lg - 1 - b);
            }
        }
        roots_fwd.assign(n, 1); roots_inv.assign(n, 1);
        for (int len = 1; len < n; len <<= 1) {
            int wlen_fwd = pow_mod(G, (MOD - 1) / (2*len));
            int wlen_inv = mod_inv(wlen_fwd);
            int w_f = 1; int w_i = 1;
            for (int j = 0; j < len; j++) {
                roots_fwd[len + j] = w_f;
                roots_inv[len + j] = w_i;
                w_f = mul_mod(w_f, wlen_fwd);
                w_i = mul_mod(w_i, wlen_inv);
            }
        }
    }
public : 
    void iterative_NTT(vector<int>& a, int invert){
        int n = (int)a.size();
        precompute(n);
        // 1) bit-reversal permutation (use rev[])
        for (int i = 1; i < n; i++){
            int j = rev[i];
            if (i < j) swap(a[i], a[j]);
        }
        // 2) butterflies by length = 2, 4, 8, ...
        // use roots_fwd/roots_inv precomputed
        const auto& roots = invert ? roots_inv : roots_fwd;
        for (int len = 2; len <= n; len <<= 1){
            int half = len >> 1;
            for (int i = 0; i < n; i += len){
                for (int j = 0; j < len / 2; j++){
                    int u = a[i+j];
                    int v = mul_mod(roots[half+j], a[i+j+half]);
                    int x = u+v; if (x >= MOD) x -= MOD;
                    int y = u-v; if (y < 0) y += MOD;
                    a[i+j] = x;
                    a[i+j+half] = y;
                }
            }
        }
        // 3) Divide by n (multiply by inverse) for inverse transform
        if (invert){
            int inv_n = mod_inv(n);
            for (int i = 0; i < n; ++i) a[i] = mul_mod(a[i], inv_n);
        }
    }
    // Input : Coefficient vector {a_0, a_1, ...}, {b_0, b_1, ...}
    // Output : Convolution of two coefficient vector
    vector<int> convolution(vector<int> a, vector<int> b) {
        if (a.empty() || b.empty()) return {};
        int S_a = (int)a.size(), S_b = (int)b.size();
        // Make vector size easy to dnc(2^n).
        int n = 1;
        while (n < S_a + S_b - 1){
            n <<= 1;
        }
        a.resize(n); b.resize(n);
        // Normalize to [0, MOD)
        for (int i = 0; i < n; ++i){
            a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
            b[i] %= MOD; if (b[i] < 0) b[i] += MOD;
        }
        // NTT
        iterative_NTT(a, 0); iterative_NTT(b, 0);
        // Pointwise product
        for (int i = 0; i < n; i++) {
            a[i] = mul_mod(a[i], b[i]);
        }
        // INTT
        iterative_NTT(a, 1);
        a.resize(S_a + S_b - 1);
        return a;
    }
    // Input : Coefficient vector A = {a_0, a_1, ...} and Power
    // Output : Coefficient vector of A^P
    vector<int> power(vector<int> a, int e){
        if (a.empty()) return {};
        if (e == 0) return {1};
        int S_a = (int)a.size();
        // Make vector size easy to dnc(2^n).
        int n = 1;
        while (n < (S_a - 1)*e + 1){
            n <<= 1;
        }
        a.resize(n);
        // Normalize to [0, MOD)
        for (int i = 0; i < n; ++i){
            a[i] %= MOD; if (a[i] < 0) a[i] += MOD;
        }
        // NTT
        iterative_NTT(a, 0);
        // Pointwise product
        for (int i = 0; i < n; i++) {
            a[i] = pow_mod(a[i], e);
        }
        // INTT
        iterative_NTT(a, 1);
        a.resize((S_a - 1)*e + 1);
        return a;
    }
};

int main(void) {
    fastio

    int K, N;
    NTT<998'244'353, 3> ntt;
    vector<int> poly(256, 0);
    for (int i = 0; i < 256; ++i) poly[i] = i;
    while (1){
        cin >> K >> N;
        if (K == 0 && N == 0) break;
        auto ret_poly = ntt.power(poly, K);
        if (ret_poly.size() < N+1) cout << 0 << endl;
        else cout << ret_poly[N] << endl;
    }
    return 0;
}
반응형