JLOG

[Self-KD]Self-Knowledge Distillation with Progrssive Refinement of Targets 리뷰 본문

논문정리

[Self-KD]Self-Knowledge Distillation with Progrssive Refinement of Targets 리뷰

정정선선 2021. 10. 28. 11:59

 

리뷰 논문은 ICCV2021 oral로 발표된 논문인

Self-Knowledge Distillation with Progrssive Refinement of Targets입니다.

 

LG CNS AI 리서치, 서울대에서 작성한 논문으로서 2021 10 27일 기준으로 3회 인용되었습니다.

 

본 논문의 메인아이디어는 self knowledge distillation을 통하여 regularization의 역할과 hard example에 더 학습이 잘되도록 만들었다는 것입니다.

 

 


 

기본 개념

논문 설명을 하기 전 기본 개념에 대해 언급하고 가겠습니다.

Knowledge distillation 의 목적은 "미리 잘 학습된 큰 네트워크(Teacher network) 의 지식을 실제로 사용하고자 하는 작은 네트워크(Student network) 에게 전달하는 것"입니다.

사진과 같이 깊은 네트워크인 Teacher model이 존재하고, teacher model one hot를 적용하기 전 단계인 soft labelstudent modelsoft predictions을 이용해 로스를 구합니다.

 

이렇게 soft한 정보를 이용해서 loss구함으로서 더 많은 정보들을 student model에 넘겨줄 수 있습니다.

 

추가적으로 one hot이 적용된 hard prediction 값과 label 값을 이용하여 loss를 구합니다.

앞서 말씀드린 multi loss를 이용하여 학습이 진행되게 됩니다.

 

knowledge distillation을 이용함으로서 작은 모델로 더 큰 모델만큼의 성능을 얻을 수 있다는 장점이 있습니다.

 

 

 

self-knowledge distillation이란 teacher network와 student network가 동일한 네트워크 인것을 말합니다.

 

self knowledge distillation Approch

-approach 두 이미지 사이의 distance를 줄이는 것입니다.

 또한, 같은 class 다른 samplKL divergenc를 통해 학습해 분포를 맞춰줍니다.

 여러 sampl에 대해 general하게 학습의 결과를 가질 수 있습니다.

 하지만 overfitting이 일어날 가능성이 있고, 특정 classhigh densit를 가질 수 있다는 한계가 존재합니다.

 

-approach는 모델의 여러 block  auxiliary classifiers을 붙티처 네트워크로 사용하고, 최종 infer에 사용하는 output을 학생 네트워크로 사용하는 것입니다.

 얕은 layer의 정보들도 사용함으로서 표현력을 높이는 효과가 있습니다.

 

 

 


Proposed Method

본 논문에서 제안하는 PS-KD 기법에 대해서 말하겠습니다.

Knowledge Distillation as Softening Target

PS-KD 기법은 self-knowledge distillationregulization의 기능을 하도록 사용하였습니다.

맨 위의 수식는 기존 연구된 기법으로서, Soft softmax를 사용한 KD loss입니다.

알파는 hyperparameter로서 각 hard, soft label의 영향력을 조절하는 역할을 수행합니다.

H() CE Loss 입니다.

 

Softer Softmax

softer softmax, 지수함수를 씌우기 전에 타우를 나누어 보다 soft하게 softmax를 수행해줍니다.

(사실 softer softmax는 제가 만든 용어입니다.. 용어는 생략하고 봐주세요...)

기존 softmaxhard label과 비슷할 가능성이 큼으로, 정답이 아닌 다른 label에 대한 정보의 중요성을 증가시키기 위해서 새로운 softmax사용합니다.

 

아래에서 softmax 내용을 참고하시면 더 이해가 빠르실 것 같습니다.

https://light-tree.tistory.com/196

 

딥러닝 용어 정리, Knowledge distillation 설명과 이해

이 글은 제가 공부한 내용을 정리하는 글입니다. 따라서 잘못된 내용이 있을 수도 있습니다. 잘못된 내용을 발견하신다면 리플로 알려주시길 부탁드립니다. 감사합니다. Knowledge distillation 이란?

light-tree.tistory.com

 

타우가 1일 때 아래와 같은 형태로 나타낼 수 있습니다.

형태를 살펴보면 기존 CE loss term에 aP^T(x) term이 붙음으로서 마치 regularization처럼 동작이 수행됩니다.

 

 

논문에서는 past prediction를 teacher network로 사용합니다.

저자는 초반 epochs일 수록 target labeloverconfident 하지 않아 softer score를 가지고 있다고 주장합니다.

그래서, 기존 loss티처 네트워크 prediction 자리에 이전 prediction 값을 입력해줍니다.

 

그 중, 맨 마지막 layer가 제일 가치있는 정보를 담고 있기 때문에 직전의 레이어만 티처 네트워크로 사용하는 것이 최종 loss 입니다.

 

또한, 초기 단계에서는 학습 prediction믿을만할 정보를 가지고 있지 않기 때문에 초기부터 알파값을 늘려가면서 티처 네트워크의 영향력을 키우게 됩니다.

알파를 어떻게 키울 것인지에 대해 step-wise, linear growth, exponential 한 방법들을 실험해 보았는데, 실험적으로 linear growth 정보가 가장 좋은 성능을 가져서 선택했다고 합니다.

 

 

 

 

Hard example mining

논문에서는 알파가 적절하게 설정된 경우에는 hard example에 대해서 더 집중하게 된다고 주장합니다.

hard example에 대해서 더 집중해서 본다는 것을 증명하기 위해 Lz미분하여줍니다.

 

GT, i가 타겟 cls인 경우 또한 icls가 아닌 경우에 대해 아래와 같이 정리할 수 있습니다.

 

 

만약 igt가 아닌 항이 0보다 크다고 가정할 때, GT clsgradient는 반대가 되어야함으로 음수라고 예측할 수 있습니다.

 

 

softmax 식을 사용하였음으로, gt가 아닌 모든 i의 절대값들을 모두 더하면, GTgradient라는 것을 알 수 있습니다.

즉, 기존 KDgradienCEgradien로 나누어서, PS-KDrescaling factor를 구할 수 있습니다.

 

 

 

기존 KDgradienCEgradien로 나누어서, PS-KDrescaling factor를 구할 수 있습니다.

 

GTprobality는 이전 epoch보다 현재 epoch 크기 때문에, 올바르게 학습이 진행이 된다면 rt-1보다 rt이 항상 작습니다.

만약 학습하기가 쉬운 경우 rt분의 rt-1항이 커질 것이고, gradient rescaling factor는 줄어들 것입니다.

학습하기가 어렵다면 이전 epoch 현재 epoch probality 변화가 적을 것이기 때문에 rt분의 rt-1항이 분의 감마티는 작을 것이고, gradient rescaling factor는 커질 것입니다.

 

이렇게 앞서 정의한 losshardexample에 대해서 더 집중하게 되는 역할을 한다는 것을 알 수 있습니다.

 

 

 

 

위의 그림은 CIFAR-100100개의 hard sample data에 대해서 학습을 시킨 결과입니다.

첫번째 사진은, 각 학습이 진행될 수록, 정답 targetprobability를 나타냅니다. 다른 기법보다 target probability가 높은 것을 확인할 수 있습니다.

두번째 사진은, max probality를 나타내줍니다. 모두 targeprobabilty보다 큰 값을 가진 것으로 보아 incorrect predictio에 대해 overconfident한 결과를 보이는 것을 확인할 수 있습니다.

 

위의 결과를 바탕으로, PS-KDgradient rescapling scheme에 의해 hard 사례에 집중한다는 것을 그림으로 알 수 있습니다.

 

 

 


Implenmenation

이전 에폭의 prediction 결과가 학습에 필요한 PS-KD는 두가지 방법으로 구현될 수 있습니다.

 

첫번째 방법은, t번째 에폭이 시작시에 t-1 epochsforward 하여 결과를 가져옵니다.

기존 방법은 더 많은 GPU 메모리를 필요로 하지만, 더 많은 store space는 필요하지 않다고 합니다.

 

두번째 방법은, t-1 에폭의 결과를 disk에 저장하여 read하는 방법입니다.

더 많은 GPU 메모리를 필요를 하지만, 더 많은 store space 필요하지 않다고 합니다.

 

 

 

 

 


Experimental Results

 

 

CIFAR-100 Classification

성능 평가 지표로서 NLL, ECE, AURC이 사용되었는데요

 

NLL은 우리가 아는 Cross entroy와 동일하다고 합니다. , 낮을 수록 좋은 성능을 보여줍니다

ECEexpected calibration error로서 calibration이 얼마나 잘 되었는지를 나타내는 성능지표로서 낮을수록 좋은 값을 보여줍니다.

AURCarea under the risk-coverage curve로서 예측이 신뢰값에 의해 정렬되는 정도를 측정하는 성능지표로서 낮을 수록 예측이 신뢰성이 있다는 것을 보여줍니다.

 

evaluation metrix에 대하여 논문의 후반에 정리한 내용이 있는데요, 더 궁금하신 분은 논문을 참고해주시면 감사하겠습니다

 

Error와 여러 평가지표를 보면 다른 distillation 기법을 사용한 것보다 좋은 성능을 가진 것을 확인할 수 있습니다.

 

 

 

 

데이터 augmentation 기법과 같이 사용하였을 때도 Ps-kd 기법을 같이 사용한 결과 성능이 개선되는 것을 확인할 수 있었습니다.

 

 

 

또한 본 논문에서 PS-KD의 앙상블 효과를 검토하였습니다.
이전 실험의 세 가지 훈련된 모델을 활용하여, CS-KD, TF-KD PS-KD를 함께 사용함으로서 앙상블 성능을 보여주었습니다.
결과적으로
앙상블은 PS-KD의 이점도 얻을 수 있는 것으로 나타났는데, 이는 PS-KD가 독립적으로 훈련된 모델의 다양성을 저하시키지 않는다는 것을 의미한다고 합니다.

 

 

 

해당 그림은 Calibration plot 입니다.

Calibration이란, 실제 분포와 유사하도록 prediction 분포를 가지는 정도를 뜻합니다.

회색 점선에 가까울 수록 calibration plot이 잘되었다고 볼 수 있는데, ResNext-29를 제외하면 PS-KD는 다른 기법들보다 calibration이 잘 된 것을 확인할 수 있습니다.

 

 

 

 

 

ImageNet Classification

그 다음으로 ImageNet Classification에 대한 결과입니다.

CutMixPs-KD 결과를 함께 사용하였을때 가장 높은 성능을 보이는 것을 확인할 수 있습니다.

 

 

위의 그림을 통하여, PS-KDpredicted probability의 퀄리티를 향상시킨 결과를 확인할 수 있습니다.

 

그림에 여러가지 물체가 존재하는 경우 classification 결과를 나타내었는데요,

basline의 경우 incorrect prediction 결과가 가장 높게 나타나는데, PS-KD의 결과는 top1, top2가 해당 그림에 대한 cls인 것을 확인할 수 있습니다.

또한, 아래 사진에서도 basline 결과보다 overconfident하지 않은, softprediction 값을 확인할 수 있습니다.

 

 

Object Detection

또한 Object Detection에 대해서도 성능 향상을 보인 것을 확인할 수 있습니다.

 

 

 

 

Machine Translation

Machine Translation에서도 좋은 성능을 보인 것을 확인할 수 있었습니다.

 

 

 

 

 

 


Conclusion

이전 epochsprediction 결과를 teacher 모델로 이용해 사용한다는 아이디어가 기발한 논문이라고 생각했습니다.

단순히 calibratio을 위한 self-distillation이 아닌, hard example minin의 효과도 낼 수 있어 간단하지만 성능이 좋은 기법이라고 생각합니다.

실험 결과에서 알 수 있듯이, Classification, Detection, Machine Translation와 같이 다양한 task에 사용될 수 있기 때문에 확장성도 크다고 생각합니다.

확실히 oral 논문은 oral인 이유가 있는 것 같습니다

 

 

틀린 부분이나 같이 토론할 것들이 있으면 댓글 남겨주세요!

Comments