python

[pytorch] torch.nn.Module - train, eval, no_grad

jiheek 2022. 7. 27. 14:41

torch.nn.Module 클래스

모든 뉴럴 네트워크 모델의 베이스 클래스이다.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

 

 

eval()

모듈을 evaluation mode로 변경한다. 이 기능은 특정 모듈에만 영향을 미친다. (Dropout, BatchNorm 등..)

Dropout의 경우 train 경우에만 적용되고 inference 시에는 모든 레이어와 노드들을 사용하기 때문에, Dropout 기능은 꺼줘야 하는데 이를 eval()이 수행해준다.

특정 모듈에 영향을 미친 후, 모듈을 return한다. model.train(False)와 같은 뜻이다.

 

 

train(mode=True)

모듈을 training mode로 변경한다. 이 또한 eval과 같게 특정 모듈(dropout, batchnorm.. train과 inference 경우에 다르게 동작하는 모듈들)에만 영향을 미친다.

 

 

https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch

 

 

따라서, model.eval() 또는 train()은 모듈의 training 상태를 변경해주는 flag 역할이다.

https://pytorch.org/docs/stable/generated/torch.nn.Module.html

 

 

torch.no_grad

Gradient 계산 기능을 꺼서 inference 시 유용하게 사용할 수 있다. 당연하게도 no_grad는 Tensor.backward()를 사용하지 않을 때 호출해야 한다. 이는 계산에 필요한 memory consumption을 절약시킬 수 있다. tensor의 requires_grad=True를 설정했더라도, torch.no_grad() 후에는 requires_grad=False로 변경된다.

사용 시에는 no_grad할 프로세스를 with torch.no_grad(): 로 감싸주면 된다.

>>> x = torch.tensor([1.], requires_grad=True) #no_grad 전
>>> with torch.no_grad(): 
...   y = x * 2
>>> y.requires_grad
False 					#no_grad 후
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False

 

BUT! tensor.backward() (loss.backward())로 backpropagation만 안하면 no_grad를 사용할 필요가 없지 않나? 어차피 gradient 계산 안할거니까. 라고 생각할 수 있다. 앞서 말했듯 memory consumption과 속도 측면에서 이득을 볼 수 있기 때문에 사용한다고 한다.. 그래서 inference/evalutaion 시에는 with torch.no_grad()로 감싸는 것이 일반적임! 

https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc

pytorch docs에서는 gradient 없이 parameter를 업데이트 하고 싶은 경우, 또는 이전의 gradient로 업데이트된 parameter를 사용해서 forward pass를 진행하고 싶을 때(~inference) 사용하라고 말한다.

 

 

참고

 

Module — PyTorch 1.12 documentation

Shortcuts

pytorch.org

https://pytorch.org/docs/stable/generated/torch.no_grad.html

 

no_grad — PyTorch 1.12 documentation

Shortcuts

pytorch.org