[39-3] Teacher-Student Network using PyTorch

또르르21 2021. 3. 19. 00:50

We build a Teacher-Student network, the best-known example of Knowledge Distillation.

 

ref) github.com/kmsravindra/ML-AI-experiments/blob/master/AI/knowledge_distillation/Knowledge%20distillation.ipynb

 

 

1️⃣ Setup

 

Import the required modules.

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.optim import lr_scheduler
from torchsummary import summary

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

 

 

2️⃣ Define Functions

1) train function / test function

The train and test functions are essentially the same as ordinary train/test functions.

# Training
def train(epoch, net):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 1000 == 0:
            print('Loss: %.3f | Acc: %.3f%%' % (train_loss / (batch_idx + 1), 100. * correct / total))


# Test
def test(epoch, net):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

 

2) train_distillation function

This is the core function of the Teacher-Student network.

The student must run in train mode and the teacher in eval mode, and the loss is computed with the loss_fn_kd function below.

def train_distillation(epoch, student, teacher, TEMPERATURE, ALPHA):
    print('\nEpoch: %d' % epoch)
    student.train()       # only the student receives gradient updates
    teacher.eval()

    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = student(inputs)            # student forward pass
        teacher_outputs = teacher(inputs)    # teacher forward pass

        # the loss takes both the student and teacher outputs
        loss = loss_fn_kd(outputs, targets, teacher_outputs, TEMPERATURE, ALPHA)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 1000 == 0:
            print('Loss: %.3f | Acc: %.3f%%' % (train_loss / (batch_idx + 1), 100. * correct / total))
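Since the teacher's parameters are frozen, an optional tweak (not in the original code) is to run the teacher's forward pass under torch.no_grad(), which avoids building an autograd graph for the teacher and saves a bit of memory:

# Optional: avoid tracking gradients through the frozen teacher
with torch.no_grad():
    teacher_outputs = teacher(inputs)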
     

 

 

3) loss_fn_kd function

In a Teacher-Student network, the loss is computed as follows [Hinton et al., arXiv 2015].
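Written out to match the loss_fn_kd implementation below ($z_s$, $z_t$: student and teacher logits; $y$: hard label):

$$\mathcal{L}_{\mathrm{KD}} = \alpha\, T^{2}\, \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T)\,\middle\|\,\mathrm{softmax}(z_s / T)\right) + (1 - \alpha)\, \mathrm{CE}(z_s,\, y)$$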

 

Implemented in code, it looks like this.

 

# Why multiply by T * T: https://github.com/peterliht/knowledge-distillation-pytorch/issues/10

 

The reason for multiplying by T*T is that a hard target (hard label, made up only of 0s and 1s) and a soft target (soft label, made up of probabilities) are used together.

When T = 1 this makes no difference, but when T > 1 the gradients of the KL-divergence (soft-target) term shrink by roughly 1/T² because the logits are divided by T inside the softmax. This would unintentionally make the cross-entropy (hard-target) term about T² times more influential, so α is multiplied by T*T to restore the balance.

As a result, even if the temperature T used for distillation is changed during training, the relative contributions of the hard target and the soft target stay roughly constant.
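Concretely (following Section 2.1 of Hinton et al.), the gradient of the soft-target term with respect to a student logit $z_i$ is

$$\frac{\partial C_{\text{soft}}}{\partial z_i} = \frac{1}{T}\,(q_i - p_i) \;\approx\; \frac{1}{N T^{2}}\,(z_i - v_i),$$

where $q_i$ and $p_i$ are the student's and teacher's temperature-softened probabilities, $v_i$ the teacher logits, and $N$ the number of classes (the approximation holds in the high-temperature limit with zero-mean logits). The 1/T² scaling of this gradient is exactly what the α·T² factor compensates for.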

 

# Why log_softmax is applied to the student output in KLDivLoss: github.com/peterliht/knowledge-distillation-pytorch/issues/2

 

If you look at the definition of KLDivLoss in the PyTorch documentation, the input must be given as log-probabilities and the target as a probability distribution.
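A minimal standalone check of this convention (illustrative only, not from the original post):

import torch
import torch.nn as nn
import torch.nn.functional as F

student_logits = torch.randn(2, 5)
teacher_logits = torch.randn(2, 5)

log_q = F.log_softmax(student_logits, dim=1)             # input: log-probabilities
p = F.softmax(teacher_logits, dim=1)                     # target: probabilities
print(nn.KLDivLoss(reduction='batchmean')(log_q, p))     # scalar KL divergence (>= 0)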

def loss_fn_kd(outputs, labels, teacher_outputs, T=2, alpha=0.5):
    """
    outputs         : student model outputs (logits)
    labels          : true labels
    teacher_outputs : teacher_model(data), with the teacher in eval mode
    T, alpha        : distillation hyperparameters
    """
    KD_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / T, dim=1),
                                                  F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss
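A quick sanity check (illustrative only, not from the original post) that the function runs on dummy logits:

dummy_student = torch.randn(4, 10)           # fake student logits: 4 samples, 10 classes
dummy_teacher = torch.randn(4, 10)           # fake teacher logits
dummy_labels = torch.randint(0, 10, (4,))    # fake hard labels
print(loss_fn_kd(dummy_student, dummy_labels, dummy_teacher, T=2, alpha=0.5))   # scalar loss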

 

 

3️⃣ Data Load

 

We use CIFAR10 as the dataset.

# Initial values
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BEST_ACC = 0
START_EPOCH = 0   # start from epoch 0 or the last checkpoint epoch
BATCH_SIZE = 16

# Data preprocessing
print('==> Preparing data..')

mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

transform_train = transforms.Compose([transforms.RandomResizedCrop(size=256),
                                      transforms.RandomRotation(degrees=15),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean_nums, std_nums),
])

transform_test = transforms.Compose([transforms.Resize(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean_nums, std_nums),
])

# Dataset & DataLoader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

 

 

4️⃣ Training the Teacher Network

 

For the teacher network we use a pretrained resnet34 (pretrained on 1000-class ImageNet).

# Teacher
teacher = models.__dict__['resnet34'](pretrained=True)
for param in teacher.parameters():
    param.requires_grad = False

in_features = teacher.fc.in_features
teacher.fc = nn.Linear(in_features, 10)

# Fine-tuning: unfreeze only the later layers and the new classifier head
for name, child in teacher.named_children():
    if name in ['layer3', 'layer4', 'fc']:
        print(name + ' has been unfrozen.')
        for param in child.parameters():
            param.requires_grad = True
    else:
        for param in child.parameters():
            param.requires_grad = False

teacher = teacher.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=1e-3,
                      momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)   # lr scheduler

# Run
for epoch in range(START_EPOCH, START_EPOCH + 20):
    train(epoch, teacher)
    test(epoch, teacher)
    exp_lr_scheduler.step()
    torch.save(teacher.state_dict(), f'./teacher_{epoch}.pth')

 

 

5️⃣ Teacher Network Output Distribution

 

Plotting the teacher outputs with plt gives the graph below.

As the softmax temperature T increases, the probability distribution becomes more spread out.

import pandas as pd
import matplotlib.pyplot as plt

# Check the output distribution for one test image
# (selected_images / selected_labels are assumed to hold a single test image and its label; they are not defined in this post)
with torch.no_grad():   # detach from autograd so .numpy() works below
    teacher_outputs = teacher(torch.unsqueeze(selected_images, 0).to(DEVICE))

# List of temperatures
T_list = [1, 5, 10, 20, 30]
output_dic = {}

for i in T_list:
    output_dic[i] = F.softmax(teacher_outputs / i, dim=1).cpu().numpy().squeeze()   # softened softmax

# Display
pd.DataFrame(output_dic).plot(title=f'Distribution of {selected_labels}')
plt.legend(title='Temperature', labelspacing=1,
           bbox_to_anchor=(1.03, 0), loc='lower left')
plt.xlabel('Categories')
plt.ylabel('Softmax value')
plt.xticks([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
plt.show()
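For a quick numeric intuition (illustrative only, not from the original post), the same softening effect can be seen on a toy logit vector:

logits = torch.tensor([2.0, 1.0, 0.0])
print(F.softmax(logits / 1, dim=0))   # ~[0.665, 0.245, 0.090]  (T=1: peaked)
print(F.softmax(logits / 5, dim=0))   # ~[0.402, 0.329, 0.269]  (T=5: much flatter)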

 

 

6️⃣ Training the Student Network

 

In this step we use the Teacher Network trained above to perform KD (knowledge distillation) into the Student Network. Here the Student Network is a ResNet18.

 

The ALPHA and TEMPERATURE values control how much the student learns from the teacher's soft targets versus the hard labels (with ALPHA = 0.3, the soft-target KL term is weighted by 0.3 · T² and the hard-label cross-entropy by 0.7).

# ALPHA & TEMPERATURE
TEMPERATURE = 2
ALPHA = 0.3

# Load the trained teacher weights
DICT_PATH = 'teacher_{SAVE_NUMBER}.pth'   # placeholder: fill in the epoch number of the saved checkpoint
teacher.load_state_dict(torch.load(DICT_PATH))
for param in teacher.parameters():
    param.requires_grad = False

teacher.to(DEVICE)

Load a ResNet18 that has not been pre-trained.

student = models.__dict__['resnet18']()
in_features = student.fc.in_features
student.fc = nn.Linear(in_features, 10)

student.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student.parameters(), lr=1e-3,
                      momentum=0.9, weight_decay=5e-4)

Training starts with the train_distillation function.

for epoch in range(0, START_EPOCH + 10):
    train_distillation(epoch, student, teacher, TEMPERATURE, ALPHA)
    test(epoch, student)

 

 

7️⃣ KD (Knowledge Distillation) vs Study Alone

 

The results below compare the performance of KD (Knowledge Distillation) via the Teacher-Student setup against the performance of the Student model trained alone, using the same number of epochs.

 

[Result when trained with KD (Knowledge Distillation)]
[Result when trained alone (Study Alone)]

 

You can see a slight difference in performance between the two.
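For readers who want to reproduce the Study Alone baseline, it can be run with the plain train/test loop defined earlier (a sketch assuming the same ResNet18 setup and optimizer as the distilled student; the original post only shows the resulting accuracies):

# Train the same ResNet18 from scratch, without a teacher
student_alone = models.__dict__['resnet18']()
student_alone.fc = nn.Linear(student_alone.fc.in_features, 10)
student_alone = student_alone.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student_alone.parameters(), lr=1e-3,
                      momentum=0.9, weight_decay=5e-4)

for epoch in range(0, 10):
    train(epoch, student_alone)
    test(epoch, student_alone)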
