본문 바로가기

유용 노트

Imbalanced data 다루기


Introduction

대회나 프로젝트를 진행하다보면 Imbalaced data 를 자주 직면하게 됩니다.

특정 class가 다른 데이터보다 현저히 적은 경우 그 class 학습이 잘 안되는 경우가 발생합니다.

얼마전에 object detection 대회에 참여하였는데 10가지 class 중 한 가지 class를 아예 캐치하지 못하는 경우가 있었는데요.

이런 경우 실험 할 수 있는 것들을 정리해보았습니다.


1. 적절한 evaluation metrics 사용

  • Precision/Specificity: how many selected instances are relevant
  • Recall/Sensitivity: how many relevant instances are seleted.
  • F1 score: harmonic mean of precision and recall
  • MCC: correlation coeeficient between the observed and predicted binary classifiations
  • AUC: relation between true-positive rate and false positive rate

2. Under-sampling과 over-sampling

  • Under-sampling
    • abundant class의 사이즈를 줄여서 balanced dataset으로 만든다.
    • rare class의 모든 샘플을 keeping하고, abudant class의 샘플을 랜덤으로 선택해서 크기를 비슷하게 만든다.
  • Over-sampling
    • rare sample의 사이즈를 증가시켜 balanced dataset으로 만든다.
    • repetition, bootstrapping, SMOTE(Synthetic Minority Over-Sampling Technique) 사용가능

3. K-fold Cross-VAlidation 사용

  • over-sampling을 사용할때 적절한 방법
  • cross-validation은 over-sampling을 하기 전에 항상 완료가 되어야 한다.

4. Ensemble different resampled datasets

  • abundant class를 K-fold + rare class 학습. K 개의 model들을 ensemble
  • 예를들어 A가 500, B가 50개면, A를 50개 씩 10번 나누고 A 50개 B 50개 학습. 10개의 모델 앙상블

5. Ensemble with different ratios

  • ensemble 할 경우 rare-class를 더 잘 잡는 모델에 가중치 주기.
  • 예를 들어 두 모델을 soft-voting 할 경우 0.3 / 0.7 weight 를 준다.

6. Cluster the abundant class

  • clustering을 통해 abundant class를 r개의 groups으로 clustering 하고 학습

7. Design own loss function

  • rare class에 편애해서 cost function 설계
  • weighted loss 사용하거나 직접 Custom loss를 만들어서 사용