python

[torch] model parameter 개수/값 확인

jiheek 2022. 3. 16. 11:17
  1.  모델 파라미터 개수 확인

python에서 모델의 파라미터 개수를 세기 위해서 다음 함수를 추가하면 된다.

print(count_parameters(model)) #사용

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

 

  • 코드 해석

model.parameters()의 파라미터 텐서들에 대해서 requires_grad = True라면 해당 파라미터 텐서의 numel값들을 더한다.

 

  • torch.numel(input) 

input: 입력 tensor. 텐서의 모든 원소의 수를 반환한다.

 

  • requires_grad

Tensor의 옵션. True로 되어있어야 해당 tensor의 gradient를 구할 수 있고, backpropagation이 가능하다.

 

 

   2. 모델 파라미터 확인

#방법 1
for param in model.parameters():
	print(param)
    
#방법 2
for key, value in model.state_dict().items():
	print(value)

#방법 1 결과
#방법 2 결과

똑같이 tensor 형식으로 출력해준다.

방법 2에서 key도 출력하면 각 레이어의 이름까지 출력 가능하다.

'python' 카테고리의 다른 글

[python] os.mkdirs 에러  (0) 2022.03.31
[python] numpy array, torch tensor 크기 확인하기  (0) 2022.03.22
[argparse] 옵션 설명  (0) 2022.02.21
VScode python interpreter 버전 변경  (0) 2022.01.14
[python] Iterator slicing  (0) 2021.12.24