또르르's 개발 Story

[18-2] Seq2Seq with Attention using PyTorch 본문

부스트캠프 AI 테크 U stage/실습

[18-2] Seq2Seq with Attention using PyTorch

또르르21 2021. 2. 18. 02:42

1️⃣ 설정

 

필요한 모듈을 import 합니다.

from tqdm import tqdm

from torch import nn

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


import torch

import random

 

2️⃣ 데이터 전처리

 

src_data trg_data로 바꾸는 task를 수행하기 위한 sample data입니다.
전체 단어 수는 100개이고 다음과 같이 pad token, start token, end token의 id도 정의합니다.

src_data trg_data는 서로 대응되게 만들어져 있습니다.

vocab_size = 100

pad_id = 0

sos_id = 1    # start token

eos_id = 2    # end token


src_data = [
  [3, 77, 56, 26, 3, 55, 12, 36, 31],
  [58, 20, 65, 46, 26, 10, 76, 44],
  [58, 17, 8],
  [59],
  [29, 3, 52, 74, 73, 51, 39, 75, 19],
  [41, 55, 77, 21, 52, 92, 97, 69, 54, 14, 93],
  [39, 47, 96, 68, 55, 16, 90, 45, 89, 84, 19, 22, 32, 99, 5],
  [75, 34, 17, 3, 86, 88],
  [63, 39, 5, 35, 67, 56, 68, 89, 55, 66],
  [12, 40, 69, 39, 49]
]

trg_data = [
  [75, 13, 22, 77, 89, 21, 13, 86, 95],
  [79, 14, 91, 41, 32, 79, 88, 34, 8, 68, 32, 77, 58, 7, 9, 87],
  [85, 8, 50, 30],
  [47, 30],
  [8, 85, 87, 77, 47, 21, 23, 98, 83, 4, 47, 97, 40, 43, 70, 8, 65, 71, 69, 88],
  [32, 37, 31, 77, 38, 93, 45, 74, 47, 54, 31, 18],
  [37, 14, 49, 24, 93, 37, 54, 51, 39, 84],
  [16, 98, 68, 57, 55, 46, 66, 85, 18],
  [20, 70, 14, 6, 58, 90, 30, 17, 91, 18, 90],
  [37, 93, 98, 13, 45, 28, 89, 72, 70]
]

각각의 데이터를 전처리합니다.

trg_data = [[sos_id]+seq+[eos_id] for seq in tqdm(trg_data)]

Padding 처리를 해주면서 padding 전 길이도 저장합니다.

def padding(data, is_src=True):   # padding 구하기

  max_len = len(max(data, key=len))
  
  print(f"Maximum sequence length: {max_len}")
  

  valid_lens = []
  
  for i, seq in enumerate(tqdm(data)):
  
    valid_lens.append(len(seq))
    
    if len(seq) < max_len:
    
      data[i] = seq + [pad_id] * (max_len - len(seq))
      

  return data, valid_lens, max_len
>>> src_data, src_lens, src_max_len = padding(src_data)

Maximum sequence length: 15


>>> trg_data, trg_lens, trg_max_len = padding(trg_data)

Maximum sequence length: 22

B: batch size, S_L: source maximum sequence length, T_L: target maximum sequence length를 나타냅니다.

src_batch = torch.LongTensor(src_data)  # (B, S_L)

src_batch_lens = torch.LongTensor(src_lens)  # (B)

trg_batch = torch.LongTensor(trg_data)  # (B, T_L)

trg_batch_lens = torch.LongTensor(trg_lens)  # (B)


>>> print(src_batch.shape)

torch.Size([10, 15])


>>> print(src_batch_lens.shape)

torch.Size([10])


>>> print(trg_batch.shape)

torch.Size([10, 22])


>>> print(trg_batch_lens.shape)

torch.Size([10])

PackedSquence를 사용을 위해 source data를 기준으로 정렬합니다.

src_batch_lens, sorted_idx = src_batch_lens.sort(descending=True)

src_batch = src_batch[sorted_idx]

trg_batch = trg_batch[sorted_idx]

trg_batch_lens = trg_batch_lens[sorted_idx]


print(src_batch)

print(src_batch_lens)

print(trg_batch)

print(trg_batch_lens)
tensor([[39, 47, 96, 68, 55, 16, 90, 45, 89, 84, 19, 22, 32, 99,  5],
        [41, 55, 77, 21, 52, 92, 97, 69, 54, 14, 93,  0,  0,  0,  0],
        [63, 39,  5, 35, 67, 56, 68, 89, 55, 66,  0,  0,  0,  0,  0],
        [ 3, 77, 56, 26,  3, 55, 12, 36, 31,  0,  0,  0,  0,  0,  0],
        [29,  3, 52, 74, 73, 51, 39, 75, 19,  0,  0,  0,  0,  0,  0],
        [58, 20, 65, 46, 26, 10, 76, 44,  0,  0,  0,  0,  0,  0,  0],
        [75, 34, 17,  3, 86, 88,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [12, 40, 69, 39, 49,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [58, 17,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [59,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
        
tensor([15, 11, 10,  9,  9,  8,  6,  5,  3,  1])

tensor([[ 1, 37, 14, 49, 24, 93, 37, 54, 51, 39, 84,  2,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 32, 37, 31, 77, 38, 93, 45, 74, 47, 54, 31, 18,  2,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 20, 70, 14,  6, 58, 90, 30, 17, 91, 18, 90,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 75, 13, 22, 77, 89, 21, 13, 86, 95,  2,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1,  8, 85, 87, 77, 47, 21, 23, 98, 83,  4, 47, 97, 40, 43, 70,  8, 65,
         71, 69, 88,  2],
        [ 1, 79, 14, 91, 41, 32, 79, 88, 34,  8, 68, 32, 77, 58,  7,  9, 87,  2,
          0,  0,  0,  0],
        [ 1, 16, 98, 68, 57, 55, 46, 66, 85, 18,  2,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 37, 93, 98, 13, 45, 28, 89, 72, 70,  2,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 85,  8, 50, 30,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 1, 47, 30,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0]])
          
tensor([12, 14, 13, 11, 22, 18, 11, 11,  6,  4])

 

 

3️⃣ Encoder 구현

 

Bidirectional GRU를 이용한 Encoder입니다.

  • self.embedding: word embedding layer.
  • self.gru: encoder 역할을 하는 Bi-GRU.
  • self.linear: 양/단방향 concat된 hidden state를 decoder의 hidden size에 맞게 linear transformation.
                   (양방향 hidden state와 단방향 decoder의 크기를 맞춰주기 위해)

encoder는 어짜피 입력데이터들이기 때문에 모든 word를 알고 있습니다. 따라서 for-loop를 사용해서 output을 뽑아내도 됩니다.

embedding_size = 256

hidden_size = 512

num_layers = 2

num_dirs = 2

dropout = 0.1
class Encoder(nn.Module):

  def __init__(self):
  
    super(Encoder, self).__init__()
    

    self.embedding = nn.Embedding(vocab_size, embedding_size)
    
    self.gru = nn.GRU(
    
        input_size=embedding_size, 
        
        hidden_size=hidden_size,
        
        num_layers=num_layers,
        
        bidirectional=True if num_dirs > 1 else False,
        
        dropout=dropout
        
    )
    
    self.linear = nn.Linear(num_dirs * hidden_size, hidden_size)    
    

  def forward(self, batch, batch_lens):  # batch: (B, S_L), batch_lens: (B)
  
    # d_w: word embedding size
    
    batch_emb = self.embedding(batch)  # (B, S_L, d_w)
    
    batch_emb = batch_emb.transpose(0, 1)  # (S_L, B, d_w)
    

    packed_input = pack_padded_sequence(batch_emb, batch_lens)
    

    h_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size))  # (num_layers*num_dirs, B, d_h) = (4, B, d_h)
    
    packed_outputs, h_n = self.gru(packed_input, h_0)  # h_n: (4, B, d_h)
    
    # encoder의 output도 attention 모델에서는 사용하기 때문에 size를 줄여줌
    
    outputs = pad_packed_sequence(packed_outputs)[0]  # outputs: (S_L, B, 2d_h)
    
    outputs = torch.tanh(self.linear(outputs))  # (S_L, B, d_h)
    

    forward_hidden = h_n[-2, :, :]
    
    backward_hidden = h_n[-1, :, :]
    
    # forward_hidden과 backward_hidden을 concat시켜(양방향) liner 통과
    
    hidden = self.linear(torch.cat((forward_hidden, backward_hidden), dim=-1)).unsqueeze(0)  # (1, B, d_h)
    

    return outputs, hidden

 

 

4️⃣ Dot-product Attention 구현

 

우선 대표적인 attention 형태 중 하나인 Dot-product Attention은 다음과 같이 구현할 수 있습니다.

Decoder에 add-on module로써 추가를 해줄 것입니다.

 

TIP : 행렬곱을 할 때 3차원 이상의 텐서(AxBxC, AxCxD) 사이의 행렬 곱은 행렬의 차원을 나타내는 마지막 두 개의 차원(BxC,CxD) 사이에 행렬 곱이 가능하고 행렬 앞에 있는 차원(A)이 같으면 행렬곱이 가능합니다.)

 

class DotAttention(nn.Module):

  def __init__(self):
  
    super().__init__()
    

  def forward(self, decoder_hidden, encoder_outputs):  # (1, B, d_h), (S_L, B, d_h)
  
    query = decoder_hidden.squeeze(0)  # (B, d_h)
    
    key = encoder_outputs.transpose(0, 1)  # (B, S_L, d_h)
    

    # encoder_hidden states의 길이만큼 반복적으로 각 차원이 곱해짐
    
    # unsqueeze - 특정 위치에 1인 차원을 추가 
    
    # torch.mul()은 텐서 요소간의 곱(element-wise 방법)
    
    # dim=-1은 마지막 차원을 더하거나 softmax시킴 (즉, 마지막 차원 제거)
    
    energy = torch.sum(torch.mul(key, query.unsqueeze(1)), dim=-1)  # (B, S_L)
  

    
    attn_scores = F.softmax(energy, dim=-1)  # (B, S_L)
    
    # encoder hidden states의 가중치로 사용하기 위해 mul를 다시 해줌
    
    # context vector
    
    attn_values = torch.sum(torch.mul(encoder_outputs.transpose(0, 1), attn_scores.unsqueeze(2)), dim=1)  # (B, d_h)


    return attn_values, attn_scores
dot_attn = DotAttention()

 

5️⃣ Decoder 구현

 

동일한 설정의 Bi-GRU로 만든 Decoder입니다.

  • self.embedding: word embedding layer.
  • self.gru: decoder 역할을 하는 Bi-GRU.
  • self.output_layer: decoder에서 나온 hidden state를 vocab_size로 linear transformation하는 layer.

여기서 전체 word를 돌릴 for-loop가 없는 이유는 아래 코드에서

batch_emb = batch_emb.unsqueeze(0) # (1, B, d_w)

unsqueeze를 통해서 1개의 word를 output 시키고 다시 그 output을 입력으로 받는 방법을 사용할 것이기 때문입니다.

 

class Decoder(nn.Module):

  def __init__(self, attention):
  
    super().__init__()
    

    self.embedding = nn.Embedding(vocab_size, embedding_size)
    
    self.attention = attention      # attention 추가
    
    self.rnn = nn.GRU(
    
        embedding_size,
        
        hidden_size
        
    )
    
    # concat을 했기 때문에 2*hidden_size
    
    self.output_linear = nn.Linear(2*hidden_size, vocab_size)
    

  def forward(self, batch, encoder_outputs, hidden):  # batch: (B), encoder_outputs: (L, B, d_h), hidden: (1, B, d_h)
  
    batch_emb = self.embedding(batch)  # (B, d_w)
    
    batch_emb = batch_emb.unsqueeze(0)  # (1, B, d_w)
    
    # 여기서 1은 token 하나 (길이 1짜리) decoder로 가정
    

    outputs, hidden = self.rnn(batch_emb, hidden)  # (1, B, d_h), (1, B, d_h)
    

    attn_values, attn_scores = self.attention(hidden, encoder_outputs)  # (B, d_h), (B, S_L)
    
    concat_outputs = torch.cat((outputs, attn_values.unsqueeze(0)), dim=-1)  # (1, B, 2d_h)
    

    return self.output_linear(concat_outputs).squeeze(0), hidden  # (B, V), (1, B, d_h)
decoder = Decoder(dot_attn)

 

 

6️⃣ Seq2seq 모델 구축

 

최종적으로 seq2seq 모델을 다음과 같이 구성할 수 있습니다.

class Seq2seq(nn.Module):

  def __init__(self, encoder, decoder):
  
    super(Seq2seq, self).__init__()
    

    self.encoder = encoder
    
    self.decoder = decoder    # attention이 포함된 decoder
    

  def forward(self, src_batch, src_batch_lens, trg_batch, teacher_forcing_prob=0.5):
  
    # src_batch: (B, S_L), src_batch_lens: (B), trg_batch: (B, T_L)
    

    encoder_outputs, hidden = self.encoder(src_batch, src_batch_lens)  # encoder_outputs: (S_L, B, d_h), hidden: (1, B, d_h)
    
    
    # decoder에서 word 하나씩 처리할 것이기 때문에 첫번째 word만 가지고 옴
    
    input_ids = trg_batch[:, 0]  # (B)
    
    batch_size = src_batch.shape[0]
    
    outputs = torch.zeros(trg_max_len, batch_size, vocab_size)  # (T_L, B, V)
    

    for t in range(1, trg_max_len):
    
      decoder_outputs, hidden = self.decoder(input_ids, encoder_outputs, hidden)  # decoder_outputs: (B, V), hidden: (1, B, d_h)
      

      outputs[t] = decoder_outputs
      
      # top_ids는 decoder_outputs에서 가장 확률이 높은 word
      
      _, top_ids = torch.max(decoder_outputs, dim=-1)  # top_ids: (B)
      

      # 다시 input id로 넣어주는데 teachar_forcing 0.5확률로 input을 넣어줄 수도 있고, 아니면 ground_truth t번째 것을 넣어줄 수 있음
      
      input_ids = trg_batch[:, t] if random.random() > teacher_forcing_prob else top_ids
      

    return outputs
seq2seq = Seq2seq(encoder, decoder)

 

7️⃣ 모델 Test

 

# V: vocab size

outputs = seq2seq(src_batch, src_batch_lens, trg_batch)  # (T_L, B, V)


print(outputs)

print(outputs.shape)
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1102, -0.1184,  0.0267,  ..., -0.0186,  0.0936, -0.0589],
         [ 0.0775, -0.0235,  0.0401,  ..., -0.0237,  0.0530, -0.0576],
         [ 0.0919, -0.1055,  0.0085,  ..., -0.0195,  0.0922, -0.0206],
         ...,
         [ 0.0965, -0.0693, -0.0040,  ..., -0.0278,  0.0924, -0.0321],
         [ 0.0926, -0.0921,  0.0423,  ...,  0.0040,  0.0582, -0.0340],
         [ 0.0769, -0.0896,  0.0229,  ..., -0.0034,  0.0633, -0.0475]],

        [[ 0.1775, -0.1280, -0.0299,  ..., -0.0470,  0.0212, -0.0806],
         [-0.0566,  0.0814,  0.0715,  ...,  0.0031, -0.0425,  0.0399],
         [ 0.1631, -0.0304,  0.0621,  ...,  0.1080,  0.1065,  0.0379],
         ...,
         [ 0.1582, -0.0817, -0.0589,  ..., -0.0476,  0.0338, -0.0748],
         [ 0.1595, -0.1026, -0.0254,  ..., -0.0249,  0.0073, -0.0718],
         [ 0.1589, -0.0085,  0.0703,  ...,  0.1145,  0.0778,  0.0194]],

        ...,

        [[ 0.1957,  0.0454, -0.0929,  ...,  0.1487,  0.0234,  0.0428],
         [ 0.0350,  0.2158, -0.0443,  ...,  0.0917, -0.0664,  0.0556],
         [ 0.1861,  0.1524, -0.1489,  ...,  0.0546, -0.0142, -0.0155],
         ...,
         [ 0.1819,  0.1884, -0.1474,  ...,  0.0351,  0.0087,  0.0487],
         [ 0.1099,  0.0982, -0.0731,  ...,  0.0425,  0.0153, -0.0314],
         [ 0.1251,  0.1598, -0.1116,  ...,  0.0967,  0.0296, -0.0261]],

        [[ 0.1983,  0.1059, -0.1248,  ...,  0.1256,  0.0124, -0.0031],
         [ 0.1025,  0.2343, -0.1056,  ...,  0.0965, -0.0357,  0.0140],
         [ 0.1937,  0.1628, -0.1677,  ...,  0.0824,  0.0074, -0.0281],
         ...,
         [ 0.1944,  0.2021, -0.1747,  ...,  0.0611,  0.0162,  0.0056],
         [ 0.1648,  0.1455, -0.1175,  ...,  0.0844,  0.0158, -0.0304],
         [ 0.1591,  0.1798, -0.1398,  ...,  0.0969,  0.0312, -0.0394]],

        [[ 0.0895,  0.1086, -0.0611,  ...,  0.1039, -0.0398, -0.0078],
         [ 0.0287,  0.2000, -0.0605,  ...,  0.0903, -0.0571,  0.0050],
         [ 0.0745,  0.1434, -0.0954,  ...,  0.0921, -0.0166, -0.0132],
         ...,
         [ 0.0783,  0.1716, -0.1127,  ...,  0.0687, -0.0083,  0.0030],
         [ 0.1946,  0.1788, -0.0976,  ...,  0.0123, -0.0285,  0.1161],
         [ 0.0659,  0.1608, -0.0835,  ...,  0.0928, -0.0128, -0.0223]]],
       grad_fn=<CopySlices>)
       
torch.Size([22, 10, 100])
sample_sent = [4, 10, 88, 46, 72, 34, 14, 51]

sample_len = len(sample_sent)


# torch tensor로 변환

sample_batch = torch.LongTensor(sample_sent).unsqueeze(0)  # (1, L)

sample_batch_len = torch.LongTensor([sample_len])  # (1)


encoder_output, hidden = seq2seq.encoder(sample_batch, sample_batch_len)  # hidden: (4, 1, d_h)
input_id = torch.LongTensor([sos_id]) # (1)

output = []


for t in range(1, trg_max_len):

  decoder_output, hidden = seq2seq.decoder(input_id, encoder_output, hidden)  # decoder_output: (1, V), hidden: (4, 1, d_h)
  

  _, top_id = torch.max(decoder_output, dim=-1)  # top_ids: (1)
  

  if top_id == eos_id:
  
    break
    
  else:
  
    output += top_id.tolist()
    
    input_id = top_id
>>> output

[41,
 35,
 35,
 74,
 35,
 74,
 35,
 74,
 91,
 35,
 74,
 91,
 11,
 75,
 20,
 11,
 11,
 75,
 20,
 11,
 71]

 

8️⃣ Loss

 

ground truth와 비교를 해서 loss를 구하는데 그냥 구하는 것이 아닌 shift한 값과 비교합니다.

Language Modeling에 대한 loss 계산을 위해 shift한 target과 비교합니다.

아래 그림과 같이 output 부분에서 부분 전까지 잘라줍니다.

 

https://www.edwith.org/bcaitech1

 

loss_function = nn.CrossEntropyLoss()

# outputs[:-1, :, :]에서 -1을 통해 뒤를 자름

# (그림 상에선 뒤의 한 token을 잘라주는 식으로 표현되어 있고, 개념상으로 이게 맞지만, 실제 구현된 코드에선 앞의 한 칸을 띄고

# 1번째 index부터 값을 채웠으므로 동일하게 앞의 한 token을 제거해주는 것이 맞습니다.)

preds = outputs[1:, :, :].transpose(0, 1)  # (B, T_L-1, V)

# trg_batch[:,1:]에서는 1을 통해 앞을 자름

loss = loss_function(preds.contiguous().view(-1, vocab_size), trg_batch[:,1:].contiguous().view(-1, 1).squeeze(1))


# outputs과 trg_batch 크기를 맞춰서 loss function 안에 넣어줌

print(loss)
tensor(4.5963, grad_fn=<NllLossBackward>)

 

 

9️⃣ Concat Attention 구현

 

Bahdanau Attention이라고도 불리는 Concat Attention을 구현합니다.

Concat Attention은 해당 시점의 decoder hidden vector와 전체 encoder hidden vector들과의 concat을 해서 score 계산합니다.

  • self.w: Concat한 query와 key 벡터를 1차적으로 linear transformation.
  • self.v: Attention logit 값을 계산.

 

1) Concat Attention 구현

class ConcatAttention(nn.Module):

  def __init__(self):
  
    super().__init__()
    

    self.w = nn.Linear(2*hidden_size, hidden_size, bias=False)
    
    self.v = nn.Linear(hidden_size, 1, bias=False)    # scalar 값으로 변경
    

  def forward(self, decoder_hidden, encoder_outputs):  # (1, B, d_h), (S_L, B, d_h)
  
    src_max_len = encoder_outputs.shape[0]
    

    # src_max_len만큼 반복을 해서 decoder_hidden과 encoder_output의 dimension을 맞춰줌
    
    decoder_hidden = decoder_hidden.transpose(0, 1).repeat(1, src_max_len, 1)  # (B, S_L, d_h)
    
    encoder_outputs = encoder_outputs.transpose(0, 1)  # (B, S_L, d_h)
    

    concat_hiddens = torch.cat((decoder_hidden, encoder_outputs), dim=2)  # (B, S_L, 2d_h)
    
    energy = torch.tanh(self.w(concat_hiddens))  # (B, S_L, d_h)
    

    # energy (B, S_L, d_h)
    
    # self.v(energy) (B, S_L, 1)
    
    # S_L에 대해 softmax를 취함(dim=1)
    
    attn_scores = F.softmax(self.v(energy), dim=1)  # (B, S_L, 1)
    
    # 1을 남겨준 이유는 가중화를 시켜주기 위하여
    
    # B x S_L x d_h, B x S_L x 1에서 1을 d_h로 element-wise 시켜줌
    
    # sum은 dim=1 즉, S_L을 sum해주므로 B,d_h로 변경
    
    attn_values = torch.sum(torch.mul(encoder_outputs, attn_scores), dim=1)  # (B, d_h)
    

    return attn_values, attn_scores
concat_attn = ConcatAttention()

 

2) Decoder 구현

class Decoder(nn.Module):

  def __init__(self, attention):
  
    super().__init__()
    

    self.embedding = nn.Embedding(vocab_size, embedding_size)
    
    self.attention = attention
    
    self.rnn = nn.GRU(
    
        embedding_size + hidden_size,   # 여기서 concat을 해줌
        
        hidden_size
        
    )
    self.output_linear = nn.Linear(hidden_size, vocab_size)
    

  def forward(self, batch, encoder_outputs, hidden):  # batch: (B), encoder_outputs: (S_L, B, d_h), hidden: (1, B, d_h) 
  
    batch_emb = self.embedding(batch)  # (B, d_w)
    
    batch_emb = batch_emb.unsqueeze(0)  # (1, B, d_w)
    

    attn_values, attn_scores = self.attention(hidden, encoder_outputs)  # (B, d_h), (B, S_L)
    

    # concat을 해줌
    
    concat_emb = torch.cat((batch_emb, attn_values.unsqueeze(0)), dim=-1)  # (1, B, d_w+d_h)
    

    outputs, hidden = self.rnn(concat_emb, hidden)  # (1, B, d_h), (1, B, d_h)
    

    return self.output_linear(outputs).squeeze(0), hidden  # (B, V), (1, B, d_h)
decoder = Decoder(concat_attn)

 

3) Seq2Seq 실행

seq2seq = Seq2seq(encoder, decoder)
outputs = seq2seq(src_batch, src_batch_lens, trg_batch)


>>> print(outputs)

>>> print(outputs.shape)
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.1489, -0.0496,  0.2811,  ..., -0.0903, -0.0500, -0.1643],
         [-0.1143, -0.0601,  0.2792,  ..., -0.1171, -0.0728, -0.1188],
         [-0.0769, -0.0391,  0.2122,  ..., -0.0673, -0.0473, -0.2142],
         ...,
         [-0.0999, -0.0731,  0.2399,  ..., -0.0651, -0.0747, -0.1747],
         [-0.1272, -0.0547,  0.2687,  ..., -0.1007, -0.0367, -0.1502],
         [-0.0899, -0.0811,  0.2492,  ..., -0.0773, -0.0597, -0.1484]],

        [[-0.1367, -0.1476,  0.1893,  ..., -0.0668, -0.0470, -0.0891],
         [-0.1294, -0.1665,  0.1874,  ..., -0.0639, -0.0714, -0.0625],
         [-0.0061,  0.1042,  0.2044,  ..., -0.1484, -0.1408, -0.1322],
         ...,
         [-0.0067,  0.0810,  0.2233,  ..., -0.1372, -0.1570, -0.1057],
         [-0.1265, -0.1640,  0.1746,  ..., -0.0511, -0.0538, -0.0954],
         [-0.1098, -0.1806,  0.1690,  ..., -0.0398, -0.0647, -0.0834]],

        ...,

        [[ 0.3878,  0.3853,  0.0732,  ..., -0.5018,  0.1130, -0.1214],
         [ 0.4493,  0.3520,  0.0675,  ..., -0.5227,  0.0758, -0.1446],
         [ 0.4425,  0.3792,  0.0747,  ..., -0.5376,  0.0669, -0.1572],
         ...,
         [ 0.4723,  0.3732,  0.0618,  ..., -0.5280,  0.0673, -0.1661],
         [ 0.4282,  0.3567,  0.0753,  ..., -0.4823,  0.1064, -0.1653],
         [ 0.4427,  0.3596,  0.0549,  ..., -0.4841,  0.0385, -0.1404]],

        [[ 0.1055,  0.2541,  0.0760,  ..., -0.1712,  0.3030, -0.0540],
         [ 0.1310,  0.2289,  0.0770,  ..., -0.1728,  0.2781, -0.0737],
         [ 0.1196,  0.2439,  0.0813,  ..., -0.1875,  0.2701, -0.0878],
         ...,
         [ 0.1520,  0.2354,  0.0696,  ..., -0.1754,  0.2685, -0.0893],
         [ 0.1343,  0.2288,  0.0753,  ..., -0.1453,  0.2877, -0.0858],
         [ 0.4400,  0.3245, -0.0218,  ..., -0.1835, -0.0818, -0.0065]],

        [[ 0.2581,  0.3152,  0.0921,  ..., -0.3583,  0.1990, -0.1037],
         [ 0.2785,  0.2875,  0.0875,  ..., -0.3471,  0.1793, -0.1231],
         [ 0.2638,  0.2985,  0.0935,  ..., -0.3590,  0.1689, -0.1299],
         ...,
         [ 0.2957,  0.2872,  0.0851,  ..., -0.3443,  0.1687, -0.1322],
         [ 0.2862,  0.2885,  0.0878,  ..., -0.3302,  0.1819, -0.1344],
         [ 0.4452,  0.3365,  0.0423,  ..., -0.3502, -0.0304, -0.0856]]],
       grad_fn=<CopySlices>)
       
torch.Size([22, 10, 100])
Comments