DeepLearning

[논문리뷰] MixMatch: A Holistic Approach to Semi-Supervised Learning

jiheek 2022. 7. 23. 12:43

MixMatch: A Holistic Approach to Semi-Supervised Learning (NIPS 2019), Google Research

https://arxiv.org/abs/1905.02249

 

MixMatch: A Holistic Approach to Semi-Supervised Learning

Semi-supervised learning has proven to be a powerful paradigm for leveraging unlabeled data to mitigate the reliance on large labeled datasets. In this work, we unify the current dominant approaches for semi-supervised learning to produce a new algorithm,

arxiv.org

 

Abstract

Augmented unlabeled 데이터를 사용해서 low-entropy label을 추측하고, labeled와 unlabeled 데이터를 MixUp을 사용해서 섞는다. abstract만 봐서는 어떤 아이디어인지 감이 오지 않는다.

 

Introduction

Label을 모으는 것은 expert knowledge를 필요로 해서 쉽지 않다. Label 데이터는 private한 정보를 가질 수도 있다. 그래서 unlabeled data를 모으는 것이 훨신 쉽다.

Semi Supervised Learning(SSL)의 사용으로 labeled data의 필요성이 감소했다. 최근의 연구들에서는, unlabeled data에 대한 loss term이 다음과 같은 세 종류 중 하나였다고 한다.

 

1. Entropy minimization: 모델이 unlabeled 데이터에 대해서 confident한 output prediction을 생성하도록 유도한다.
2. Consistency regularization: 왜곡된 input에 대해서도 모델이 같은 출력 분포를 생성하도록 한다.
3. Generic regularization: 모델이 잘 generalizee되고 training data에 overfit되지 않게 한다.

 

MixMatch에서는 위 주요 세 접근법들을 하나로 합치는 하나의 unlabeled data loss를 소개한다.

 

Consistency Regularization

FixMatch에서도 많이 설명해서 넘어간다.

 

Entropy Minimization

많은 SSL 방법들의 일반적인 가정은, classifier의 decision boundary가 marginal data distribution에서 높은 밀도의 영역을 통과하지 않는다는 것이다.

marginal distribution? 독립 변수, 즉 하나의 변수에 대한 distribution. 예를 들어 컴퓨터 type과 성별에 따른 컴퓨터 가격 선호에 대해 연구를 할 때, 오직 type에 대한 distribution을 알고 싶다고 하자. 이게 marginal distribution이다. 즉, 한 변수 외의 관심 없는 변수들에 대해서는 "marginalize out(소외)"시키는 것이다.

marginal distribution example(https://statisticsbyjim.com/basics/marginal-distribution/)

 

이를 강화하는 방법 중 하나는, unlabeled data에 대해서 classifier가 낮은 entropy의 예측값을 생성하도록 하는 것이다. 

prediction

Entropy(of random variable)? 쉽게 말하자면 확률변수상 얼마나 무질서/불확실한지에 대한 척도. 무질서도이다. 원래의 정의는 확률분포 p를 binary digits으로 인코딩했을 때 필요한 channel capacity를 의미한다. (by Shannon's definition)

Entropy( -Average level of “information” or “uncertainty”)
https://en.wikipedia.org/wiki/Entropy_(information_theory)

예를 들어, 앞/뒷면이 나올 확률이 같은 동전 던지기의 정보량은 1이다.(0 또는 1) 한 비트만 있으면 동전 던지기의 distribution을 표현할 수 있다는 뜻이다. 만약 앞/뒷면의 확률이 다르다면 더 많은 정보량이 있어야 distribution을 설명할 수 있을 것이다.

동시에 동전의 수가 많아 질 수록, 정보량이 커진다. 즉, 엔트로피가 커진다. 예측하기 어렵고 확률이 낮을 수록 엔트로피가 커진다는 것이다.

 

따라서 본론으로 돌아가서, 낮은 entropy의 예측값을 생성하도록 한다는 것은 예측하기 쉽도록, 즉 높은 정확도 + 높은 confidence의 예측값을 생성하도록 만든다는 것이다. 단순하게 CEE가 낮다고 이해할 수도 있지만,. 정보 이론 측면에서 한번 더 이해하고 싶어 정리해보았다.

이렇게 낮은 entropy의 예측값으로 stronger result를 생성하는 것이 Entropy Minimization이다. "Pseudo-Label"에서 이 기법을 사용해서 unlabeled data에 대해 high-confidence prediction을 얻어 one-hot labeling을 수행하고, target으로 사용해서 일반적인 CEE를 계산하는 것으로 이어졌다.

MixMatch 또한 unlabeled data의 target distribution에 대해 "sharpening"을 사용해서 entropy Minimization을 도입했다.

 

Traditional Regularization: weight decay, Mixup

Regularization이란 모델에 제약을 주어서 training data를 기억하는 것을 방해하고, unseen data에 대한 generalization을 강화하는 일반적인 방법이다. MixMatch에서는 L2 norm을 방해하기 위해 weight decay를 사용했고, Labeled/Unlabeled data에 Mixup을 사용해서 데이터들 사이에 convex behavior를 향상시켰다. (convex behavior?)

 

 

MixMatch

드디어 본론! Semi-supervised learning method 방법인 MixMatch를 소개한다.

Labeled sample batch인 X와 unlabeled sample batch인 U가 주어졌을 때, 각각 augmented된 batch X'와 U'를 생성한다. (여기서 U'는 guessed label을 가진다고 함(?) 이 label guessing process는 뒤에서 더 자세히 설명한다.)  두 batch들은 labeld와 unlabeled loss를 계산하는데 사용된다.

 

lableled, unlabeled loss, 합쳐진 loss. q는 guessed label

H는 cross entropy를 의미한다.

 

Label Guessing

Unlabeled sample U에 대해서, MixMatch는 모델의 예측값을 사용해서 예상되는 레이블을 생성한다. 그리고 이 guess는 unsupervised loss를 계산하는데 사용된다. (위 식에서 q에 해당)

이를 위해서 unlabeled data에 대해 K회 이루어진 augmentation들에 대한 예측들의 평균을 계산해야 한다.

unlabeled 데이터에 대한 가짜 레이블을 얻기 위해서 augmentation을 사용하는 것은 consistency regularization method에서 매우 일반적인 방식이다.

 

Sharpening

Label guessing을 할 때, Sharpening이라는 추가적인 단계가 더 있다. Augmentation들에 걸친 평균 예측값이 계산되었을 때, label distribution의 entropy를 줄이기 위해서 sharpening function을 사용했다. 

Sharpening function으로는 catetorial distribution의 "temperature"'를 조절하는 일반적인 접근법을 사용했다.

T가 크면 t제곱근이 씌워져서 스무딩되고, T가 1 이하면 제곱이 되어서 차이가 극대화된다.

 

Temperature of Categorial probability distribution
아래 그림에서 C1~C6은 다른 클래스를 의미하고, softmax input logit data는 테이블의 마지막 행에 나열되어 있다. 일반적인 softmax는 T=1인 경우이며, 가장 높은 input logit class의 확률은 1, 나머지는 0을 출력한다. T(temperature)가 높아질 수록 categorial distribution은 큰 차이 없이 일정해진다.

https://www.researchgate.net/figure/An-example-of-categorical-probability-distributions-of-high-temperature-softmax-output_fig1_325016605
https://www.researchgate.net/figure/An-example-of-high-temperature-softmax-output-with-different-temperature_tbl1_325016605

 

MixUp

Labeled sample과 Unlabeled sample(with guessed labels)을 함께 MixUp을 사용했다. 두 이미지와 label 쌍을 (x1, p1), (x2, p2)라고 하고, MixUp 결과 이미지, label을 (x', p')라고 하자.이 (x', p')는 다음과 같이 계산된다.

vanilla MixUp은 (9)번 식을 제외한다고 한다. 람다'가 없는데, 본 논문에서 새로 도입한 개념이라고 한다. Beta distribution에서 얻은 확률값 람다와, 1-람다 중 더 큰 값을 취해서 MixUp에 사용하였다. 따라서 x'는 x2보다 x1에 더 가깝게 된다. (어디에 더 가깝게 되든 그게 중요한걸까>..?)

 

 

알고리즘 수도코드

알고리즘 한줄요약

Unlabeled image(U)에 augmentation + Sharpening 사용해서 label 생성 -> Labeled data(X) + Unlabeled data(+가짜 label)(U) = W 새로운 데이터셋 생성 -> W + X, W + U로 MixUp 진행 및 return

 

따~라~서~! MixMatch의 목적은 labeled, unlabeled data를 가지고 모두 labeling을 해주고, 그 데이터를 사용해서 mixup을 진행한 데이터셋을 리턴하는 것이다. 

"Data->augmentation->pseudo labeling(sharpening)-> MixUp"

Loss Function, Hyperparameters

Return된 데이터셋 X', U'에 대해서는 앞서 설명한 Lx, Lu를 사용한다. 이 때 label과 prediction 사이에는 CEE를 사용하고, prediction과 guessed label U' 사이에는 L2 loss를 사용한다. 또한 guessed label 계산 시에는 back propagation을 사용하지 않았다고 한다.

Temperature T는 0.5, unlabeled augmentation의 개수 K=2로 설정하였다.

 

Experiments

Implementation details

1. lr decay를 사용하는 대신, parameter EMA (decay of 0.999)를 사용했다.

2. weight decay 0.0004, Wide ResNet-28

3. Baseline model:  II-Model , Mean Teacher, Virtual Adversarial Training, and Pseudo-Label

매우 적은(250개) 양의 labeled data만으로도 supervised에 가까운 성능을 보였다.

 

  • Ablation study를 통해, temperature sharpening을 하지 않았을 때 성능이 확연히 줄은 것을 확인했다. Sharpening하는 경우는 T=0.5이고, 사용하지 않았을 때의 T=1이다.
  • MixUp을 사용하는 것이 약 27~8% 성능이 더 좋았고, labeled data에만 MixUp을 사용하는 것은 아예 사용하지 않는 것과 큰 차이가 나지 않았다.
  • Unlabeled data에 대해서 MixUp을 사용했을 때, 둘 다 사용했을때와 거의 비슷했다.
  • EMA의 사용은 드라마틱한 변화는 없었다. 약간의 성능 향상만 있었음
  • Augmentation의 개수(K, average 할 개수)도 1개만 아니면 큰 차이는 없었다.

 

Conclusion

논문의 conclusion은 결국엔 좋았다~였다. 개인적으로 아쉬웠던 점은 ablation study에서 왜 unlabeled only에 MixUp을 적용시켰을 때 드라마틱하게 좋아졌는가에 대한 분석이 없었던 것과, 더 큰 K에 대한 실험의 부재이다.

 

참고s

https://gooopy.tistory.com/63