또르르's 개발 Story
[34-4] Hourglass Network using PyTorch
Hourglass Network is a representative network for landmark localization.
The code in this post is a basic version; the full code is available at the link in section 3️⃣.
1️⃣ Setup
Import the required modules and fix the random seeds so that results are reproducible.
# Seed
import torch
import numpy as np
import random
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Ignore warnings
import warnings
warnings.filterwarnings('ignore')
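The four seeding calls above can be wrapped in one helper so the same seed is applied consistently everywhere. This is a minimal sketch, not part of the original code: the name `set_seed` is hypothetical, and it uses `torch.cuda.manual_seed_all` (which seeds every GPU and is silently ignored when CUDA is unavailable) in place of `torch.cuda.manual_seed`. It then checks that re-seeding reproduces the same random tensor.

```python
import random

import numpy as np
import torch


def set_seed(seed=0):
    """Seed Python, NumPy and PyTorch RNGs (hypothetical helper)."""
    random.seed(seed)                 # Python's built-in random module
    np.random.seed(seed)              # NumPy's global RNG
    torch.manual_seed(seed)           # PyTorch RNG
    torch.cuda.manual_seed_all(seed)  # every GPU; silently ignored without CUDA


set_seed(0)
a = torch.rand(3)
set_seed(0)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```

Seeding makes random draws repeatable; bit-exact results on GPU may additionally require cuDNN settings such as `torch.backends.cudnn.deterministic = True`.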
2️⃣ Hourglass module
A Hourglass Network is built from a stack of Hourglass modules, and each Hourglass module uses the residual block as its basic convolution block.
1) Residual block
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    # Pre-activation bottleneck: (BN -> ReLU -> Conv) three times, plus an identity skip
    def __init__(self, num_channels=256):
        super(ResidualBlock, self).__init__()
        # 1x1 conv: reduce the channels to num_channels // 2
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv1 = nn.Conv2d(num_channels, num_channels//2, kernel_size=1, bias=True)
        # 3x3 conv: padding=1 keeps the height and width unchanged
        self.bn2 = nn.BatchNorm2d(num_channels//2)
        self.conv2 = nn.Conv2d(num_channels//2, num_channels//2, kernel_size=3, stride=1, padding=1, bias=True)
        # 1x1 conv: restore num_channels so the output can be added to the input
        self.bn3 = nn.BatchNorm2d(num_channels//2)
        self.conv3 = nn.Conv2d(num_channels//2, num_channels, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        out += residual  # identity skip connection
        return out
2) Hourglass module
class Hourglass(nn.Module):
    def __init__(self, block, num_channels=256):
        super(Hourglass, self).__init__()
        # Encoder: residual block, then a 2x2 max-pool that halves H and W (four levels)
        self.downconv_1 = block(num_channels)
        self.pool_1 = nn.MaxPool2d(kernel_size=2)
        self.downconv_2 = block(num_channels)
        self.pool_2 = nn.MaxPool2d(kernel_size=2)
        self.downconv_3 = block(num_channels)
        self.pool_3 = nn.MaxPool2d(kernel_size=2)
        self.downconv_4 = block(num_channels)
        self.pool_4 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck at the lowest resolution
        self.midconv_1 = block(num_channels)
        self.midconv_2 = block(num_channels)
        self.midconv_3 = block(num_channels)
        # Skip branches applied to the saved encoder features
        self.skipconv_1 = block(num_channels)
        self.skipconv_2 = block(num_channels)
        self.skipconv_3 = block(num_channels)
        self.skipconv_4 = block(num_channels)
        # Decoder blocks applied after each upsample + addition
        self.upconv_1 = block(num_channels)
        self.upconv_2 = block(num_channels)
        self.upconv_3 = block(num_channels)
        self.upconv_4 = block(num_channels)

    def forward(self, x):
        x1 = self.downconv_1(x)
        x = self.pool_1(x1)
        x2 = self.downconv_2(x)
        x = self.pool_2(x2)
        x3 = self.downconv_3(x)
        x = self.pool_3(x3)
        x4 = self.downconv_4(x)
        x = self.pool_4(x4)

        x = self.midconv_1(x)
        x = self.midconv_2(x)
        x = self.midconv_3(x)

        # F.upsample is deprecated; F.interpolate with mode='nearest' gives the same result
        x4 = self.skipconv_1(x4)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = x + x4  # element-wise addition
        x = self.upconv_1(x)

        x3 = self.skipconv_2(x3)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = x + x3  # element-wise addition
        x = self.upconv_2(x)

        x2 = self.skipconv_3(x2)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = x + x2  # element-wise addition
        x = self.upconv_3(x)

        x1 = self.skipconv_4(x1)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = x + x1  # element-wise addition
        x = self.upconv_4(x)
        return x
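Because the forward pass adds each upsampled map to a saved encoder map element-wise, the two must have the same height and width. The pure-Python trace below (hypothetical helper names) assumes the four-level layout above, with `MaxPool2d(kernel_size=2)` flooring odd sizes and x2 upsampling, and follows the side length of a square feature map through the module. A 64x64 input shrinks to 4x4 at the bottleneck and every addition lines up; more generally, the input side must be divisible by 2^4 = 16.

```python
def hourglass_sizes(size, depth=4):
    """Trace the side length of a square feature map through the hourglass.

    Returns the bottleneck size and, for each decoder step (deepest first),
    the pair (upsampled size, saved encoder size) that gets added together.
    """
    saved = []
    for _ in range(depth):
        saved.append(size)  # x1..x4 are saved before pooling
        size //= 2          # MaxPool2d(kernel_size=2) floors odd sizes
    bottleneck = size
    pairs = []
    for skip in reversed(saved):
        size *= 2           # upsampling with scale_factor=2
        pairs.append((size, skip))
    return bottleneck, pairs


def additions_match(size, depth=4):
    """True if every element-wise addition in the decoder has matching sizes."""
    _, pairs = hourglass_sizes(size, depth)
    return all(up == skip for up, skip in pairs)


print(hourglass_sizes(64))  # (4, [(8, 8), (16, 16), (32, 32), (64, 64)])
print(additions_match(60))  # False: 60 -> 30 -> 15 -> 7 -> 3, and 3 * 2 = 6 != 7
print(additions_match(48))  # True: 48 is divisible by 16
```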
torchsummary (https://github.com/sksq96/pytorch-summary) is a library for inspecting the structure of a network implemented in PyTorch at a glance. With it, you can check the dimensions of each feature map and the number of parameters in each layer.
from torchsummary import summary

hg = Hourglass(ResidualBlock)
summary(hg, input_size=(256, 64, 64), device='cpu')
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
BatchNorm2d-1 [-1, 256, 64, 64] 512
ReLU-2 [-1, 256, 64, 64] 0
Conv2d-3 [-1, 128, 64, 64] 32,896
BatchNorm2d-4 [-1, 128, 64, 64] 256
ReLU-5 [-1, 128, 64, 64] 0
Conv2d-6 [-1, 128, 64, 64] 147,584
BatchNorm2d-7 [-1, 128, 64, 64] 256
ReLU-8 [-1, 128, 64, 64] 0
Conv2d-9 [-1, 256, 64, 64] 33,024
.........................................
Conv2d-147 [-1, 128, 64, 64] 32,896
BatchNorm2d-148 [-1, 128, 64, 64] 256
ReLU-149 [-1, 128, 64, 64] 0
Conv2d-150 [-1, 128, 64, 64] 147,584
BatchNorm2d-151 [-1, 128, 64, 64] 256
ReLU-152 [-1, 128, 64, 64] 0
Conv2d-153 [-1, 256, 64, 64] 33,024
ResidualBlock-154 [-1, 256, 64, 64] 0
================================================================
Total params: 3,217,920
Trainable params: 3,217,920
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.00
Forward/backward pass size (MB): 226.44
Params size (MB): 12.28
Estimated Total Size (MB): 242.71
----------------------------------------------------------------
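The Total params line above can be checked by hand. A `Conv2d` with bias has c_in * c_out * k * k + c_out parameters and a `BatchNorm2d` has 2 * C learnable parameters (its running mean and variance are buffers, not parameters), and the Hourglass module contains 4 + 3 + 4 + 4 = 15 residual blocks. The sketch below (hypothetical helper names) redoes this arithmetic; assuming float32 weights (4 bytes each) and MB = 2^20 bytes, it also reproduces the Params size and Input size values.

```python
def conv_params(c_in, c_out, k):
    """nn.Conv2d with bias=True: c_in * c_out * k * k weights plus c_out biases."""
    return c_in * c_out * k * k + c_out


def bn_params(c):
    """nn.BatchNorm2d: one weight (gamma) and one bias (beta) per channel."""
    return 2 * c


def residual_block_params(c=256):
    half = c // 2
    return (bn_params(c) + conv_params(c, half, 1)          # bn1, conv1
            + bn_params(half) + conv_params(half, half, 3)  # bn2, conv2
            + bn_params(half) + conv_params(half, c, 1))    # bn3, conv3


# 4 downconv + 3 midconv + 4 skipconv + 4 upconv; pooling and upsampling add no parameters
NUM_BLOCKS = 4 + 3 + 4 + 4

per_block = residual_block_params(256)
total = NUM_BLOCKS * per_block
print(per_block)                    # 214528
print(total)                        # 3217920
print(round(total * 4 / 2**20, 2))  # 12.28  (Params size, float32)
print(256 * 64 * 64 * 4 / 2**20)    # 4.0    (Input size)
```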
3️⃣ Stacked hourglass network
The link below contains the full code for the stacked hourglass network.
github.com/bearpaw/pytorch-pose (bearpaw/pytorch-pose: A PyTorch toolkit for 2D Human Pose Estimation)
Instead of stacking the layers of the Hourglass module one by one, a for loop together with nn.ModuleList makes the implementation more intuitive and concise.
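Here is one possible loop-based rewrite of the Hourglass module. It is a sketch, not the code from the linked repository: the class name `LoopHourglass` and the `depth`/`num_mid` arguments are hypothetical, the residual block is condensed with `nn.Sequential` (so its `state_dict` keys differ from the `ResidualBlock` above), and a single parameter-free `MaxPool2d` is shared across levels. `nn.ModuleList` matters here because modules kept in a plain Python list are not registered as submodules, so their parameters would be missing from `model.parameters()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Pre-activation bottleneck: (BN-ReLU-Conv) x 3 plus an identity skip."""

    def __init__(self, num_channels=256):
        super().__init__()
        half = num_channels // 2
        self.body = nn.Sequential(
            nn.BatchNorm2d(num_channels), nn.ReLU(inplace=True),
            nn.Conv2d(num_channels, half, kernel_size=1),
            nn.BatchNorm2d(half), nn.ReLU(inplace=True),
            nn.Conv2d(half, half, kernel_size=3, padding=1),
            nn.BatchNorm2d(half), nn.ReLU(inplace=True),
            nn.Conv2d(half, num_channels, kernel_size=1),
        )

    def forward(self, x):
        return x + self.body(x)


class LoopHourglass(nn.Module):
    """Hourglass module built with for loops and nn.ModuleList (hypothetical name)."""

    def __init__(self, block, num_channels=256, depth=4, num_mid=3):
        super().__init__()
        # One block per level for each of the encoder, skip and decoder paths
        self.down = nn.ModuleList([block(num_channels) for _ in range(depth)])
        self.skip = nn.ModuleList([block(num_channels) for _ in range(depth)])
        self.up = nn.ModuleList([block(num_channels) for _ in range(depth)])
        # Bottleneck blocks at the lowest resolution
        self.mid = nn.Sequential(*[block(num_channels) for _ in range(num_mid)])
        # MaxPool2d has no parameters, so one instance can serve every level
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        saved = []
        for down in self.down:
            x = down(x)
            saved.append(x)  # encoder feature kept for the skip branch
            x = self.pool(x)
        x = self.mid(x)
        # Decoder: deepest saved feature first, mirroring skipconv_1..4 / upconv_1..4
        for skip, up, feat in zip(self.skip, self.up, reversed(saved)):
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = up(x + skip(feat))
        return x


hg = LoopHourglass(ResidualBlock, num_channels=16)
out = hg(torch.randn(2, 16, 64, 64))
print(out.shape)  # torch.Size([2, 16, 64, 64])

num_params = sum(p.numel() for p in LoopHourglass(ResidualBlock).parameters())
print(num_params)  # 3217920
```

With the default 256 channels and depth 4, this version contains the same 15 residual blocks as the layer-by-layer module, so its parameter count matches the Total params reported by torchsummary. Changing `depth` or `num_mid` now only requires changing an argument rather than adding attributes and forward steps by hand.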