본문 바로가기

부스트캠프 AI Tech/Pytorch

[06] Pytorch의 hook

hook이 뭘까요 ??


from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

def module_hook(grad):
    pass

model_obj = Model()
model_obj.__dict__

# {'training': True,
#  '_parameters': OrderedDict(),
#  '_buffers': OrderedDict(),
#  '_non_persistent_buffers_set': set(),
#  '_backward_hooks': OrderedDict(),
#  '_is_full_backward_hook': None,
#  '_forward_hooks': OrderedDict(),
#  '_forward_pre_hooks': OrderedDict(),
#  '_state_dict_hooks': OrderedDict(),
#  '_load_state_dict_pre_hooks': OrderedDict(),
#  '_modules': OrderedDict()}

nn.Module 상속 하는 모델을 만들면

~

hooks 가 자동으로 있는 것을 볼 수 있습니다.
hook은 예를들어 forward()나 backward() 시에 사용자가 커스텀한 함수나 레이어 들을 실행 시킬수 있도록
nn.Module을 만들 때 정의해 둔 것입니다.
이름 처럼 함수 실행 전이나 후에 '훅' 들어와 훅함수를 실행시키는 것이죠.
훅함수를 등록하는 방법은 아래와 같습니다.

 

 

  • model.register_forward_pre_hook(module_hook)
  • model.register_forward_hook(module_hook)
  • model.register_full_backward_hook(module_hook)
  • state_dict의 경우도 hook이 있는데 저희가 사용하는게 아니라
    "load_state_dict" 함수가 내부적으로 사용한다고 하네요!

직접 해보기

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)

        return output

위와 같은 모델이 있을 때 answer이란 list에 forward 전에 x1, x2 값, forward 후에 output 값을 넣어봅시다.

add = Add()
answer = []

def pre_hook(module, input):
    answer.extend(input)
    pass

def hook(module, input, output):
    answer.append(output)
    pass

add.register_forward_pre_hook(pre_hook)
add.register_forward_hook(hook)

x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2) # [tensor([0.7252]), tensor([0.6348]), tensor([1.3600])
print(answer) # [tensor([0.7252]), tensor([0.6348]), tensor([1.3600])]

위에서는 값을 가져오는 hook을 넣었지만 전파되는 값을 수정하는 것도 가능 !

full backward hook


  • forward hook은 module에만 적용할 수 있지만 backward hook은 tensor, module 2가지 적용 가능 !
  • The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature (pytorch 공식문서)
  • register_backward_hook 말고 register_full_backward_hook 사용. (옛날 버전이라고 생각하면 될 듯)

모듈 단위 backward hook

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W
        return output

model = Model()
answer = []

def module_hook(module, grad_input, grad_output):
    answer.extend(grad_input)
    answer.append(grad_output[0])
    pass

model.register_full_backward_hook(module_hook)

x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.retain_grad()
output.backward()

# 결과 output [tensor([1.1247]), tensor([3.2380]), tensor([1.])]
#        answer [tensor([1.1247]), tensor([3.2380]), tensor([1.])]

탠서 단위 backward hook

module 단위의 backward hook은, module 기준으로 input, output gradient 값만 가져와서 내부의 tensor의 gradient값은 알아낼 수 없다.

이 때는, tensor 단위의 hook를 사용한다. model.W.register_hook(tensor_hook)

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

model = Model()
answer = []

def tensor_hook(grad):
    answer.append(grad)
    pass

model.W.register_hook(tensor_hook)

x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

# output, answer : tensor([0.3604])

backward hook 으로 gradient 다루기

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()

# hook를 이용해서 module의 gradient 출력의 합이 1이 되게 만들어보자
#        ex) (1.5, 0.5) -> (0.75, 0.25)
def module_hook(module, grad_input, grad_output):
    print(grad_input)
    total = 0
    for grad in grad_input:
      total+=grad

    grad_input = torch.divide(grad_input[0],total), torch.divide(grad_input[1],total)
    print(grad_input)
    return grad_input

model.register_full_backward_hook(module_hook)    
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[08] backward 과정 이해하기  (0) 2022.01.20
[07] apply  (0) 2022.01.20
[05] named_children, named_modules.  (0) 2022.01.19
[04] nn.parameter  (0) 2022.01.19
[03] nn.Module 이해하기  (0) 2022.01.18