Dataloader는 데이터셋을 미니 배치 단위로 제공해주는 역할
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
.
Datalaoder에는 Dataset 인스턴스가 들어감
a = iter(DataLoader(dataset_iris)) # 이렇게 정의해 두고
next(a) # 를 계속 실행하면 다음 batch 가 나옴. 출력 잘 나오는지 파악할 때 확인 가능
# [tensor([[5.1000, 3.5000, 1.4000, 0.2000]], dtype=torch.float64), tensor([0])]
batch_size
next(iter(DataLoader(dataset_iris, batch_size=4)))
'''
[tensor([[5.1000, 3.5000, 1.4000, 0.2000],
[4.9000, 3.0000, 1.4000, 0.2000],
[4.7000, 3.2000, 1.3000, 0.2000],
[4.6000, 3.1000, 1.5000, 0.2000]], dtype=torch.float64),
tensor([0, 0, 0, 0])]
'''
shuffle
- 데이터를 섞음
sampler와 batch_sampler
- sampler는 index를 컨트롤하는 방법
- index를 컨트롤 하기 때문에 shuffle = False (기본값) 이어야 함
- map-style에서 컨트롤 하기 위해 사용. len 과 iter 구현
- 제공되는 Sampler 예시
- equentialSampler : 항상 같은 순서
- RandomSampler : 랜덤, replacemetn 여부 선택 가능, 개수 선택 가능
- SubsetRandomSampler : 랜덤 리스트, 위와 두 조건 불가능
- WeigthRandomSampler : 가중치에 따른 확률
- BatchSampler : batch단위로 sampling 가능
- DistributedSampler : 분산처리 (torch.nn.parallel.DistributedDataParallel과 함께 사용)
num_workers
- 데이터를 불러올때 사용하는 서브 프로세스 개수.
- cpq core 개수 0.5 // 또는 gpu 개수 당 4 ??
- DataLoader num_workers에 대한 고찰 (https://jybaek.tistory.com/799) 참고
collate_fn
- map-style 데이터셋에서 sample list를 batch 단위로 바꾸기 위해 필요한 기능
- zero-padding이나 Variable Size 데이터 등 데이터 사이즈를 맞추기 위해 많이 사용
- 참고 (https://deepbaksuvision.github.io/Modu_ObjectDetection/posts/03_01_dataloader.html)
- ((피처1, 라벨1) (피처2, 라벨2))와 같은 배치 단위 데이터가 ((피처1, 피처2), (라벨1, 라벨2))와 같이 수정 가능
collate_fn 실습
class ExampleDataset(Dataset):
def __init__(self, num):
self.num = num
def __len__(self):
return self.num
def __getitem__(self, idx):
return {"X":torch.tensor([idx] * (idx+1), dtype=torch.float32),
"y": torch.tensor(idx, dtype=torch.float32)}
dataset_example = ExampleDataset(4)
dataloader_example = torch.utils.data.DataLoader(dataset_example)
for d in dataloader_example:
print(d['X'], d['y'])
#tensor([[0.]]) tensor([0.])
#tensor([[1., 1.]]) tensor([1.])
#tensor([[2., 2., 2.]]) tensor([2.])
#tensor([[3., 3., 3., 3.]]) tensor([3.])
위 Dataloader 에서 batch_size를 2로 설정하면 에러가 발생. 에러 발생안하게 collate_fn 작성
def my_collate_fn(batch):
collate_X = []
collate_y = []
l = len(batch) - 1
for sample in batch:
collate_X.append(sample['X'])
collate_y.append(sample['y'])
for i in range(len(batch)):
zero_tensor = torch.tensor([0] * l)
collate_X[i] = torch.cat([collate_X[i], zero_tensor])
l -= 1
return {'X': torch.stack(collate_X),
'y': torch.stack(collate_y)}
dataloader_example = torch.utils.data.DataLoader(dataset_example,
batch_size=2,
collate_fn=my_collate_fn)
for d in dataloader_example:
print(d['X'], d['y'])
#tensor([[0., 0.],
# [1., 1.]]) tensor([0., 1.])
#tensor([[2., 2., 2., 0.],
# [3., 3., 3., 3.]]) tensor([2., 3.])
'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글
[12] 전이학습 tansfer learning (0) | 2022.01.24 |
---|---|
[11] image transform (1) | 2022.01.23 |
[09] Dataset (0) | 2022.01.21 |
[08] backward 과정 이해하기 (0) | 2022.01.20 |
[07] apply (0) | 2022.01.20 |