Paper Review/Deep Learning

[OUTTA Alpha팀 논문 리뷰 요약] Part 9-5. Distilling the Knowledge in a Neural Network

YeonJuJeon 2025. 2. 1. 23:12

논문 링크: 1503.02531

 

OUTTA 논문 리뷰 링크: [2025-1] 박서형 - Distilling the Knowledge in a Neural Network

 

[2025-1] 박서형 - Distilling the Knowledge in a Neural Network

https://arxiv.org/abs/1503.02531 Distilling the Knowledge in a Neural NetworkA very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Un

blog.outta.ai


1. Introduction

  • 일반적으로 머신러닝 모델의 성능을 향상시키는 방법 중 하나는 앙상블(ensemble) 기법을 사용하는 것.
  • 그러나 앙상블 모델은 연산량이 많아 배포 및 실시간 사용이 어려움.
  • 이를 해결하기 위해 Knowledge Distillation (지식 증류) 기법을 제안함.
  • 이 기법은 복잡한 모델(Teacher)의 지식을 작은 모델(Student)로 전이시켜 성능을 유지하면서도 연산량을 줄이는 것을 목표로 함.

2. Knowledge Distillation

  • Teacher Model (큰 모델)에서 학습한 지식을 Student Model (작은 모델)에게 전달.
  • 단순히 정답 라벨(hard target)을 사용하는 것이 아니라, 출력 확률 분포(soft target)를 활용.
  • Soft Target을 이용하면 클래스 간 관계 및 일반화 성능을 보존하면서 더 작은 모델을 훈련 가능.

Distillation 과정

  1. Teacher Model $T(x)$ 학습 $$\mathcal{L}_{\text{teacher}} = \sum -y \log T(x)$$
  2. Teacher Model의 출력을 Softmax 변환하여 확률 분포 생성
    • Softmax의 온도 $T$ 조절하여 더 부드러운 분포 생성
    $$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
  3. Student Model $S(x)$ 훈련
    • 두 가지 손실 함수 사용
      • Soft Target Loss: Soft target $q_i$를 이용한 cross-entropy loss
      • Hard Target Loss: 실제 정답 라벨을 이용한 cross-entropy loss
    • 최종 손실 함수:
    $$\mathcal{L} = (1 - \lambda) \mathcal{L}_{\text{hard}} + \lambda T^2 \mathcal{L}_{\text{soft}}$$
    • 여기서 $T^2$를 곱해 gradient scaling을 조절하여 안정적인 학습이 가능하도록 함.

3. Temperature 의 역할

  • $T > 1$ → 확률 분포를 부드럽게 함 (Soft Targets) → 약한 클래스 정보까지 고려 가능
  • $T = 1$ → 일반적인 소프트맥스 분포와 동일
  • $T < 1$ → 원-핫(one-hot) 레이블과 가까워짐
  • Knowledge Distillation에서는 주로 $T > 1$을 사용하여 다양한 클래스 관계를 학습하도록 함.

4. Logit Matching: Distillation의 특수한 경우

  • 일반적인 Distillation은 Soft Target을 이용한 loss minimization을 수행하지만, Logit Matching은 logit 자체를 맞추는 방법.
  • 손실 함수: $$\mathcal{L} = \frac{1}{2} || z_{\text{teacher}} - z_{\text{student}} ||^2$$
  • 특징:
    • Softmax 이후 확률 분포가 아니라, logit 자체를 맞추므로 더 직접적인 학습 가능.
    • 수식 유도 과정:
      • Softmax 적용 후의 cross-entropy loss의 편미분은 다음과 같이 나타남:
      $$\frac{\partial C}{\partial z_i} = \frac{1}{T} \left( q_i - p_i \right)$$
      • $T$가 충분히 크고 logit을 zero-mean 처리하면, loss 함수는 MSE 형태와 유사해짐:
      $$\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2} (z_i - v_i)$$
      • 즉, logit matching은 distillation의 특수한 경우로 볼 수 있음.

5. Experiments & Results

1) MNIST 실험

  • Teacher Model: 1200 hidden units, dropout 적용 → 67 test errors
  • Student Model: 800 hidden units, regularization 없이 학습 → 146 test errors
  • Distilled Student Model (Soft Targets 사용) → 74 test errors
    • Soft Targets만으로 지식 전이가 가능하며, 약 50%의 성능 향상을 보여줌.

2) 음성 인식 실험 (Speech Recognition, ASR)

  • Baseline 모델: 8 hidden layers, 2560 units → Frame Accuracy 58.9%, WER 10.9%
  • 10개 앙상블 모델Frame Accuracy 61.1%, WER 10.7%
  • Distilled 모델Frame Accuracy 60.8%, WER 10.7%
    • 단일 모델이 앙상블 성능을 거의 복제할 수 있음.

6. Specialist Models (특화 모델)

  • 큰 데이터셋에서 성능을 높이기 위해, 전체 클래스를 학습하는 Generalist Model과 특정 클래스에 특화된 Specialist Models을 조합하는 방식을 제안.
  • 예시: Google JFT 데이터셋 (100M 이미지, 15,000 클래스)
    • Generalist Model: 전체 클래스 학습
    • Specialist Model: 특정 혼동되는 클래스 그룹을 전문적으로 학습 (예: 자동차 브랜드 구분)
    • Specialist Model은 overfitting 문제가 있지만, Soft Targets을 활용하면 이를 방지 가능.

7. Soft Targets의 Regularization 효과

  • 데이터가 부족한 경우에도 Soft Targets을 사용하면 일반화 성능을 향상시킬 수 있음.
  • 예시:
    • 3% 데이터만 사용한 ASR 모델
      • Hard Target 사용 → Test Accuracy 44.5% (심각한 과적합)
      • Soft Target 사용 → Test Accuracy 57.0% (거의 전체 데이터 사용과 유사한 성능)
  • 이는 Soft Targets이 모델이 학습한 구조적 정보를 보존한다는 것을 의미함.

8. Conclusion

  • Knowledge Distillation은 작은 모델에서도 강력한 성능을 유지하면서 연산 비용을 절감하는 방법.
  • Soft Targets은 단순한 hard label보다 훨씬 많은 정보를 포함하고 있으며, 작은 데이터셋에서도 강력한 regularization 효과를 가짐.
  • Logit Matching은 Distillation의 특수한 경우로, 직접적인 지식 전이가 가능.
  • 앙상블의 성능을 단일 모델로 복제할 수 있어 배포 및 실시간 처리에서 큰 이점을 가짐.