본문 바로가기

부스트캠프 AI Tech/Pytorch

[15] Multi-GPU 학습

Multi-GPU


  • GPU vs Node : Node를 system이라고 부르며 이 경우 1대의 컴퓨터로 보면 된다.
  • Single Node Single GPU : 1대의 컴퓨터에 1개의 GPU
  • Single Node Multi GPU : 1개의 컴퓨터에 여러 개의 GPU
  • Multi Node Multi GPU : 서버실에 달려있는 GPU

Model Parallel


  • 다중 GPU에 학습을 분산하는 두 가지 방법
    • 모델 나누기 / 데이터 나누기
  • 모델을 나누는 것은 생각보다 예전부터 사용(alexnet)
  • 모델의 병목, 파이프라인의 어려움 등으로 인해 모델 병렬화는 고난이도 과제

Model parallel(e.g. AlexNet)

모델 병렬화는 다음과 같이 있는데, 첫번째 케이스는 하나의 GPU가 작업이 끝나야 다른 GPU가 작업을 하므로, 병렬적으로 처리하는 의미가 없다. 이는 파이프라인이 안만들어졌기에 병렬화 효과가 없다.
즉, 아래와 케이스와 같이, 파이프라인을 통해 두 GPU가 각기 다른 처리를 동시에 진행할 수 있도록 해야 한다.

coda:0에서 처리한 것을 cuda:1으로 보내서 또 다른 처리를 하고, 마지막에 fc layer를 cude:1에서 처리한 것을 볼 수 있다. 단, 이렇게만 구현하면 병렬화가 제대로 안되어 병렬현상이 발생하므로, 앞에서와 같은 파이프라인 구축이 필요하다.

Data Parallel

  • 데이터를 나눠 GPU에 할당 후 결과의 평균을 취하는 방법
  • mini batch 수식과 유사한데 한번에 여러 GPU에서 수행
    1. mini-batch inputs을 여러 GPU에 나눔
    2. 모델을 각 GPU에 복사
    3. 각 GPU 순전파 과정 진행
    4. 연산 결과를 한 GPU에 모음 (한 곳에 모아줄 경우 각각의 loss를 한 번에 구할 수 있음)
    5. 한 gpu에서 구한 4개의 로스를 기본으로 gradient 계산
    6. 이후 각 loss에 gradient를 각 GPU에 보냄
    7. backward
    8. 최종적으로 출력된 gardient를 첫 GPU에 모은 다음 평균을 내서 gradient를 update

  • PyTorch에서는 아래 두 가지 방식을 제공
    • DataParallel, DistributedDataParallel
  • DataParallel - 단순히 데이터를 분배한 후 평균을 취함
    • GPU 사용 불균형 문제 발생, Batch 사이즈 감소 (한 GPU가 병목), GIL(Global interpreter Lock)
    • 각 GPU 성능이 동일할 떄 한 GPU가 많은 업무를 할당받으면, 같이 batch size를 지정해도 한 GPU만 처리를 늦어짐(병목현상 발생) 이 경우 처리가 늦어지는 GPU 때문에 전체 GPU의 batch size를 줄여야 함.
  • DistributedDataParallel - 각 CPU마다 process 생성하여 개별 GPU에 할당
    • 기본적으로 DataParallel로 하나 개별적으로 연산의 평균을 냄
    • DataParallel과 달리, loss를 구하기 위해 각 gpu의 output을 모으는게 아니라, 각 GPU에서 forward, backward 둘 다 수행 후 최종 gradient만 하나로 모아서 평균을 취함
    • 각 GPU에 CPU도 할당하여 각각이 Coordinator 역할을 수행하면 됨

  • 반면, Distributed Data Parallell은 몇가지 과정이 필요하다.
    1. sampler 사용, shuffle=False, pin_memory = True
    2. Dataloader에 이 조건들을 적용, num_workers는 GPU x 4개 정도가 일반적임

'부스트캠프 AI Tech > Pytorch' 카테고리의 다른 글

[17] Troubleshooting  (0) 2022.01.24
[16] Hyperparameter Tuning  (0) 2022.01.24
[14] monitoring tool - wandb  (0) 2022.01.24
[13] Monitoring tool - Tensorboard  (0) 2022.01.24
[12] 전이학습 tansfer learning  (0) 2022.01.24