본문 바로가기

알고리즘/백준 문제 풀이

[Python] 31532번 선형 회귀는 너무 쉬워 3

728x90

https://www.acmicpc.net/problem/31532

 

31532번: 선형 회귀는 너무 쉬워 3

첫 번째 줄에 $f_3(a)$의 값이 $0$에 가장 가깝게 하는 $a$, 즉, $a_3$을 출력한다. 답이 여러 가지라면 그중 아무거나 하나 출력한다. 가능한 정답 중 최소 하나 이상과의 절대오차 또는 상대오차가 $10

www.acmicpc.net


 

24/03/12

 

 

지난 겨울 MatKor Cup에 나왔던 문제로, 단순한 수치해석 문제입니다.

 

선형 회귀는 너무 쉬워 시리즈의 3번째 문제에 해당됩니다.

 

서브태스크 별로 문제 해설을 해보겠습니다.


 

문제 접근 방식:

 

 

  • 서브태스크 1 ($b = 0;y_i = 0$)

모든 데이터(점)들의 $y$값 들이 $0$이고, 직선의 $y$절편이 $0$으로 주어져 있습니다.

 

이를 설명하는 가장 적합한 직선의 기울기 $a_3$은 당연히 $0$일 것이므로, $0$을 출력하면 $10$점을 얻을 수 있습니다.

 

  • 서브태스크 2 ($b = 0; -10 \leq y_i \leq 10$)

직선의 $y$절편이 $0$으로 주어진 상황입니다.

 

점들의 $y$값이 최소 $-10$이고, 최대 $10$이므로, 이 점을 '잘' 지나도록 하는 직선의 기울기 $a_3$은 최소 $-10$에서 최대 $10$임을 알 수 있습니다.($\because x_i \geq 1$)

 

허용 가능한 상대 오차는 $10^{-7}$이므로, $-10$부터 시작하여 $10^{-7}$씩 올려가며 완전탐색을 해볼 수 있습니다.

 

즉, 총 $2\cdot10^9$번의 탐색을 진행합니다. 또 다른 접근으로는 이 $a_3$값을 이분 탐색을 통해 찾을 수도 있습니다.

 

이분 탐색이 통할 것이라는 확신을 가지고 증명까지 했다면 사실 문제를 다 푼 것이나 다름이 없긴 하지만... 어쨌든 서브태스크 2는 그런 식으로 해결할 수 있습니다.

 

참고로 검수 당시에 서브태스크 2만 해결하는 코드는 짜지 못했습니다.

 

  • 서브태스크 3 ($b=0; y_i = x_i \text{ or } y_i=0$)

서브태스크 2와 마찬가지 방법으로 접근할 수 있습니다. 이 경우 직선의 기울기 $a_3$의 범위는 $0$부터 $1$까지 이므로 더 쉬운 접근을 할 수 있습니다.

 

  • 서브태스크 4 ($n \leq 2$)

점이 $1$개일 경우, 그 점을 지나도록 $a_3$을 조정해 주면 됩니다.

 

점이 $2$개일 경우, 식을 잘 정리해서 풀 수 있습니다.(근데 이 식 정리를 하는 아이디어를 떠올릴 정도면, 정해는 금방 떠올릴 수 있을 것 같습니다.)

 

  • 서브태스크 5 (정해)

사실 저 식은 $0$이 되도록 하는 $a_3$의 값을 무조건 찾을 수 있습니다.

 

그리고 이는 하나로 유일합니다.

 

따라서 이분 탐색을 사용하여 문제를 풀 수 있습니다.

 

그 이유는, 위의 식을 미분해 봄으로써 확인할 수 있습니다.

 

위의 식은 다음과 같습니다.

 

$$\begin{align}f_{3}(a) &= \sum_{i=1}^{n} (y_{i}-ax_{i}-b)^{3} \\&= a^{3}\sum_{i=1}^{n} x^{3} + 3a^{2}\sum_{i=1}^{n}x^{2}(b-y) + 3a\sum_{i=1}^{n}x(b-y)^{2}+\sum_{i=1}^{n}(b-y)^{3}\end{align}$$

 

여기서, $\sum_{i=1}^{n} x^{3}$는 어떤 특정한 상수이므로, 이를 모두 나누어줍시다.

 

$$\begin{align}f_{3}(a) &= a^{3} + 3a^{2}\underbrace{ \sum_{i=1}^{n}\left( \frac{b-y}{x} \right) }_{ p } + 3a\underbrace{ \sum_{i=1}^{n}\left( \frac{b-y}{x} \right)^{2} }_{ q }+\underbrace{ \sum_{i=1}^{n}\left( \frac{b-y}{x} \right)^{3} }_{ r } \\&=a^{3}+3pa^{2}+3qa+r \\f_{3}'(a) &=3a^{2}+6pa+3q \rightarrow a^{2}+2pa+q \\\frac{D}{4} &= p^{2}-q < 0 (\because p^{2}\text{ is square of sum}, q \text{ is sum of square})\end{align}$$

 

미분하고 판별식을 써보면 $0$보다 작음을 쉽게 확인할 수 있습니다. $0$보다 작은 이유는 합의 제곱보다 제곱의 합이 더 크기 때문입니다.

 

따라서, $0$이 되는 경우는 오직 한 경우 밖에 없음을 쉽게 확인할 수 있습니다.

 

이분 탐색을 할 때에는 정밀도를 잘 따져주어서 구현해야 함에 주의합시다. 파이썬에서는 decimal 모듈을 사용하여 정밀하게 구현할 수 있습니다. 저의 경우, 정밀도를 50자리까지 하여 맞을 수 있었습니다. 


아래는 내가 위의 접근 방식과 같이 작성한 파이썬 코드이다. 더 보기를 누르면 확인할 수 있다.

더보기
# 선형 회귀는 너무 쉬워 3
import sys
input = sys.stdin.readline
from decimal import *
getcontext().prec = 50

N, b = map(int, input().split())
p, q, r, s = 0, 0, 0, 0
for _ in range(N):
    x, y = map(int, input().split())
    p += x*x*x
    q += x*x*(b-y)
    r += x*(b-y)*(b-y)
    s += (b-y)*(b-y)*(b-y)
def f(a):
    return a*a*a*p + 3*a*a*q + 3*a*r + s
start, end = Decimal(-1_000_000_000), Decimal(1_000_000_000)
for _ in range(1000):
    mid = (start + end)/Decimal(2)
    if f(mid) > 0:
        end = mid
    else:
        start = mid
print(mid)