深層学習モデルの軽量化: PyTorchによる知識蒸留の実践

近年、ディープラーニングは画像認識、自然言語処理など、様々な分野で目覚ましい成果を上げています。特に、大規模なニューラルネットワークは、大量のデータを学習し、高度なタスクをこなす能力を持っています。しかし、これらのモデルは、その巨大なサイズと計算量の多さから、スマートフォンやIoTデバイスなど、計算資源が限られた環境への導入が困難です。
そこで注目されているのが、知識蒸留という技術です。知識蒸留とは、大規模で高性能な教師モデルが学習した知識を、より小規模な生徒モデルに知識転送することで、高精度かつ軽量なモデルを構築する手法です。これにより、計算資源の少ないデバイスでも、高性能なモデルを活用できるようになります。
今回は、知識蒸留について簡単に解説し、PyTorchによる実装方法を紹介します。
知識蒸留とは?
近年、深層学習は画像認識や自然言語処理をはじめとする様々な分野で広く普及しています。しかしながら、深層学習モデルは、その大きなモデルサイズと計算量の多さから、スマートフォンやIoTデバイスははじめとする、計算資源が限られたエッジ環境への導入が困難です。
そこで、この問題に対処するため、様々なモデル圧縮・高速化技術が開発されてきました。代表的な手法としては、以下のものが挙げられます。
- パラメータの刈り込みと共有: 不要なパラメータを削除することで、モデルを軽量化
- 低ランク因数分解: 行列分解やテンソル分解を用いて、パラメータの冗長性を減らす
- 畳み込みフィルタの圧縮: 畳み込みフィルタを圧縮することで、モデルを軽量化
- 知識蒸留: 大規模な教師モデルの知識を、より小規模な生徒モデルに転送
今回は、これらの手法の中でも特に注目されている知識蒸留について解説していきます。
知識蒸留の仕組み
知識蒸留(Knowledge distillation)とは、大規模で高性能な教師モデルが学習した知識を、より小規模な生徒モデルに知識転送することで、高精度かつ軽量なモデルを構築する手法です。これにより、計算資源の少ないデバイスでも、高性能なモデルを活用できるようになります。
知識蒸留のメリット
- モデルの軽量化: モデルサイズを削減し、メモリ使用量を減らすことができる
- 高速化: 推論時間を短縮し、リアルタイム処理が可能になる
- 低消費電力: 計算量を減らすことで、消費電力を抑えることができる。
知識蒸留の構成要素
下図に示すように、知識蒸留は、教師モデル、生徒モデル、蒸留アルゴリズムの3つの主要なコンポーネントから構成されます。
- 教師モデル: 大規模で高性能なモデルであり、生徒モデルに教え込む役割を担う
- 生徒モデル: 小規模なモデルであり、教師モデルから知識を学習する
- 蒸留アルゴリズム: 教師モデルの知識を生徒モデルに転送する方法を定義

知識の種類
知識蒸留では、単にモデルの出力だけでなく、様々な種類の情報が「知識」として捉えられ、教師モデルから生徒モデルへ転送されます。知識蒸留における知識の種類は、主に応答ベース、特徴ベース、関係ベースの3つのカテゴリに分類できます。

応答ベースの知識 (Response-Based Knowledge)
応答ベースの知識は、最も直感的な知識の形式であり、教師モデルの最終的な出力(ロジットなど)を直接生徒モデルに伝達します。この方法はシンプルながら効果的で、多くのタスクで広く利用されています。
- ソフトターゲット: 教師モデルの出力にソフトマックス関数をかけることで得られる確率分布を指す。教師モデルが各クラスに属する確率をより詳細に表すことができ、生徒モデルの学習に使用する。
- 温度パラメータ: ソフトターゲットの分布を調整し、知識転送の効率を向上させるために使用する。

長所:
- 実装が簡単
- 様々なタスクに適用可能
短所:
- 中間層の情報を活用できない
- 教師モデルの構造に依存

特徴ベースの知識 (Feature-Based Knowledge)
特徴ベースの知識は、教師モデルの中間層の出力(特徴マップ)を知識として利用します。特徴マップには、入力データの高レベルな抽象表現が含まれており、生徒モデルの学習を改善できる情報を提供します。
- 特徴マップの一致: 教師モデルと生徒モデルの特徴マップを直接比較し、その間の距離を最小化する。
- 注意マップ: 特徴マップの中から重要な部分を抽出し、その情報を生徒モデルに伝える。
- 因子: 特徴マップをより簡潔な表現に変換し、知識転送を効率化する。

長所:
- 中間層の情報を活用できる
- より深いレベルでの知識転送が可能
短所:
- 特徴マップの次元数が異なる場合、対応付けが難しい
- ハイパーパラメータの調整が複雑
手法 | 知識タイプ | 知識ソース | 蒸留損失 |
---|---|---|---|
Fitnet (Romero et al., 2015) | Feature representation | Hint layer | L2損失 |
NST (Huang and Wang, 2017) | Neuron selectivity patterns | Hint layer | 最大平均不一致損失 |
AT (Zagoruyko and Komodakis, 2017) | Attention maps | Multi-layer group | L2損失 |
FT (Kim et al., 2018) | Paraphraser | Multi-layer group | L1損失 |
Rocket Launching (Zhou et al., 2018) | Sharing parameters | Hint layer | L2損失 |
KR (Liu et al., 2019c) | Parameters distribution | Multi-layer group | 交差エントロピー損失 |
AB (Heo et al., 2019c) | Activation boundaries | Pre-ReLU | L2損失 |
Shen et al. (2019a) | Knowledge amalgamation | Hint layer | L2損失 |
Heo et al. (2019a) | Margin ReLU | Pre-ReLU | L2損失 |
FN (Xu et al., 2020b) | Feature representation | Fully-connected layer | 交差エントロピー損失 |
DFA (Guan et al., 2020) | Feature aggregation | Hint layer | L2損失 |
AdaIN (Yang et al., 2020a) | Feature statistics | Hint layer | L2損失 |
FN (Xu et al., 2020b) | Feature representation | Penultimate layer | 交差エントロピー損失 |
EC-KD (Wang et al., 2020b) | Feature representation | Hint layer | L2損失 |
ALP-KD (Passban et al., 2021) | Attention-based layer projection | Hint layer | L2損失 |
SemCKD (Chen et al., 2021) | Feature maps | Hint layer | L2損失 |
関係ベースの知識 (Relation-Based Knowledge)
関係ベースの知識は、教師モデルの異なる層間の関係や、データサンプル間の関係を知識として利用します。
- FSP (Fisher Vector) 行列: 特徴マップ間の相関関係を表す行列。
- インスタンス関係グラフ: データサンプル間の関係をグラフで表現。
- 多様体学習: データを低次元空間に埋め込み、その構造を保持しながら知識を転送。

長所:
- データの構造的な情報を活用できる
- より深いレベルでの知識転送が可能
短所:
- 計算コストが高い
- ハイパーパラメータの調整が複雑
法 | 知識タイプ | 知識ソース | 蒸留損失 |
---|---|---|---|
FSP (Yim et al., 2017) | FSP matrix | End of multi-layer group | L2損失 |
You et al. (2017) | Instance relation | Hint layers | L2損失 |
Zhang and Peng (2018) | Logits graph, Representation graph | Softmax layers, Hint layers | Earth Mover距離, 最大平均不一致損失 |
DarkRank (Chen et al., 2018c) | Similarity DarkRank | Fully-connected layers | カルバック・ライブラー発散 |
MHGD (Lee and Song, 2019) | Multi-head graph | Hint layers | カルバック・ライブラー発散 |
RKD (Park et al., 2019) | Instance relation | Fully-connected layers | Huber損失, Angle-wise損失 |
IRG (Liu et al., 2019g) | Instance relationship graph | Hint layers | L2損失 |
SP (Tung and Mori, 2019) | Similarity matrix | Hint layers | Frobeniusノルム |
CCKD (Peng et al., 2019a) | Instance relation | Hint layers | L2損失 |
MLKD (Yu et al., 2019) | Instance relation | Hint layers | Frobeniusノルム |
PKT(Passalis et al., 2020a) | Similarity probability distribution | Fully-connected layers | カルバック・ライブラー発散 |
Passalis et al. (2020b) | Mutual information flow | Hint layers | カルバック・ライブラー発散 |
LP (Chen et al., 2021) | Instance relation | Hint layers | L2損失 |
蒸留スキーム
知識蒸留における学習スキームは、教師モデルと生徒モデルの更新方法によって大きく3つに分類されます。
蒸留手法 | 特徴 | 長所 | 短所 | 補足説明 |
---|---|---|---|---|
オフライン蒸留 | 教師モデルを事前に学習させ、その知識を固定して生徒モデルに転送する。 | シンプルで実装が容易、教師モデルの知識を事前に抽出し、再利用可能 | 教師モデルの学習コストが高い、生徒モデルは教師モデルに大きく依存する | バニラ知識蒸留などが代表的な手法。 教師モデルの知識を「蒸留」して、生徒モデルに「注入」するイメージ。 事前学習された大規模モデルの知識を、小型デバイスに適したモデルに転送する際に有効。 |
オンライン蒸留 | 教師モデルと生徒モデルを同時に学習させる。 | 教師モデルと生徒モデルが相互に学習することで、より高性能なモデルが得られる可能性がある。 | 実装が複雑、ハイパーパラメータの調整が難しい | 深層相互学習などが代表的な手法。 教師モデルと生徒モデルが互いに教え合いながら学習を進める。 |
自己蒸留 | 教師モデルと生徒モデルに同じアーキテクチャのネットワークを使用する。 | – | ネットワークの設計が複雑になる可能性がある | スナップショット蒸留などが代表的な手法。 同じネットワークを異なる視点から見ることで、より深い学習を実現。 モデルの汎化性能を向上させる効果が期待できる。 |
Teacher-Studentアーキテクチャ
知識蒸留において、教師モデルと生徒モデルのアーキテクチャは、知識転移の成否を大きく左右します。教師モデルから生徒モデルへの知識の獲得と蒸留の質は、両者のネットワーク構造の設計に大きく依存します。
従来型アーキテクチャ
従来の知識蒸留では、一般的に以下の様な教師-生徒モデルの組み合わせが採用されてきました。

- 教師モデル: 深く幅の広い大規模なネットワーク
- 生徒モデル:
- 教師モデルの簡略化バージョン(層数やチャネル数を減らす)
- 教師モデルの量子化バージョン
- パラメータ効率のよい計算をできる小規模ネットワーク
- ネットワーク全体として最適化された小規模ネットワーク
- 教師モデルと同じネットワーク
モデル容量のギャップと対策
教師モデルと生徒モデルの間には、モデル容量に大きな差があることが多く、これが知識転送の妨げとなることがあります。この問題に対処するために、以下の様な対策が提案されています。
- 教師アシスタントの導入: 教師アシスタントを導入し、教師モデルと生徒モデル間のトレーニングギャップを軽減。
- 残差学習: 教師モデルと生徒モデルの出力の差を学習することで、より正確な知識転移を実現する。
- ネットワークの量子化: 教師モデルを量子化することで、生徒モデルとの構造的な差異を小さくする。
- 構造圧縮: 教師モデルの複数の層の知識を、生徒モデルの単一の層に転送する。
- ブロック単位の知識転送: 教師モデルのネットワークを、サブネットワークと呼ばれるブロック単位に分割し、それぞれのブロックが持つ知識を生徒モデルの各レイヤに段階的に転送する。
最新の動向
最近の研究では、以下の様な新しい方向性が注目されています。
- ニューラルアーキテクチャ検索 (NAS): 効率的なニューラルネットワークを自動的に設計する技術を、知識蒸留に適用する。
- 動的な教師-生徒学習アーキテクチャ: 学習中に教師と生徒の構造を適応的に変化させる。
PyTorchによる知識蒸留実践
では、PyTorchを使用して知識蒸留を試してみましょう。まずは必要なパッケージをインストールします。
# PyTorch のインストール
# バージョンの指定方法は公式ページを参照: https://pytorch.org/get-started/previous-versions/
$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 実験で使用するPyTorch関連のパッケージ
$ pip install torcheval torchinfo
# 進捗バーを表示するためのパッケージ
$ pip install tqdm
ここからコードを実装していきます。最初に、必要なパッケージを読み込んでおきましょう。
from typing import OrderedDict import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms import torchvision.datasets as datasets from torcheval.metrics import MulticlassAccuracy from torchinfo import summary from torchvision.models import resnet18 from tqdm import tqdm # GPUが利用できればCUDAを使う device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
実験で使用するデータセットには、CIFAR10(10クラスの画像分類データセット)を使用します。次のようにして、訓練用のデータセットを読み込みます。
# 学習データはオーグメンテーションの処理を考慮 train_transform = transforms.Compose([ transforms.RandomResizedCrop(32, scale=(0.50, 1.0), ratio=(1.0, 1.0)), transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), transforms.RandomHorizontalFlip(p=0.1), transforms.RandomVerticalFlip(p=0.1), transforms.RandomAutocontrast(p=0.1), transforms.RandomEqualize(p=0.1), transforms.RandomGrayscale(p=0.1), transforms.RandomPerspective(distortion_scale=0.2, p=0.1), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # CIFAR-10データセットをロード train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
続けて、テスト用のデータセットも読み込みます。
# テストデータには、オーグメンテーションの処理は不要 test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # CIFAR-10データセットをロード test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
読み込んだデータセットからDataLoaderのインスタンスを準備します。
# データローダーの作成 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
次に、ネットワークの定義ですが、教師ネットワークには事前学習済みのResNet18を利用します。
# 教師ネットワークには事前学習済みの ResNet18 を利用 teacher_nn = resnet18(pretrained=True) # 出力層を10クラス分類に対応させる teacher_nn.fc = nn.Linear(in_features=teacher_nn.fc.in_features, out_features=10) # ネットワーク情報 summary(teacher_nn, input_size=(2, 3, 32, 32))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet [2, 10] --
├─Conv2d: 1-1 [2, 64, 16, 16] 9,408
├─BatchNorm2d: 1-2 [2, 64, 16, 16] 128
├─ReLU: 1-3 [2, 64, 16, 16] --
├─MaxPool2d: 1-4 [2, 64, 8, 8] --
├─Sequential: 1-5 [2, 64, 8, 8] --
│ └─BasicBlock: 2-1 [2, 64, 8, 8] --
│ │ └─Conv2d: 3-1 [2, 64, 8, 8] 36,864
│ │ └─BatchNorm2d: 3-2 [2, 64, 8, 8] 128
│ │ └─ReLU: 3-3 [2, 64, 8, 8] --
│ │ └─Conv2d: 3-4 [2, 64, 8, 8] 36,864
│ │ └─BatchNorm2d: 3-5 [2, 64, 8, 8] 128
│ │ └─ReLU: 3-6 [2, 64, 8, 8] --
│ └─BasicBlock: 2-2 [2, 64, 8, 8] --
│ │ └─Conv2d: 3-7 [2, 64, 8, 8] 36,864
│ │ └─BatchNorm2d: 3-8 [2, 64, 8, 8] 128
│ │ └─ReLU: 3-9 [2, 64, 8, 8] --
│ │ └─Conv2d: 3-10 [2, 64, 8, 8] 36,864
│ │ └─BatchNorm2d: 3-11 [2, 64, 8, 8] 128
│ │ └─ReLU: 3-12 [2, 64, 8, 8] --
├─Sequential: 1-6 [2, 128, 4, 4] --
│ └─BasicBlock: 2-3 [2, 128, 4, 4] --
│ │ └─Conv2d: 3-13 [2, 128, 4, 4] 73,728
│ │ └─BatchNorm2d: 3-14 [2, 128, 4, 4] 256
│ │ └─ReLU: 3-15 [2, 128, 4, 4] --
│ │ └─Conv2d: 3-16 [2, 128, 4, 4] 147,456
│ │ └─BatchNorm2d: 3-17 [2, 128, 4, 4] 256
│ │ └─Sequential: 3-18 [2, 128, 4, 4] 8,448
│ │ └─ReLU: 3-19 [2, 128, 4, 4] --
│ └─BasicBlock: 2-4 [2, 128, 4, 4] --
│ │ └─Conv2d: 3-20 [2, 128, 4, 4] 147,456
│ │ └─BatchNorm2d: 3-21 [2, 128, 4, 4] 256
│ │ └─ReLU: 3-22 [2, 128, 4, 4] --
│ │ └─Conv2d: 3-23 [2, 128, 4, 4] 147,456
│ │ └─BatchNorm2d: 3-24 [2, 128, 4, 4] 256
│ │ └─ReLU: 3-25 [2, 128, 4, 4] --
├─Sequential: 1-7 [2, 256, 2, 2] --
│ └─BasicBlock: 2-5 [2, 256, 2, 2] --
│ │ └─Conv2d: 3-26 [2, 256, 2, 2] 294,912
│ │ └─BatchNorm2d: 3-27 [2, 256, 2, 2] 512
│ │ └─ReLU: 3-28 [2, 256, 2, 2] --
│ │ └─Conv2d: 3-29 [2, 256, 2, 2] 589,824
│ │ └─BatchNorm2d: 3-30 [2, 256, 2, 2] 512
│ │ └─Sequential: 3-31 [2, 256, 2, 2] 33,280
│ │ └─ReLU: 3-32 [2, 256, 2, 2] --
│ └─BasicBlock: 2-6 [2, 256, 2, 2] --
│ │ └─Conv2d: 3-33 [2, 256, 2, 2] 589,824
│ │ └─BatchNorm2d: 3-34 [2, 256, 2, 2] 512
│ │ └─ReLU: 3-35 [2, 256, 2, 2] --
│ │ └─Conv2d: 3-36 [2, 256, 2, 2] 589,824
│ │ └─BatchNorm2d: 3-37 [2, 256, 2, 2] 512
│ │ └─ReLU: 3-38 [2, 256, 2, 2] --
├─Sequential: 1-8 [2, 512, 1, 1] --
│ └─BasicBlock: 2-7 [2, 512, 1, 1] --
│ │ └─Conv2d: 3-39 [2, 512, 1, 1] 1,179,648
│ │ └─BatchNorm2d: 3-40 [2, 512, 1, 1] 1,024
│ │ └─ReLU: 3-41 [2, 512, 1, 1] --
│ │ └─Conv2d: 3-42 [2, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-43 [2, 512, 1, 1] 1,024
│ │ └─Sequential: 3-44 [2, 512, 1, 1] 132,096
│ │ └─ReLU: 3-45 [2, 512, 1, 1] --
│ └─BasicBlock: 2-8 [2, 512, 1, 1] --
│ │ └─Conv2d: 3-46 [2, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-47 [2, 512, 1, 1] 1,024
│ │ └─ReLU: 3-48 [2, 512, 1, 1] --
│ │ └─Conv2d: 3-49 [2, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-50 [2, 512, 1, 1] 1,024
│ │ └─ReLU: 3-51 [2, 512, 1, 1] --
├─AdaptiveAvgPool2d: 1-9 [2, 512, 1, 1] --
├─Linear: 1-10 [2, 10] 5,130
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 74.05
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 1.62
Params size (MB): 44.73
Estimated Total Size (MB): 46.37
==========================================================================================
合わせて生徒ネットワークを定義していきます。今回は、以下のように単純な畳込みネットワークとします。
# 生徒ネットワークの定義 class StudentNN(nn.Module): def __init__(self, num_classes=10): super(StudentNN, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), ) self.classifier = nn.Sequential( nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, num_classes) ) def forward(self, x): x = self.features(x) x = torch.flatten(x, 1) x = self.classifier(x) return x # 生徒ネットワーク student_nn = StudentNN() summary(student_nn, input_size=(2, 3, 32, 32))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
StudentNN [2, 10] --
├─Sequential: 1-1 [2, 16, 8, 8] --
│ └─Conv2d: 2-1 [2, 16, 32, 32] 448
│ └─ReLU: 2-2 [2, 16, 32, 32] --
│ └─MaxPool2d: 2-3 [2, 16, 16, 16] --
│ └─Conv2d: 2-4 [2, 16, 16, 16] 2,320
│ └─ReLU: 2-5 [2, 16, 16, 16] --
│ └─MaxPool2d: 2-6 [2, 16, 8, 8] --
├─Sequential: 1-2 [2, 10] --
│ └─Linear: 2-7 [2, 256] 262,400
│ └─ReLU: 2-8 [2, 256] --
│ └─Dropout: 2-9 [2, 256] --
│ └─Linear: 2-10 [2, 10] 2,570
==========================================================================================
Total params: 267,738
Trainable params: 267,738
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 2.64
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.33
Params size (MB): 1.07
Estimated Total Size (MB): 1.43
==========================================================================================
ネットワークが定義できたので、知識蒸留前に教師ネットワークと生徒ネットワークの精度を確認していきます。まずは、以下のように訓練用の関数を定義します。
def train(model, train_loader, epochs, learning_rate, device): # 損失関数とオプティマイザを定義 criterion = nn.CrossEntropyLoss() optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999)) metric = MulticlassAccuracy(average="micro", num_classes=10) model.to(device) model.train() for epoch in range(epochs): with tqdm(train_loader) as pbar: pbar.set_description(f"[Epoch {epoch + 1}/{epochs}]") for inputs, labels in pbar: inputs, labels = inputs.to(device), labels.to(device) # 順伝搬処理 optimizer.zero_grad() outputs = model(inputs) # バックプロパゲーション loss = criterion(outputs, labels) loss.backward() optimizer.step() # Accuracyをログ出力 metric.update(outputs, labels) accuracy = metric.compute().item() pbar.set_postfix(OrderedDict(Accuracy=accuracy)) metric.reset()
続けて、テスト用の関数です。
def test(model, test_loader, device): model.to(device) model.eval() metric = MulticlassAccuracy(average="micro", num_classes=10) with torch.no_grad(): for inputs, labels in test_loader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) metric.update(outputs, labels) return metric.compute().item()
精度確認の準備ができたので、まずは教師ネットワークを訓練して、精度を確認します。
# 教師ネットワークの訓練 train(teacher_nn, train_loader, epochs=20, learning_rate=1.0E-3, device=device) # 教師ネットワークの精度を確認 accuracy = test(teacher_nn, test_loader, device) print(f"Teacher NN Accuracy: {accuracy}")
Teacher NN Accuracy: 0.8389999866485596
次に、知識蒸留前の生徒ネットワークの精度を確認します。
# 生徒ネットワークの訓練 train(student_nn, train_loader, epochs=20, learning_rate=0.001, device=device) # 生徒ネットワークの精度を確認 accuracy = test(student_nn, test_loader, device) print(f"Student NN Accuracy: {accuracy}")
Student NN Accuracy: 0.6978999972343445
では、ここから知識蒸留の実装を進めて行きますが、まずは損失関数を定義します。今回は、Hinton et al., 2015 の論文中で定義されているものを使用します。この損失関数は交差エントロピー損失をベースにしたものとなっています。なお、詳細については、原論文を参照してください。
以下、損失関数の実装になります。
class SoftCrossEntropyLoss(nn.Module): def __init__(self, temperature: float = 1.0) -> None: super().__init__() if temperature <= 0.0: raise ValueError("温度パラメータが不正です。0より大きい値にしてください。") self.T = temperature def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: soft_prob = nn.functional.log_softmax(input / self.T, dim=-1) soft_targets = nn.functional.softmax(target / self.T, dim=-1) soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (self.T**2) return soft_targets_loss
次に、この損失関数を利用した知識蒸留の関数を実装します。
def knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, soft_target_weight, device): ce_loss = nn.CrossEntropyLoss() soft_loss = SoftCrossEntropyLoss(temperature=2.0) optimizer = optim.AdamW(student.parameters(), lr=learning_rate, betas=(0.9, 0.999)) metric = MulticlassAccuracy(average="micro", num_classes=10) # 教師ネットワークの勾配計算は凍結 teacher.to(device) teacher.eval() # 生徒ネットワークのみ学習させる student.to(device) student.train() for epoch in range(epochs): with tqdm(train_loader) as pbar: pbar.set_description(f"[Epoch {epoch + 1}/{epochs}]") for inputs, labels in pbar: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() student_logits = student(inputs) with torch.no_grad(): teacher_logits = teacher(inputs) # 交差エントピー損失にMMD損失を加えたものをトータルの損失とする hard_target_loss = (1.0 - soft_target_weight ) * ce_loss(student_logits, labels) soft_target_loss = soft_target_weight * soft_loss (student_logits, teacher_logits) loss = soft_target_loss + hard_target_loss loss.backward() optimizer.step() metric.update(student_logits, labels) accuracy = metric.compute().item() pbar.set_postfix(OrderedDict(Accuracy=accuracy, HardLoss=hard_target_loss.item(), SoftLoss=soft_target_loss.item())) metric.reset()
では、知識蒸留を実行して精度を確認してみましょう。
student_nn = StudentNN() knowledge_distillation( teacher=teacher_nn, student=student_nn, train_loader=train_loader, epochs=20, learning_rate=1.0E-3, soft_target_weight=0.10, device=device ) # 知識蒸留後の精度を確認 accuracy = test(student_nn, test_loader, device) print(f"Knowledge Distillation Accuracy: {accuracy}")
Knowledge Distillation Accuracy: 0.703499972820282
上記のように、知識蒸留を行うことで生徒ネットワークの精度がベースラインからわずかに改善しました。今回は、ネットワークの出力であるLogitsを損失関数に利用したので、応答ベースの知識蒸留と言えます。また、学習済みの教師ネットワークを利用しているため、蒸留のスキーマとしてはオフライン蒸留に区分されます。
この例では、交差エントロピーをベーズにした損失を使用しましたが、L1損失やL2損失に置き換えても問題ありません。また、複数の損失を組み合わせたハイブリッドな知識蒸留を行うのもよいかもしれません。
さらに詳しいことを知りたい場合は、PyTorchが公開しているチュートリアルKnowledge Distillation Tutorialを参照してみてください。
おわりに
知識蒸留とは、大規模なモデルが学習した知識を、より小さなモデルに転送する技術です。
今回は、知識蒸留の基礎的な概念から、具体的な手法、そして教師と生徒のモデルのアーキテクチャまで解説しました。また、PyTorchを使用した具体的な実装方法も紹介しました。
知識蒸留は、深層学習モデルの軽量化や高速化に寄与し、様々な分野への応用が期待されています。
More Information
- arXiv:2006.05525, Jianping Gou et al., 「Knowledge Distillation: A Survey」, https://arxiv.org/abs/2006.05525
- arXiv:2304.04262, Weijian Luo, 「A Comprehensive Survey on Knowledge Distillation of Diffusion Models」, https://arxiv.org/abs/2304.04262
- arXiv:2308.04268, Chengming Hu et al., 「Teacher-Student Architecture for Knowledge Distillation: A Survey」, https://arxiv.org/abs/2308.04268
- arXiv:2407.01885, Chuanpeng Yang et al., 「Survey on Knowledge Distillation for Large Language Models: Methods, Evaluation, and Application」, https://arxiv.org/abs/2407.01885