JLOG
[Self-KD]Self-Knowledge Distillation with Progrssive Refinement of Targets 리뷰 본문
[Self-KD]Self-Knowledge Distillation with Progrssive Refinement of Targets 리뷰
정정선선 2021. 10. 28. 11:59
리뷰 논문은 ICCV2021 oral로 발표된 논문인
Self-Knowledge Distillation with Progrssive Refinement of Targets입니다.
LG CNS AI 리서치, 서울대에서 작성한 논문으로서 2021년 10월 27일 기준으로 3회 인용되었습니다.
본 논문의 메인아이디어는 self knowledge distillation을 통하여 regularization의 역할과 hard example에 더 학습이 잘되도록 만들었다는 것입니다.
기본 개념
논문 설명을 하기 전 기본 개념에 대해 언급하고 가겠습니다.
Knowledge distillation 의 목적은 "미리 잘 학습된 큰 네트워크(Teacher network) 의 지식을 실제로 사용하고자 하는 작은 네트워크(Student network) 에게 전달하는 것"입니다.
사진과 같이 깊은 네트워크인 Teacher model이 존재하고, teacher model의 one hot를 적용하기 전 단계인 soft label과 student model의 soft predictions을 이용해 로스를 구합니다.
이렇게 soft한 정보를 이용해서 loss를 구함으로서 더 많은 정보들을 student model에 넘겨줄 수 있습니다.
추가적으로 one hot이 적용된 hard prediction 값과 label 값을 이용하여 loss를 구합니다.
앞서 말씀드린 multi loss를 이용하여 학습이 진행되게 됩니다.
knowledge distillation을 이용함으로서 작은 모델로 더 큰 모델만큼의 성능을 얻을 수 있다는 장점이 있습니다.
self-knowledge distillation이란 teacher network와 student network가 동일한 네트워크 인것을 말합니다.
self knowledge distillation Approch
-approach는 두 이미지 사이의 distance를 줄이는 것입니다.
또한, 같은 class 다른 sampl을 KL divergenc를 통해 학습해 분포를 맞춰줍니다.
여러 sampl에 대해 general하게 학습의 결과를 가질 수 있습니다.
하지만 overfitting이 일어날 가능성이 있고, 특정 class는 high densit를 가질 수 있다는 한계가 존재합니다.
-approach는 모델의 여러 block에 auxiliary classifiers을 붙여 티처 네트워크로 사용하고, 최종 infer에 사용하는 output을 학생 네트워크로 사용하는 것입니다.
얕은 layer의 정보들도 사용함으로서 표현력을 높이는 효과가 있습니다.
Proposed Method
본 논문에서 제안하는 PS-KD 기법에 대해서 말하겠습니다.
Knowledge Distillation as Softening Target
PS-KD 기법은 self-knowledge distillation이 regulization의 기능을 하도록 사용하였습니다.
맨 위의 수식는 기존 연구된 기법으로서, Soft softmax를 사용한 KD loss입니다.
알파는 hyperparameter로서 각 hard, soft label의 영향력을 조절하는 역할을 수행합니다.
H()는 CE Loss 입니다.
Softer Softmax
softer softmax란, 지수함수를 씌우기 전에 타우를 나누어 보다 soft하게 softmax를 수행해줍니다.
(사실 softer softmax는 제가 만든 용어입니다.. 용어는 생략하고 봐주세요...)
기존 softmax는 hard label과 비슷할 가능성이 큼으로, 정답이 아닌 다른 label에 대한 정보의 중요성을 증가시키기 위해서 새로운 softmax를 사용합니다.
아래에서 softmax 내용을 참고하시면 더 이해가 빠르실 것 같습니다.
https://light-tree.tistory.com/196
타우가 1일 때 아래와 같은 형태로 나타낼 수 있습니다.
형태를 살펴보면 기존 CE loss term에 aP^T(x) term이 붙음으로서 마치 regularization처럼 동작이 수행됩니다.
논문에서는 past prediction를 teacher network로 사용합니다.
저자는 초반 epochs일 수록 target label에 overconfident 하지 않아 softer score를 가지고 있다고 주장합니다.
그래서, 기존 loss의 티처 네트워크 prediction 자리에 이전 prediction 값을 입력해줍니다.
그 중, 맨 마지막 layer가 제일 가치있는 정보를 담고 있기 때문에 직전의 레이어만 티처 네트워크로 사용하는 것이 최종 loss 입니다.
또한, 초기 단계에서는 학습 prediction이 믿을만할 정보를 가지고 있지 않기 때문에 초기부터 알파값을 늘려가면서 티처 네트워크의 영향력을 키우게 됩니다.
알파를 어떻게 키울 것인지에 대해 step-wise, linear growth, exponential 한 방법들을 실험해 보았는데, 실험적으로 linear growth 정보가 가장 좋은 성능을 가져서 선택했다고 합니다.
Hard example mining
본 논문에서는 알파가 적절하게 설정된 경우에는 hard example에 대해서 더 집중하게 된다고 주장합니다.
hard example에 대해서 더 집중해서 본다는 것을 증명하기 위해 L를 z로 미분하여줍니다.
GT, 즉 i가 타겟 cls인 경우 또한 i가 cls가 아닌 경우에 대해 아래와 같이 정리할 수 있습니다.
만약 i가 gt가 아닌 항이 0보다 크다고 가정할 때, GT cls의 gradient는 반대가 되어야함으로 음수라고 예측할 수 있습니다.
softmax 식을 사용하였음으로, gt가 아닌 모든 i의 절대값들을 모두 더하면, GT의 gradient라는 것을 알 수 있습니다.
즉, 기존 KD의 gradien를 CE의 gradien로 나누어서, PS-KD의 rescaling factor를 구할 수 있습니다.
기존 KD의 gradien를 CE의 gradien로 나누어서, PS-KD의 rescaling factor를 구할 수 있습니다.
GT의 probality는 이전 epoch보다 현재 epoch이 크기 때문에, 올바르게 학습이 진행이 된다면 rt-1보다 rt이 항상 작습니다.
만약 학습하기가 쉬운 경우 rt분의 rt-1항이 커질 것이고, gradient rescaling factor는 줄어들 것입니다.
학습하기가 어렵다면 이전 epoch과 현재 epoch의 probality의 변화가 적을 것이기 때문에 rt분의 rt-1항이 분의 감마티는 작을 것이고, gradient rescaling factor는 커질 것입니다.
이렇게 앞서 정의한 loss가 hard한 example에 대해서 더 집중하게 되는 역할을 한다는 것을 알 수 있습니다.
위의 그림은 CIFAR-100의 100개의 hard sample data에 대해서 학습을 시킨 결과입니다.
첫번째 사진은, 각 학습이 진행될 수록, 정답 target의 probability를 나타냅니다. 다른 기법보다 target probability가 높은 것을 확인할 수 있습니다.
두번째 사진은, max probality를 나타내줍니다. 모두 targe의 probabilty보다 큰 값을 가진 것으로 보아 incorrect predictio에 대해 overconfident한 결과를 보이는 것을 확인할 수 있습니다.
위의 결과를 바탕으로, PS-KD가 gradient rescapling scheme에 의해 hard 사례에 집중한다는 것을 그림으로 알 수 있습니다.
Implenmenation
이전 에폭의 prediction 결과가 학습에 필요한 PS-KD는 두가지 방법으로 구현될 수 있습니다.
첫번째 방법은, t번째 에폭이 시작시에 t-1 epochs를 forward 하여 결과를 가져옵니다.
기존 방법은 더 많은 GPU 메모리를 필요로 하지만, 더 많은 store space는 필요하지 않다고 합니다.
두번째 방법은, t-1 에폭의 결과를 disk에 저장하여 read하는 방법입니다.
더 많은 GPU 메모리를 필요를 하지만, 더 많은 store space 필요하지 않다고 합니다.
Experimental Results
CIFAR-100 Classification
성능 평가 지표로서 NLL, ECE, AURC이 사용되었는데요
NLL은 우리가 아는 Cross entroy와 동일하다고 합니다. 즉, 낮을 수록 좋은 성능을 보여줍니다
ECE는 expected calibration error로서 calibration이 얼마나 잘 되었는지를 나타내는 성능지표로서 낮을수록 좋은 값을 보여줍니다.
AURC는 area under the risk-coverage curve로서 예측이 신뢰값에 의해 정렬되는 정도를 측정하는 성능지표로서 낮을 수록 예측이 신뢰성이 있다는 것을 보여줍니다.
evaluation metrix에 대하여 논문의 후반에 정리한 내용이 있는데요, 더 궁금하신 분은 논문을 참고해주시면 감사하겠습니다
Error와 여러 평가지표를 보면 다른 distillation 기법을 사용한 것보다 좋은 성능을 가진 것을 확인할 수 있습니다.
데이터 augmentation 기법과 같이 사용하였을 때도 Ps-kd 기법을 같이 사용한 결과 성능이 개선되는 것을 확인할 수 있었습니다.
또한 본 논문에서 PS-KD의 앙상블 효과를 검토하였습니다.
이전 실험의 세 가지 훈련된 모델을 활용하여, CS-KD, TF-KD 및 PS-KD를 함께 사용함으로서 앙상블 성능을 보여주었습니다.
결과적으로 앙상블은 PS-KD의 이점도 얻을 수 있는 것으로 나타났는데, 이는 PS-KD가 독립적으로 훈련된 모델의 다양성을 저하시키지 않는다는 것을 의미한다고 합니다.
해당 그림은 Calibration plot 입니다.
Calibration이란, 실제 분포와 유사하도록 prediction 분포를 가지는 정도를 뜻합니다.
회색 점선에 가까울 수록 calibration plot이 잘되었다고 볼 수 있는데, ResNext-29를 제외하면 PS-KD는 다른 기법들보다 calibration이 잘 된 것을 확인할 수 있습니다.
ImageNet Classification
그 다음으로 ImageNet Classification에 대한 결과입니다.
CutMix와 Ps-KD 결과를 함께 사용하였을때 가장 높은 성능을 보이는 것을 확인할 수 있습니다.
위의 그림을 통하여, PS-KD가 predicted probability의 퀄리티를 향상시킨 결과를 확인할 수 있습니다.
그림에 여러가지 물체가 존재하는 경우 classification 결과를 나타내었는데요,
basline의 경우 incorrect prediction 결과가 가장 높게 나타나는데, PS-KD의 결과는 top1, top2가 해당 그림에 대한 cls인 것을 확인할 수 있습니다.
또한, 아래 사진에서도 basline 결과보다 overconfident하지 않은, soft한 prediction 값을 확인할 수 있습니다.
Object Detection
또한 Object Detection에 대해서도 성능 향상을 보인 것을 확인할 수 있습니다.
Machine Translation
Machine Translation에서도 좋은 성능을 보인 것을 확인할 수 있었습니다.
Conclusion
이전 epochs의 prediction 결과를 teacher 모델로 이용해 사용한다는 아이디어가 기발한 논문이라고 생각했습니다.
단순히 calibratio을 위한 self-distillation이 아닌, hard example minin의 효과도 낼 수 있어 간단하지만 성능이 좋은 기법이라고 생각합니다.
실험 결과에서 알 수 있듯이, Classification, Detection, Machine Translation와 같이 다양한 task에 사용될 수 있기 때문에 확장성도 크다고 생각합니다.
확실히 oral 논문은 oral인 이유가 있는 것 같습니다…
틀린 부분이나 같이 토론할 것들이 있으면 댓글 남겨주세요!
'논문정리' 카테고리의 다른 글
[논문리뷰]SEAM-Self-supervised Equivariant Attention Mechanism for Weakly Supervised Semantic Segmentation (0) | 2022.02.08 |
---|