본문 바로가기

알고리즘/SWEA

[C++] 3234번 준환이의 양팔저울

728x90

https://swexpertacademy.com/main/code/problem/problemDetail.do?contestProbId=AWAe7XSKfUUDFAUw


 

25/02/06

 

 

가지치기를 "잘"해야 하는 백트래킹 문제다. 나이브하게 무지성으로 탐색하면 시간 초과의 늪에 빠지기 쉽다.

 

그래서 최적화를 두 번 해야하는데, 내가 풀었던 흐름대로 해설하고자 한다.


 

문제 접근 방식:

 

 

문제를 확인하면 특정 조건을 만족하는 순열의 수를 구하는 것이 목적이다. 또한 문제의 제한도 $N = 9$까지로 매우 작아서 백트래킹이 잘 동작할 것이라는 어떤 믿음이 존재할 것이다.

 

백트래킹 문제를 좀 풀어봤다면 나이브한 백트래킹 코드는 금방 짤 수 있을 것이라고 확신한다.

 

나이브한 코드는 다음과 같다.

// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
 
*/
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'
 
void dfs(int depth, const vector<int>& arr, int left_sum, int right_sum, vector<bool>& vis, int& ans){
    if (depth == 0){
        ans++;
        return;
    }
    for (int i = 0; i < arr.size(); i++){
        if (!vis[i]){
            vis[i] = true;
            dfs(depth-1, arr, left_sum+arr[i], right_sum, vis, ans);
            if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, left_sum, right_sum+arr[i], vis, ans);
            vis[i] = false;
        }
    }
    return;
}
 
void solve(int test_case_num){
    int N; cin >> N;
    vector<int> arr(N);
    for (int i = 0; i < N; i++) cin >> arr[i];
    vector<bool> vis(N, 0);
    int ans = 0;
    dfs(N, arr, 0, 0, vis, ans);
    cout << '#' << test_case_num << ' ' << ans << endl;
    return;
}
 
int main(void){
    fastio;
    int T; cin >> T;
    for (int i = 1; i <= T; i++){
        solve(i);
    }
    return 0;
}

 

주어진 그대로, (오른쪽 저울에 올려진 무게 추의 합) + (현재 무게 추)를 해도 (왼쪽 저울에 올려진 무게 추의 합)인 경우에만 오른쪽 저울에 무게추를 올려보는 형식으로 탐색을 진행한다.

 

하지만 이 코드는 시간 초과를 받는다. 어디선가 느리게 동작하는 부분이 있어서 그렇다는 뜻이다.

 

첫번째로 최적화 할 부분이 여기서 생긴다.

 

엄청 중요한 아이디어인데, 탐색을 하다보면 어느 순간에 왼쪽 저울에 무게추가 개 많이 올라가있어서 이후에 "어떤" 무게추를 "어디에" 올려도 상관 없는 순간이 올 것이다.

 

근데 위의 코드에서는 그런 순간이 와도 묵묵히 탐색을 진행한다. 그리고 하나씩 더한다. 아오~ 이러면 시간초과를 받기 쉽다.

 

편의 상 남은 무게 추가 $N$개가 남았고, 위의 상황에 마주하게 되었다고 해보자.

 

문제에서 주어진 대로 따져주면 남은 $N$개의 무게 추를 저울에 올리는 경우의 수는 $2^N \cdot N!$이다.

 

따라서 위의 상황에 처하게 되었다면 ans에 저 숫자를 더해주고 탐색을 끊어주면 될 것이다.

 

위의 상황에 처하게 되는 판단은 어떻게 할 것인가? 가 문제인데, 지금까지 (왼쪽 저울에 올라가 있던 추들의 무게 합) $\geq$ (남은 무게추들의 무게 합) + (오른쪽 저울에 올라가 있던 추들의 무게 합)을 만족한다면 될 것이다.

 

이를 적용한 코드는 아래와 같다.

 

// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
프루닝이 필요하다
*/
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'
 
int factorial(int N){
    if (N == 1) return 1;
    return N*factorial(N-1);
}
 
void dfs(int depth, const vector<int>& arr, const int& total_sum, int left_sum, int right_sum, vector<bool>& vis, int& ans){
    if (depth == 0){
        ans++;
        return;
    }
    if (left_sum - right_sum >= (total_sum - right_sum - left_sum)){
        ans += factorial(depth)*(1 << depth);
        return;
    }
    for (int i = 0; i < arr.size(); i++){
        if (vis[i]) continue;
        vis[i] = true;
        dfs(depth-1, arr, total_sum, left_sum+arr[i], right_sum, vis, ans);
        if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, total_sum, left_sum, right_sum+arr[i], vis, ans);
        vis[i] = false;
    }
    return;
}
 
void solve(int test_case_num){
    int N; cin >> N;
    int S = 0;
    vector<int> arr(N);
    for (int i = 0; i < N; i++){
        cin >> arr[i];
        S += arr[i];
    }
    vector<bool> vis(N, 0);
    int ans = 0;
    dfs(N, arr, S, 0, 0, vis, ans);
    cout << '#' << test_case_num << ' ' << ans << endl;
    return;
}
 
int main(void){
    fastio;
    int T; cin >> T;
    for (int i = 1; i <= T; i++){
        solve(i);
    }
    return 0;
}

 

하지만 이 코드도 시간 초과를 받는다.

 

여기서 두번째 최적화가 들어간다.

 

위에서 방문배열을 선언하고 처리할 때, vector<bool>은 뭔가 거창하다. 어차피 $N = 9$까지가 최대라 좀 느리게 동작할 것이라는 느낌이 든다.

 

$N$이 많이 작다. 그래서 비트마스킹을 사용해볼 수 있을 것이라는 느낌이 든다. 비트를 9개 만들면 $2^9$이여서 int로도 쉽게 비트마스킹을 진행할 수 있다.

 

방문 배열을 비트마스킹으로 처리하여 최적화하면 맞았습니다를 받을 수 있다.


아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
프루닝이 필요하다
*/
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'

int factorial[] = {0,1,2,6,24,120,720,5040,40320,362880};
int exp[] = {1,2,4,8,16,32,64,128,256,512};

void dfs(int depth, const vector<int>& arr, const int& total_sum, int left_sum, int right_sum, int& vis, int& ans){
    if (depth == 0){
        ans++;
        return;
    }
    if (left_sum - right_sum >= (total_sum - right_sum - left_sum)){
        ans += factorial[depth]*exp[depth];
        return;
    }
    for (int i = 0; i < arr.size(); i++){
        if (vis & (1 << i)) continue;
        vis ^= (1 << i);
        dfs(depth-1, arr, total_sum, left_sum+arr[i], right_sum, vis, ans);
        if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, total_sum, left_sum, right_sum+arr[i], vis, ans);
        vis ^= (1 << i);
    }
    return;
}

void solve(int test_case_num){
    int N; cin >> N;
    int S = 0;
    vector<int> arr(N);
    for (int i = 0; i < N; i++){
        cin >> arr[i];
        S += arr[i];
    }
    int vis = 0;
    int ans = 0;
    dfs(N, arr, S, 0, 0, vis, ans);
    cout << '#' << test_case_num << ' ' << ans << endl;
    return;
}

int main(void){
    fastio;
    int T; cin >> T;
    for (int i = 1; i <= T; i++){
        solve(i);
    }
    return 0;
}