또르르's 개발 Story

[34-4] Hourglass Network using PyTorch

또르르21 2021. 3. 12. 03:26

The Hourglass Network is a representative network for landmark localization.

This is a basic version of the code; the full code is available at the link.

 

[Newell et al., ECCV 2016]

 

1️⃣ Setup

 

Import the required modules and fix the random seeds.

# Seed
import torch
import numpy as np
import random

torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

 

2️⃣ Hourglass module

 

The Hourglass Network consists of a stack of Hourglass modules, and each Hourglass module uses a residual block as its basic convolution block.

 

1) Residual block

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
  def __init__(self, num_channels=256):
    super(ResidualBlock, self).__init__()

    # 1x1 conv: reduce channels by half
    self.bn1 = nn.BatchNorm2d(num_channels)
    self.conv1 = nn.Conv2d(num_channels, num_channels//2, kernel_size=1, bias=True)

    # 3x3 conv at the reduced width
    self.bn2 = nn.BatchNorm2d(num_channels//2)
    self.conv2 = nn.Conv2d(num_channels//2, num_channels//2, kernel_size=3, stride=1, padding=1, bias=True)

    # 1x1 conv: restore the original channel count
    self.bn3 = nn.BatchNorm2d(num_channels//2)
    self.conv3 = nn.Conv2d(num_channels//2, num_channels, kernel_size=1, bias=True)

    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    residual = x

    # Pre-activation ordering: BN -> ReLU -> Conv, three times
    out = self.bn1(x)
    out = self.relu(out)
    out = self.conv1(out)

    out = self.bn2(out)
    out = self.relu(out)
    out = self.conv2(out)

    out = self.bn3(out)
    out = self.relu(out)
    out = self.conv3(out)

    # Identity skip connection
    out += residual

    return out
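As a quick sanity check (my own sketch, not from the original post), the pre-activation bottleneck preserves both the channel count and the spatial resolution, which is what lets the Hourglass module reuse one block type everywhere. A compact restatement (the `Bottleneck` name and `body` attribute are illustrative, and a small channel count is used for speed):

```python
import torch
import torch.nn as nn

# Compact restatement of the pre-activation bottleneck above
# (BN -> ReLU -> 1x1, BN -> ReLU -> 3x3, BN -> ReLU -> 1x1, plus skip).
class Bottleneck(nn.Module):
    def __init__(self, c=256):
        super().__init__()
        self.body = nn.Sequential(
            nn.BatchNorm2d(c), nn.ReLU(inplace=True),
            nn.Conv2d(c, c // 2, kernel_size=1),
            nn.BatchNorm2d(c // 2), nn.ReLU(inplace=True),
            nn.Conv2d(c // 2, c // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(c // 2), nn.ReLU(inplace=True),
            nn.Conv2d(c // 2, c, kernel_size=1),
        )

    def forward(self, x):
        return x + self.body(x)

x = torch.randn(1, 16, 32, 32)
y = Bottleneck(16)(x)
print(y.shape)  # torch.Size([1, 16, 32, 32]) -- same shape in, same shape out
```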

 

2) Hourglass module

class Hourglass(nn.Module):
  def __init__(self, block, num_channels=256):
    super(Hourglass, self).__init__()

    # Encoder: conv block followed by 2x max-pooling, four times
    self.downconv_1 = block(num_channels)
    self.pool_1 = nn.MaxPool2d(kernel_size=2)
    self.downconv_2 = block(num_channels)
    self.pool_2 = nn.MaxPool2d(kernel_size=2)
    self.downconv_3 = block(num_channels)
    self.pool_3 = nn.MaxPool2d(kernel_size=2)
    self.downconv_4 = block(num_channels)
    self.pool_4 = nn.MaxPool2d(kernel_size=2)

    # Bottleneck at the lowest resolution
    self.midconv_1 = block(num_channels)
    self.midconv_2 = block(num_channels)
    self.midconv_3 = block(num_channels)

    # Blocks applied on the skip connections
    self.skipconv_1 = block(num_channels)
    self.skipconv_2 = block(num_channels)
    self.skipconv_3 = block(num_channels)
    self.skipconv_4 = block(num_channels)

    # Decoder: conv block after each 2x upsampling
    self.upconv_1 = block(num_channels)
    self.upconv_2 = block(num_channels)
    self.upconv_3 = block(num_channels)
    self.upconv_4 = block(num_channels)

  def forward(self, x):
    x1 = self.downconv_1(x)
    x  = self.pool_1(x1)

    x2 = self.downconv_2(x)
    x  = self.pool_2(x2)

    x3 = self.downconv_3(x)
    x  = self.pool_3(x3)

    x4 = self.downconv_4(x)
    x  = self.pool_4(x4)

    x = self.midconv_1(x)
    x = self.midconv_2(x)
    x = self.midconv_3(x)

    # F.interpolate replaces the deprecated F.upsample
    x4 = self.skipconv_1(x4)
    x = F.interpolate(x, scale_factor=2)
    x = x + x4  # element-wise addition
    x = self.upconv_1(x)

    x3 = self.skipconv_2(x3)
    x = F.interpolate(x, scale_factor=2)
    x = x + x3  # element-wise addition
    x = self.upconv_2(x)

    x2 = self.skipconv_3(x2)
    x = F.interpolate(x, scale_factor=2)
    x = x + x2  # element-wise addition
    x = self.upconv_3(x)

    x1 = self.skipconv_4(x1)
    x = F.interpolate(x, scale_factor=2)
    x = x + x1  # element-wise addition
    x = self.upconv_4(x)

    return x

torchsummary (https://github.com/sksq96/pytorch-summary) is a library for inspecting a network implemented in PyTorch at a glance. With it you can check the dimensions of each feature map and how many parameters each layer has.

hg = Hourglass(ResidualBlock)

from torchsummary import summary
>>> summary(hg, input_size=(256,64,64), device='cpu')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
       BatchNorm2d-1          [-1, 256, 64, 64]             512
              ReLU-2          [-1, 256, 64, 64]               0
            Conv2d-3          [-1, 128, 64, 64]          32,896
       BatchNorm2d-4          [-1, 128, 64, 64]             256
              ReLU-5          [-1, 128, 64, 64]               0
            Conv2d-6          [-1, 128, 64, 64]         147,584
       BatchNorm2d-7          [-1, 128, 64, 64]             256
              ReLU-8          [-1, 128, 64, 64]               0
            Conv2d-9          [-1, 256, 64, 64]          33,024

            .........................................
     
          Conv2d-147          [-1, 128, 64, 64]          32,896
     BatchNorm2d-148          [-1, 128, 64, 64]             256
            ReLU-149          [-1, 128, 64, 64]               0
          Conv2d-150          [-1, 128, 64, 64]         147,584
     BatchNorm2d-151          [-1, 128, 64, 64]             256
            ReLU-152          [-1, 128, 64, 64]               0
          Conv2d-153          [-1, 256, 64, 64]          33,024
   ResidualBlock-154          [-1, 256, 64, 64]               0
================================================================
Total params: 3,217,920
Trainable params: 3,217,920
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.00
Forward/backward pass size (MB): 226.44
Params size (MB): 12.28
Estimated Total Size (MB): 242.71
----------------------------------------------------------------

 

3️⃣ Stacked hourglass network

 

[Newell et al., ECCV 2016]

 

The link below is the full code for the stacked hourglass network.

 

github.com/bearpaw/pytorch-pose

 


 

When implementing the Hourglass module, instead of stacking each layer by hand, you can use a for loop and nn.ModuleList to write more intuitive and concise code.
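As an illustration of that idea (a minimal sketch with made-up names like `HourglassLoop`; the actual pytorch-pose implementation differs in detail), the four down/skip/up stages can be generated in a loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HourglassLoop(nn.Module):
    """Hourglass module with its stages built in a for loop.
    `block` is any channel-preserving module class, e.g. the
    ResidualBlock above; all names here are illustrative."""
    def __init__(self, block, depth=4, num_channels=256):
        super().__init__()
        self.depth = depth
        self.down = nn.ModuleList([block(num_channels) for _ in range(depth)])
        self.skip = nn.ModuleList([block(num_channels) for _ in range(depth)])
        self.up   = nn.ModuleList([block(num_channels) for _ in range(depth)])
        self.mid  = nn.Sequential(*[block(num_channels) for _ in range(3)])

    def forward(self, x):
        skips = []
        for d in range(self.depth):
            x = self.down[d](x)
            skips.append(self.skip[d](x))   # skip branch at this resolution
            x = F.max_pool2d(x, kernel_size=2)
        x = self.mid(x)
        for d in reversed(range(self.depth)):
            x = F.interpolate(x, scale_factor=2)
            x = x + skips[d]                # element-wise addition
            x = self.up[d](x)
        return x

# Demo with a trivial channel-preserving block and small channels:
class ConvBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv = nn.Conv2d(c, c, kernel_size=3, padding=1)
    def forward(self, x):
        return F.relu(self.conv(x))

hg = HourglassLoop(ConvBlock, depth=4, num_channels=8)
y = hg(torch.randn(1, 8, 64, 64))
print(y.shape)  # torch.Size([1, 8, 64, 64])
```

Changing `depth` then adjusts the number of pooling stages without touching the rest of the code, which is the main benefit over writing out each stage by hand.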
