본문 바로가기

알고리즘/백준 문제 풀이

[Python] 17367번 공교육 도박

728x90

https://www.acmicpc.net/problem/17367


 

24/07/05

 

 

살짝 변형된 확률 DP문제이다. 기댓값의 의미를 되살리며 문제를 해결해보자.


 

문제 접근 방식:

 

 

모든 게임의 상태는 마지막 주사위 눈금 3개에 의해 결정되고, 주사위를 몇 번 던질 수 있는 지에 의해 결정이 된다.

 

따라서 4차원 확률 DP로 문제를 해결할 수 있다.

 

먼저 마지막 주사위 눈금 3개를 입력받으면 이에 대한 상금을 반환하는 prize(d1, d2, d3)를 구현하자.

 

수찬이가 받을 수 있는 상금의 기댓값이 최대가 되도록 게임을 한다는 말이 무슨 말일까?

 

이 말은, 현재 상태에서 수찬이가 주사위를 굴려서 상금의 기댓값이 올라간다면 주사위를 굴리고, 그렇지 않다면 현재 상태에서 얻을 수 있는 상금이 최댓값이 된다는 뜻이다.

 

따라서, 마지막 주사위 눈금 3개와 주사위를 몇 번 더 던질 수 있는지의 여부를 모두 저장하여 다음과 같은 DP식을 정의하자.

 

$$DP[d_1][d_2][d_3][i] = \text{최근 주사위의 눈금이 }d_1, d_2, d_3\text{이고 주사위를 } i\text{번 더 굴릴 수 있을 때 상금의 최댓값}$$

 

점화식은 다음과 같다.

 

$$DP[d_1][d_2][d_3][i] = \max (\text{멈췄을 때의 상금}, \text{한 번 더 던졌을 때의 상금})$$

 

구체적으로, 다음과 같이 표현할 수 있다.

 

$$\text{멈췄을 때의 상금} = \text{prize}(d_1, d_2, d_3)$$

 

한 번 더 던졌을 경우, $i$번 던질 수 있는 기회는 하나 줄어든다. 또한 $1/6$이라는 "균등한"확률로 갈 수 있으므로 한 번 더 던졌을 때의 상금의 "기댓값"은 다음과 같다.

 

$$\text{한 번 더 던졌을 때의 상금의 기댓값} = \sum (DP[new\_ roll][d_1][d_2][i-1]/6) $$

 

이를 그대로 구현하면 된다.

 

최종 답을 구할 때는 $\sum (DP[d_1][d_2][d_3][N-3]/216)$을 구하면 된다.

 

그 이유는 $3$번을 던지기 전까지는 DP식 자체를 논할 수 없고 따라서 처음부터 $N$번이 남은 것이 아니라 $N-3$번이 남았다고 생각해야 옳기 때문이다.

 

그리고 그 각각의 모든 경우(216가지)에 대한 기댓값이므로 균등하게 $216$으로 나누어야 하기 때문이다.

 

추가로, 위의 DP식이 전이되는 모습을 그래프로도 그릴 수 있는데, 이는 다음과 같다.

 

 

위의 그림은 $d_1, d_2, d_3 = 1, 1, 4$인 상황에서 주사위 $1$을 굴려 기댓값이 $d_1, d_2, d_3 = 1, 1, 1$로 전이되는 모습을 표현한 것이다.(나머지 간선들은 모두 생략)


아래는 내가 위의 접근 방식과 같이 작성한 파이썬 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
# 17367번 공교육 도박
# 확률론, DP
'''
접근 방식:
확률 DP
게임의 상태는 마지막 주사위 3개에 의해 결정지어짐을 생각하자.
마지막 주사위 3개에 의해 결정되는 상금을 내뱉는 함수 Score를 생각하자.
다음과 같은 DP식을 생각해보자.
DP[d1][d2][d3][i]
이는 주사위를 i번 더 던질 수 있고,
주사위의 눈이 d1, d2, d3가 나왔을 때 얻을 수 있는 상금의 기댓값이다.
최선의 전략으로 주사위를 던진다는 말은 무슨 말일까?
멈추는게 더 높은 상금을 얻는다면 멈추고, 그렇지 않다면 한번 더 던진다는 뜻
따라서 점화식은 다음과 같다.
DP[d1][d2][d3][i] =
max(멈췄을 때의 상금, 한번 더 던졌을 때의 상금 = i-1번 남았을 때의 상금)
'''
import sys
input = sys.stdin.readline

N = int(input())

def prize(d1, d2, d3):
    if d1 == d2 and d2 == d3:
        return 10000 + 1000*d1
    if d1 == d2:
        return 1000 + 100*d1
    if d2 == d3:
        return 1000 + 100*d2
    if d1 == d3:
        return 1000 + 100*d3
    return 100*max(d1, d2, d3)

DP = [[[[0 for _ in range(N)]
        for _ in range(7)]
       for _ in range(7)]
      for _ in range(7)]

# 초기항
for d1 in range(1, 7):
    for d2 in range(1, 7):
        for d3 in range(1, 7):
            DP[d1][d2][d3][0] = prize(d1, d2, d3)

for n in range(1, N-2):  # 3번은 이미 던졌다고 생각하자.
    for d1 in range(1, 7):
        for d2 in range(1, 7):
            for d3 in range(1, 7):
                S = 0
                for new_roll in range(1, 7):
                    S += DP[new_roll][d1][d2][n-1] / 6
                DP[d1][d2][d3][n] = max(S, prize(d1, d2, d3))

ans = 0
for d1 in range(1, 7):
    for d2 in range(1, 7):
        for d3 in range(1, 7):
            ans += DP[d1][d2][d3][N-3] / 216
print(ans)