본문 바로가기

부스트캠프 AI Tech/Pytorch

[10] Dataloader 의 기본요소

Dataloader는 데이터셋을 미니 배치 단위로 제공해주는 역할


DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

.

Datalaoder에는 Dataset 인스턴스가 들어감

a = iter(DataLoader(dataset_iris)) # 이렇게 정의해 두고
next(a) # 를 계속 실행하면 다음 batch 가 나옴. 출력 잘 나오는지 파악할 때 확인 가능

# [tensor([[5.1000, 3.5000, 1.4000, 0.2000]], dtype=torch.float64), tensor([0])]

batch_size

next(iter(DataLoader(dataset_iris, batch_size=4)))
'''
[tensor([[5.1000, 3.5000, 1.4000, 0.2000],
         [4.9000, 3.0000, 1.4000, 0.2000],
         [4.7000, 3.2000, 1.3000, 0.2000],
         [4.6000, 3.1000, 1.5000, 0.2000]], dtype=torch.float64),
tensor([0, 0, 0, 0])]
 '''

shuffle

  • 데이터를 섞음

sampler와 batch_sampler

  • sampler는 index를 컨트롤하는 방법
  • index를 컨트롤 하기 때문에 shuffle = False (기본값) 이어야 함
  • map-style에서 컨트롤 하기 위해 사용. leniter 구현
  • 제공되는 Sampler 예시
    • equentialSampler : 항상 같은 순서
    • RandomSampler : 랜덤, replacemetn 여부 선택 가능, 개수 선택 가능
    • SubsetRandomSampler : 랜덤 리스트, 위와 두 조건 불가능
    • WeigthRandomSampler : 가중치에 따른 확률
    • BatchSampler : batch단위로 sampling 가능
    • DistributedSampler : 분산처리 (torch.nn.parallel.DistributedDataParallel과 함께 사용)

num_workers

  • 데이터를 불러올때 사용하는 서브 프로세스 개수.
  • cpq core 개수 0.5 // 또는 gpu 개수 당 4 ??
  • DataLoader num_workers에 대한 고찰 (https://jybaek.tistory.com/799) 참고

collate_fn

  • map-style 데이터셋에서 sample list를 batch 단위로 바꾸기 위해 필요한 기능
  • zero-padding이나 Variable Size 데이터 등 데이터 사이즈를 맞추기 위해 많이 사용
  • 참고 (https://deepbaksuvision.github.io/Modu_ObjectDetection/posts/03_01_dataloader.html)
  • ((피처1, 라벨1) (피처2, 라벨2))와 같은 배치 단위 데이터가 ((피처1, 피처2), (라벨1, 라벨2))와 같이 수정 가능

collate_fn 실습

class ExampleDataset(Dataset):
    def __init__(self, num):
        self.num = num

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return {"X":torch.tensor([idx] * (idx+1), dtype=torch.float32), 
                "y": torch.tensor(idx, dtype=torch.float32)}

dataset_example = ExampleDataset(4)

dataloader_example = torch.utils.data.DataLoader(dataset_example)
for d in dataloader_example:
    print(d['X'], d['y'])

#tensor([[0.]]) tensor([0.])
#tensor([[1., 1.]]) tensor([1.])
#tensor([[2., 2., 2.]]) tensor([2.])
#tensor([[3., 3., 3., 3.]]) tensor([3.])

위 Dataloader 에서 batch_size를 2로 설정하면 에러가 발생. 에러 발생안하게 collate_fn 작성

def my_collate_fn(batch):
    collate_X = []
    collate_y = []
    l = len(batch) - 1

    for sample in batch:
      collate_X.append(sample['X'])
      collate_y.append(sample['y'])

    for i in range(len(batch)):
      zero_tensor = torch.tensor([0] * l)
      collate_X[i] = torch.cat([collate_X[i], zero_tensor])
      l -= 1

    return {'X': torch.stack(collate_X),
             'y': torch.stack(collate_y)}

dataloader_example = torch.utils.data.DataLoader(dataset_example, 
                                                 batch_size=2,
                                                 collate_fn=my_collate_fn)
for d in dataloader_example:
    print(d['X'], d['y'])

#tensor([[0., 0.],
#        [1., 1.]]) tensor([0., 1.])
#tensor([[2., 2., 2., 0.],
#        [3., 3., 3., 3.]]) tensor([2., 3.])

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[12] 전이학습 tansfer learning  (0) 2022.01.24
[11] image transform  (1) 2022.01.23
[09] Dataset  (0) 2022.01.21
[08] backward 과정 이해하기  (0) 2022.01.20
[07] apply  (0) 2022.01.20