torch.nn.Module 클래스
모든 뉴럴 네트워크 모델의 베이스 클래스이다.
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
eval()
모듈을 evaluation mode로 변경한다. 이 기능은 특정 모듈에만 영향을 미친다. (Dropout, BatchNorm 등..)
Dropout의 경우 train 경우에만 적용되고 inference 시에는 모든 레이어와 노드들을 사용하기 때문에, Dropout 기능은 꺼줘야 하는데 이를 eval()이 수행해준다.
특정 모듈에 영향을 미친 후, 모듈을 return한다. model.train(False)와 같은 뜻이다.
train(mode=True)
모듈을 training mode로 변경한다. 이 또한 eval과 같게 특정 모듈(dropout, batchnorm.. train과 inference 경우에 다르게 동작하는 모듈들)에만 영향을 미친다.
따라서, model.eval() 또는 train()은 모듈의 training 상태를 변경해주는 flag 역할이다.
torch.no_grad
Gradient 계산 기능을 꺼서 inference 시 유용하게 사용할 수 있다. 당연하게도 no_grad는 Tensor.backward()를 사용하지 않을 때 호출해야 한다. 이는 계산에 필요한 memory consumption을 절약시킬 수 있다. tensor의 requires_grad=True를 설정했더라도, torch.no_grad() 후에는 requires_grad=False로 변경된다.
사용 시에는 no_grad할 프로세스를 with torch.no_grad(): 로 감싸주면 된다.
>>> x = torch.tensor([1.], requires_grad=True) #no_grad 전
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False #no_grad 후
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
BUT! tensor.backward() (loss.backward())로 backpropagation만 안하면 no_grad를 사용할 필요가 없지 않나? 어차피 gradient 계산 안할거니까. 라고 생각할 수 있다. 앞서 말했듯 memory consumption과 속도 측면에서 이득을 볼 수 있기 때문에 사용한다고 한다.. 그래서 inference/evalutaion 시에는 with torch.no_grad()로 감싸는 것이 일반적임!
pytorch docs에서는 gradient 없이 parameter를 업데이트 하고 싶은 경우, 또는 이전의 gradient로 업데이트된 parameter를 사용해서 forward pass를 진행하고 싶을 때(~inference) 사용하라고 말한다.
참고
Module — PyTorch 1.12 documentation
Shortcuts
pytorch.org
https://pytorch.org/docs/stable/generated/torch.no_grad.html
no_grad — PyTorch 1.12 documentation
Shortcuts
pytorch.org
'python' 카테고리의 다른 글
[런타임에러] RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same (0) | 2022.07.28 |
---|---|
[numpy] array에서 특정 조건 인덱스 return, numpy.where (0) | 2022.07.27 |
torch.tensor mul_ (0) | 2022.07.21 |
[pycharm] 단축키 정리 (0) | 2022.06.22 |
[CUDA] cuda 정리하기 (0) | 2022.05.27 |