0. 내가 참가했던 Mask classification Competiton의 data는 매우 imbalance 했다.
3가지의 기준 나이 (young, middle, old), 성별(male, female), 마스크 착용 여부(정상 착용, 오착용, 미착용)로 총 18가지의 class가 존재했는데, 이 데이터를 살펴본 결과 나이의 class는 young : middel : old = 6 : 6 : 1으로 불균형했고, 성별은 거의 비슷했고, 마스크 착용 여부를 나타내는 데이터는 정상 착용 : 오착용 : 미착용 = 5 : 1 : 1 정도로 불균형했다.
우리 팀에서 각 기준에 대해 얼마나 정확도가 나오는지를 뽑아보았는데, 성별 여부와 마스크 착용 여부의 validation acc는 굉장히 높았으나 나이에 해당하는 validation acc는 상대적으로 낮은 편이었다. 그래서, dataloader의 sampler를 통해 class imbalance를 해결하려 하였다.
1. sampler는 무엇일까?
데이터 로더에서는 학습을 위해 (또는 validation을 위해) 모집단에서 배치 사이즈만큼의 데이터를 뽑아온다.
그런데 앞서 설명했던 Mask 데이터와 같이 data의 수가 imbalance 한 경우, 배치 사이즈 안에 속해있는 데이터들의 class도 불균형하게 불러와지기 때문에, 학습이 잘 되지 않을 가능성이 크다.
예를 들어 사과 데이터가 900개 있고, 바나나 데이터가 100개 있을 때 10개라는 배치사이즈 안에서 1개의 데이터 set을 불러왔다고 가정할 때 사과 데이터가 불러와질 확률은 0.9이고, 바나나 데이터가 불러와질 확률은 0.1이다. 즉, 사과 데이터는 훨씬 학습이 잘 되겠지만 바나나 데이터는 학습이 잘 안 된다. 이럴 때 우리는 sampler를 설정해 주는 방법을 사용해 바나나 데이터를 5개, 사과 데이터를 5개 불러오도록 하여 class imbalance를 해결할 수 있다.
2. 어떤 sampler를 쓸까?
보통, imbalanced datasets을 처리하는 sampler 종류로는 크게 Oversampling, Undersampling 두 가지가 있다.
첫 번째는 Oversampling이다.
데이터 수가 적은 class에 해당하는 더 많은 예제를 추가하여 오버 샘플링하는 것은 data imbalance 문제를 해결할 수 있다. 하지만 이때 소수 클래스의 예제를 복제하여 학습시킬 경우 overfitting 즉, 과적합을 일으킬 수 있어 문제가 된다.
두 번째는 Undersampling이다.
데이터수가 많을 경우 이 방법을 채택할 것인데, 언더샘플링의 경우 데이터 수가 많은 class에서 무작위 레코드를 제거하여 class 간의 balance를 맞출 수 있다. 하지만 이때 무작위 데이터 셋에서 몇 개만 뽑아 오는 과정이기 때문에 정보 손실을 유발하게 된다.
03_1. Oversampling도 싫고 Undersampling도 싫어_ImbalancedDatasetSampler
이 두 가지 방법을 어느 정도 합친 것이 내가 사용했던 것은 ImbalancedDatasetSampler인데, data augmentation을 적용하여 과적합을 완화하고 이를 통해 소수의 dataset보다 더 많은 데이터셋을 적용하므로 많은 양의 정보손실이 일어나지 않는다고 한다.
데이터 수가 적은 class에 해당하는 더 많은 예제를 추가하여 오버 샘플링하는 것은 data imbalance 문제를 해결할 수 있다. 하지만 이때 소수 클래스의 예제를 복제하여 학습시킬 경우 overfitting 즉, 과적합을 일으킬 수 있어 문제가 된다.
https://github.com/ufoym/imbalanced-dataset-sampler
GitHub - ufoym/imbalanced-dataset-sampler: A (PyTorch) imbalanced dataset sampler for oversampling low frequent classes and unde
A (PyTorch) imbalanced dataset sampler for oversampling low frequent classes and undersampling high frequent ones. - GitHub - ufoym/imbalanced-dataset-sampler: A (PyTorch) imbalanced dataset sample...
github.com
이 sampler를 사용하려고 할 때, 팀원들에게 이 sampler에 대해 소개를 했는데 github Readme에는 소수의 class에 있는 데이터들을 증강시켜 즉, oversampling 해주었다고 설명을 써놓아서 이 깃헙에 있는 코드를 살펴보았다. 그러나 augmentation에 대한 코드는 못 찾았으며 그냥 data의 크기에 따라 가중치를 다르게 적용한 코드만 확인할 수 있었다. 뭘까.
(이에 대해 정확하게 아시는 분은 댓글을 남겨 알려주시면 정말 감사하겠습니다.)
03_2. 그럴거면 Weighted_Random_Sampler를 쓰면 되잖아
맞다. 그래서 나도 둘 다 써봤는데 솔직히 무슨 차이가 있는지 이해가 가지 않아서 성능 차이라도 보자 라는 마음에 대회가 다 끝난 이후에 실험해 보았다. 가중치를 두는 것은 가장 imbalance가 심하고 학습이 잘 안 된다고 느꼈던 age에만 적용하였다.
결론적으로는 두 sampler 모두 실험 모델 1에서는 성능이 있었고 적용하기로 했는데, 제출용으로 사용하던 모델 2에서는 sampler를 적용할 때 성능이 더 낮아지는 것을 확인했다. 하지만 public과 private의 리더보드 상에서 차이가 있을 것이라 예상하고 마지막에 제출하긴 하였는데 그다지 좋은 효과는 보지 못한 것 같다.
'AI&ML > 부스트캠프 기록' 카테고리의 다른 글
[Linux] csv에서 특정 열을 count 하기 (0) | 2023.04.25 |
---|---|
[Linux] Linux에서 "!!!!" 쓰는 방법 (0) | 2023.04.24 |
[대회 기록] Mask classification Competiton 회고록 (0) | 2023.04.20 |
[대회 기록] Mask classification Competiton_loss 종류 (0) | 2023.04.19 |
[Data viz] fig.add_subplot(111)의 의미 (0) | 2023.03.26 |
댓글