인공지능 AI/자연어처리

[논문리뷰/NLP] Distilling the Knowledge in a Neural Network

Distilling the Knowledge in a Neural Network

Geoffrey HintonOriol VinyalsJeff Dean

#NIPS 2014 Deep Learning Workshop

논문선정이유

모델링 경량화 작업을 공부하기 위해 읽었던 사전의 DistilBERT 논문이 차용한 논문이다.

Knowledge Distillation에 대해 처음으로 소개하는 논문이다. Teacher model과 Student model이라는 개념을 제시했다.

Abstract

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions [3]. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. Caruana and his collaborators [1] have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel.

Introduction

머신러닝의 앙상블은 대부분 한 데이터셋을 여러 다른 모델에서 훈련한 후, 그 결과의 평균을 내어 예측한다. 하지만 이는 비용이 매우 비싸다.

해당 논문은 대규모 모델을 하나의 작은 모델로 압축해 동등한 수준의 성능을 내는 방법론을 제시한다.

  • soft target은 high entropy라서 일반적인 학습에 사용하는 hard target보다 information이 많다.
  • training gradient 간에 gradient의 variance가 작아서, small model이 적은 data로도 효율적으로 학습이 가능해진다.

2. Distillation

Neural networks typically produce class probabilities by using a “softmax” output layer that converts the logit, zi, computed for each class into a probability, qi, by comparing zi with the other logits.

Distillation은 잘 학습된 large model 이 주는 결과를 바탕으로 small model 역시 좋은 성능을 내도록 하는 과정이라 설명할 수 있을 것이다.

  • 기존의 softmax: where T is a temperature that is normally set to 1. Using a higher value for T produces a softer probability distribution over classes.
  • 본 논문의 softmax 함수:
    1. training set (x, hard target)을 사용해 large model을 학습한다.
    2. large model이 충분히 학습된 뒤에, large model의 output을 soft target으로 하는 transfer set(x, soft target)을 생성해낸다. 이 때 soft target의 는 1이 아닌 높은 값을 사용한다.
    3. transfer set을 사용해 small model을 학습한다. 는 soft target을 생성할 때와 같은 값을 사용한다.
    4. training set을 사용해 small model을 학습한다. 는 1로 고정한다.
    각각의 loss function은 모두 Cross-Entropy-Loss를 사용한다. 결국, small model의 최종 loss function은 soft target과의 Cross-Entropy-Loss + hard target과의 Cross-Entropy-Loss가 된다.

2.1 Matching logits is a special case of distillation

Cross entropy 식을 distilled model에서 나오는 logit (z) 으로 미분하면 식은 다음과 같다.

Cross entropy 식을 distilled model에서 나오는 logit (z) 으로 미분하면 식은 다음과 같다.

이때 temperature T가 logit 값보다 더 크다면, 지수가 0에 가까워지고 exp 함수의 특성상 1에 가까워진다. 따라서 다음과 같은 근사가 가능해진다.

logit 값이 zero-mean이라고 가정한다면, 위와 같이 단순화가 가능해진다.

We show that when the distilled model is much too small to capture all of the knowledge in the cumbersome model, intermediate temperatures work best which strongly suggests that ignoring the large negative logits can be helpful.

3. Preliminary experiments on MNIST

To see how well distillation works, we trained a single large neural net with two hidden layers of 1200 rectified linear hidden units on all 60,000 training cases.

This net achieved 67 test errors whereas a smaller net with two hidden layers of 800 rectified linear hidden units and no regularization achieved 146 errors. But if the smaller net was regularized solely by adding the additional task of matching the soft targets produced by the large net at a temperature of 20, it achieved 74 test errors.

⇒ This shows that soft targets can transfer a great deal of knowledge to the distilled model, including the knowledge about how to generalize that is learned from translated training data even though the transfer set does not contain any translations.

soft target이 번역 데이터가 전송 세트에 포함되어있지 않더라도 번역된 데이터에서 학습한 일반화 방법에 대한 지식을 포함한 많은 지식을 전송할 수 있음을 보여준다.

5. Training ensembles of specialists on very big datasets

model ensemble의 단점을 한 번 더 언급. KD를 사용하여 처리하자.

앙상블은 매우 쉽게 과적합한다. soft target을 사용하여 이런 과적합을 방지할 수 있다.

5.2 Specialist Models

When the number of classes is very large, it makes sense for the cumbersome model to be an ensemble that contains one generalist model trained on all the data and many “specialist” models, each of which is trained on data that is highly enriched in examples from a very confusable subset of the classes (like different types of mushroom). The softmax of this type of specialist can be made much smaller by combining all of the classes it does not care about into a single dustbin class.

클래스 수가 매우 많을 때는 모든 데이터에 대해 훈련된 하나의 일반화 모델과 많은 ‘전문가’ 모델을 포함하는 앙상블로 처리.

5.3 Assigning classes to specialists

모델 돌리고, confusion matrix에서 오분류율 높은 클래스에 대해 집중하여 개선

5.4 Performing inference with ensembles of specialists

6. Soft Targets as Regularizers

soft target에는 hard target에는 담을 수 없는 유용한 정보들이 포함되어 있는데, 이 정보들이 overfitting을 방지하는 효과를 가져다준다. 위 table은 660M example을 포함하는 dataset에 대해 학습을 할 때, hard target과 soft target으로 성능을 측정한 것이다.

hard target으로 학습했을 때는 overfitting이 심했지만, soft target으로 했을 때는 정확도가 full training set을 사용한 것처럼 유사하게 나왔다.

7. Relations to Mixtures of Experts

It is much easier to parallelize the training of multiple specialists. We first train a generalist model and then use the confusion matrix to define the subsets that the specialists are trained on. Once these subsets have been defined the specialists can be trained entirely independently. At test time we can use the predictions from the generalist model to decide which specialists are relevant and only these specialists need to be run.

specialist model들을 여러 개 학습하는 것에 대한 내용.

specialist model 을 여러 개 학습할 때도 모델 하나 당 전체 데이터의 일부분을 주고 학습하는 것이 좋다.

요약

  • Knowledge Distillation에 대해 처음으로 소개하는 논문이다. Teacher model과 Student model이라는 개념을 제시했다.
  • MNIST, 음성 인식 등 딥러닝 전반적으로 적용 가능한 방법론임을 제시.
  • 앙상블보다 soft target이 더 성능이 높음을 제시.
  • 데이터 전체를 학습하는 비용이 비쌀 때, 경량화 서비스를 구현할 때 유용할 듯.
  • Knowledge Distillation 개념에 대해 알 수 있었지만, 중복된 내용으로 실험을 반복한 느낌.
  • 현실적으로는 이것보다 transformer와 attention에 더 많은 주목과 성과가 나오고 있음