본문 바로가기

알고리즘/백준 문제 풀이

[Python] 2733번 Brainf*ck

728x90

 

 

https://www.acmicpc.net/problem/2733


 

24/11/18

 

 

입력 커맨드를 제외한 brainfuck 인터프리터를 구현하는 문제이다.

 

문제의 핵심 부분은 괄호 쌍을 찾아서 넘어가는 부분이다.


 

문제 접근 방식:

 

 

구현의 용이성을 위해, 먼저 파싱부분과 인터프리터 부분을 나눴다.

 

 

기본적으로 파싱부분은 하나의 테스트 케이스를 입력받아서 프로그램에 해당하는 하나의 문자열을 내뱉는 함수로 구현했다.

 

하나의 테스트 케이스는 'end'만 적혀 있는 문자열로 구분된다.

 

따라서 'end'라는 문자열이 입력되기 전 까지 계속해서 입력을 받는다.

 

파이썬은 줄바꿈을 기준으로 문자열을 입력받기 때문에, while True를 사용하여 'end'를 입력 받을 때까지 계속 입력받도록 했다.

 

이후, 주석의 처리, 공백의 제거, 유효하지 않은 커맨드(즉, 주어진 명령어 외의 문자열)등의 제거를 위해 입력받은 문자열을 하나하나 살펴가며 한 줄의 공백없는 프로그램 문자열을 생성했다.

 

그 문자열을 전체 프로그램 문자열에 반복하며 추가함으로써 전체 프로그램 문자열을 구성했다.

 

 

파싱 부분을 이렇게 간단하게 구성하고 인터프리터 부분을 구현하려고 보니, 괄호 문자열이 서로 매칭되지 않아서 컴파일 에러가 나오는 경우를 처리해야 한다는 사실을 알게 되었다.

 

어차피 인터프리터 부분에서 이걸 처리하는 것보다, 파싱부분에서 문자열을 읽을 때 바로 처리하는 것이 나을 것이라고 생각이 들었다.

 

괄호 문자열이 모두 매칭이 되려면, 문자열을 읽는 도중에 닫는 괄호의 개수가 더 많아지면 안되고, 문자열을 다 읽었을 때 열린 괄호의 개수와 닫힌 괄호의 개수가 같아야한다.

 

그래서 위의 기본 구조에서 매칭되지 않은 열린 괄호의 개수를 나타내는 변수 value를 더 추가했다.

 

여는 괄호가 나올 때 value의 값을 1증가시키고, 닫힌 괄호가 나올 때 value의 값을 1감소시키며, 중간에 value의 값이 음수가 되거나 모든 문자열을 다 읽었을 때 value가 0이 되지 않는다면 ValueError를 raise하도록 구성했다.

 

 

이후 인터프리터 부분은 주어진 문제 그대로 구현했다.

 

크기가 32768짜리인 바이트배열 0으로 초기화하여 선언했고, 그 바이트 배열을 가리키는 포인터 bp(바이트 포인터)또한 0으로 초기화하여 선언했다.

 

추가적으로 바이트 배열의 값은 255로 제한되어있으므로, 상수 ascii_mod = 255로 선언하였다.

 

파싱부분을 통해 하나의 문자열로 나온 프로그램을 하나하나씩 읽으며 명령어를 처리하도록 했다.

 

어차피 중간에 괄호를 통해 명령어를 읽는 부분을 옮기는 경우도 있으므로, 명령어를 가리키는 포인터 pp(프로그램 포인터)를 0으로 초기화하여 선언해주었다.

 

구현하다보니, 괄호 문자열이 나올 때 매칭이 되는 괄호문자열로 pp를 옮겨야했다.

 

 

좋은 구현방법을 고민하다가 어차피 파싱부분에서 이를 처리해서, 프로그램을 나타내는 문자열 외에도 괄호문자열을 어디로 옮겨야할 지 나타내는 딕셔너리도 반환하도록 하면 좋을 것 같다는 생각을 했다.

 

매칭이 되는 괄호 문자열들은 스택을 통해 찾을 수 있다.

 

따라서, 여는 괄호가 나온다면 스택에 그 여는 괄호가 프로그램의 몇번째 문자인지에 대한 숫자(즉, pp)를 push해준다.

 

닫힌 괄호가 나온다면 가장 최근에 push된 여는 괄호와 매칭되므로, 스택에서 pop을 하여 pp랑 매칭하여 튜플로 나타낸다.

 

이 튜플을 matching이라는 리스트에 모두 저장했다.

 

matching에 있는 모든 튜플을 순회하면서 여는 괄호를 닫힌 괄호로 바로 보낼 수 있는 mat_s_to_e와 닫힌 괄호를 여는 괄호로 보낼 수 있는 mat_e_to_s를 만들었고, 이 두 딕셔너리를 추가적으로 반환하도록 파싱부분을 수정했다.


아래는 내가 위의 접근 방식과 같이 작성한 파이썬 코드이다. 더보기를 누르면 확인할 수 있다.

더보기
# 2733번 Brainf*ck
# 구현, 파싱, 스택
import sys
input = sys.stdin.readline

def parsing():
    OK = set('><+-.[]')
    program = ''
    value = 0
    stack, matching = [], []
    pp = 0
    while True:
        line = input().strip()
        if line == 'end': break
        l = ''
        for i in line:
            if i == '%': break
            if i not in OK: continue
            if i == '[':
                value += 1
                stack.append(pp)
            if i == ']':
                value -= 1
                if value < 0: raise ValueError
                matching.append((stack.pop(), pp))
            l += i
            pp += 1
        program += l
    if value != 0: raise ValueError
    mat_s_to_e, mat_e_to_s = dict(), dict()
    for s, e in matching:
        mat_s_to_e[s] = e
        mat_e_to_s[e] = s
    return program, mat_s_to_e, mat_e_to_s
def solve(test_case_num):
    try:
        pointer_mod = 32768
        ascii_mod = 255
        program, mat_s_to_e, mat_e_to_s = parsing()
        byte_array = [0 for _ in range(pointer_mod)]
        bp, pp = 0, 0
        program_output = ''
        while pp < len(program):
            inst = program[pp]
            if inst == '>':
                bp += 1
                bp %= pointer_mod
            elif inst == '<':
                bp -= 1
                bp %= pointer_mod
            elif inst == '+':
                byte_array[bp] += 1
                byte_array[bp] %= ascii_mod
            elif inst == '-':
                byte_array[bp] -= 1
                byte_array[bp] %= ascii_mod
            elif inst == '.':
                program_output += chr(byte_array[bp])
            elif inst == '[':
                if byte_array[bp] == 0:
                    pp = mat_s_to_e[pp]
            else:
                if byte_array[bp] != 0:
                    pp = mat_e_to_s[pp]
            pp += 1
        print(f'PROGRAM #{test_case_num}:')
        print(program_output)
        return
    except ValueError:
        print(f'PROGRAM #{test_case_num}:')
        print('COMPILE ERROR')
        return
def main():
    N = int(input())
    for test_case_num in range(1, N+1):
        solve(test_case_num)
    return
main()