PyTorch를 사용하면서 Pretrained된 모델을 많이 사용하게 될텐데 모델 자체에 버그가 있거나, 혹은 수정해서 써야만 하는 등의 사항들이 발생할 수 있다. 이런 경우 전 포스팅 했던 hook이나 apply 등을 이용하여 모델을 수정 할 수 있다.
apply
입력으로 받는 모든 module을 순차적으로 처리한다.
(1) apply를 활용해 parameter(W)를 1로 초기화하는 함수를 구현해보자
model = Model()
# pply를 이용해 모든 Parameter 값을 1로 초기화
def weight_initialization(module):
module_name = module.__class__.__name__
for param in module.parameters():
# param data를 update
param.data = torch.ones_like(param.data)
# apply는 apply가 적용된 module을 return 해준다
returned_module = model.apply(weight_initialization)
(2) apply를 활용해 repr을 수정해보자
⬇
model = Model()
# apply를 이용해서 repr 출력을 수정
from functools import partial
def function_repr(self):
# print(self.name)
return f'name={self.name}'
def add_repr(module):
module_name = module.__class__.__name__
try:
print(function_repr(module))
extra_repr = lambda repr:repr
module.extra_repr = partial(extra_repr, function_repr(module))
except:
pass
# apply 적용된 module을 return
returned_module = model.apply(add_repr)
model_repr = repr(model)
print("모델 출력 결과")
print("-" * 30)
print(model_repr)
print("-" * 30)
(3) apply를 활용해 function을 linear transformation처럼 동작하도록 수정하자
- Function_A : x+W
- Function_B : x-W
- Function_C : x+W
- Function_D : x/W
➡ x @ W + b
model = Model()
from functools import partial
# Parameter b 추가
def add_bias(module):
module_name = module.__class__.__name__
if module_name.split('_')[0] == "Function":
module.b = Parameter(torch.rand(2,1))
# 1로 초기화
def weight_initialization(module):
module_name = module.__class__.__name__
add_bias(module)
if module_name.split('_')[0] == "Function":
module.W.data.fill_(1.0)
module.b.data.fill_(1.0)
# apply를 이용해 모든 Function을 linear transformation으로 바꾸자 (X @ W + b)
def hook(module, input, output):
module_name = module.__class__.__name__
output = input[0] @ module.W.T
# output = torch.mul(input[0],module.W.T)
output = torch.add(output, module.b)
return output
def linear_transformation(module):
module_name = module.__class__.__name__
print(module_name)
if module_name.split('_')[0] == "Function":
module.register_forward_hook(hook)
returned_module = model.apply(add_bias)
returned_module = model.apply(weight_initialization)
returned_module = model.apply(linear_transformation)
# FriendLinearModel : nn.linear
class FriendLinearModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Sequential(nn.Linear(2, 2),
nn.Linear(2, 2),
nn.Linear(2, 2),
nn.Linear(2, 2))
def forward(self, x):
return self.linear(x)
def friends_init_weights(m):
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
m.bias.data.fill_(1.0)
friend_model = FriendLinearModel()
friend_model.apply(friends_init_weights)
# nn.Linear 모델과 비교
grads = tester(model, friend_model)
'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글
[09] Dataset (0) | 2022.01.21 |
---|---|
[08] backward 과정 이해하기 (0) | 2022.01.20 |
[06] Pytorch의 hook (0) | 2022.01.20 |
[05] named_children, named_modules. (0) | 2022.01.19 |
[04] nn.parameter (0) | 2022.01.19 |