Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |
Tags
- 부스트캠프 AI테크
- Numpy
- 정규분포 MLE
- Comparisons
- Python 유래
- pivot table
- subplot
- unstack
- 가능도
- 딥러닝
- groupby
- dtype
- ndarray
- python 문법
- seaborn
- 최대가능도 추정법
- boolean & fancy index
- type hints
- 카테고리분포 MLE
- scatter
- Array operations
- Numpy data I/O
- Operation function
- linalg
- BOXPLOT
- Python
- 표집분포
- VSCode
- Python 특징
- namedtuple
Archives
- Today
- Total
또르르's 개발 Story
[19-1] Byte Pair Encoding with Python 본문
1️⃣ Byte Pair Encoding
일반적으로 하나의 단어에 대해 하나의 embedding을 생성할 경우 out-of-vocabulary(OOV)라는 치명적인 문제를 갖게 됩니다. 학습 데이터에서 등장하지 않은 단어가 나오는 경우 Unknown token으로 처리해주어 모델의 입력으로 넣게 되면서 전체적으로 모델의 성능이 저하될 수 있습니다.
반면 모든 단어의 embedding을 만들기에는 필요한 embedding parameter의 수가 지나치게 많습니다.
이러한 문제를 해결하기 위해 컴퓨터가 이해하는 단어를 표현하는 데에 데이터 압축 알고리즘 중 하나인 byte pair encoding 기법을 적용한 sub-word tokenizaiton이라는 개념이 있습니다.
2️⃣ BPE 코드
1) 전체 코드
from typing import List, Dict, Set
from itertools import chain
import re
from collections import defaultdict, Counter
def get_stats(dictionary):
# 유니그램의 pair들의 빈도수 count
pairs = defaultdict(int)
for word, freq in dictionary.items():
symbols = word.split(' ')
for i in range(len(symbols)-1):
pairs[symbols[i], symbols[i+1]] += freq # i와 i+1 symbol를 묶어서 freq만큼 빈도수를 올려줌
# print('frequency of Current pairs : ', dict(pairs))
return pairs
def merge_char(pair, v_in):
v_out = {}
# ' '.join(pair) => e s
bigram = re.escape(' '.join(pair)) # 문자열을 입력받으면 특수문자들을 이스케이프 처리 "e\ s" (여기서 띄어쓰기를 escape 시킴)
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)') # 특수/문자 + "e\ s" + 특수/문자 일 경우
for word in v_in:
w_out = p.sub(''.join(pair), word) # 패턴에 일치되는 문자열은 다른 문자열로 바꿔주는 것
# 여기서는 일치하는 e s가 있을 경우 es 이렇게 붙여서 다음 pair에서 붙여서 나오게 함
v_out[w_out] = v_in[word] # char_cnt의 빈도수를 v_out에 넣어줌
return v_out
def build_bpe(
corpus: List[str],
max_vocab_size: int
) -> List[int]:
# Special tokens
PAD = BytePairEncoding.PAD_token # Index of <PAD> must be 0
UNK = BytePairEncoding.UNK_token # Index of <UNK> must be 1
CLS = BytePairEncoding.CLS_token # Index of <CLS> must be 2
SEP = BytePairEncoding.SEP_token # Index of <SEP> must be 3
MSK = BytePairEncoding.MSK_token # Index of <MSK> must be 4
SPECIAL = [PAD, UNK, CLS, SEP, MSK]
WORD_END = BytePairEncoding.WORD_END # Use this token as the end of a word
idx2word = []
char_cnt = dict(Counter([' '.join([char for char in word]+[WORD_END]) for word in corpus]))
idx2word.extend(set([char for word in corpus for char in word]+[WORD_END]))
while len(idx2word) < max_vocab_size-len(SPECIAL): # SPECIAL는 마지막에 앞에 붙여주기 때문에 max_vocab_size-len(SPECIAL)까지 순회
try:
pairs = get_stats(char_cnt) # 문자 두개씩 쌍으로 묶은 빈도수를 구함
best = max(pairs, key=pairs.get) # pairs에서 빈도수가 가장 많은 pair를 구함
char_cnt = merge_char(best, char_cnt)
idx2word.append(''.join(best))
except: # 만약 max_vocab_size가 크게 들어오면 max에서 empty가 발생하고 except문으로 들어옴
break
return SPECIAL+sorted(idx2word, key=len, reverse=True)
2) 알고리즘 설명
우선 딕셔너리의 모든 단어들을 글자(chracter) 단위로 분리합니다.
# dictionary
l o w : 5, l o w e r : 2, n e w e s t : 6, w i d e s t : 3
여기서는 WORD_END를 '_'를 사용했습니다.
# dictionary
l o w _ : 5, l o w e r _ : 2, n e w e s t _ : 6, w i d e s t _ : 3
Dictionary의 초기 단어 집합은 "글자 단위로 분리"된 상태입니다.
# vocabulary
l, o, w, e, r, n, w, s, t, i, d
- STEP 1 : pairs dictionary를 생성합니다. pairs dictionary는 Vocabulary에서 나온 순서 (앞, 뒤)로 묶은 pair를 key로 갖고 빈도수를 value로 갖습니다. (get_stats 함수)
frequency of Current pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '_'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '_'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '_'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}
- STEP 2 : pairs dictionary에서 빈도수가 9로 가장 높은 (e, s)의 쌍을 es로 통합합니다. (merge_char 함수)
# dictionary update
l o w : 5,
l o w e r : 2,
n e w es t : 6,
w i d es t : 3
# vocabulary update
l, o, w, e, r, n, w, s, t, i, d, es
- STEP 3 : 다시 새로운 pairs dictionary에서 (w, es), (es, t)의 쌍이 생성된 것을 알 수 있습니다.
frequency of Current pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '_'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '_'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '_'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}
- STEP 4 : pairs dictionary에서 빈도수가 9로 가장 높은 (es, t)의 쌍을 est로 통합합니다.
# dictionary update
l o w : 5,
l o w e r : 2,
n e w est : 6,
w i d est : 3
# vocabulary update
l, o, w, e, r, n, w, s, t, i, d, es, est
위와 같은 방법으로 계속 수행하게되면 아래와 같은 결과를 얻을 수 있습니다.
# dictionary update
low : 5,
low e r : 2,
newest : 6,
widest : 3
# vocabulary update
l, o, w, e, r, n, w, s, t, i, d, es, est, lo, low, ne, new, newest, wi, wid, widest
'부스트캠프 AI 테크 U stage > 실습' 카테고리의 다른 글
[19-3] Masked Multi-head Attention Using PyTorch (0) | 2021.02.19 |
---|---|
[19-2] Multi-head Attention Using PyTorch (0) | 2021.02.19 |
[18-3] fairseq 사용하기 (0) | 2021.02.18 |
[18-2] Seq2Seq with Attention using PyTorch (0) | 2021.02.18 |
[17-2] 번역 모델 전처리 using PyTorch (0) | 2021.02.17 |