https://swexpertacademy.com/main/code/problem/problemDetail.do?contestProbId=AWAe7XSKfUUDFAUw
25/02/06
가지치기를 "잘"해야 하는 백트래킹 문제다. 나이브하게 무지성으로 탐색하면 시간 초과의 늪에 빠지기 쉽다.
그래서 최적화를 두 번 해야하는데, 내가 풀었던 흐름대로 해설하고자 한다.
문제 접근 방식:
문제를 확인하면 특정 조건을 만족하는 순열의 수를 구하는 것이 목적이다. 또한 문제의 제한도 $N = 9$까지로 매우 작아서 백트래킹이 잘 동작할 것이라는 어떤 믿음이 존재할 것이다.
백트래킹 문제를 좀 풀어봤다면 나이브한 백트래킹 코드는 금방 짤 수 있을 것이라고 확신한다.
나이브한 코드는 다음과 같다.
// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
*/
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'
void dfs(int depth, const vector<int>& arr, int left_sum, int right_sum, vector<bool>& vis, int& ans){
if (depth == 0){
ans++;
return;
}
for (int i = 0; i < arr.size(); i++){
if (!vis[i]){
vis[i] = true;
dfs(depth-1, arr, left_sum+arr[i], right_sum, vis, ans);
if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, left_sum, right_sum+arr[i], vis, ans);
vis[i] = false;
}
}
return;
}
void solve(int test_case_num){
int N; cin >> N;
vector<int> arr(N);
for (int i = 0; i < N; i++) cin >> arr[i];
vector<bool> vis(N, 0);
int ans = 0;
dfs(N, arr, 0, 0, vis, ans);
cout << '#' << test_case_num << ' ' << ans << endl;
return;
}
int main(void){
fastio;
int T; cin >> T;
for (int i = 1; i <= T; i++){
solve(i);
}
return 0;
}
주어진 그대로, (오른쪽 저울에 올려진 무게 추의 합) + (현재 무게 추)를 해도 (왼쪽 저울에 올려진 무게 추의 합)인 경우에만 오른쪽 저울에 무게추를 올려보는 형식으로 탐색을 진행한다.
하지만 이 코드는 시간 초과를 받는다. 어디선가 느리게 동작하는 부분이 있어서 그렇다는 뜻이다.
첫번째로 최적화 할 부분이 여기서 생긴다.
엄청 중요한 아이디어인데, 탐색을 하다보면 어느 순간에 왼쪽 저울에 무게추가 개 많이 올라가있어서 이후에 "어떤" 무게추를 "어디에" 올려도 상관 없는 순간이 올 것이다.
근데 위의 코드에서는 그런 순간이 와도 묵묵히 탐색을 진행한다. 그리고 하나씩 더한다. 아오~ 이러면 시간초과를 받기 쉽다.
편의 상 남은 무게 추가 $N$개가 남았고, 위의 상황에 마주하게 되었다고 해보자.
문제에서 주어진 대로 따져주면 남은 $N$개의 무게 추를 저울에 올리는 경우의 수는 $2^N \cdot N!$이다.
따라서 위의 상황에 처하게 되었다면 ans에 저 숫자를 더해주고 탐색을 끊어주면 될 것이다.
위의 상황에 처하게 되는 판단은 어떻게 할 것인가? 가 문제인데, 지금까지 (왼쪽 저울에 올라가 있던 추들의 무게 합) $\geq$ (남은 무게추들의 무게 합) + (오른쪽 저울에 올라가 있던 추들의 무게 합)을 만족한다면 될 것이다.
이를 적용한 코드는 아래와 같다.
// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
프루닝이 필요하다
*/
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'
int factorial(int N){
if (N == 1) return 1;
return N*factorial(N-1);
}
void dfs(int depth, const vector<int>& arr, const int& total_sum, int left_sum, int right_sum, vector<bool>& vis, int& ans){
if (depth == 0){
ans++;
return;
}
if (left_sum - right_sum >= (total_sum - right_sum - left_sum)){
ans += factorial(depth)*(1 << depth);
return;
}
for (int i = 0; i < arr.size(); i++){
if (vis[i]) continue;
vis[i] = true;
dfs(depth-1, arr, total_sum, left_sum+arr[i], right_sum, vis, ans);
if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, total_sum, left_sum, right_sum+arr[i], vis, ans);
vis[i] = false;
}
return;
}
void solve(int test_case_num){
int N; cin >> N;
int S = 0;
vector<int> arr(N);
for (int i = 0; i < N; i++){
cin >> arr[i];
S += arr[i];
}
vector<bool> vis(N, 0);
int ans = 0;
dfs(N, arr, S, 0, 0, vis, ans);
cout << '#' << test_case_num << ' ' << ans << endl;
return;
}
int main(void){
fastio;
int T; cin >> T;
for (int i = 1; i <= T; i++){
solve(i);
}
return 0;
}
하지만 이 코드도 시간 초과를 받는다.
여기서 두번째 최적화가 들어간다.
위에서 방문배열을 선언하고 처리할 때, vector<bool>은 뭔가 거창하다. 어차피 $N = 9$까지가 최대라 좀 느리게 동작할 것이라는 느낌이 든다.
$N$이 많이 작다. 그래서 비트마스킹을 사용해볼 수 있을 것이라는 느낌이 든다. 비트를 9개 만들면 $2^9$이여서 int로도 쉽게 비트마스킹을 진행할 수 있다.
방문 배열을 비트마스킹으로 처리하여 최적화하면 맞았습니다를 받을 수 있다.
아래는 내가 위의 접근 방식과 같이 작성한 C++ 코드이다. 더보기를 누르면 확인할 수 있다.
// SWEA3234 준환이의 양팔저울
// 백트래킹
/*
접근 방법:
프루닝이 필요하다
*/
#include <iostream>
#include <vector>
using namespace std;
#define fastio ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
typedef long long ll;
#define endl '\n'
int factorial[] = {0,1,2,6,24,120,720,5040,40320,362880};
int exp[] = {1,2,4,8,16,32,64,128,256,512};
void dfs(int depth, const vector<int>& arr, const int& total_sum, int left_sum, int right_sum, int& vis, int& ans){
if (depth == 0){
ans++;
return;
}
if (left_sum - right_sum >= (total_sum - right_sum - left_sum)){
ans += factorial[depth]*exp[depth];
return;
}
for (int i = 0; i < arr.size(); i++){
if (vis & (1 << i)) continue;
vis ^= (1 << i);
dfs(depth-1, arr, total_sum, left_sum+arr[i], right_sum, vis, ans);
if (left_sum >= right_sum + arr[i]) dfs(depth-1, arr, total_sum, left_sum, right_sum+arr[i], vis, ans);
vis ^= (1 << i);
}
return;
}
void solve(int test_case_num){
int N; cin >> N;
int S = 0;
vector<int> arr(N);
for (int i = 0; i < N; i++){
cin >> arr[i];
S += arr[i];
}
int vis = 0;
int ans = 0;
dfs(N, arr, S, 0, 0, vis, ans);
cout << '#' << test_case_num << ' ' << ans << endl;
return;
}
int main(void){
fastio;
int T; cin >> T;
for (int i = 1; i <= T; i++){
solve(i);
}
return 0;
}
'알고리즘 > SWEA' 카테고리의 다른 글
[C++] 5658번 [모의 SW 역량테스트] 보물상자 비밀번호 (0) | 2025.02.07 |
---|---|
[C++] 1952번 [모의 SW 역량테스트] 수영장 (0) | 2025.02.07 |