[19-3] Masked Multi-head Attention Using PyTorch
1️⃣ Setup
Import the required modules.
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import torch
import math
2️⃣ Data preprocessing
Create the data.
pad_id = 0
vocab_size = 100
data = [
[62, 13, 47, 39, 78, 33, 56, 13],
[60, 96, 51, 32, 90],
[35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
[66, 88, 98, 47],
[77, 65, 51, 77, 19, 15, 35, 19, 23]
]
Apply padding.
def padding(data):
    max_len = len(max(data, key=len))
    print(f"Maximum sequence length: {max_len}")

    for i, seq in enumerate(tqdm(data)):
        if len(seq) < max_len:
            data[i] = seq + [pad_id] * (max_len - len(seq))

    return data, max_len
data, max_len = padding(data)
>>> data
[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
[60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
[35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
[66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
[77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]
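As an aside, the hand-written padding above can also be done with PyTorch's built-in torch.nn.utils.rnn.pad_sequence. A minimal sketch, assuming raw_data holds the original un-padded lists (the name is just for illustration):
from torch.nn.utils.rnn import pad_sequence

# Pad variable-length sequences to the batch maximum with pad_id.
padded_alt = pad_sequence([torch.LongTensor(seq) for seq in raw_data],
                          batch_first=True, padding_value=pad_id) # (B, L) LongTensor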
3️⃣ Hyperparameter setting and embedding
Set the model's hidden state size and the number of heads.
d_model = 8 # hidden size of the model
num_heads = 2 # number of heads
inf = 1e12 # large value used in place of infinity when masking
Create an embedding layer that takes token ids from a vocabulary of size vocab_size and outputs d_model-dimensional vectors.
embedding = nn.Embedding(vocab_size, d_model)
# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data) # (B, L)
batch_emb = embedding(batch) # (B, L, d_model)
>>> print(batch_emb.shape)
torch.Size([5, 10, 8])
4️⃣ Building the mask
The mask blocks attention to words that have not yet been predicted.
While performing self-attention, the future positions (the upper triangle of the score matrix) must be masked.
If, say, the 4th and 5th words are <PAD>, positions 4 and 5 are masked as well.
In code, True marks positions where attention is applied and False marks positions to be masked.
>>> print(batch)
tensor([[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
[60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
[35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
[66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
[77, 65, 51, 77, 19, 15, 35, 19, 23, 0]])
>>> print(pad_id)
0
Since pad_id = 0, positions in batch that hold a real token become True and positions that are 0 become False.
unsqueeze inserts a size-1 dimension in the middle so the mask matches the shape expected later.
padding_mask = (batch != pad_id).unsqueeze(1) # (B, 1, L)
# Non-zero positions in batch are True, zero positions are False
# True marks positions where attention is allowed, False marks positions to be masked
>>> print(padding_mask)
tensor([[[ True, True, True, True, True, True, True, True, False, False]],
[[ True, True, True, True, True, False, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, True]],
[[ True, True, True, True, False, False, False, False, False, False]],
[[ True, True, True, True, True, True, True, True, True, False]]])
>>> print(padding_mask.shape)
torch.Size([5, 1, 10])
torch.tril keeps the lower half of the matrix (including the diagonal) as True.
nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool) # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask) # (1, L, L)
# tril (lower triangular) keeps only the lower half of the matrix as True
>>> print(nopeak_mask)
tensor([[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]]])
>>> print(nopeak_mask.shape)
torch.Size([1, 10, 10])
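Equivalently (just as a sketch), the same causal mask can be built by negating the strictly upper triangle with torch.triu:
# Negating the strictly upper triangle gives the same lower-triangular mask.
nopeak_alt = ~torch.triu(torch.ones(1, max_len, max_len, dtype=torch.bool), diagonal=1) # (1, L, L)
assert torch.equal(nopeak_alt, nopeak_mask)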
Therefore, combining the nopeak_mask built with torch.tril and the padding_mask with & keeps only the positions where both are True.
mask = padding_mask & nopeak_mask # (B, 1, L) & (1, L, L) => (B, L, L)
# Only positions where both are True remain True
# padding_mask is broadcast across the rows and combined with nopeak_mask via &
>>> print(mask)
tensor([[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, False, False]],
[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False]],
[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]],
[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False]],
[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[ True, True, True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, False]]])
>>> print(mask.shape)
torch.Size([5, 10, 10])
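For reference, the two masks can also be combined in a single step; a small sketch that should reproduce mask exactly (mask_alt is an illustrative name):
# Broadcast the (B, 1, L) padding mask against the (1, L, L) causal mask.
mask_alt = (batch != pad_id).unsqueeze(1) & torch.tril(torch.ones(1, max_len, max_len, dtype=torch.bool)) # (B, L, L)
assert torch.equal(mask_alt, mask)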
5️⃣ Linear transformation & splitting into multiple heads
Define the linear transformation matrices used inside multi-head attention.
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)
w_0 = nn.Linear(d_model, d_model)
Split Q, K, V into num_heads vectors, each with the reduced dimension d_k.
q = w_q(batch_emb) # (B, L, d_model)
k = w_k(batch_emb) # (B, L, d_model)
v = w_v(batch_emb) # (B, L, d_model)
batch_size = q.shape[0]
d_k = d_model // num_heads
q = q.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
q = q.transpose(1, 2) # (B, num_heads, L, d_k)
k = k.transpose(1, 2) # (B, num_heads, L, d_k)
v = v.transpose(1, 2) # (B, num_heads, L, d_k)
>>> print(q.shape)
torch.Size([5, 2, 10, 4])
>>> print(k.shape)
torch.Size([5, 2, 10, 4])
>>> print(v.shape)
torch.Size([5, 2, 10, 4])
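A quick sanity check (a sketch, not part of the original code): after view and transpose, head h of q is simply the h-th d_k-sized slice of the projected features.
q_full = w_q(batch_emb) # (B, L, d_model), same projection as above
h = 0
assert torch.allclose(q[:, h], q_full[..., h * d_k:(h + 1) * d_k]) # (B, L, d_k)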
6️⃣ Self-attention with masking
This is the self-attention computation that runs in each head.
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (B, num_heads, L, L)
After the scores are computed, the scores for words that have not yet been predicted must be pushed to -inf so that they reliably go to 0 after softmax. masked_fill_ fills positions where the mask is False with -inf.
masks = mask.unsqueeze(1) # (B, 1, L, L)
# (B, num_heads, L, L) against (B, 1, L, L): the num_heads dimension is broadcast element-wise
# masked_fill_ fills positions where masks == False with (near) negative infinity
# After softmax these huge negative scores become essentially 0
masked_attn_scores = attn_scores.masked_fill_(masks == False, -1 * inf) # (B, num_heads, L, L)
>>> print(masked_attn_scores)
tensor([[[[ 1.3143e+00, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 8.9403e-01, 8.6426e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 5.1270e-01, 3.3004e-01, -1.2025e-01, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 1.5980e-01, 1.5570e-01, 3.0258e-02, 1.6243e-02, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-3.4218e-01, -2.0450e-01, -2.8070e-01, -1.8888e-01, 7.9033e-02,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-2.4986e-01, -1.7202e-01, -2.0935e-01, -3.9877e-01, 9.3217e-02,
-1.0008e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-8.9391e-01, -6.9801e-01, -3.8013e-01, -3.2609e-01, 2.0198e-01,
1.5171e-01, 4.4246e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 8.9403e-01, 8.6426e-01, 1.5397e-01, 2.4023e-01, -2.0241e-01,
-3.3018e-01, -4.1745e-01, 8.6426e-01, -1.0000e+12, -1.0000e+12],
[-6.1210e-01, -4.9643e-01, -1.7917e-01, -8.4149e-02, 1.2341e-01,
1.8224e-01, 1.4375e-01, -4.9643e-01, -1.0000e+12, -1.0000e+12],
[-6.1210e-01, -4.9643e-01, -1.7917e-01, -8.4149e-02, 1.2341e-01,
1.8224e-01, 1.4375e-01, -4.9643e-01, -1.0000e+12, -1.0000e+12]],
...
[[[-3.4352e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-7.4922e-02, -1.1854e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-4.5488e-01, -2.8107e-01, -7.6803e-02, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-3.4352e-01, -2.5569e-01, -5.3718e-02, -3.4352e-01, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 1.6794e-01, 1.2540e-01, 9.9537e-03, 1.6794e-01, -3.4771e-01,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-8.6918e-02, -2.4528e-02, 1.9679e-03, -8.6918e-02, -1.1026e-01,
-9.3543e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 2.9392e-02, 4.2141e-02, 1.5314e-02, 2.9392e-02, -1.6987e-01,
-8.8894e-02, -1.3360e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 1.6794e-01, 1.2540e-01, 9.9537e-03, 1.6794e-01, -3.4771e-01,
-2.3102e-01, -1.7695e-02, -3.4771e-01, -1.0000e+12, -1.0000e+12],
[ 1.3523e-01, 8.7783e-02, 7.0304e-02, 1.3523e-01, 2.0379e-03,
2.2378e-01, -2.3160e-02, 2.0379e-03, 5.1986e-01, -1.0000e+12],
[-3.3574e-01, -2.3610e-01, -7.8838e-02, -3.3574e-01, 2.7018e-01,
-4.6749e-02, -2.8217e-01, 2.7018e-01, -4.6545e-01, -1.0000e+12]],
[[-3.6017e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-5.0642e-01, 2.1329e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 5.9064e-02, 9.1698e-02, 2.9015e-01, -1.0000e+12, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[-3.6017e-02, -1.5189e-01, -1.7008e-01, -3.6017e-02, -1.0000e+12,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 3.0850e-01, -5.6023e-01, -5.6363e-01, 3.0850e-01, -8.2863e-01,
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 1.3936e-01, -2.2218e-01, -7.3533e-02, 1.3936e-01, -9.0720e-02,
-5.8827e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 2.6230e-02, -5.8277e-01, -3.1651e-01, 2.6230e-02, -1.2308e+00,
-5.0647e-01, -2.7437e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
[ 3.0850e-01, -5.6023e-01, -5.6363e-01, 3.0850e-01, -8.2863e-01,
-6.3746e-01, -5.0875e-01, -8.2863e-01, -1.0000e+12, -1.0000e+12],
[-3.0797e-01, 3.7385e-02, -4.1402e-01, -3.0797e-01, -1.2205e+00,
-7.0539e-01, -4.6010e-01, -1.2205e+00, -2.5901e-01, -1.0000e+12],
[-2.8654e-01, 3.7312e-01, 1.9621e-01, -2.8654e-01, 1.3229e-01,
1.3978e-01, 2.8992e-01, 1.3229e-01, -1.6794e-01, -1.0000e+12]]]],
grad_fn=<MaskedFillBackward0>)
>>> print(masked_attn_scores.shape)
torch.Size([5, 2, 10, 10])
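Note that masked_fill_ (with the trailing underscore) modifies attn_scores in place. If you prefer to keep the raw scores, the out-of-place variant is a drop-in alternative (a sketch; masked_alt is just an illustrative name):
# masked_fill without the underscore returns a new tensor instead of overwriting attn_scores.
masked_alt = attn_scores.masked_fill(~masks, -1 * inf) # (B, num_heads, L, L)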
The positions masked with -1 * inf become 0 after softmax.
attn_dists = F.softmax(masked_attn_scores, dim=-1) # (B, num_heads, L, L)
>>> print(attn_dists)
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.5074, 0.4926, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.4230, 0.3524, 0.2246, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2673, 0.2662, 0.2348, 0.2316, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.1695, 0.1945, 0.1802, 0.1976, 0.2583, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.1526, 0.1649, 0.1589, 0.1315, 0.2150, 0.1772, 0.0000, 0.0000,
0.0000, 0.0000],
[0.0654, 0.0795, 0.1093, 0.1154, 0.1956, 0.1860, 0.2488, 0.0000,
0.0000, 0.0000],
[0.2068, 0.2007, 0.0987, 0.1075, 0.0691, 0.0608, 0.0557, 0.2007,
0.0000, 0.0000],
[0.0775, 0.0869, 0.1194, 0.1313, 0.1616, 0.1714, 0.1649, 0.0869,
0.0000, 0.0000],
[0.0775, 0.0869, 0.1194, 0.1313, 0.1616, 0.1714, 0.1649, 0.0869,
0.0000, 0.0000]],
...
[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.5109, 0.4891, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2740, 0.3260, 0.3999, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2258, 0.2466, 0.3018, 0.2258, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2268, 0.2173, 0.1936, 0.2268, 0.1354, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.1632, 0.1737, 0.1784, 0.1632, 0.1594, 0.1621, 0.0000, 0.0000,
0.0000, 0.0000],
[0.1525, 0.1545, 0.1504, 0.1525, 0.1250, 0.1355, 0.1296, 0.0000,
0.0000, 0.0000],
[0.1537, 0.1473, 0.1312, 0.1537, 0.0918, 0.1031, 0.1276, 0.0918,
0.0000, 0.0000],
[0.1104, 0.1053, 0.1035, 0.1104, 0.0966, 0.1206, 0.0942, 0.0966,
0.1622, 0.0000],
[0.0882, 0.0975, 0.1141, 0.0882, 0.1617, 0.1178, 0.0931, 0.1617,
0.0775, 0.0000]],
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.3275, 0.6725, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.3037, 0.3137, 0.3826, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2656, 0.2365, 0.2323, 0.2656, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.3166, 0.1328, 0.1324, 0.3166, 0.1016, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[0.1953, 0.1361, 0.1579, 0.1953, 0.1552, 0.1602, 0.0000, 0.0000,
0.0000, 0.0000],
[0.2055, 0.1118, 0.1459, 0.2055, 0.0585, 0.1206, 0.1522, 0.0000,
0.0000, 0.0000],
[0.2321, 0.0974, 0.0970, 0.2321, 0.0744, 0.0901, 0.1025, 0.0744,
0.0000, 0.0000],
[0.1299, 0.1835, 0.1169, 0.1299, 0.0522, 0.0873, 0.1116, 0.0522,
0.1365, 0.0000],
[0.0767, 0.1484, 0.1243, 0.0767, 0.1166, 0.1175, 0.1366, 0.1166,
0.0864, 0.0000]]]], grad_fn=<SoftmaxBackward>)
>>> print(attn_dists.shape)
torch.Size([5, 2, 10, 10])
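As a quick check (a sketch), every row of attn_dists is a valid distribution: it sums to 1, while padded and future keys receive zero weight.
row_sums = attn_dists.sum(dim=-1) # (B, num_heads, L)
assert torch.allclose(row_sums, torch.ones_like(row_sums))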
Multiplying by the value vectors completes self-attention.
attn_values = torch.matmul(attn_dists, v) # (B, num_heads, L, d_k)
>>> print(attn_values.shape)
torch.Size([5, 2, 10, 4])
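What remains, and what the full module below also does, is to concatenate the heads back into d_model and apply the output projection w_0 (concat_values and output are just illustrative names):
# (B, num_heads, L, d_k) -> (B, L, num_heads, d_k) -> (B, L, d_model)
concat_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
output = w_0(concat_values) # (B, L, d_model)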
7️⃣ Full code
Combining all the steps above, we implement a single masked multi-head attention module.
class MultiheadAttention(nn.Module):
    def __init__(self):
        super(MultiheadAttention, self).__init__()

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Linear transformation for concatenated outputs
        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]

        q = self.w_q(q) # (B, L, d_model)
        k = self.w_k(k) # (B, L, d_model)
        v = self.w_v(v) # (B, L, d_model)

        q = q.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
        k = k.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
        v = v.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)

        q = q.transpose(1, 2) # (B, num_heads, L, d_k)
        k = k.transpose(1, 2) # (B, num_heads, L, d_k)
        v = v.transpose(1, 2) # (B, num_heads, L, d_k)

        attn_values = self.self_attention(q, k, v, mask=mask) # (B, num_heads, L, d_k)
        attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model) # (B, L, num_heads, d_k) => (B, L, d_model)

        return self.w_0(attn_values)

    def self_attention(self, q, k, v, mask=None):
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (B, num_heads, L, L)

        # If a mask is provided (the part added for masked attention)
        if mask is not None:
            mask = mask.unsqueeze(1) # (B, 1, L, L) or (B, 1, 1, L)
            attn_scores = attn_scores.masked_fill_(mask == False, -1 * inf)

        attn_dists = F.softmax(attn_scores, dim=-1) # (B, num_heads, L, L)
        attn_values = torch.matmul(attn_dists, v) # (B, num_heads, L, d_k)

        return attn_values
multihead_attn = MultiheadAttention()
outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask) # (B, L, d_model)
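For comparison, recent PyTorch versions ship an nn.MultiheadAttention module. A sketch of the equivalent call, assuming the batch_first flag is available; note its inverted mask convention, where True marks positions that must NOT be attended to:
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
causal_blocked = ~torch.tril(torch.ones(max_len, max_len, dtype=torch.bool)) # (L, L), True = blocked
key_pad = (batch == pad_id) # (B, L), True = padding
out, attn_weights = mha(batch_emb, batch_emb, batch_emb,
                        attn_mask=causal_blocked, key_padding_mask=key_pad) # out: (B, L, d_model)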
8️⃣ (Extra) Encoder-Decoder attention
The encoder and decoder exchange Q, K, and V. More precisely, K and V are taken from the encoder and combined with the decoder's Q in the same attention computation. In other words, only the query, key, and value change; the implementation stays the same.
Here we only build a separate batch for the decoder input.
# src_data reuses the data defined above
trg_data = [
[33, 11, 49, 10],
[88, 34, 5, 29, 99, 45, 11, 25],
[67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88 ,19],
[16, 58, 91, 47, 12, 5, 8],
[71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]
trg_data, trg_max_len = padding(trg_data)
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch # (B, S_L)
trg_batch = torch.LongTensor(trg_data) # (B, T_L)
>>> print(src_batch.shape)
torch.Size([5, 10])
>>> print(trg_batch.shape)
torch.Size([5, 12])
src_emb = embedding(src_batch) # (B, S_L, d_w)
trg_emb = embedding(trg_batch) # (B, T_L, d_w)
>>> print(src_emb.shape)
torch.Size([5, 10, 8])
>>> print(trg_emb.shape)
torch.Size([5, 12, 8])
Assume src_emb is the encoder output and trg_emb is the decoder output after masked multi-head attention.
q = w_q(trg_emb) # (B, T_L, d_model) # only the query uses the target embedding vectors
k = w_k(src_emb) # (B, S_L, d_model)
v = w_v(src_emb) # (B, S_L, d_model)
batch_size = q.shape[0]
d_k = d_model // num_heads
q = q.view(batch_size, -1, num_heads, d_k) # (B, T_L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k) # (B, S_L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k) # (B, S_L, num_heads, d_k)
q = q.transpose(1, 2) # (B, num_heads, T_L, d_k)
k = k.transpose(1, 2) # (B, num_heads, S_L, d_k)
v = v.transpose(1, 2) # (B, num_heads, S_L, d_k)
>>> print(q.shape)
torch.Size([5, 2, 12, 4])
>>> print(k.shape)
torch.Size([5, 2, 10, 4])
>>> print(v.shape)
torch.Size([5, 2, 10, 4])
The shape of attn_scores ultimately follows the target sequence length.
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (B, num_heads, T_L, S_L)
attn_dists = F.softmax(attn_scores, dim=-1) # (B, num_heads, T_L, S_L)
>>> print(attn_dists.shape)
torch.Size([5, 2, 12, 10])
attn_values = torch.matmul(attn_dists, v) # (B, num_heads, T_L, d_k)
>>> print(attn_values.shape)
torch.Size([5, 2, 12, 4])
The result has the same shape as the output of masked multi-head attention, and the computations in the following layers proceed in exactly the same way.
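The MultiheadAttention module above can be reused directly for this encoder-decoder attention; a sketch where only the encoder-side padding is masked (the (B, 1, S_L) mask is broadcast over every target position inside self_attention; cross_mask and cross_out are illustrative names):
cross_mask = (src_batch != pad_id).unsqueeze(1) # (B, 1, S_L)
cross_out = multihead_attn(trg_emb, src_emb, src_emb, mask=cross_mask) # (B, T_L, d_model)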