본문 바로가기

알고리즘/백준 문제 풀이

[Python] 9027번 Stadium

728x90

 

 

https://www.acmicpc.net/problem/9027


 

24/07/22

 

 

지난 주 주간 연습 문제 중 하나로 뽑았던 문제인데, 개인적으로 누적 합 알고리즘을 적용하기에 좋은 문제인 것 같아서 쉽지만 해설을 작성해보고자 한다.


 

문제 접근 방식:

 

 

먼저 문제에서 요구하는 바를 보자.

 

$N$개의 마을 중 한 곳을 선택하여 돔 구장을 짓고자 하고, 각 마을마다 팬이 몇 명 살고있는지, 각 마을의 위치가 어떠한지 또한 주어진다. 각 마을의 위치는 오름차순으로 주어진다.

이때 문제에서 요구하는 것은 모든 팬들이 돔 구장으로 모일 때의 이동 거리의 합이 최소가 되도록 돔 구장을 지을 때, 돔 구장의 위치를 출력하는 것이다.

 

문제 설명의 편의 상 $i$번째 마을의 위치를 $v_i$, 돔 구장의 위치를 $d$, $i$번째 마을에 몇 명의 팬이 있는지를 $f_i$라고 표현하겠다.

 

이를 사용하여 문제의 요구 사항을 다음과 같이 재작성 할 수 있다.

 

$$\text{Minimize }total_d \text{ where }d \text{ is stadium position}$$

$$total_d = \sum _{i=1} ^{n} |v_i - d|\cdot f_i$$

 

가능한 $d$의 위치는 마을 위치와 동일하므로 $N$개가 가능하고, 각 마을마다 돔 구장을 지었을 때의 $total_d$를 각각 계산한다면 얼만큼의 시간이 걸릴까?

 

파이썬의 경우 초 당 약 2000만번 정도의 계산이 가능하다고 알려져 있다.

이 문제에서는 $N$의 제한이 최대 $100,000$이다.

 

위의 식을 통해 각 마을마다 다른 $total_d$를 한 번 계산할 때마다 $N$번 계산하게 되므로, 시간 복잡도는 $\mathcal{O}(N^2)$, 즉, 약 100억번 정도의 연산을 해야한다.

 

당연히 시간을 줄이는 방법을 써야한다. 이럴 때 쓸 수 있는 테크닉이 누적 합이다.

 

위의 식을 조금 만 더 정리해보자.

 

우리는 돔 구장을 기준으로 하여 $v_i - d$의 값이 음/양으로 부호가 바뀜을 알고 있다.

 

또한, 돔 구장의 위치 $d$는 마을의 위치 중 하나이므로, $k = \text{index of }d$라고 해보자.

 

이를 통해 식을 정리하면 다음과 같다.

 

$$\sum_{i=1}^{k} (df_i - v_i f_i) + \sum_{i=k+1}^{n}(v_i f_i - df_i)$$

$$total_d = -\sum_{i=1}^{k} v_if_i + \sum_{i=1}^{k} df_i - \sum_{i=k+1}^{n}df_i + \sum_{i=k+1}^{n} v_if_i$$

 

$\sum v_i f_i$와 $\sum f_i$가 중복되어 보임을 확인할 수 있다.

 

따라서, $p=1$부터 $n$까지 모든 $\sum_{i=1}^{p} v_if_i$의 값들을 미리 구해 놓아 저장한 다음, $\sum_{i=k+1}^{n}$의 값이 필요하다면, $\sum_{i=1}^{n} - \sum_{i=1}^{k}$를 주는 것으로 중복된 계산들을 피할 수 있다. (이것이 누적 합 알고리즘이다.)

 

이렇게 구현한다면 시간 복잡도는 $\mathcal{O}(N^2)$에서 $\mathcal{O}(N)$으로 줄어들게 된다.

 

즉, $\mathcal{O}(N)$만큼의 메모리를 희생하는 대신, $\mathcal{O}(N)$만큼의 시간을 더욱 확보할 수 있는 것이다.


아래는 내가 위의 접근 방식과 같이 작성한 파이썬 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
# 9027번 Stadium
# 누적합
import sys
input = sys.stdin.readline

T = int(input())

def solve():
    N = int(input())
    village = list(map(int, input().split()))
    fan = list(map(int, input().split()))
    villagefan = [village[i]*fan[i] for i in range(N)]
    accum_villagefan = [0]
    accum_fan = [0]
    for i in range(N):
        accum_villagefan.append(villagefan[i] + accum_villagefan[-1])
        accum_fan.append(fan[i] + accum_fan[-1])
    min_val, min_idx = 1_000_000_000_000_000_000, 0
    for k in range(N):
        d = village[k]
        p = accum_villagefan[k+1]
        q = accum_villagefan[-1] - p
        r = d*accum_fan[k+1]
        s = d*accum_fan[-1] - r
        total_d = -p + r - s + q
        if total_d < min_val:
            min_val = total_d
            min_idx = k
    print(village[min_idx])

for _ in range(T):
    solve()