hook이 뭘까요 ??
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
def module_hook(grad):
pass
model_obj = Model()
model_obj.__dict__
# {'training': True,
# '_parameters': OrderedDict(),
# '_buffers': OrderedDict(),
# '_non_persistent_buffers_set': set(),
# '_backward_hooks': OrderedDict(),
# '_is_full_backward_hook': None,
# '_forward_hooks': OrderedDict(),
# '_forward_pre_hooks': OrderedDict(),
# '_state_dict_hooks': OrderedDict(),
# '_load_state_dict_pre_hooks': OrderedDict(),
# '_modules': OrderedDict()}
nn.Module 상속 하는 모델을 만들면
~
hooks 가 자동으로 있는 것을 볼 수 있습니다.
hook은 예를들어 forward()나 backward() 시에 사용자가 커스텀한 함수나 레이어 들을 실행 시킬수 있도록
nn.Module을 만들 때 정의해 둔 것입니다.
이름 처럼 함수 실행 전이나 후에 '훅' 들어와 훅함수를 실행시키는 것이죠.
훅함수를 등록하는 방법은 아래와 같습니다.
- model.register_forward_pre_hook(module_hook)
- model.register_forward_hook(module_hook)
- model.register_full_backward_hook(module_hook)
- state_dict의 경우도 hook이 있는데 저희가 사용하는게 아니라
"load_state_dict" 함수가 내부적으로 사용한다고 하네요!
직접 해보기
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
위와 같은 모델이 있을 때 answer이란 list에 forward 전에 x1, x2 값, forward 후에 output 값을 넣어봅시다.
add = Add()
answer = []
def pre_hook(module, input):
answer.extend(input)
pass
def hook(module, input, output):
answer.append(output)
pass
add.register_forward_pre_hook(pre_hook)
add.register_forward_hook(hook)
x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2) # [tensor([0.7252]), tensor([0.6348]), tensor([1.3600])
print(answer) # [tensor([0.7252]), tensor([0.6348]), tensor([1.3600])]
위에서는 값을 가져오는 hook을 넣었지만 전파되는 값을 수정하는 것도 가능 !
full backward hook
- forward hook은 module에만 적용할 수 있지만 backward hook은 tensor, module 2가지 적용 가능 !
- The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature (pytorch 공식문서)
- register_backward_hook 말고 register_full_backward_hook 사용. (옛날 버전이라고 생각하면 될 듯)
모듈 단위 backward hook
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
answer = []
def module_hook(module, grad_input, grad_output):
answer.extend(grad_input)
answer.append(grad_output[0])
pass
model.register_full_backward_hook(module_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.retain_grad()
output.backward()
# 결과 output [tensor([1.1247]), tensor([3.2380]), tensor([1.])]
# answer [tensor([1.1247]), tensor([3.2380]), tensor([1.])]
탠서 단위 backward hook
module 단위의 backward hook은, module 기준으로 input, output gradient 값만 가져와서 내부의 tensor의 gradient값은 알아낼 수 없다.
이 때는, tensor 단위의 hook를 사용한다. model.W.register_hook(tensor_hook)
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
answer = []
def tensor_hook(grad):
answer.append(grad)
pass
model.W.register_hook(tensor_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.backward()
# output, answer : tensor([0.3604])
backward hook 으로 gradient 다루기
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
# 모델 생성
model = Model()
# hook를 이용해서 module의 gradient 출력의 합이 1이 되게 만들어보자
# ex) (1.5, 0.5) -> (0.75, 0.25)
def module_hook(module, grad_input, grad_output):
print(grad_input)
total = 0
for grad in grad_input:
total+=grad
grad_input = torch.divide(grad_input[0],total), torch.divide(grad_input[1],total)
print(grad_input)
return grad_input
model.register_full_backward_hook(module_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.backward()
'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글
[08] backward 과정 이해하기 (0) | 2022.01.20 |
---|---|
[07] apply (0) | 2022.01.20 |
[05] named_children, named_modules. (0) | 2022.01.19 |
[04] nn.parameter (0) | 2022.01.19 |
[03] nn.Module 이해하기 (0) | 2022.01.18 |