
[19-3] Masked Multi-head Attention Using PyTorch

또르르21 2021. 2. 19. 03:22

1️⃣ Setup

 

Import the required modules.

import math

import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
 

2️⃣ Data Preprocessing

 

Create the data.

pad_id = 0

vocab_size = 100


data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23]
]

Perform padding.

def padding(data):
  max_len = len(max(data, key=len))
  print(f"Maximum sequence length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      data[i] = seq + [pad_id] * (max_len - len(seq))

  return data, max_len

data, max_len = padding(data)
>>> data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]
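As an aside (not part of the original walkthrough), the same padding can also be done with PyTorch's built-in torch.nn.utils.rnn.pad_sequence; raw_data below is a small hypothetical example used only for this sketch.

from torch.nn.utils.rnn import pad_sequence

# Alternative sketch: pad_sequence pads every sequence to the longest one in the batch.
raw_data = [[62, 13, 47], [60, 96, 51, 32, 90]]  # hypothetical unpadded sequences
padded = pad_sequence([torch.LongTensor(seq) for seq in raw_data],
                      batch_first=True, padding_value=pad_id)

print(padded)  # tensor([[62, 13, 47, 0, 0], [60, 96, 51, 32, 90]])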

 

 

3️⃣ Hyperparameter Setup and Embedding

 

Set the model's hidden state size and the number of heads.

d_model = 8  # hidden size of the model

num_heads = 2  # number of heads

inf = 1e12

Create an embedding layer that takes inputs from a vocabulary of size vocab_size and outputs vectors of size d_model.

embedding = nn.Embedding(vocab_size, d_model)


# B: batch size, L: maximum sequence length

batch = torch.LongTensor(data)  # (B, L)

batch_emb = embedding(batch)  # (B, L, d_model)
>>> print(batch_emb.shape)

torch.Size([5, 10, 8])

 

4️⃣ Building the Mask

 

The mask blocks attention to words that have not yet been predicted.

As shown in the figure below, the hatched positions must be masked while performing self-attention.

 

(Figure source: https://www.edwith.org/bcaitech1)

 

If the 4th and 5th words are <PAD>, positions 4 and 5 are masked as well.

 

(Figure source: https://www.edwith.org/bcaitech1)

 

In code, True marks positions where attention is applied and False marks positions to be masked.

>>> print(batch)

tensor([[62, 13, 47, 39, 78, 33, 56, 13,  0,  0],
        [60, 96, 51, 32, 90,  0,  0,  0,  0,  0],
        [35, 45, 48, 65, 91, 99, 92, 10,  3, 21],
        [66, 88, 98, 47,  0,  0,  0,  0,  0,  0],
        [77, 65, 51, 77, 19, 15, 35, 19, 23,  0]])


>>> print(pad_id)

0

Since pad_id = 0, positions in batch that hold a value become True and positions that are 0 become False.

unsqueeze inserts a dimension of size 1 in the middle to match the expected input shape.

padding_mask = (batch != pad_id).unsqueeze(1)  # (B, 1, L)

# Positions in batch that are not 0 are True; positions that are 0 are False

# True marks positions where attention is allowed; False marks positions to be masked


>>> print(padding_mask)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
        
        
>>> print(padding_mask.shape)

torch.Size([5, 1, 10])

torch.tril keeps only the lower triangular half as True.

nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool)  # (1, L, L)

nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L)

# torch.tril (lower triangular) keeps the lower triangular half as True


>>> print(nopeak_mask)

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
         
         
>>> print(nopeak_mask.shape)

torch.Size([1, 10, 10])

Therefore, combining the nopeak_mask built with torch.tril and padding_mask with & leaves True only where both are True.

mask = padding_mask & nopeak_mask  # (B, 1, L) & (1, L, L) => (B, L, L)

# Only positions where both are True remain True

# padding_mask is broadcast across the rows and combined with nopeak_mask via &


>>> print(mask)

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
         
         
>>> print(mask.shape)

torch.Size([5, 10, 10])

 

 

5️⃣ Linear Transformation & Splitting into Multiple Heads

 

Define the linear transformation matrices used inside multi-head attention.

w_q = nn.Linear(d_model, d_model)

w_k = nn.Linear(d_model, d_model)

w_v = nn.Linear(d_model, d_model)


w_0 = nn.Linear(d_model, d_model)

Split Q, K, V into num_heads vectors of reduced dimension.

q = w_q(batch_emb)  # (B, L, d_model)

k = w_k(batch_emb)  # (B, L, d_model)

v = w_v(batch_emb)  # (B, L, d_model)


batch_size = q.shape[0]

d_k = d_model // num_heads


q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)


q = q.transpose(1, 2)  # (B, num_heads, L, d_k)

k = k.transpose(1, 2)  # (B, num_heads, L, d_k)

v = v.transpose(1, 2)  # (B, num_heads, L, d_k)


>>> print(q.shape)

torch.Size([5, 2, 10, 4])


>>> print(k.shape)

torch.Size([5, 2, 10, 4])


>>> print(v.shape)

torch.Size([5, 2, 10, 4])

 

 

6️⃣ Implementing Self-Attention with Masking

 

This is the self-attention process executed in each head.

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)

After the scores are computed, the scores for words that have not yet been predicted must be set to -inf (so that they converge to 0 after the softmax). The masked_fill_ function sets the positions that are False to -inf.

masks = mask.unsqueeze(1)  # (B, 1, L, L)

# (B, num_heads, L, L) vs. (B, 1, L, L): the mask is broadcast across num_heads


# masked_fill_ sets the positions where masks == False to -inf

# Using -inf means those scores become effectively 0 after passing through softmax

masked_attn_scores = attn_scores.masked_fill_(masks == False, -1 * inf)  # (B, num_heads, L, L)


>>> print(masked_attn_scores)

tensor([[[[ 1.3143e+00, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 8.9403e-01,  8.6426e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 5.1270e-01,  3.3004e-01, -1.2025e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.5980e-01,  1.5570e-01,  3.0258e-02,  1.6243e-02, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.4218e-01, -2.0450e-01, -2.8070e-01, -1.8888e-01,  7.9033e-02,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.4986e-01, -1.7202e-01, -2.0935e-01, -3.9877e-01,  9.3217e-02,
           -1.0008e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-8.9391e-01, -6.9801e-01, -3.8013e-01, -3.2609e-01,  2.0198e-01,
            1.5171e-01,  4.4246e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 8.9403e-01,  8.6426e-01,  1.5397e-01,  2.4023e-01, -2.0241e-01,
           -3.3018e-01, -4.1745e-01,  8.6426e-01, -1.0000e+12, -1.0000e+12],
          [-6.1210e-01, -4.9643e-01, -1.7917e-01, -8.4149e-02,  1.2341e-01,
            1.8224e-01,  1.4375e-01, -4.9643e-01, -1.0000e+12, -1.0000e+12],
          [-6.1210e-01, -4.9643e-01, -1.7917e-01, -8.4149e-02,  1.2341e-01,
            1.8224e-01,  1.4375e-01, -4.9643e-01, -1.0000e+12, -1.0000e+12]],

         ...


        [[[-3.4352e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-7.4922e-02, -1.1854e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-4.5488e-01, -2.8107e-01, -7.6803e-02, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.4352e-01, -2.5569e-01, -5.3718e-02, -3.4352e-01, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.6794e-01,  1.2540e-01,  9.9537e-03,  1.6794e-01, -3.4771e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-8.6918e-02, -2.4528e-02,  1.9679e-03, -8.6918e-02, -1.1026e-01,
           -9.3543e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.9392e-02,  4.2141e-02,  1.5314e-02,  2.9392e-02, -1.6987e-01,
           -8.8894e-02, -1.3360e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.6794e-01,  1.2540e-01,  9.9537e-03,  1.6794e-01, -3.4771e-01,
           -2.3102e-01, -1.7695e-02, -3.4771e-01, -1.0000e+12, -1.0000e+12],
          [ 1.3523e-01,  8.7783e-02,  7.0304e-02,  1.3523e-01,  2.0379e-03,
            2.2378e-01, -2.3160e-02,  2.0379e-03,  5.1986e-01, -1.0000e+12],
          [-3.3574e-01, -2.3610e-01, -7.8838e-02, -3.3574e-01,  2.7018e-01,
           -4.6749e-02, -2.8217e-01,  2.7018e-01, -4.6545e-01, -1.0000e+12]],

         [[-3.6017e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-5.0642e-01,  2.1329e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 5.9064e-02,  9.1698e-02,  2.9015e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.6017e-02, -1.5189e-01, -1.7008e-01, -3.6017e-02, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 3.0850e-01, -5.6023e-01, -5.6363e-01,  3.0850e-01, -8.2863e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.3936e-01, -2.2218e-01, -7.3533e-02,  1.3936e-01, -9.0720e-02,
           -5.8827e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 2.6230e-02, -5.8277e-01, -3.1651e-01,  2.6230e-02, -1.2308e+00,
           -5.0647e-01, -2.7437e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 3.0850e-01, -5.6023e-01, -5.6363e-01,  3.0850e-01, -8.2863e-01,
           -6.3746e-01, -5.0875e-01, -8.2863e-01, -1.0000e+12, -1.0000e+12],
          [-3.0797e-01,  3.7385e-02, -4.1402e-01, -3.0797e-01, -1.2205e+00,
           -7.0539e-01, -4.6010e-01, -1.2205e+00, -2.5901e-01, -1.0000e+12],
          [-2.8654e-01,  3.7312e-01,  1.9621e-01, -2.8654e-01,  1.3229e-01,
            1.3978e-01,  2.8992e-01,  1.3229e-01, -1.6794e-01, -1.0000e+12]]]],
       grad_fn=<MaskedFillBackward0>)


>>> print(masked_attn_scores.shape)

torch.Size([5, 2, 10, 10])

Positions masked with -1 * inf become 0 after softmax.

attn_dists = F.softmax(masked_attn_scores, dim=-1)  # (B, num_heads, L, L)


>>> print(attn_dists)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5074, 0.4926, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.4230, 0.3524, 0.2246, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2673, 0.2662, 0.2348, 0.2316, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1695, 0.1945, 0.1802, 0.1976, 0.2583, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1526, 0.1649, 0.1589, 0.1315, 0.2150, 0.1772, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0654, 0.0795, 0.1093, 0.1154, 0.1956, 0.1860, 0.2488, 0.0000,
           0.0000, 0.0000],
          [0.2068, 0.2007, 0.0987, 0.1075, 0.0691, 0.0608, 0.0557, 0.2007,
           0.0000, 0.0000],
          [0.0775, 0.0869, 0.1194, 0.1313, 0.1616, 0.1714, 0.1649, 0.0869,
           0.0000, 0.0000],
          [0.0775, 0.0869, 0.1194, 0.1313, 0.1616, 0.1714, 0.1649, 0.0869,
           0.0000, 0.0000]],

         ...


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5109, 0.4891, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2740, 0.3260, 0.3999, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2258, 0.2466, 0.3018, 0.2258, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2268, 0.2173, 0.1936, 0.2268, 0.1354, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1632, 0.1737, 0.1784, 0.1632, 0.1594, 0.1621, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1525, 0.1545, 0.1504, 0.1525, 0.1250, 0.1355, 0.1296, 0.0000,
           0.0000, 0.0000],
          [0.1537, 0.1473, 0.1312, 0.1537, 0.0918, 0.1031, 0.1276, 0.0918,
           0.0000, 0.0000],
          [0.1104, 0.1053, 0.1035, 0.1104, 0.0966, 0.1206, 0.0942, 0.0966,
           0.1622, 0.0000],
          [0.0882, 0.0975, 0.1141, 0.0882, 0.1617, 0.1178, 0.0931, 0.1617,
           0.0775, 0.0000]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3275, 0.6725, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3037, 0.3137, 0.3826, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2656, 0.2365, 0.2323, 0.2656, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3166, 0.1328, 0.1324, 0.3166, 0.1016, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1953, 0.1361, 0.1579, 0.1953, 0.1552, 0.1602, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2055, 0.1118, 0.1459, 0.2055, 0.0585, 0.1206, 0.1522, 0.0000,
           0.0000, 0.0000],
          [0.2321, 0.0974, 0.0970, 0.2321, 0.0744, 0.0901, 0.1025, 0.0744,
           0.0000, 0.0000],
          [0.1299, 0.1835, 0.1169, 0.1299, 0.0522, 0.0873, 0.1116, 0.0522,
           0.1365, 0.0000],
          [0.0767, 0.1484, 0.1243, 0.0767, 0.1166, 0.1175, 0.1366, 0.1166,
           0.0864, 0.0000]]]], grad_fn=<SoftmaxBackward>)



>>> print(attn_dists.shape)

torch.Size([5, 2, 10, 10])

Multiplying by the value vectors completes self-attention.

attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)


>>> print(attn_values.shape)

torch.Size([5, 2, 10, 4])
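Although the full module in the next section takes care of this step, the per-head outputs can already be concatenated back to d_model and passed through w_0. The short sketch below mirrors the same reshaping that MultiheadAttention.forward performs later; concat_values and head_concat_output are names introduced only for this sketch.

# Concatenate the heads: (B, num_heads, L, d_k) => (B, L, num_heads, d_k) => (B, L, d_model)
concat_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)

head_concat_output = w_0(concat_values)  # (B, L, d_model)

>>> print(head_concat_output.shape)

torch.Size([5, 10, 8])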

 

7️⃣ Full Code

 

Combining all the steps above, we now implement a single masked multi-head attention module.

class MultiheadAttention(nn.Module):

  def __init__(self):
  
    super(MultiheadAttention, self).__init__()
    

    # Q, K, V learnable matrices
    
    self.w_q = nn.Linear(d_model, d_model)
    
    self.w_k = nn.Linear(d_model, d_model)
    
    self.w_v = nn.Linear(d_model, d_model)
    

    # Linear transformation for concatenated outputs
    
    self.w_0 = nn.Linear(d_model, d_model)
    

  def forward(self, q, k, v, mask=None):
  
    batch_size = q.shape[0]


    q = self.w_q(q)  # (B, L, d_model)
    
    k = self.w_k(k)  # (B, L, d_model)
    
    v = self.w_v(v)  # (B, L, d_model)
    

    q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    
    k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    
    v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    

    q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
    
    k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
    
    v = v.transpose(1, 2)  # (B, num_heads, L, d_k)
    

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L, d_k)
    
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)
    

    return self.w_0(attn_values)
    

  def self_attention(self, q, k, v, mask=None):
  
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)


    # If a mask is provided (newly added part)
    
    if mask is not None:
    
      mask = mask.unsqueeze(1)  # (B, 1, L, L) or  (B, 1, 1, L)
      
      attn_scores = attn_scores.masked_fill_(mask == False, -1*inf)
      

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)
    

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)
    

    return attn_values
multihead_attn = MultiheadAttention()

outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask)  # (B, L, d_model)
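As a quick sanity check (the actual values depend on the random initialization, so only the shape is meaningful here):

>>> print(outputs.shape)

torch.Size([5, 10, 8])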

 

 

8️⃣ (Extra) Encoder-Decoder Attention

 

The Encoder and Decoder exchange Q, K, and V. More precisely, K and V come from the Encoder and are combined with the Decoder's Q to perform attention. In other words, only the sources of the query, key, and value change; the implementation stays the same.
Here we only build the batch that goes into the Decoder separately.

# src_data uses the data defined above

trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88 ,19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]

trg_data, trg_max_len = padding(trg_data)
# S_L: source maximum sequence length, T_L: target maximum sequence length

src_batch = batch  # (B, S_L)

trg_batch = torch.LongTensor(trg_data)  # (B, T_L)


>>> print(src_batch.shape)

torch.Size([5, 10])


>>> print(trg_batch.shape)

torch.Size([5, 12])
src_emb = embedding(src_batch)  # (B, S_L, d_w)

trg_emb = embedding(trg_batch)  # (B, T_L, d_w)


>>> print(src_emb.shape)

torch.Size([5, 10, 8])


>>> print(trg_emb.shape)

torch.Size([5, 12, 8])

Assume src_emb is the output of the encoder and trg_emb is the result after masked multi-head attention.

q = w_q(trg_emb)  # (B, T_L, d_model)   # only the query uses the target embedding vector

k = w_k(src_emb)  # (B, S_L, d_model)

v = w_v(src_emb)  # (B, S_L, d_model)


batch_size = q.shape[0]

d_k = d_model // num_heads


q = q.view(batch_size, -1, num_heads, d_k)  # (B, T_L, num_heads, d_k)

k = k.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)

v = v.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)


q = q.transpose(1, 2)  # (B, num_heads, T_L, d_k)

k = k.transpose(1, 2)  # (B, num_heads, S_L, d_k)

v = v.transpose(1, 2)  # (B, num_heads, S_L, d_k)


>>> print(q.shape)

torch.Size([5, 2, 12, 4])


>>> print(k.shape)

torch.Size([5, 2, 10, 4])


>>> print(v.shape)

torch.Size([5, 2, 10, 4])

The shape of attn_scores ultimately follows the target sequence length.

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)

attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)


>>> print(attn_dists.shape)

torch.Size([5, 2, 12, 10])
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, T_L, d_k)


>>> print(attn_values.shape)

torch.Size([5, 2, 12, 4])

The result has the same shape as the output of masked multi-head attention, and all subsequent operations in the following layers proceed in the same way.
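As a final illustration, the MultiheadAttention module defined above can be reused for this encoder-decoder attention as well. This is only a usage sketch: src_padding_mask is a name introduced here (it is not in the original code) for a (B, 1, S_L) mask over the source, so that <PAD> positions in the encoder output are ignored; inside self_attention it is broadcast to (B, 1, 1, S_L).

# Usage sketch: cross-attention with the same module, masking <PAD> tokens in the source
src_padding_mask = (src_batch != pad_id).unsqueeze(1)  # (B, 1, S_L)

cross_outputs = multihead_attn(trg_emb, src_emb, src_emb, mask=src_padding_mask)  # (B, T_L, d_model)

>>> print(cross_outputs.shape)

torch.Size([5, 12, 8])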
