본문 바로가기

부스트캠프 AI Tech/Pytorch

[03] nn.Module 이해하기

pytorch documentation


https://pytorch.org/docs/stable/index.html

 

PyTorch documentation — PyTorch 1.10.1 documentation

Shortcuts

pytorch.org

index select


>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

문제

``` python
[[1  2]
 [3  4]] 2차원 텐서에서 [1  3]이라는 값을 가져오기.
# 방법 1 - index_select 사용
output = torch.index_select(A,1,indices).view(-1,2)

# 방법2 - 리스트인덱싱 사용
output = torch.tensor(np.array(A)[:,0])

방법2 에서.. 리스트 사용하다가 안되서.. 이유 찾다가 결국 알아냄

import numpy as np
ar = np.array([[1,2],[3,4]])
ar[0,1] # array 는 이런 접근이 가능하구나..
li = [[1,2],[3,4]]
li[0,1] # 오류 발생

nn.Module 모델 제작


add Model 제작

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        return (x1 + x2)

x1 = torch.tensor([1])
x2 = torch.tensor([2])

add = Add()
output = add(x1, x2) # -> forward 자동 실행 리턴
print(output) ## 출력 3

torch.nn.Sequential

  • 모듈들을 하나로 묶어 순차적으로 실행시키고 싶을 때
  • 내가 이해한 과정에 따르면 클래스 init 초기화 -> forward -> return값으로 init 초기화라고 이해했음
class Add(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, x):
        return x + self.value

calculator = nn.Sequential(
    Add(3),
    Add(2),
    Add(5),
)


x = torch.tensor([1])
output = calculator(x)
print(output) # tensor([11])

''' 
내가 이해하려고 만든 과정
add = Add(torch.tensor([1]))
add(3) # tensor([4])
add = Add(add(3))
add(2) # tensor([4])
add = Add(add(2))
add(5) # tensor([4])
'''

뭔가 너무 복잡해서 코드를 하나하나 돌려봤는데 "Add(3)(1)"이 되는 것을 발견했습니다.
왜 이 코드가 작동하는지 잘 몰라서 stack overflow에 질문했습니다. 허접한 영어 실력.. 동원

https://stackoverflow.com/questions/70754176/add35-nn-sequential-how-it-works

 

Add(3)(5) nn.Sequential. How it works?

class Add(nn.Module): def __init__(self, value): super().__init__() self.value = value def forward(self, x): return x + self.value calculator = nn.Sequential( ...

stackoverflow.com

 

한 줄 답변을 기대했으나 매우 친절한 외쿡 형님들.. 바로 이해해버렸습니다.

add() - callable object

Add() - Module class

헷갈리지 않기 위해 add_obj = Add(1) 라고 하는 것도 좋은 방법
callable object는 snake_case / Class는 CamelCase 관습 !

Add(3)(x) 는 self.value가 3인 add_obj에 add_obj(x)와 같다.

add_obj = Add(3)
add_obj(x)

쉬운 것인데 Add()() 이런 형태는 처음봐서 좀 당황한게 큰 것 같다.

이걸 이해하니 ModuleList, ModuleDict는 바로바로 이해하고 넘어갔다.

nn.ModuleList

class Calculator(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_list = nn.ModuleList([Add(2), Add(3), Add(5)])

    def forward(self, x):
        # y = ((x + 3) + 2) + 5 의 연산을 구현하세요!
        x = self.add_list[1](x)
        x = self.add_list[0](x)
        x = self.add_list[2](x)

        return x

nn.ModuleDict

class Calculator(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_dict = nn.ModuleDict({'add2': Add(2),
                                       'add3': Add(3),
                                       'add5': Add(5)})

    def forward(self, x):
        # y = ((x + 3) + 2) + 5 의 연산을 구현하세요!
        x = self.add_dict['add3'](x)
        x = self.add_dict['add2'](x)
        x = self.add_dict['add5'](x)

        return x

function 과 layer 쌓기


 import torch
from torch import nn


# Function
class Function_A(nn.Module):
    def __init__(self):
        super().__init__()
        print(f"        Function A Initialized")

    def forward(self, x):
        print(f"        Function A started")
        print(f"        Function A done")

class Function_B(nn.Module):
    def __init__(self):
        super().__init__()
        print(f"        Function B Initialized")

    def forward(self, x):
        print(f"        Function B started")
        print(f"        Function B done")

class Function_C(nn.Module):
    def __init__(self):
        super().__init__()
        print(f"        Function C Initialized")

    def forward(self, x):
        print(f"        Function C started")
        print(f"        Function C done")

class Function_D(nn.Module):
    def __init__(self):
        super().__init__()
        print(f"        Function D Initialized")

    def forward(self, x):
        print(f"        Function D started")
        print(f"        Function D done")


# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()

        self.a = Function_A()
        self.b = Function_B()

        print(f"    Layer AB Initialized")

    def forward(self, x):
        print(f"    Layer AB started")
        self.a(x)
        self.b(x)
        print(f"    Layer AB done")

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()

        self.c = Function_C()
        self.d = Function_D()

        print(f"    Layer CD Initialized")

    def forward(self, x):
        print(f"    Layer CD started")
        self.c(x)
        self.d(x)
        print(f"    Layer CD done")


# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.ab = Layer_AB()
        self.cd = Layer_CD()

        print(f"Model ABCD Initialized\n")

    def forward(self, x):
        print(f"Model ABCD started")
        self.ab(x)
        self.cd(x)
        print(f"Model ABCD done\n")


x = torch.tensor([7])

model = Model()
model(x)


#         Function A Initialized
#         Function B Initialized
#     Layer AB Initialized
#         Function C Initialized
#         Function D Initialized
#     Layer CD Initialized
# Model ABCD Initialized

# Model ABCD started
#     Layer AB started
#         Function A started
#         Function A done
#         Function B started
#         Function B done
#     Layer AB done
#     Layer CD started
#         Function C started
#         Function C done
#         Function D started
#         Function D done
#     Layer CD done
# Model ABCD done

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[06] Pytorch의 hook  (0) 2022.01.20
[05] named_children, named_modules.  (0) 2022.01.19
[04] nn.parameter  (0) 2022.01.19
[02] Pytorch Basics  (0) 2022.01.18
[01] Introduction (Pytorch vs Tensorflow)  (0) 2022.01.17