深層学習モデルの軽量化: PyTorchによる知識蒸留の実践

近年、ディープラーニングは画像認識、自然言語処理など、様々な分野で目覚ましい成果を上げています。特に、大規模なニューラルネットワークは、大量のデータを学習し、高度なタスクをこなす能力を持っています。しかし、これらのモデルは、その巨大なサイズと計算量の多さから、スマートフォンやIoTデバイスなど、計算資源が限られた環境への導入が困難です。

そこで注目されているのが、知識蒸留という技術です。知識蒸留とは、大規模で高性能な教師モデルが学習した知識を、より小規模な生徒モデルに知識転送することで、高精度かつ軽量なモデルを構築する手法です。これにより、計算資源の少ないデバイスでも、高性能なモデルを活用できるようになります。

今回は、知識蒸留について簡単に解説し、PyTorchによる実装方法を紹介します。

知識蒸留とは?

近年、深層学習は画像認識や自然言語処理をはじめとする様々な分野で広く普及しています。しかしながら、深層学習モデルは、その大きなモデルサイズと計算量の多さから、スマートフォンやIoTデバイスははじめとする、計算資源が限られたエッジ環境への導入が困難です。

そこで、この問題に対処するため、様々なモデル圧縮・高速化技術が開発されてきました。代表的な手法としては、以下のものが挙げられます。

  • パラメータの刈り込みと共有: 不要なパラメータを削除することで、モデルを軽量化
  • 低ランク因数分解: 行列分解やテンソル分解を用いて、パラメータの冗長性を減らす
  • 畳み込みフィルタの圧縮: 畳み込みフィルタを圧縮することで、モデルを軽量化
  • 知識蒸留: 大規模な教師モデルの知識を、より小規模な生徒モデルに転送

今回は、これらの手法の中でも特に注目されている知識蒸留について解説していきます。

知識蒸留の仕組み

知識蒸留(Knowledge distillation)とは、大規模で高性能な教師モデルが学習した知識を、より小規模な生徒モデルに知識転送することで、高精度かつ軽量なモデルを構築する手法です。これにより、計算資源の少ないデバイスでも、高性能なモデルを活用できるようになります。

知識蒸留のメリット

  • モデルの軽量化: モデルサイズを削減し、メモリ使用量を減らすことができる
  • 高速化: 推論時間を短縮し、リアルタイム処理が可能になる
  • 低消費電力: 計算量を減らすことで、消費電力を抑えることができる。

知識蒸留の構成要素

下図に示すように、知識蒸留は、教師モデル、生徒モデル、蒸留アルゴリズムの3つの主要なコンポーネントから構成されます。

  • 教師モデル: 大規模で高性能なモデルであり、生徒モデルに教え込む役割を担う
  • 生徒モデル: 小規模なモデルであり、教師モデルから知識を学習する
  • 蒸留アルゴリズム: 教師モデルの知識を生徒モデルに転送する方法を定義
図1. 一般的な知識蒸留の仕組み

知識の種類

知識蒸留では、単にモデルの出力だけでなく、様々な種類の情報が「知識」として捉えられ、教師モデルから生徒モデルへ転送されます。知識蒸留における知識の種類は、主に応答ベース特徴ベース関係ベースの3つのカテゴリに分類できます。

図2. 知識ソースの概略図

応答ベースの知識 (Response-Based Knowledge)

応答ベースの知識は、最も直感的な知識の形式であり、教師モデルの最終的な出力(ロジットなど)を直接生徒モデルに伝達します。この方法はシンプルながら効果的で、多くのタスクで広く利用されています。

  • ソフトターゲット: 教師モデルの出力にソフトマックス関数をかけることで得られる確率分布を指す。教師モデルが各クラスに属する確率をより詳細に表すことができ、生徒モデルの学習に使用する。
  • 温度パラメータ: ソフトターゲットの分布を調整し、知識転送の効率を向上させるために使用する。
図3. 応答ベースの知識蒸留

長所:

  • 実装が簡単
  • 様々なタスクに適用可能

短所:

  • 中間層の情報を活用できない
  • 教師モデルの構造に依存
図4. 応答ベース知識蒸留の具体例 (Hinton et al., 2015)

特徴ベースの知識 (Feature-Based Knowledge)

特徴ベースの知識は、教師モデルの中間層の出力(特徴マップ)を知識として利用します。特徴マップには、入力データの高レベルな抽象表現が含まれており、生徒モデルの学習を改善できる情報を提供します。

  • 特徴マップの一致: 教師モデルと生徒モデルの特徴マップを直接比較し、その間の距離を最小化する。
  • 注意マップ: 特徴マップの中から重要な部分を抽出し、その情報を生徒モデルに伝える。
  • 因子: 特徴マップをより簡潔な表現に変換し、知識転送を効率化する。
図5. 特徴ベース知識蒸留

長所:

  • 中間層の情報を活用できる
  • より深いレベルでの知識転送が可能

短所:

  • 特徴マップの次元数が異なる場合、対応付けが難しい
  • ハイパーパラメータの調整が複雑
手法知識タイプ知識ソース蒸留損失
Fitnet (Romero et al., 2015)Feature representationHint layerL2損失
NST (Huang and Wang, 2017)Neuron selectivity patternsHint layer最大平均不一致損失
AT (Zagoruyko and Komodakis, 2017)Attention mapsMulti-layer groupL2損失
FT (Kim et al., 2018)ParaphraserMulti-layer groupL1損失
Rocket Launching (Zhou et al., 2018)Sharing parametersHint layerL2損失
KR (Liu et al., 2019c)Parameters distributionMulti-layer group交差エントロピー損失
AB (Heo et al., 2019c)Activation boundariesPre-ReLUL2損失
Shen et al. (2019a)Knowledge amalgamationHint layerL2損失
Heo et al. (2019a)Margin ReLUPre-ReLUL2損失
FN (Xu et al., 2020b)Feature representationFully-connected layer交差エントロピー損失
DFA (Guan et al., 2020)Feature aggregationHint layerL2損失
AdaIN (Yang et al., 2020a)Feature statisticsHint layerL2損失
FN (Xu et al., 2020b)Feature representationPenultimate layer交差エントロピー損失
EC-KD (Wang et al., 2020b)Feature representationHint layerL2損失
ALP-KD (Passban et al., 2021)Attention-based layer projectionHint layerL2損失
SemCKD (Chen et al., 2021)Feature mapsHint layerL2損失
表1. 特徴ベース知識蒸留の代表的手法

関係ベースの知識 (Relation-Based Knowledge)

関係ベースの知識は、教師モデルの異なる層間の関係や、データサンプル間の関係を知識として利用します。

  • FSP (Fisher Vector) 行列: 特徴マップ間の相関関係を表す行列。
  • インスタンス関係グラフ: データサンプル間の関係をグラフで表現。
  • 多様体学習: データを低次元空間に埋め込み、その構造を保持しながら知識を転送。
図6. 関係ベースの知識蒸留

長所:

  • データの構造的な情報を活用できる
  • より深いレベルでの知識転送が可能

短所:

  • 計算コストが高い
  • ハイパーパラメータの調整が複雑
知識タイプ知識ソース蒸留損失
FSP (Yim et al., 2017)FSP matrixEnd of multi-layer groupL2損失
You et al. (2017)Instance relationHint layersL2損失
Zhang and Peng (2018)Logits graph, Representation graphSoftmax layers, Hint layersEarth Mover距離, 最大平均不一致損失
DarkRank (Chen et al., 2018c)Similarity DarkRankFully-connected layersカルバック・ライブラー発散
MHGD (Lee and Song, 2019)Multi-head graphHint layersカルバック・ライブラー発散
RKD (Park et al., 2019)Instance relationFully-connected layersHuber損失, Angle-wise損失
IRG (Liu et al., 2019g)Instance relationship graphHint layersL2損失
SP (Tung and Mori, 2019)Similarity matrixHint layersFrobeniusノルム
CCKD (Peng et al., 2019a)Instance relationHint layersL2損失
MLKD (Yu et al., 2019)Instance relationHint layersFrobeniusノルム
PKT(Passalis et al., 2020a)Similarity probability distributionFully-connected layersカルバック・ライブラー発散
Passalis et al. (2020b)Mutual information flowHint layersカルバック・ライブラー発散
LP (Chen et al., 2021)Instance relationHint layersL2損失
表2. 関係ベース知識蒸留の代表的手法

蒸留スキーム

知識蒸留における学習スキームは、教師モデルと生徒モデルの更新方法によって大きく3つに分類されます。

蒸留手法特徴長所短所補足説明
オフライン蒸留教師モデルを事前に学習させ、その知識を固定して生徒モデルに転送する。シンプルで実装が容易、教師モデルの知識を事前に抽出し、再利用可能教師モデルの学習コストが高い、生徒モデルは教師モデルに大きく依存するバニラ知識蒸留などが代表的な手法。
教師モデルの知識を「蒸留」して、生徒モデルに「注入」するイメージ。
事前学習された大規模モデルの知識を、小型デバイスに適したモデルに転送する際に有効。
オンライン蒸留教師モデルと生徒モデルを同時に学習させる。教師モデルと生徒モデルが相互に学習することで、より高性能なモデルが得られる可能性がある。実装が複雑、ハイパーパラメータの調整が難しい深層相互学習などが代表的な手法。
教師モデルと生徒モデルが互いに教え合いながら学習を進める。
自己蒸留教師モデルと生徒モデルに同じアーキテクチャのネットワークを使用する。ネットワークの設計が複雑になる可能性があるスナップショット蒸留などが代表的な手法。
同じネットワークを異なる視点から見ることで、より深い学習を実現。
モデルの汎化性能を向上させる効果が期待できる。
表3. 知識蒸留における学習スキーム

Teacher-Studentアーキテクチャ

知識蒸留において、教師モデルと生徒モデルのアーキテクチャは、知識転移の成否を大きく左右します。教師モデルから生徒モデルへの知識の獲得と蒸留の質は、両者のネットワーク構造の設計に大きく依存します。

従来型アーキテクチャ

従来の知識蒸留では、一般的に以下の様な教師-生徒モデルの組み合わせが採用されてきました。

図7. 教師モデルと生徒モデルの関係
  • 教師モデル: 深く幅の広い大規模なネットワーク
  • 生徒モデル:
    • 教師モデルの簡略化バージョン(層数やチャネル数を減らす)
    • 教師モデルの量子化バージョン
    • パラメータ効率のよい計算をできる小規模ネットワーク
    • ネットワーク全体として最適化された小規模ネットワーク
    • 教師モデルと同じネットワーク

    モデル容量のギャップと対策

    教師モデルと生徒モデルの間には、モデル容量に大きな差があることが多く、これが知識転送の妨げとなることがあります。この問題に対処するために、以下の様な対策が提案されています。

    • 教師アシスタントの導入: 教師アシスタントを導入し、教師モデルと生徒モデル間のトレーニングギャップを軽減。
      • 残差学習: 教師モデルと生徒モデルの出力の差を学習することで、より正確な知識転移を実現する。
    • ネットワークの量子化: 教師モデルを量子化することで、生徒モデルとの構造的な差異を小さくする。
    • 構造圧縮: 教師モデルの複数の層の知識を、生徒モデルの単一の層に転送する。
    • ブロック単位の知識転送: 教師モデルのネットワークを、サブネットワークと呼ばれるブロック単位に分割し、それぞれのブロックが持つ知識を生徒モデルの各レイヤに段階的に転送する。

    最新の動向

    最近の研究では、以下の様な新しい方向性が注目されています。

    • ニューラルアーキテクチャ検索 (NAS): 効率的なニューラルネットワークを自動的に設計する技術を、知識蒸留に適用する。
    • 動的な教師-生徒学習アーキテクチャ: 学習中に教師と生徒の構造を適応的に変化させる。

    PyTorchによる知識蒸留実践

    では、PyTorchを使用して知識蒸留を試してみましょう。まずは必要なパッケージをインストールします。

    # PyTorch のインストール
    # バージョンの指定方法は公式ページを参照: https://pytorch.org/get-started/previous-versions/
    $ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    
    # 実験で使用するPyTorch関連のパッケージ
    $ pip install torcheval torchinfo
    
    # 進捗バーを表示するためのパッケージ
    $ pip install tqdm

    ここからコードを実装していきます。最初に、必要なパッケージを読み込んでおきましょう。

    from typing import OrderedDict
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    
    from torcheval.metrics import MulticlassAccuracy
    from torchinfo import summary
    from torchvision.models import resnet18
    from tqdm import tqdm
    
    # GPUが利用できればCUDAを使う
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    実験で使用するデータセットには、CIFAR10(10クラスの画像分類データセット)を使用します。次のようにして、訓練用のデータセットを読み込みます。

    # 学習データはオーグメンテーションの処理を考慮
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.50, 1.0), ratio=(1.0, 1.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.RandomHorizontalFlip(p=0.1),
        transforms.RandomVerticalFlip(p=0.1),
        transforms.RandomAutocontrast(p=0.1),
        transforms.RandomEqualize(p=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # CIFAR-10データセットをロード
    train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)

    続けて、テスト用のデータセットも読み込みます。

    # テストデータには、オーグメンテーションの処理は不要
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # CIFAR-10データセットをロード
    test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

    読み込んだデータセットからDataLoaderのインスタンスを準備します。

    # データローダーの作成
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    次に、ネットワークの定義ですが、教師ネットワークには事前学習済みのResNet18を利用します。

    # 教師ネットワークには事前学習済みの ResNet18 を利用
    teacher_nn = resnet18(pretrained=True)
    
    # 出力層を10クラス分類に対応させる
    teacher_nn.fc = nn.Linear(in_features=teacher_nn.fc.in_features, out_features=10)
    
    # ネットワーク情報
    summary(teacher_nn, input_size=(2, 3, 32, 32))
    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    ResNet                                   [2, 10]                   --
    ├─Conv2d: 1-1                            [2, 64, 16, 16]           9,408
    ├─BatchNorm2d: 1-2                       [2, 64, 16, 16]           128
    ├─ReLU: 1-3                              [2, 64, 16, 16]           --
    ├─MaxPool2d: 1-4                         [2, 64, 8, 8]             --
    ├─Sequential: 1-5                        [2, 64, 8, 8]             --
    │    └─BasicBlock: 2-1                   [2, 64, 8, 8]             --
    │    │    └─Conv2d: 3-1                  [2, 64, 8, 8]             36,864
    │    │    └─BatchNorm2d: 3-2             [2, 64, 8, 8]             128
    │    │    └─ReLU: 3-3                    [2, 64, 8, 8]             --
    │    │    └─Conv2d: 3-4                  [2, 64, 8, 8]             36,864
    │    │    └─BatchNorm2d: 3-5             [2, 64, 8, 8]             128
    │    │    └─ReLU: 3-6                    [2, 64, 8, 8]             --
    │    └─BasicBlock: 2-2                   [2, 64, 8, 8]             --
    │    │    └─Conv2d: 3-7                  [2, 64, 8, 8]             36,864
    │    │    └─BatchNorm2d: 3-8             [2, 64, 8, 8]             128
    │    │    └─ReLU: 3-9                    [2, 64, 8, 8]             --
    │    │    └─Conv2d: 3-10                 [2, 64, 8, 8]             36,864
    │    │    └─BatchNorm2d: 3-11            [2, 64, 8, 8]             128
    │    │    └─ReLU: 3-12                   [2, 64, 8, 8]             --
    ├─Sequential: 1-6                        [2, 128, 4, 4]            --
    │    └─BasicBlock: 2-3                   [2, 128, 4, 4]            --
    │    │    └─Conv2d: 3-13                 [2, 128, 4, 4]            73,728
    │    │    └─BatchNorm2d: 3-14            [2, 128, 4, 4]            256
    │    │    └─ReLU: 3-15                   [2, 128, 4, 4]            --
    │    │    └─Conv2d: 3-16                 [2, 128, 4, 4]            147,456
    │    │    └─BatchNorm2d: 3-17            [2, 128, 4, 4]            256
    │    │    └─Sequential: 3-18             [2, 128, 4, 4]            8,448
    │    │    └─ReLU: 3-19                   [2, 128, 4, 4]            --
    │    └─BasicBlock: 2-4                   [2, 128, 4, 4]            --
    │    │    └─Conv2d: 3-20                 [2, 128, 4, 4]            147,456
    │    │    └─BatchNorm2d: 3-21            [2, 128, 4, 4]            256
    │    │    └─ReLU: 3-22                   [2, 128, 4, 4]            --
    │    │    └─Conv2d: 3-23                 [2, 128, 4, 4]            147,456
    │    │    └─BatchNorm2d: 3-24            [2, 128, 4, 4]            256
    │    │    └─ReLU: 3-25                   [2, 128, 4, 4]            --
    ├─Sequential: 1-7                        [2, 256, 2, 2]            --
    │    └─BasicBlock: 2-5                   [2, 256, 2, 2]            --
    │    │    └─Conv2d: 3-26                 [2, 256, 2, 2]            294,912
    │    │    └─BatchNorm2d: 3-27            [2, 256, 2, 2]            512
    │    │    └─ReLU: 3-28                   [2, 256, 2, 2]            --
    │    │    └─Conv2d: 3-29                 [2, 256, 2, 2]            589,824
    │    │    └─BatchNorm2d: 3-30            [2, 256, 2, 2]            512
    │    │    └─Sequential: 3-31             [2, 256, 2, 2]            33,280
    │    │    └─ReLU: 3-32                   [2, 256, 2, 2]            --
    │    └─BasicBlock: 2-6                   [2, 256, 2, 2]            --
    │    │    └─Conv2d: 3-33                 [2, 256, 2, 2]            589,824
    │    │    └─BatchNorm2d: 3-34            [2, 256, 2, 2]            512
    │    │    └─ReLU: 3-35                   [2, 256, 2, 2]            --
    │    │    └─Conv2d: 3-36                 [2, 256, 2, 2]            589,824
    │    │    └─BatchNorm2d: 3-37            [2, 256, 2, 2]            512
    │    │    └─ReLU: 3-38                   [2, 256, 2, 2]            --
    ├─Sequential: 1-8                        [2, 512, 1, 1]            --
    │    └─BasicBlock: 2-7                   [2, 512, 1, 1]            --
    │    │    └─Conv2d: 3-39                 [2, 512, 1, 1]            1,179,648
    │    │    └─BatchNorm2d: 3-40            [2, 512, 1, 1]            1,024
    │    │    └─ReLU: 3-41                   [2, 512, 1, 1]            --
    │    │    └─Conv2d: 3-42                 [2, 512, 1, 1]            2,359,296
    │    │    └─BatchNorm2d: 3-43            [2, 512, 1, 1]            1,024
    │    │    └─Sequential: 3-44             [2, 512, 1, 1]            132,096
    │    │    └─ReLU: 3-45                   [2, 512, 1, 1]            --
    │    └─BasicBlock: 2-8                   [2, 512, 1, 1]            --
    │    │    └─Conv2d: 3-46                 [2, 512, 1, 1]            2,359,296
    │    │    └─BatchNorm2d: 3-47            [2, 512, 1, 1]            1,024
    │    │    └─ReLU: 3-48                   [2, 512, 1, 1]            --
    │    │    └─Conv2d: 3-49                 [2, 512, 1, 1]            2,359,296
    │    │    └─BatchNorm2d: 3-50            [2, 512, 1, 1]            1,024
    │    │    └─ReLU: 3-51                   [2, 512, 1, 1]            --
    ├─AdaptiveAvgPool2d: 1-9                 [2, 512, 1, 1]            --
    ├─Linear: 1-10                           [2, 10]                   5,130
    ==========================================================================================
    Total params: 11,181,642
    Trainable params: 11,181,642
    Non-trainable params: 0
    Total mult-adds (Units.MEGABYTES): 74.05
    ==========================================================================================
    Input size (MB): 0.02
    Forward/backward pass size (MB): 1.62
    Params size (MB): 44.73
    Estimated Total Size (MB): 46.37
    ==========================================================================================

    合わせて生徒ネットワークを定義していきます。今回は、以下のように単純な畳込みネットワークとします。

    # 生徒ネットワークの定義
    class StudentNN(nn.Module):
        def __init__(self, num_classes=10):
            super(StudentNN, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
    
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
    
            self.classifier = nn.Sequential(
                nn.Linear(1024, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, num_classes)
            )
    
        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x
    
    
    # 生徒ネットワーク
    student_nn = StudentNN()
    summary(student_nn, input_size=(2, 3, 32, 32))
    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    StudentNN                                [2, 10]                   --
    ├─Sequential: 1-1                        [2, 16, 8, 8]             --
    │    └─Conv2d: 2-1                       [2, 16, 32, 32]           448
    │    └─ReLU: 2-2                         [2, 16, 32, 32]           --
    │    └─MaxPool2d: 2-3                    [2, 16, 16, 16]           --
    │    └─Conv2d: 2-4                       [2, 16, 16, 16]           2,320
    │    └─ReLU: 2-5                         [2, 16, 16, 16]           --
    │    └─MaxPool2d: 2-6                    [2, 16, 8, 8]             --
    ├─Sequential: 1-2                        [2, 10]                   --
    │    └─Linear: 2-7                       [2, 256]                  262,400
    │    └─ReLU: 2-8                         [2, 256]                  --
    │    └─Dropout: 2-9                      [2, 256]                  --
    │    └─Linear: 2-10                      [2, 10]                   2,570
    ==========================================================================================
    Total params: 267,738
    Trainable params: 267,738
    Non-trainable params: 0
    Total mult-adds (Units.MEGABYTES): 2.64
    ==========================================================================================
    Input size (MB): 0.02
    Forward/backward pass size (MB): 0.33
    Params size (MB): 1.07
    Estimated Total Size (MB): 1.43
    ==========================================================================================

    ネットワークが定義できたので、知識蒸留前に教師ネットワークと生徒ネットワークの精度を確認していきます。まずは、以下のように訓練用の関数を定義します。

    def train(model, train_loader, epochs, learning_rate, device):
        # 損失関数とオプティマイザを定義
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
        metric = MulticlassAccuracy(average="micro", num_classes=10)
    
        model.to(device)
        model.train()
    
        for epoch in range(epochs):
            with tqdm(train_loader) as pbar:
                pbar.set_description(f"[Epoch {epoch + 1}/{epochs}]")
    
                for inputs, labels in pbar:
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    # 順伝搬処理
                    optimizer.zero_grad()
                    outputs = model(inputs)
    
                    # バックプロパゲーション
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
    
                    # Accuracyをログ出力
                    metric.update(outputs, labels)
                    accuracy = metric.compute().item()
                    pbar.set_postfix(OrderedDict(Accuracy=accuracy))
    
                metric.reset()

    続けて、テスト用の関数です。

    def test(model, test_loader, device):
        model.to(device)
        model.eval()
    
        metric = MulticlassAccuracy(average="micro", num_classes=10)
    
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
    
                outputs = model(inputs)
                metric.update(outputs, labels)
    
        return metric.compute().item()

    精度確認の準備ができたので、まずは教師ネットワークを訓練して、精度を確認します。

    # 教師ネットワークの訓練
    train(teacher_nn, train_loader, epochs=20, learning_rate=1.0E-3, device=device)
    
    # 教師ネットワークの精度を確認
    accuracy = test(teacher_nn, test_loader, device)
    print(f"Teacher NN Accuracy: {accuracy}")
    Teacher NN Accuracy: 0.8389999866485596

    次に、知識蒸留前の生徒ネットワークの精度を確認します。

    # 生徒ネットワークの訓練
    train(student_nn, train_loader, epochs=20, learning_rate=0.001, device=device)
    
    # 生徒ネットワークの精度を確認
    accuracy = test(student_nn, test_loader, device)
    print(f"Student NN Accuracy: {accuracy}")
    Student NN Accuracy: 0.6978999972343445

    では、ここから知識蒸留の実装を進めて行きますが、まずは損失関数を定義します。今回は、Hinton et al., 2015 の論文中で定義されているものを使用します。この損失関数は交差エントロピー損失をベースにしたものとなっています。なお、詳細については、原論文を参照してください。

    以下、損失関数の実装になります。

    class SoftCrossEntropyLoss(nn.Module):
        def __init__(self, temperature: float = 1.0) -> None:
            super().__init__()
            if temperature <= 0.0:
                raise ValueError("温度パラメータが不正です。0より大きい値にしてください。")
            self.T = temperature
    
        def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            soft_prob = nn.functional.log_softmax(input / self.T, dim=-1)
            soft_targets = nn.functional.softmax(target / self.T, dim=-1)
    
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (self.T**2)
    
            return soft_targets_loss

    次に、この損失関数を利用した知識蒸留の関数を実装します。

    def knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, soft_target_weight, device):
        ce_loss = nn.CrossEntropyLoss()
        soft_loss = SoftCrossEntropyLoss(temperature=2.0)
        optimizer = optim.AdamW(student.parameters(), lr=learning_rate, betas=(0.9, 0.999))
        metric = MulticlassAccuracy(average="micro", num_classes=10)
    
        # 教師ネットワークの勾配計算は凍結
        teacher.to(device)
        teacher.eval()
    
        # 生徒ネットワークのみ学習させる
        student.to(device)
        student.train()
    
        for epoch in range(epochs):
            with tqdm(train_loader) as pbar:
                pbar.set_description(f"[Epoch {epoch + 1}/{epochs}]")
    
                for inputs, labels in pbar:
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    optimizer.zero_grad()
                    student_logits = student(inputs)
    
                    with torch.no_grad():
                        teacher_logits = teacher(inputs)
    
                    # 交差エントピー損失にMMD損失を加えたものをトータルの損失とする
                    hard_target_loss = (1.0 - soft_target_weight ) * ce_loss(student_logits, labels)
                    soft_target_loss = soft_target_weight * soft_loss (student_logits, teacher_logits)
                    loss = soft_target_loss + hard_target_loss
    
                    loss.backward()
                    optimizer.step()
    
                    metric.update(student_logits, labels)
                    accuracy = metric.compute().item()
                    pbar.set_postfix(OrderedDict(Accuracy=accuracy, HardLoss=hard_target_loss.item(), SoftLoss=soft_target_loss.item()))
    
                metric.reset()

    では、知識蒸留を実行して精度を確認してみましょう。

    student_nn = StudentNN()
    
    knowledge_distillation(
        teacher=teacher_nn,
        student=student_nn,
        train_loader=train_loader,
        epochs=20,
        learning_rate=1.0E-3,
        soft_target_weight=0.10,
        device=device
    )
    
    # 知識蒸留後の精度を確認
    accuracy = test(student_nn, test_loader, device)
    print(f"Knowledge Distillation Accuracy: {accuracy}")
    Knowledge Distillation Accuracy: 0.703499972820282

    上記のように、知識蒸留を行うことで生徒ネットワークの精度がベースラインからわずかに改善しました。今回は、ネットワークの出力であるLogitsを損失関数に利用したので、応答ベースの知識蒸留と言えます。また、学習済みの教師ネットワークを利用しているため、蒸留のスキーマとしてはオフライン蒸留に区分されます。

    この例では、交差エントロピーをベーズにした損失を使用しましたが、L1損失やL2損失に置き換えても問題ありません。また、複数の損失を組み合わせたハイブリッドな知識蒸留を行うのもよいかもしれません。

    さらに詳しいことを知りたい場合は、PyTorchが公開しているチュートリアルKnowledge Distillation Tutorialを参照してみてください。

    おわりに

    知識蒸留とは、大規模なモデルが学習した知識を、より小さなモデルに転送する技術です。

    今回は、知識蒸留の基礎的な概念から、具体的な手法、そして教師と生徒のモデルのアーキテクチャまで解説しました。また、PyTorchを使用した具体的な実装方法も紹介しました。

    知識蒸留は、深層学習モデルの軽量化や高速化に寄与し、様々な分野への応用が期待されています。

    More Information