Trained model을 어떻게 사용할까? Torch 버전
훈련한 모델을 저장하고 불러오는 여러 가지 방법들에 대해서 소개한다. 마지막에는 상황에 따른 사용 예시 코드도 첨부하였다.
import torch
import torch.nn as nn
torch.save(arg, PATH) #tensor, dictionary.. as parameter for saving
torch.load(PATH)
model.load_state_dict(arg)
save function
import torch
import torch.nn as nn
#### COMPLETE MODEL ####
torch.save(model, PATH)
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
#### another method ####
#When you want to save trained model and use it for inference
torch.save(model.state_dict(), PATH)
#model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
예제1: Lazy method
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = Model(n_input_features = 6)
#train your model ...
FILE = "model.pth" #pth(pytorch) file: non human readable file
torch.save(model, FILE)
#AFTER SAVING...
model = torch.load(FILE)
model.eval()
for param in model.parameters():
print(param)
예제2: Prefered method (model.state_dict())
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
model = Model(n_input_features = 6)
#train your model ...
for param in model.parameters():
print(param) # ----(1)
FILE = "model.pth" #pth(pytorch) file: non human readable file
torch.save(model.state_dict(), FILE)
#AFTER SAVING...
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))
loaded_model.eval()
for param in loaded_model.parameters():
print(param) # ----(2)
(1), (2)의 출력은 같다. 즉, 동일한 파라미터가 저장, 로드되었다.
weight를 따로 pth파일로 저장하고 다시 불러오는 방법이 예제2이고,
A 모델의 weight를 B 모델의 weight로 바로 복제해 오려면 아래의 방법을 사용해도 된다.
B.load_state_dict(A.state_dict())
[차이점]
예제1: 아예 모델 전체를 파일로 저장 // 예제2: 모델의 parameter만 저장
ImageNet pretrained model 사용 예시
기본 사용 예시, class 수가 다른 경우 수정하는 방법, key error가 생길 때 무시하는 방법에 대한 예시를 추가했다.
key error의 경우, strict=False라는 옵션을 주어서 key가 다른 경우는 넘어가고, 일치하는 weight,bias만 받도록 하면 된다.
import torch.nn as nn
import torchvision.models as models
#Case 1. 기본
model.load_state_dict(models.resnet50(pretrained=True).state_dict())
#Case 2. class 수가 다른 경우(imagenet: 1000, target: 10)
pretrain = models.resnet50(pretrained=True)
fc_in = model.fc.in_features
pretrain.fc = nn.Linear(fc_in, 10)
model.load_state_dict(pretrain.state_dict())
#Case 3. key error 무시하고 싶은 경우
model.load_state_dict(models.resnet50(pretrained=True).state_dict(), strict=False)
생성한 모델은 to(device)로 GPU에 저장할 수 있다.
출처
https://www.youtube.com/watch?v=9L9jEOwRrCg
https://tutorials.pytorch.kr/beginner/saving_loading_models.html
모델 저장하기 & 불러오기 — PyTorch Tutorials 1.10.2+cu102 documentation
Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich 번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다
tutorials.pytorch.kr
'python' 카테고리의 다른 글
VScode python interpreter 버전 변경 (0) | 2022.01.14 |
---|---|
[python] Iterator slicing (0) | 2021.12.24 |
[tf, keras] trained model 저장 및 사용 - (1)TensorFlow, Keras (0) | 2021.12.21 |
[python] datetime, replace (0) | 2021.10.27 |
map 함수 (0) | 2021.10.26 |