import torch
from torch import nn
from torch.nn.parameter import Parameter
# Function
class Function_A(nn.Module):
    """Parameter-free module that doubles its input.

    Args:
        name: Label stored on the instance as ``self.name``; it is not used
            by ``forward``.
    """

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * 2``."""
        return x * 2
class Function_B(nn.Module):
    """Divide the input by two learnable scalars, ``W1`` then ``W2``.

    ``W1`` is initialised to 10 and ``W2`` to 2; both are registered as
    parameters of the module.
    """

    def __init__(self):
        super().__init__()
        # torch.tensor with float literals yields the same float32 values as
        # the legacy torch.Tensor([...]) constructor.
        self.W1 = Parameter(torch.tensor([10.0]))
        self.W2 = Parameter(torch.tensor([2.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / W1 / W2``."""
        x = x / self.W1
        x = x / self.W2
        return x
class Function_C(nn.Module):
    """Multiply the input by a constant held in the buffer ``duck``.

    ``duck`` is registered as a persistent buffer with value 7, so it is
    listed by ``named_buffers()`` but is not a parameter.
    """

    def __init__(self):
        super().__init__()
        # torch.tensor with a float literal yields the same float32 value as
        # the legacy torch.Tensor([7]) constructor.
        self.register_buffer('duck', torch.tensor([7.0]), persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * duck``."""
        return x * self.duck
class Function_D(nn.Module):
    """Add ``W1``, apply a nested ``Function_C``, then divide by ``W2``.

    ``W1`` is initialised to 3 and ``W2`` to 5; both are registered as
    parameters. The nested ``Function_C`` is registered as submodule ``c``.
    """

    def __init__(self):
        super().__init__()
        # torch.tensor with float literals yields the same float32 values as
        # the legacy torch.Tensor([...]) constructor.
        self.W1 = Parameter(torch.tensor([3.0]))
        self.W2 = Parameter(torch.tensor([5.0]))
        self.c = Function_C()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``c(x + W1) / W2``."""
        x = x + self.W1
        x = self.c(x)
        x = x / self.W2
        return x
# Layer
class Layer_AB(nn.Module):
    """Layer that chains ``Function_A`` (submodule ``a``) and ``Function_B`` (submodule ``b``)."""

    def __init__(self):
        super().__init__()
        self.a = Function_A('duck')
        self.b = Function_B()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``b(a(x) / 5)``."""
        x = self.a(x) / 5
        x = self.b(x)
        return x
class Layer_CD(nn.Module):
    """Layer that chains ``Function_C`` (submodule ``c``) and ``Function_D`` (submodule ``d``)."""

    def __init__(self):
        super().__init__()
        self.c = Function_C()
        self.d = Function_D()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``d(c(x)) + 1``."""
        x = self.c(x)
        x = self.d(x) + 1
        return x
# Model
class Model(nn.Module):
    """Top-level model: ``Layer_AB`` (submodule ``ab``) followed by ``Layer_CD`` (submodule ``cd``)."""

    def __init__(self):
        super().__init__()
        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``cd(ab(x))``."""
        x = self.ab(x)
        x = self.cd(x)
        return x
# Instantiate the model, build a one-element integer input, and run a single
# forward pass. The returned tensor is not stored.
model = Model()
x = torch.tensor([7])
model(x)
# model.named_modules()
# Yields the model itself (under the empty name) followed by every nested
# submodule, each with its dotted attribute path as the name.
for name, module in model.named_modules():
    print(f"[ Name ] : {name}\n[ Module ]\n{module}")
    print("-" * 30)
# [ Name ] :
# [ Module ]
# Model(
# (ab): Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# (cd): Layer_CD(
# (c): Function_C()
# (d): Function_D(
# (c): Function_C()
# )
# )
# )
# ------------------------------
# [ Name ] : ab
# [ Module ]
# Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# ------------------------------
# [ Name ] : ab.a
# [ Module ]
# Function_A()
# ------------------------------
# [ Name ] : ab.b
# [ Module ]
# Function_B()
# ------------------------------
# [ Name ] : cd
# [ Module ]
# Layer_CD(
# (c): Function_C()
# (d): Function_D(
# (c): Function_C()
# )
# )
# ------------------------------
# [ Name ] : cd.c
# [ Module ]
# Function_C()
# ------------------------------
# [ Name ] : cd.d
# [ Module ]
# Function_D(
# (c): Function_C()
# )
# ------------------------------
# [ Name ] : cd.d.c
# [ Module ]
# Function_C()
# ------------------------------
# model.named_children()
# Yields only the direct submodules of the model (here: ab and cd).
for name, child in model.named_children():
    print(f"[ Name ] : {name}\n[ Children ]\n{child}")
    print("-" * 30)
# [ Name ] : ab
# [ Children ]
# Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# ------------------------------
# [ Name ] : cd
# [ Children ]
# Layer_CD(
# (c): Function_C()
# (d): Function_D(
# (c): Function_C()
# )
# )
# ------------------------------
# model.named_parameters()
# Yields every Parameter in the module tree with its dotted path as the name.
for name, parameter in model.named_parameters():
    print(f"[ Name ] : {name}\n[ Parameter ]\n{parameter}")
    print("-" * 30)
# [ Name ] : ab.b.W1
# [ Parameter ]
# Parameter containing:
# tensor([10.], requires_grad=True)
# ------------------------------
# [ Name ] : ab.b.W2
# [ Parameter ]
# Parameter containing:
# tensor([2.], requires_grad=True)
# ------------------------------
# [ Name ] : cd.d.W1
# [ Parameter ]
# Parameter containing:
# tensor([3.], requires_grad=True)
# ------------------------------
# [ Name ] : cd.d.W2
# [ Parameter ]
# Parameter containing:
# tensor([5.], requires_grad=True)
# ------------------------------
# model.named_buffers()
# Yields every registered buffer in the module tree with its dotted path as
# the name.
for name, buffer in model.named_buffers():
    print(f"[ Name ] : {name}\n[ Buffer ] : {buffer}")
    print("-" * 30)
# [ Name ] : cd.c.duck
# [ Buffer ] : tensor([7.])
# ------------------------------
# [ Name ] : cd.d.c.duck
# [ Buffer ] : tensor([7.])
# ------------------------------
- "children"은 한 단계 아래의 submodule까지만 표시
- "modules"는 자기 자신(이름은 빈 문자열)을 포함하여, 자신에게 속하는 모든 submodule들을 재귀적으로 표시
- Function_A() 접근 -> model.ab.a
- 파라미터 W1 접근 -> model.ab.b.W1
- buffer duck 접근 -> model.cd.c.duck
'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글
[07] apply (0) | 2022.01.20 |
---|---|
[06] Pytorch의 hook (0) | 2022.01.20 |
[04] nn.parameter (0) | 2022.01.19 |
[03] nn.Module 이해하기 (0) | 2022.01.18 |
[02] Pytorch Basics (0) | 2022.01.18 |