본문 바로가기

부스트캠프 AI Tech/Pytorch

[07] apply

PyTorch를 사용하면서 Pretrained된 모델을 많이 사용하게 될텐데 모델 자체에 버그가 있거나, 혹은 수정해서 써야만 하는 등의 사항들이 발생할 수 있다. 이런 경우 전 포스팅 했던 hook이나 apply 등을 이용하여 모델을 수정 할 수 있다.

apply

입력으로 받는 모든 module을 순차적으로 처리한다.

(1) apply를 활용해 parameter(W)를 1로 초기화하는 함수를 구현해보자

model = Model()

# pply를 이용해 모든 Parameter 값을 1로 초기화
def weight_initialization(module):
    module_name = module.__class__.__name__

    for param in module.parameters():
      # param data를 update
      param.data = torch.ones_like(param.data)

# apply는 apply가 적용된 module을 return 해준다
returned_module = model.apply(weight_initialization)

(2) apply를 활용해 repr을 수정해보자

model = Model()

# apply를 이용해서 repr 출력을 수정
from functools import partial

def function_repr(self):
    # print(self.name)
    return f'name={self.name}'

def add_repr(module):
    module_name = module.__class__.__name__
    try:
      print(function_repr(module))
      extra_repr = lambda repr:repr
      module.extra_repr = partial(extra_repr, function_repr(module))
    except:
      pass

# apply 적용된 module을 return
returned_module = model.apply(add_repr)

model_repr = repr(model)

print("모델 출력 결과")
print("-" * 30)
print(model_repr)
print("-" * 30)

(3) apply를 활용해 function을 linear transformation처럼 동작하도록 수정하자

  • Function_A : x+W
  • Function_B : x-W
  • Function_C : x+W
  • Function_D : x/W

 x @ W + b

model = Model()

from functools import partial

# Parameter b 추가
def add_bias(module):
    module_name = module.__class__.__name__
    if module_name.split('_')[0] == "Function":
      module.b = Parameter(torch.rand(2,1))

# 1로 초기화
def weight_initialization(module):
    module_name = module.__class__.__name__
    add_bias(module)
    if module_name.split('_')[0] == "Function":
        module.W.data.fill_(1.0)
        module.b.data.fill_(1.0)


# apply를 이용해 모든 Function을 linear transformation으로 바꾸자 (X @ W + b)
def hook(module, input, output):
    module_name = module.__class__.__name__  
    output = input[0] @ module.W.T
    # output = torch.mul(input[0],module.W.T)
    output = torch.add(output, module.b)
    return output


def linear_transformation(module):
    module_name = module.__class__.__name__
    print(module_name)
    if module_name.split('_')[0] == "Function":
        module.register_forward_hook(hook)

returned_module = model.apply(add_bias)
returned_module = model.apply(weight_initialization)
returned_module = model.apply(linear_transformation)


# FriendLinearModel : nn.linear
class FriendLinearModel(nn.Module):
    def __init__(self):
        super().__init__() 
        self.linear = nn.Sequential(nn.Linear(2, 2),
                                    nn.Linear(2, 2),
                                    nn.Linear(2, 2),
                                    nn.Linear(2, 2))

    def forward(self, x):
        return self.linear(x)

def friends_init_weights(m):
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(1.0)

friend_model = FriendLinearModel()
friend_model.apply(friends_init_weights)


# nn.Linear 모델과 비교
grads = tester(model, friend_model)

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[09] Dataset  (0) 2022.01.21
[08] backward 과정 이해하기  (0) 2022.01.20
[06] Pytorch의 hook  (0) 2022.01.20
[05] named_children, named_modules.  (0) 2022.01.19
[04] nn.parameter  (0) 2022.01.19