https://www.acmicpc.net/problem/15681
24/03/29
트리 DP를 연습할 수 있는 아주 좋은 문제다. 이전에 해결했던 문제를 복습하는 겸, 글로 작성해본다.
문제 접근 방식:
나이브하게 각 쿼리마다 위의 값을 일일히 구해서 내보내려고 한다면, 최악의 경우 $\mathcal{O}(QN)$이 소요된다.
각 쿼리마다 DFS를 한 번 진행할 것이고, 트리에서의 DFS는 $\mathcal{O}(N)$의 시간 복잡도가 소요되기 때문이다.
따라서, 한 번의 DFS를 진행하여, 그 결과를 저장해야한다.
즉, DP를 진행하는데, DFS를 진행하며 DP를 진행하면 된다.
탑다운 DP의 구현 방식과 비슷한데, 트리 DP는 서브 트리 별로 문제를 분할하는 것이 핵심이다.
DP 점화식을 다음과 같이 정의해보자.
$$DP[i] = i\text{번 노드가 서브트리의 루트일 때 서브트리에 속한 정점의 수}$$
dfs코드는 다음과 같이 짤 수 있다. 대부분의 트리 DP 문제들은 이러한 형식의 코드를 따르므로 잘 익혀둬보자.
def dfs(i):
visited[i] = 1
for below_node in tree[i]:
if visited[below_node] == 0:
dfs(below_node)
DP[i] += DP[below_node]
DP[i] += 1
이를 예제 입력을 통해 확인해보자.
예제 입력은 다음과 같은 트리이다.
여기서 위의 코드를 실행한 방문 순서 결과는 다음과 같다. 가독성을 위해 디버깅용 코드의 결과를 출력해보았다.
한 줄 한 줄 읽으며 따라가보자.
5번 노드 방문!
이 노드의 자식 노드는 4 6 입니다.
---- 4번 노드 탐색 시작
---- 4번 노드 방문!
---- 이 노드의 자식 노드는 3 입니다.
-------- 3번 노드 탐색 시작
-------- 3번 노드 방문!
-------- 이 노드의 자식 노드는 1 2 입니다.
------------ 1번 노드 탐색 시작
------------ 1번 노드 방문!
------------ 이 노드는 리프 노드입니다.
------------ 1번 노드의 값은 1입니다.
-------- 3번 노드의 값이 1만큼 상승
------------ 2번 노드 탐색 시작
------------ 2번 노드 방문!
------------ 이 노드는 리프 노드입니다.
------------ 2번 노드의 값은 1입니다.
-------- 3번 노드의 값이 1만큼 상승
-------- 3번 노드의 값은 3입니다.
---- 4번 노드의 값이 3만큼 상승
---- 4번 노드의 값은 4입니다.
5번 노드의 값이 4만큼 상승
---- 6번 노드 탐색 시작
---- 6번 노드 방문!
---- 이 노드의 자식 노드는 7 9 8 입니다.
-------- 7번 노드 탐색 시작
-------- 7번 노드 방문!
-------- 이 노드는 리프 노드입니다.
-------- 7번 노드의 값은 1입니다.
---- 6번 노드의 값이 1만큼 상승
-------- 9번 노드 탐색 시작
-------- 9번 노드 방문!
-------- 이 노드는 리프 노드입니다.
-------- 9번 노드의 값은 1입니다.
---- 6번 노드의 값이 1만큼 상승
-------- 8번 노드 탐색 시작
-------- 8번 노드 방문!
-------- 이 노드는 리프 노드입니다.
-------- 8번 노드의 값은 1입니다.
---- 6번 노드의 값이 1만큼 상승
---- 6번 노드의 값은 4입니다.
5번 노드의 값이 4만큼 상승
5번 노드의 값은 9입니다.
다음과 같이 한 번의 DFS를 통해 모든 DP배열의 값을 채울 수 있다.
이 결과를 각 쿼리마다 출력하면 충분히 빠른 시간 안에 구할 수 있다.
시간 복잡도는 $\mathcal{O}(2N-1 + Q)$가 소요된다.
아래는 내가 위의 접근 방식과 같이 작성한 파이썬 코드이다. 더보기를 누르면 확인할 수 있다.
# 15681번 트리와 쿼리
# tree DP
import sys
input = sys.stdin.readline
sys.setrecursionlimit(500_000)
N, R, Q = map(int, input().split())
tree = [[] for _ in range(N+1)]
for _ in range(N-1):
U, V = map(int, input().split())
tree[U].append(V)
tree[V].append(U)
DP = [0 for _ in range(N+1)]
visited = [0 for _ in range(N+1)]
def dfs(i):
visited[i] = 1
for below_node in tree[i]:
if visited[below_node] == 0:
dfs(below_node)
DP[i] += DP[below_node]
DP[i] += 1
dfs(R)
ans = []
for _ in range(Q):
ans.append(DP[int(input())])
print(*ans, sep='\n')
'알고리즘 > 백준 문제 풀이' 카테고리의 다른 글
[Python] 30959번 앙상블할래? (0) | 2024.05.24 |
---|---|
[Python] 31462번 삼각 초콜릿 포장 (Sweet) (0) | 2024.05.23 |
[Python] 28422번 XOR 카드 게임 (0) | 2024.05.21 |
[Python] 28464번 Potato (0) | 2024.05.20 |
[Python] 12850번 본대 산책2 (0) | 2024.05.19 |