일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |
- 표집분포
- Python 유래
- Array operations
- Numpy
- python 문법
- subplot
- dtype
- 딥러닝
- 부스트캠프 AI테크
- boolean & fancy index
- Python
- scatter
- 가능도
- groupby
- Python 특징
- 정규분포 MLE
- pivot table
- 카테고리분포 MLE
- ndarray
- linalg
- VSCode
- 최대가능도 추정법
- type hints
- Comparisons
- BOXPLOT
- Operation function
- unstack
- seaborn
- namedtuple
- Numpy data I/O
- Today
- Total
또르르's 개발 Story
[39-3] Teacher-Student Network using PyTorch 본문
Knowledge Distillation의 대표격인 Teacher-Student network를 만들어봅니다.

1️⃣ 설정
필요한 모듈을 import합니다.
# import
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torchsummary import summary
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
2️⃣ Define Function
1) train function / test function
train function / test function은 일반 train function / test function과 비슷합니다.
# Training
def train(epoch,net):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if batch_idx %1000 ==0:
print('Loss: %.3f | Acc: %.3f%% ' %(train_loss/(batch_idx+1),100.*correct/total))
# test
def test(epoch,net):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
2) train_distillation function
Teacher-Student network에서 가장 중요한 function입니다.
Student는 train mode, teacher는 eval mode로 수행해야하며, loss는 아래 loss_fn_kd function으로 수행합니다.
def train_distillation(epoch, student, teacher, TEMPERATURE, ALPHA):
print('\nEpoch: %d' % epoch)
student.train() # student만 gradient 수행
teacher.eval()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
optimizer.zero_grad()
outputs = student(inputs) # student
teacher_outputs = teacher(inputs) # teacher
loss = loss_fn_kd(outputs, targets, teacher_outputs,TEMPERATURE,ALPHA) # loss가 student, teacher 둘 다 받음
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if batch_idx %1000 ==0:
print('Loss: %.3f | Acc: %.3f%% ' %(train_loss/(batch_idx+1),100.*correct/total))
3) loss_fn_kd function
Teacher-Student network에서는 Loss를 다음과 같이 구합니다.


코드로 구현하면 다음과 같습니다.
# T * T를 곱하는 이유 https://github.com/peterliht/knowledge-distillation-pytorch/issues/10
T*T를 곱하는 이유는 hard target(hard label; 0,1로만 구성)와 soft target(soft label, 확률로 구성)을 같이 사용하기 때문입니다.
T=1일 때는 상관없지만, T>1 이상일 경우 softmax의 역수로 곱해지기 때문에 KL divergence의 영향력이 1/T2으로 줄어들게 됩니다. 이 뜻은 의도치않게 Cross-Entropy의 영향력이 T2으로 커지게 된다는 의미입니다. 따라서 α에다가 T*T를 곱해서 균형을 맞춰줍니다.
따라서 train을 하는 동안 distillation에 사용되는 T가 변경되더라도, hard target과 soft target의 상대적인 영향력이 거의 변하지 않습니다.
# KLDivLoss에서 student output에서 log_softmax를 사용한 이유 github.com/peterliht/knowledge-distillation-pytorch/issues/2
PyTorch의 KLDIvLoss에 대한 정의 문서를 참조하면, 이때 입력은 log 확률분포와 확률분포가 들어가야 합니다.
def loss_fn_kd(outputs, labels, teacher_outputs, T=2, alpha=0.5):
"""
# outputs = result of model
# labels = true label
# teacher_outputs = teacher_model(data) # teacher model eval()
# params = T & alpha
"""
KD_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),
F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
F.cross_entropy(outputs, labels) * (1. - alpha)
return KD_loss
3️⃣ Data Load
데이터는 CIFAR10을 사용합니다.
# Initial Value
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BEST_ACC = 0
START_EPOCH = 0 # start from epoch 0 or last checkpoint epoch
BATCH_SIZE = 16
# Data preprocessing
print('==> Preparing data..')
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]
transform_train = transforms.Compose([transforms.RandomResizedCrop(size=256),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean_nums, std_nums),
])
transform_test = transforms.Compose([transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize(mean_nums, std_nums),
])
# DataSet & DataLoader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=BATCH_SIZE,
shuffle=True,)
testloader = torch.utils.data.DataLoader(testset,
batch_size=BATCH_SIZE,
shuffle=False)
4️⃣ Teacher Network 학습
Teacher network로는 pretrained된 resnet34을 사용합니다.(pretrained with 1000-class imagenet)
# teacher
teacher = models.__dict__['resnet34'](pretrained=True)
for param in teacher.parameters():
param.requires_grad = False
in_features = teacher.fc.in_features
teacher.fc = nn.Linear(in_features,10)
# Fine-tuning
for name, child in teacher.named_children():
if name in ['layer3', 'layer4','fc']:
print(name + 'has been unfrozen.')
for param in child.parameters():
param.requires_grad = True
else:
for param in child.parameters():
param.requires_grad = False
teacher = teacher.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=1e-3,
momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # lr_scheduler
# RUN
for epoch in range(START_EPOCH,START_EPOCH+20):
train(epoch,teacher)
test(epoch,teacher)
exp_lr_scheduler.step()
torch.save(teacher.state_dict(), f'./teacher_{epoch}.pth')
5️⃣ Teacher Network 분포
plt를 사용해서 그래프를 그리면 다음과 같습니다.
Softmax에서 T가 커질수록 확률분포가 퍼지는 것을 알 수 있습니다.
import pandas as pd
# Check distribution
teacher_outputs = teacher(torch.unsqueeze(selected_images, 0).to(DEVICE))
# list of Temperature
T_list = [1,5,10,20,30]
output_dic ={}
for i in T_list:
output_dic[i] = F.softmax(teacher_outputs/i, dim=1).cpu().numpy().squeeze() # softmax
# display
pd.DataFrame(output_dic).plot(title=f'Distribution of {selected_labels}')
plt.legend(title=' Temperature', labelspacing=1,
bbox_to_anchor=(1.03,0), loc='lower left')
plt.xlabel('Categories')
plt.ylabel('Softmax value')
plt.xticks([0,1,2,3,4,5,6,7,8,9])
plt.show()

6️⃣ Student Network 학습
앞서 학습된 Teacher Network를 가지고 Student Network 에 KD(knowledge distillation)을 하는 단계입니다. 여기서 Student Network 는 'ResNet18'을 사용합니다.
ALPHA, TEMPERATURE값에 따라 학습량이 변경됩니다.
# ALPHA & TEMPERATURE
TEMPERATURE = 2
ALPHA = 0.3
# Load Teacher dict
DICT_PATH = 'teacher_{SAVE_NUMBER}.pth'
teacher.load_state_dict(torch.load(DICT_PATH,))
for param in teacher.parameters():
param.requires_grad = False
teacher.to(DEVICE)
Pre-trained이 되지않은 ResNet18을 불러옵니다.
student = models.__dict__['resnet18']()
in_features = student.fc.in_features
student.fc = nn.Linear(in_features,10)
student.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student.parameters(), lr=1e-3,
momentum=0.9, weight_decay=5e-4)
train_distillation 함수를 사용해서 훈련을 시작합니다.
for epoch in range(0, START_EPOCH+10):
train_distillation(epoch,student, teacher,TEMPERATURE, ALPHA)
test(epoch,student)
7️⃣ KD(Knowledge Distilliation) vs Study Alone
아래 결과는 Teacher-Student를 통한 KD(Knowledge Distilliation)의 성능과 Student model이 혼자(alone) 훈련했을 때의 성능 차이입니다. 같은 Epoch을 수행했을 때의 차이입니다.


약간의 성능차이가 나는 것을 알 수 있습니다.
'부스트캠프 AI 테크 U stage > 실습' 카테고리의 다른 글
[39-2] Quantization using PyTorch (0) | 2021.03.18 |
---|---|
[38-3] Pruning using PyTorch (0) | 2021.03.18 |
[38-2] Python 병렬 Processing (0) | 2021.03.17 |
[37-2] PyTorch profiler (0) | 2021.03.17 |
[36-1] Model Conversion (0) | 2021.03.16 |