Pytorch의 nn 모듈은 neural networks를 위한 다양한 구성 요소 클래스를 제공합니다. 특히, 여러 개의 구성 요소를 하나의 리스트로 담는 nn.ModuleList 객체 또한 많이 사용되는데요, 겉보기에는 일반 파이썬 list와 큰 차이가 없어 보입니다. 다음과 같이 간단한 네트워크를 구성해 보겠습니다.
import torch
import torch.nn as nn
class MyNN(nn.Module):
def __init__(self, fc_input_size, fc_hidden_sizes, num_classes):
super(MyNN, self).__init__()
fcs = [nn.Sequential(
nn.Linear(fc_input_size, fc_hidden_size),
nn.ReLU(),
nn.Linear(fc_hidden_size, num_classes)
) for fc_hidden_size in fc_hidden_sizes]
self.network1 = nn.ModuleList(fcs)
self.network2 = fcs
self.network1은 nn.ModuleList로 구성된 네트워크이고 self.network2는 일반 파이썬 list로 구성된 네트워크입니다. nn.ModuleList는 일반 파이썬 list의 속성들을 그대로 가지고 있어 이후 forward 메소드도 정의하는 방법도 크게 차이가 없습니다. 그렇다면 어디에서 차이가 발생할까요?
이를 알아보기 위해 MyNN 네트워크를 생성하고 모델의 named_paramters() 함수를 통해 네트워크의 구성 모듈을 살펴보겠습니다. 보다시피 nn.ModuleList로 구성된 네트워크인 self.network1만 모델 파라미터에 등록되 있고 일반 파이썬 list로 구성된 self.network2는 보이지 않습니다. 따라서 추후 optimizer를 호출할 때 model.parameters()로 모델의 파라미터를 전달할텐데 일반 list로 구성한 파라미터들은 전달되지 않을테니 훈련되지 않을 것이고, state_dict()로 저장하고자 할 때도 일반 파이썬 list로 구성한 모듈들은 저장되지 않겠죠.
if __name__ == "__main__":
nn = MyNN(fc_input_size=4, fc_hidden_sizes=(8,8), num_classes=3)
for n, param in nn.named_parameters():
print(f"Parameter Name: {n}, Shape: {param.shape!r}")
Parameter Name: network1.0.0.weight, Shape: torch.Size([8, 4])
Parameter Name: network1.0.0.bias, Shape: torch.Size([8])
Parameter Name: network1.0.2.weight, Shape: torch.Size([3, 8])
Parameter Name: network1.0.2.bias, Shape: torch.Size([3])
Parameter Name: network1.1.0.weight, Shape: torch.Size([8, 4])
Parameter Name: network1.1.0.bias, Shape: torch.Size([8])
Parameter Name: network1.1.2.weight, Shape: torch.Size([3, 8])
Parameter Name: network1.1.2.bias, Shape: torch.Size([3])
즉, 우리가 리스트 컨테이너 방식으로 neural networks의 구성 요소를 담고자 한다면 nn.Module의 파라미터로 등록이 되어야 하므로 무조건 nn.ModuleList를 사용해야 합니다.
참조
'Machine Learning Models > Pytorch' 카테고리의 다른 글
Pytorch - DistributedDataParallel (1) - 개요 (0) | 2021.10.23 |
---|---|
Pytorch - DataParallel (0) | 2021.10.22 |
Pytorch - backward(retain_graph=True) (2) (1) | 2021.08.01 |
Pytorch LSTM (0) | 2021.06.26 |
Pytorch - embedding (0) | 2021.06.23 |