본문 바로가기

부스트캠프 AI Tech/Pytorch

[04] nn.parameter

nn.Module 에서 parameter 정의하기


  • linear transformation인 Y = XW + b 에 대해서 W,b 를 어떻게 만들까??
  • nn.Module안에 미리 만들어진 tensor들을 보관 가능 -> Parameter
  • tensor를 안쓰고 Parameter 사용하는 이유 : 아래에 나와용
  • 보통은 torch.nn에 구현된 layer들을 가져다 쓰기 때문에 Parameter를 직접 다루는 경우는
    직접 layer를 작성하지 않는 이상 사용할 일이 거의 없음
import torch
from torch import nn
from torch.nn.parameter import Parameter

class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = Parameter(torch.ones(out_features, in_features))
        self.b = Parameter(torch.ones(out_features))

    def forward(self, x):
        output = torch.addmm(self.b, x, self.W.T)

        return output

x = torch.Tensor([[1, 2],
                  [3, 4]])

linear = Linear(2, 3)
output = linear(x)

# torch.Tensor([[4, 4, 4],
#              [8, 8, 8]]):

Linear의 forward에서 왜 transpose를 취해서 곱해줄까요 ??

transpose을 안하고 W에 파라미터를 줄때
(in_features, out_features) 로 주면 안 헷갈리고 좋을 것 같은데요.

구글링 해보니 Transpose를 취하면 backward() 를 할 때 효율이 좋다고 합니다.
중요한 파트라고 생각 안해서 자세히는 안보고 넘어가겠습니다.

state_dict()

* dict 형태로 parameter 를 저장되어 있음
  linear_parameter.state_dict()

  #OrderedDict([('W', tensor([[1., 1.],
  #                    [1., 1.],
  #                    [1., 1.]])), ('b', tensor([1., 1., 1.]))])

buffer


  • "Tensor"
    • ❌ gradient 계산
    • ❌ 값 업데이트
    • ❌ 모델 저장시 값 저장
  • "Parameter"
    • ✅ gradient 계산
    • ✅ 값 업데이트
    • ✅ 모델 저장시 값 저장
  • "Buffer"
    • ❌ gradient 계산
    • ❌ 값 업데이트
    • ✅ 모델 저장시 값 저장
import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.parameter = Parameter(torch.Tensor([7]))
        self.tensor = torch.Tensor([7])
        self.register_buffer('buffer',torch.Tensor([7]))

model = Model()
model.state_dict() # OrderedDict([('parameter', tensor([7.])), ('buffer', tensor([7.]))])

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[06] Pytorch의 hook  (0) 2022.01.20
[05] named_children, named_modules.  (0) 2022.01.19
[03] nn.Module 이해하기  (0) 2022.01.18
[02] Pytorch Basics  (0) 2022.01.18
[01] Introduction (Pytorch vs Tensorflow)  (0) 2022.01.17