또르르's 개발 Story

[19-1] Byte Pair Encoding with Python 본문

부스트캠프 AI 테크 U stage/실습

[19-1] Byte Pair Encoding with Python

또르르21 2021. 2. 19. 03:17

1️⃣ Byte Pair Encoding

 

일반적으로 하나의 단어에 대해 하나의 embedding을 생성할 경우 out-of-vocabulary(OOV)라는 치명적인 문제를 갖게 됩니다. 학습 데이터에서 등장하지 않은 단어가 나오는 경우 Unknown token으로 처리해주어 모델의 입력으로 넣게 되면서 전체적으로 모델의 성능이 저하될 수 있습니다.

 

반면 모든 단어의 embedding을 만들기에는 필요한 embedding parameter의 수가 지나치게 많습니다. 

 

이러한 문제를 해결하기 위해 컴퓨터가 이해하는 단어를 표현하는 데에 데이터 압축 알고리즘 중 하나인 byte pair encoding 기법을 적용한 sub-word tokenizaiton이라는 개념이 있습니다.

 

2️⃣ BPE 코드

1) 전체 코드

from typing import List, Dict, Set

from itertools import chain

import re

from collections import defaultdict, Counter



def get_stats(dictionary):

    # 유니그램의 pair들의 빈도수 count
    
    pairs = defaultdict(int)
    
    for word, freq in dictionary.items():
    
        symbols = word.split(' ')
        
        for i in range(len(symbols)-1):
        
            pairs[symbols[i], symbols[i+1]] += freq     # i와 i+1 symbol를 묶어서 freq만큼 빈도수를 올려줌
            
    # print('frequency of Current pairs : ', dict(pairs))
    
    return pairs
    


def merge_char(pair, v_in):

    v_out = {}
    
    # ' '.join(pair) => e s 
    
    bigram = re.escape(' '.join(pair))    # 문자열을 입력받으면 특수문자들을 이스케이프 처리 "e\ s" (여기서 띄어쓰기를 escape 시킴)
    
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')   # 특수/문자 + "e\ s" + 특수/문자 일 경우 
    
    for word in v_in:
    
        w_out = p.sub(''.join(pair), word)    # 패턴에 일치되는 문자열은 다른 문자열로 바꿔주는 것
        
                                              # 여기서는 일치하는 e s가 있을 경우 es 이렇게 붙여서 다음 pair에서 붙여서 나오게 함
                                              
        v_out[w_out] = v_in[word]             # char_cnt의 빈도수를 v_out에 넣어줌
        
    return v_out
    


def build_bpe(

        corpus: List[str],
        
        max_vocab_size: int
        
) -> List[int]:


    # Special tokens
    
    PAD = BytePairEncoding.PAD_token  # Index of <PAD> must be 0
    
    UNK = BytePairEncoding.UNK_token  # Index of <UNK> must be 1
    
    CLS = BytePairEncoding.CLS_token  # Index of <CLS> must be 2
    
    SEP = BytePairEncoding.SEP_token  # Index of <SEP> must be 3
    
    MSK = BytePairEncoding.MSK_token  # Index of <MSK> must be 4
    
    SPECIAL = [PAD, UNK, CLS, SEP, MSK]
    

    WORD_END = BytePairEncoding.WORD_END  # Use this token as the end of a word
    

    idx2word = []


    char_cnt = dict(Counter([' '.join([char for char in word]+[WORD_END]) for word in corpus]))
    
    idx2word.extend(set([char for word in corpus for char in word]+[WORD_END]))
    
    while len(idx2word) < max_vocab_size-len(SPECIAL):    # SPECIAL는 마지막에 앞에 붙여주기 때문에 max_vocab_size-len(SPECIAL)까지 순회
    
      try:
      
          pairs = get_stats(char_cnt)         # 문자 두개씩 쌍으로 묶은 빈도수를 구함
          
          best = max(pairs, key=pairs.get)    # pairs에서 빈도수가 가장 많은 pair를 구함
          
          char_cnt = merge_char(best, char_cnt)
          
          idx2word.append(''.join(best))
          
      except:       # 만약 max_vocab_size가 크게 들어오면 max에서 empty가 발생하고 except문으로 들어옴
      
          break
          

    return SPECIAL+sorted(idx2word, key=len, reverse=True)

2) 알고리즘 설명

우선 딕셔너리의 모든 단어들을 글자(chracter) 단위로 분리합니다.

# dictionary

l o w : 5,  l o w e r : 2,  n e w e s t : 6,  w i d e s t : 3

여기서는 WORD_END를 '_'를 사용했습니다.

# dictionary

l o w _ : 5,  l o w e r _ : 2,  n e w e s t _ : 6,  w i d e s t _ : 3

Dictionary의 초기 단어 집합은 "글자 단위로 분리"된 상태입니다.

# vocabulary

l, o, w, e, r, n, w, s, t, i, d

 


  • STEP 1 : pairs dictionary를 생성합니다. pairs dictionary는 Vocabulary에서 나온 순서 (앞, 뒤)로 묶은 pair를 key로 갖고 빈도수를 value로 갖습니다. (get_stats 함수)
frequency of Current pairs :  {('l', 'o'): 7, ('o', 'w'): 7, ('w', '_'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '_'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '_'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}
  • STEP 2 : pairs dictionary에서 빈도수가 9로 가장 높은 (e, s)의 쌍을 es로 통합합니다. (merge_char 함수)
# dictionary update

l o w : 5,

l o w e r : 2,

n e w es t : 6,

w i d es t : 3
# vocabulary update

l, o, w, e, r, n, w, s, t, i, d, es

 

  • STEP 3 : 다시 새로운 pairs dictionary에서 (w, es), (es, t)의 쌍이 생성된 것을 알 수 있습니다.
frequency of Current pairs :  {('l', 'o'): 7, ('o', 'w'): 7, ('w', '_'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '_'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '_'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}
  • STEP 4 : pairs dictionary에서 빈도수가 9로 가장 높은 (es, t)의 쌍을 est로 통합합니다.
# dictionary update

l o w : 5,

l o w e r : 2,

n e w est : 6,

w i d est : 3
# vocabulary update

l, o, w, e, r, n, w, s, t, i, d, es, est

 

위와 같은 방법으로 계속 수행하게되면 아래와 같은 결과를 얻을 수 있습니다.

# dictionary update

low : 5,

low e r : 2,

newest : 6,

widest : 3
# vocabulary update

l, o, w, e, r, n, w, s, t, i, d, es, est, lo, low, ne, new, newest, wi, wid, widest
Comments