python

[torch] trained model 저장 및 사용 - (2)PyTorch

jiheek 2021. 12. 21. 22:28

Trained model을 어떻게 사용할까? Torch 버전

훈련한 모델을 저장하고 불러오는 여러 가지 방법들에 대해서 소개한다. 마지막에는 상황에 따른 사용 예시 코드도 첨부하였다.

import torch
import torch.nn as nn

torch.save(arg, PATH) #tensor, dictionary.. as parameter for saving

torch.load(PATH)

model.load_state_dict(arg)

 

 

save function

import torch
import torch.nn as nn

#### COMPLETE MODEL ####
torch.save(model, PATH)

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()


#### another method #### 
#When you want to save trained model and use it for inference
torch.save(model.state_dict(), PATH)

#model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

 

 

예제1: Lazy method

import torch
import torch.nn as nn

class Model(nn.Module):
	def __init__(self, n_input_features):
    	super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

	def forward(self, x):
    	y_pred = torch.sigmoid(self.linear(x))
        return y_pred
        
model = Model(n_input_features = 6)
#train your model ...

FILE = "model.pth" #pth(pytorch) file: non human readable file
torch.save(model, FILE)


#AFTER SAVING...
model = torch.load(FILE)
model.eval()

for param in model.parameters():
	print(param)

 

 

예제2: Prefered method (model.state_dict())

import torch
import torch.nn as nn

class Model(nn.Module):
	def __init__(self, n_input_features):
    	super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

	def forward(self, x):
    	y_pred = torch.sigmoid(self.linear(x))
        return y_pred
        
model = Model(n_input_features = 6)
#train your model ...

for param in model.parameters():
	print(param) # ----(1)

FILE = "model.pth" #pth(pytorch) file: non human readable file
torch.save(model.state_dict(), FILE)


#AFTER SAVING...
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))
loaded_model.eval()

for param in loaded_model.parameters():
	print(param) # ----(2)

(1), (2)의 출력은 같다. 즉, 동일한 파라미터가 저장, 로드되었다.

 

weight를 따로 pth파일로 저장하고 다시 불러오는 방법이 예제2이고,

A 모델의 weight를 B 모델의 weight로 바로 복제해 오려면 아래의 방법을 사용해도 된다.

B.load_state_dict(A.state_dict())

 

[차이점]

예제1: 아예 모델 전체를 파일로 저장 // 예제2: 모델의 parameter만 저장

 

 

ImageNet pretrained model 사용 예시

기본 사용 예시, class 수가 다른 경우 수정하는 방법, key error가 생길 때 무시하는 방법에 대한 예시를 추가했다.

key error의 경우, strict=False라는 옵션을 주어서 key가 다른 경우는 넘어가고, 일치하는 weight,bias만 받도록 하면 된다.

import torch.nn as nn
import torchvision.models as models

#Case 1. 기본
model.load_state_dict(models.resnet50(pretrained=True).state_dict())

#Case 2. class 수가 다른 경우(imagenet: 1000, target: 10)
pretrain = models.resnet50(pretrained=True)
fc_in = model.fc.in_features
pretrain.fc = nn.Linear(fc_in, 10)
model.load_state_dict(pretrain.state_dict())

#Case 3. key error 무시하고 싶은 경우
model.load_state_dict(models.resnet50(pretrained=True).state_dict(), strict=False)

 

 


생성한 모델은 to(device)로 GPU에 저장할 수 있다.

 


출처

https://www.youtube.com/watch?v=9L9jEOwRrCg 

https://tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기 — PyTorch Tutorials 1.10.2+cu102 documentation

Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich 번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다

tutorials.pytorch.kr

 

'python' 카테고리의 다른 글

VScode python interpreter 버전 변경  (0) 2022.01.14
[python] Iterator slicing  (0) 2021.12.24
[tf, keras] trained model 저장 및 사용 - (1)TensorFlow, Keras  (0) 2021.12.21
[python] datetime, replace  (0) 2021.10.27
map 함수  (0) 2021.10.26