ディープラーニングモデルの量子化: PyTorchによる実践解説

近年、ディープラーニングは画像認識、自然言語処理、音声認識など、多岐にわたる分野で目覚ましい成果を上げています。しかし、これらの高精度なモデルは、しばしば膨大なパラメータ数を持ち、その結果として大きなメモリ消費量や計算コストを必要とします。特に、スマートフォンや組み込みシステムのようなリソースに制約のあるエッジデバイス上でこれらのモデルを動作させることは、依然として大きな課題です。

この課題に対する有力な解決策の一つが「量子化(Quantization)」です。量子化とは、ニューラルネットワークの重みや活性化関数などの値を、より少ないビット数で表現する技術です。通常、32ビットの浮動小数点数で表現されるこれらの値を、8ビット整数などに変換することで、モデルサイズの大幅な削減、推論速度の向上、さらには消費電力の低減を実現できます。

本記事では、このディープラーニングモデルの量子化技術に焦点を当て、その基本的な理論から、PyTorchを用いた具体的な実装方法までを解説します。

量子化の概要

ディープラーニングモデルの量子化は、モデルの性能を維持しつつ、そのサイズと計算コストを削減するための重要な技術です。ここでは、その理論的背景、必要性、具体的なユースケース、そしてメリットとデメリットについて詳しく見ていきましょう。

理論的背景と動作原理

量子化の基本的な考え方は、ニューラルネットワークで使用される数値の精度を意図的に下げることです。通常、ディープラーニングモデルの重みや活性化(ニューロンの出力)は、32ビット浮動小数点数(FP32)で表現されます。これは広いダイナミックレンジと高い精度を提供しますが、多くのメモリを消費し、計算にも時間がかかります。

量子化では、これらのFP32の値を、よりビット数の少ないデータ型、例えば8ビット整数(INT8)や、場合によってはさらに低いビット数(例:4ビット、2ビット、あるいは1ビットのバイナリ値)に変換します。この変換処理は、一般的に以下の式で表されます。

$$
\text{Q} (x) = \text{round} (\frac{x}{S})+Z
$$

ここで、\(x\)は元の浮動小数点数、\(S\)はスケールファクター、\(Z\)はゼロポイント(オフセット)です。スケールファクター\(S\)は、元の浮動小数点数の範囲を量子化後の整数の範囲にマッピングするための係数です。ゼロポイント\(Z\)は、元の浮動小数点数の0が量子化後のどの整数値に対応するかを示すオフセット値です。

この変換により、数値の表現に必要なビット数が減り、結果としてモデル全体のサイズが小さくなります。また、整数演算は浮動小数点演算に比べて計算ユニットが単純で済むため、推論速度の向上や消費電力の削減にも繋がります。

なぜ量子化が必要とされるのか

量子化が必要とされる主な理由は以下の通りです。

  • モデルの軽量化: 高度なディープラーニングモデルは、数百万から数十億ものパラメータを持つことがあります。これによりモデルファイルが非常に大きくなり、ストレージ容量を圧迫したり、ネットワーク経由での配信が困難になったりします。量子化によってモデルサイズを数分の一に削減できます。
  • 推論速度の向上: 整数演算は浮動小数点演算よりも高速に実行できます。特に、専用のハードウェアアクセラレータ(例:GPU、TPU、NPU)では、低ビット整数の並列演算に最適化されている場合があり、大幅な速度向上が期待できます。
  • エッジデバイスへの展開: スマートフォン、ドローン、IoTデバイスなどのエッジデバイスは、計算能力、メモリ容量、バッテリー容量に大きな制約があります。量子化された軽量なモデルは、これらのリソースが限られた環境でも効率的に動作させることが可能になります。
  • 省電力化: 計算量の削減とメモリフットプリントの縮小は、消費電力の低減に直接的に貢献します。これは、バッテリー駆動のデバイスや、大規模なデータセンターにおける運用コスト削減の観点から非常に重要です。

具体的なユースケース

量子化技術は、すでに様々な分野で活用されています。

  • モバイル端末でのCNNモデル推論: スマートフォンのカメラアプリにおけるリアルタイム物体認識、画像フィルタリング、拡張現実(AR)機能など。
  • 音声アシスタント: スマートスピーカーやスマートフォンに搭載される音声認識モデルの軽量化と応答速度の向上。
  • 自動運転: 車載システムにおける歩行者検知、車線認識などのためのモデルを、限られた計算資源でリアルタイムに実行。
  • 産業用IoT: 工場内のセンサーデータを処理し異常検知を行うエッジAIシステム。

これらのユースケースでは、レイテンシの低減、バッテリー消費の抑制、オフラインでの動作といった要件があり、量子化が重要な役割を果たします。

メリットとデメリットのバランス

量子化は多くの利点をもたらしますが、いくつかの潜在的な課題も存在します。

メリット:

  • モデルサイズの削減: ストレージコストの削減、配信効率の向上。
  • 推論速度の向上: レイテンシの短縮、スループットの向上。
  • メモリ使用量の削減: 特に実行時のRAM消費量を抑える。
  • 消費電力の削減: バッテリー寿命の延長、運用コストの削減。
  • ハードウェアサポートの活用: 低ビット演算に特化したプロセッサの性能を最大限に引き出す。

デメリット:

  • 精度の低下: 数値の精度を下げるため、元のモデルと比較して若干の精度低下が生じる可能性があります。特に、極端に低いビット数(例:1ビットや2ビット)に量子化すると、精度への影響が大きくなる傾向があります。
  • 量子化手法の選択と調整の複雑さ: モデルの特性やタスクに応じて、適切な量子化手法を選択し、パラメータ(スケールファクターやゼロポイント)を適切に調整する必要があります。これには試行錯誤が伴うことがあります。
  • 量子化に適さないモデルやレイヤーの存在: 一部のモデルアーキテクチャや特定のレイヤーは、量子化による精度低下の影響を受けやすい場合があります。

量子化を導入する際には、これらのメリットとデメリットを総合的に評価し、許容できる精度低下の範囲内で、最大限の効率化を目指すことが重要です。多くの場合、わずかな精度低下と引き換えに、モデルサイズや推論速度の大幅な改善を得ることができます。

代表的な量子化手法の比較と解説

ディープラーニングモデルの量子化には、いくつかの代表的なアプローチが存在します。それぞれの手法は、量子化を適用するタイミングや対象、そしてそれに伴う精度やパフォーマンスへの影響、導入の容易さが異なります。ここでは、主要な3つの手法である「Dynamic Quantization(動的量子化)」、「Static Quantization(静的量子化)」、そして「Quantization Aware Training(QAT、量子化を意識した学習)」について、その特徴とPyTorchにおける位置づけを解説します。

Dynamic Quantization(動的量子化)

概要: 動的量子化は、主にモデルの重み (weights) を事前に低ビット整数(例: INT8)に変換しておき、活性化 (activations) に関しては推論実行時に浮動小数点数から整数へ動的に量子化し、計算後すぐに浮動小数点数に戻す手法です。

動作原理:

  1. 重みの量子化: モデルの学習後、重みパラメータはINT8などの低ビット整数に変換され、保存されます。この際、各重みテンソル(またはその一部)に対してスケールファクターとゼロポイントが計算されます。
  2. 活性化の動的量子化: 推論時、ある層の入力となる活性化(前層の出力)が浮動小数点数で渡されると、その都度INT8に量子化されます。この量子化のパラメータ(スケールとゼロポイント)は、活性化テンソルの実際の最小値・最大値に基づいて実行時に計算されます。
  3. 整数演算: 量子化された重みと量子化された活性化を用いて、整数演算(主にINT8行列積)が実行されます。
  4. 結果の逆量子化: 演算結果(整数)は、次の層への入力として渡すために、再び浮動小数点数に戻されます(Dequantization)。

特徴:

  • 導入の容易さ: キャリブレーションデータセットを用意する必要がなく、学習済みのFP32モデルに対して比較的簡単に適用できます。PyTorchでは数行のコードで実装可能です。
  • モデルサイズ削減: 重みを低ビットで保存するため、モデルサイズの大幅な削減効果があります。
  • 推論速度: 整数演算による高速化が期待できますが、活性化の動的変換とデ量子化のオーバーヘッドがあるため、Static Quantizationほどの大幅な速度向上にはならない場合があります。
  • 適したモデル: LSTM、Transformer、RNNなど、活性化の分布が入力データによって大きく変動し、事前のキャリブレーションが難しいモデルに適しています。

Static Quantization(静的量子化)

概要: 静的量子化(Post-Training Static Quantizationとも呼ばれます)は、モデルの重みと活性化の両方を事前に低ビット整数に量子化する手法です。活性化の量子化パラメータ(スケールとゼロポイント)を決定するために、「キャリブレーション」というステップが必要になります。

動作原理:

  1. 重みの量子化: 動的量子化と同様に、学習後のモデルの重みパラメータを低ビット整数に変換します。
  2. キャリブレーション:
    • 代表的な入力データセット(キャリブレーションデータ)を少量用意します。
    • このキャリブレーションデータを用いて、FP32モデルを実際に数回推論実行(フォワードパス)します。
    • 各層の活性化の統計情報(例えば、最小値と最大値)を収集・観測します。
    • 収集した統計情報に基づいて、各活性化テンソルに対する固定のスケールファクターとゼロポイントを計算し、決定します。
  3. モデル変換: 量子化された重みと、キャリブレーションによって決定された活性化の量子化パラメータを持つ、完全に量子化されたモデルを生成します。
  4. 整数推論: 推論時は、入力データも最初の層で量子化され、以降の計算はすべて整数演算で行われます。活性化の量子化・逆量子化は実行時には行われません(または最小限に抑えられます)。

特徴:

  • 推論速度の大幅な向上: 重みと活性化の両方が低ビット整数で扱われ、演算も整数で行われるため、動的量子化よりも大きな推論速度の向上が期待できます。
  • キャリブレーションが必要: 活性化の量子化パラメータを決定するために、代表的な入力データ(検証データセットの一部など)を用いたキャリブレーション処理が必要です。このキャリブレーションデータの品質が、量子化後のモデルの精度に影響を与えることがあります。
  • 適したモデル: CNN(畳み込みニューラルネットワーク)など、活性化の分布が比較的入力データ間で安定しているモデルに適しています。
  • 精度: 一般的に、キャリブレーションを適切に行うことで、動的量子化よりも精度低下を抑えやすい傾向にあります。

Quantization Aware Training(QAT:量子化を意識した学習)

概要: Quantization Aware Training (QAT) は、モデルの学習プロセス自体に量子化処理を組み込む手法です。学習中に量子化による情報損失(丸め誤差など)をシミュレートし、モデルがその誤差に対して頑健になるようにパラメータを調整します。

動作原理:

  1. 偽量子化ノードの挿入: FP32モデルの学習中またはファインチューニング中に、重みと活性化の量子化・デ量子化をシミュレートする演算(偽量子化ノード、fake quantization nodes)をモデルグラフに挿入します。
    • フォワードパスでは、これらのノードは値を量子化し、すぐにデ量子化して浮動小数点数に戻します。これにより、量子化によって生じるであろう誤差がシミュレートされます。
    • バックワードパス(勾配計算)では、偽量子化ノードは通常、Straight-Through Estimator (STE) と呼ばれる手法を用いて勾配を近似的に計算し、量子化の非連続性による勾配消失問題を回避します。
  2. 量子化を意識した学習: この偽量子化されたモデルで学習を継続します。モデルは、量子化による情報損失が存在する状態で精度が最大になるように重みを調整していきます。
  3. 真の量子化モデルへの変換: 学習完了後、偽量子化ノードで学習された重みと量子化パラメータ(スケールとゼロポイント)を用いて、実際に低ビット整数で演算を行うモデルに変換します。

特徴:

  • 最高の精度維持: 量子化による精度低下を最小限に抑えることができ、多くの場合、元のFP32モデルに近い精度、あるいは同等の精度を達成できます。
  • 導入コストが高い: モデルの再学習またはファインチューニングが必要となるため、上記2つの手法(Post-Training Quantization: PTQ)と比較して、時間と計算リソースのコストが最も高くなります。
  • 複雑性: 学習パイプラインの変更が必要であり、実装の複雑性が増します。
  • 汎用性: 精度が非常に重要なタスクや、PTQでは精度低下が許容範囲を超えてしまう場合に有効な選択肢となります。
特徴Dynamic QuantizationStatic Quantization (PTQ)Quantization Aware Training (QAT)
主な量子化対象重み (アクティベーションは実行時)重み、アクティベーション重み、アクティベーション
キャリブレーション不要必要不要(学習データでパラメータ決定)
再学習不要不要必要
期待される精度△ (ベースラインより低下しやすい)〇 (比較的維持しやすい)◎ (最も維持しやすい)
期待される推論速度〇 (モデルサイズ削減効果大、速度向上は中程度)◎ (大幅な向上が期待できる)◎ (大幅な向上が期待できる)
導入コスト(容易さ)◎ (最も容易)〇 (キャリブレーションの手間)△ (再学習の手間と時間、実装の複雑性)
適したモデル例RNN, Transformerなどアクティベーションが変動するモデルCNNなどアクティベーションが比較的安定したモデル高い精度が求められるあらゆるモデル
備考手軽にモデルサイズを削減したい場合に有効キャリブレーションデータの代表性が重要精度低下を極力避けたい場合の最終手段

補足:

  • 精度: 一般的に、QAT > Static Quantization > Dynamic Quantization の順で精度が高くなる傾向があります。ただし、モデルの構造やタスク、データセットによって最適な手法は異なります。
  • 推論速度: Static QuantizationとQATは、重みと活性化の両方が事前に量子化されるため、同程度の高い推論速度が期待できます。Dynamic Quantizationは、実行時のアクティベーション変換コストがあるため、これら2つの手法よりは速度面で劣る可能性がありますが、それでもFP32よりは高速化が見込めます。
  • 導入コスト: 実装の手軽さや必要な時間で考えると、Dynamic Quantizationが最も低コストです。Static Quantizationはキャリブレーションのステップが加わります。QATは再学習を伴うため、最も時間と計算リソースを要します。

PyTorchによる実践例

このセクションでは、実際にPyTorchの torch.ao.quantization モジュール(PyTorch 1.3以降で導入され、継続的に改善されています。以前は torch.quantization)を用いて、MNIST手書き数字分類タスクを題材に、これまで解説した3つの主要な量子化手法、Dynamic Quantization、Static Quantization、Quantization Aware Training (QAT) を実装し、その効果を比較します。

準備

まず、データセットを準備する関数、モデル定義、評価関数など、各手法で共通で利用できる処理を準備します。

データセットの読み込み(MNIST)

# datasets.py

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def load_mnist() -> tuple[DataLoader, DataLoader]:
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    train_dataset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    return train_loader, test_loader

モデル定義

# model.py

import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub, fuse_modules


class SimpleCNN(nn.Module):
    def __init__(self, in_channels: int = 1, n_classes: int = 10) -> None:
        super(__class__, self).__init__()

        # CNN部分(順次構築)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.relu = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(num_features=64)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(num_features=128)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # 全結合ヘッド
        self.fc1 = nn.Linear(in_features=128, out_features=64, bias=True)
        self.norm = nn.LayerNorm(normalized_shape=64)
        self.fc2 = nn.Linear(in_features=64, out_features=n_classes, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # Flatten

        x = self.fc1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.fc2(x)

        return x


class QuantizableCNN(nn.Module):
    def __init__(self, in_channels: int = 1, n_classes: int = 10) -> None:
        super(__class__, self).__init__()
        self.quant = QuantStub()  # 入力を量子化

        # CNN部分(順次構築)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.relu = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(num_features=64)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(num_features=128)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # 全結合ヘッド
        self.fc1 = nn.Linear(in_features=128, out_features=64, bias=True)
        self.norm = nn.LayerNorm(normalized_shape=64)
        self.fc2 = nn.Linear(in_features=64, out_features=n_classes, bias=True)

        self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # Flatten

        x = self.fc1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.fc2(x)

        return self.dequant(x)

    def fuse_model(self) -> None:
        self.eval()  # 融合の前に評価モードにする必要がある

        # Conv-ReLU/Conv-BN/Conv-BN-ReLU のような一部のシーケンスパターンのみサポートされている。
        # 参考: https://docs.pytorch.org/tutorials/recipes/fuse.html
        fuse_modules(self, ["conv2", "bn1"], inplace=True)
        fuse_modules(self, ["conv3", "bn2"], inplace=True)
        fuse_modules(self, ["conv4", "bn3"], inplace=True)

訓練用の関数

# trainer.py

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from model import QuantizableCNN, SimpleCNN


def train_model(
    model: SimpleCNN | QuantizableCNN,
    train_loader: DataLoader,
    num_epochs=100,
    learning_rate=0.001,
    device: str = "cuda:0",
) -> SimpleCNN | QuantizableCNN:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        losses = []
        for _, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            losses.append(float(loss))

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {sum(losses) / len(losses):.4f}")
    print("Finished Training FP32 model.")
    return model

評価用の関数

# helpers.py

import os

import torch


# 精度計算やモデルサイズ表示のためのヘルパー関数
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_model.p")
    size = os.path.getsize("temp_model.p")
    os.remove("temp_model.p")
    return size


def evaluate_model(model, data_loader, device="cpu"):
    model.to(device)
    model.eval()  # 評価モード
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

Dynamic Quantization (動的量子化) の実践

動的量子化は、学習済みモデルに対して最も簡単に適用できる量子化手法の一つです。重みのみを量子化し、活性化は実行時に動的に量子化されます。

import copy
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets import load_mnist
from helpers import evaluate_model, print_size_of_model
from model import SimpleCNN
from trainer import train_model


def train_base_model(train_loader: DataLoader) -> SimpleCNN:
    model = SimpleCNN()

    if os.path.exists("fp32_model.pth"):
        model.load_state_dict(torch.load("fp32_model.pth"))
        model.to("cuda:0")
    else:
        model = train_model(model, train_loader, num_epochs=10)
        torch.save(model.to("cpu").state_dict(), "fp32_model.pth")

    return model


def dynamically_quantize(fp32_model: SimpleCNN) -> SimpleCNN:
    # 1. モデルをコピー
    dynamic_quantized_model = copy.deepcopy(fp32_model)
    dynamic_quantized_model.to("cpu")  # 動的量子化は主にCPUでサポート
    dynamic_quantized_model.eval()

    # 2. 動的量子化を適用
    # torch.ao.quantization.quantize_dynamic() を使用
    # qconfig_spec で量子化するモジュールタイプを指定 (ここでは nn.Linear と nn.Conv2d)
    # dtype で重みのデータ型を指定 (torch.qint8)
    return torch.ao.quantization.quantize_dynamic(
        dynamic_quantized_model,
        qconfig_spec={nn.Linear, nn.Conv2d},  # 量子化対象のレイヤータイプ
        dtype=torch.qint8,
    )


def main() -> None:
    train_loader, test_loader = load_mnist()

    fp32_model = train_base_model(train_loader)
    size = print_size_of_model(fp32_model)
    accuracy = evaluate_model(fp32_model, test_loader, device="cuda:0")
    print(f"FP32 Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => FP32 Model Size=0.4381[MB], Accuracy=98.3%

    dynamic_quantized_model = dynamically_quantize(fp32_model)
    size = print_size_of_model(dynamic_quantized_model)
    accuracy = evaluate_model(dynamic_quantized_model, test_loader)
    print(f"Dynamic Quantized Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => Dynamic Quantized Model Size=0.4127[MB], Accuracy=98.3%


if __name__ == "__main__":
    main()

解説: torch.ao.quantization.quantize_dynamic 関数は、指定されたモジュール(ここでは nn.Linearnn.Conv2d)の重みを torch.qint8 (符号付き8ビット整数)に変換します。活性化は推論時に浮動小数点数から整数へ、そして再び浮動小数点数へと動的に変換されます。

Static Quantization (静的量子化) の実践

静的量子化では、重みと活性化の両方を事前に量子化します。活性化の量子化パラメータを決定するためにキャリブレーションが必要です。

import os

import torch
from torch.utils.data import DataLoader

from datasets import load_mnist
from helpers import evaluate_model, print_size_of_model
from model import QuantizableCNN, SimpleCNN
from trainer import train_model


def train_base_model(train_loader: DataLoader) -> SimpleCNN:
    model = SimpleCNN()

    if os.path.exists("fp32_model.pth"):
        model.load_state_dict(torch.load("fp32_model.pth"))
        model.to("cuda:0")
    else:
        model = train_model(model, train_loader, num_epochs=10)
        torch.save(model.to("cpu").state_dict(), "fp32_model.pth")

    return model


def statically_quantize(train_loader: DataLoader) -> QuantizableCNN:
    # 1. モデルの準備 (QuantizableSimpleCNNを使用)
    # FP32モデルの重みをロード
    static_quantized_model = QuantizableCNN()
    static_quantized_model.load_state_dict(torch.load("fp32_model.pth", map_location="cpu"))  # fp32_modelの重みをロード
    static_quantized_model.to("cpu")
    static_quantized_model.eval()

    # 2. モジュールの融合
    # これにより精度が向上し、実行速度も速くなる場合がある
    static_quantized_model.fuse_model()

    # 3. QConfig (量子化設定) の指定
    # デフォルトのqconfigを使用 (fbgemmバックエンド向け8ビット対称量子化)
    # static_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    # より明示的に
    static_quantized_model.qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.default_observer.with_args(dtype=torch.quint8),  # 符号なし8ビット
        weight=torch.ao.quantization.default_per_channel_weight_observer.with_args(
            dtype=torch.qint8
        ),  # 符号あり8ビット、チャネルごと
    )

    # 4. 量子化の準備 (オブザーバーを挿入)
    # torch.ao.quantization.prepare() はモデルにオブザーバーを挿入し、
    # キャリブレーション中に活性化の統計情報を収集できるようにする
    static_model_prepared = torch.ao.quantization.prepare(
        static_quantized_model, inplace=False
    )  # inplace=Trueにすると元のモデルが変更される

    # 5. キャリブレーション
    print("Running calibration for static quantization...")
    with torch.no_grad():
        for images, _ in train_loader:  # ラベルは不要
            static_model_prepared(images.to("cpu"))
    print("Calibration done.")

    # 6. モデルの変換 (実際に量子化)
    # torch.ao.quantization.convert() は収集された統計情報に基づいてモデルを量子化
    return torch.ao.quantization.convert(static_model_prepared, inplace=False)


def main() -> None:
    train_loader, test_loader = load_mnist()

    fp32_model = train_base_model(train_loader)
    size = print_size_of_model(fp32_model)
    accuracy = evaluate_model(fp32_model, test_loader, device="cuda:0")
    print(f"FP32 Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => FP32 Model Size=0.4381[MB], Accuracy=98.3%

    static_quantized_model = statically_quantize(train_loader)
    size = print_size_of_model(static_quantized_model)
    accuracy = evaluate_model(static_quantized_model, test_loader)
    print(f"Static Quantized Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => Static Quantized Model Size=0.1257[MB], Accuracy=98.11%


if __name__ == "__main__":
    main()

解説: 静的量子化の主なステップは以下の通りです。

  1. モデルの準備と融合: QuantStubDeQuantStub を持つモデルを使用し、可能であれば fuse_modules で演算子を融合します。これにより、計算効率と量子化の精度が向上することがあります。
  2. QConfigの設定: 量子化の方法(オブザーバーの種類、データ型など)を定義します。get_default_qconfig('fbgemm')get_default_qconfig('qnnpack') でバックエンドに応じたデフォルト設定を取得できます。
  3. prepare: モデルにオブザーバー(統計情報を収集するモジュール)を挿入します。
  4. キャリブレーション: 訓練用データを使用して、オブザーバーに活性化の範囲を学習させます。
  5. convert: 学習したオブザーバーの情報を使って、モデルの重みと活性化を実際に量子化します。

Quantization Aware Training (QAT) の実践

QATは、学習プロセス中に量子化の影響をシミュレートし、精度低下を最小限に抑える手法です。ファインチューニングが必要になります。

import os

import torch
from torch.utils.data import DataLoader

from datasets import load_mnist
from helpers import evaluate_model, print_size_of_model
from model import QuantizableCNN, SimpleCNN
from trainer import train_model


def train_base_model(train_loader: DataLoader) -> SimpleCNN:
    model = SimpleCNN()

    if os.path.exists("fp32_model.pth"):
        model.load_state_dict(torch.load("fp32_model.pth"))
        model.to("cuda:0")
    else:
        model = train_model(model, train_loader, num_epochs=10)
        torch.save(model.to("cpu").state_dict(), "fp32_model.pth")

    return model


def quantization_aware_training(train_loader: DataLoader) -> QuantizableCNN:
    # 1. モデルの準備 (QuantizableSimpleCNNを使用)
    qat_model = QuantizableCNN()
    qat_model.load_state_dict(torch.load("fp32_model.pth", map_location="cpu"))  # FP32の学習済み重みから開始
    qat_model.to("cpu")  # QATはCPUでもGPUでも実行可能だが、ここではCPUで統一

    # 2. モジュールの融合
    qat_model.fuse_model()

    # 3. QConfig の指定 (QAT用)
    # QATでは学習も行うため、学習可能なオブザーバーを使用する
    # get_default_qat_qconfig() を使用
    # qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    qat_model.qconfig = torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.FakeQuantize.with_args(
            observer=torch.ao.quantization.MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
        ),
        weight=torch.ao.quantization.FakeQuantize.with_args(
            observer=torch.ao.quantization.MovingAveragePerChannelMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_channel_symmetric,
        ),
    )

    # 4. QATの準備 (偽量子化ノードを挿入)
    # torch.ao.quantization.prepare_qat() を使用
    qat_model.train()
    qat_model_prepared = torch.ao.quantization.prepare_qat(qat_model, inplace=False)

    # 5. ファインチューニング (再学習)
    print("Running QAT fine-tuning...")
    qat_model_prepared = train_model(qat_model_prepared, train_loader, num_epochs=5, learning_rate=0.0001)
    print("QAT fine-tuning done.")

    # 6. 量子化されたモデルに変換
    qat_model_prepared.to("cpu")  # 変換前にCPUに戻す
    qat_model_prepared.eval()
    return torch.ao.quantization.convert(qat_model_prepared, inplace=False)


def main() -> None:
    train_loader, test_loader = load_mnist()

    fp32_model = train_base_model(train_loader)
    size = print_size_of_model(fp32_model)
    accuracy = evaluate_model(fp32_model, test_loader, device="cuda:0")
    print(f"FP32 Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => FP32 Model Size=0.4381[MB], Accuracy=98.3%

    qat_model = quantization_aware_training(train_loader)
    size = print_size_of_model(qat_model)
    accuracy = evaluate_model(qat_model, test_loader)
    print(f"QAT Quantized Model Size={size / 1e6:.4f}[MB], Accuracy={accuracy}%")
    # => QAT Quantized Model Size=0.1257[MB], Accuracy=99.09%


if __name__ == "__main__":
    main()

解説: QATの主なステップは以下の通りです。

  1. モデルの準備と融合: Static Quantizationと同様です。
  2. QConfigの設定: QATでは、学習中に量子化の範囲を調整するため、FakeQuantize モジュールを活性化と重みの両方に設定します。get_default_qat_qconfig() を使うと便利です。
  3. prepare_qat: モデルに偽量子化ノード(FakeQuantizeモジュール)を挿入します。これにより、フォワードパスでは量子化・デ量子化がシミュレートされ、バックワードパスでは勾配が適切に流れるようになります。
  4. ファインチューニング: 偽量子化ノードが挿入されたモデルを、通常の学習と同様に数エポック学習させます。学習率はFP32の初期学習時よりも低く設定するのが一般的です。
  5. convert: ファインチューニング後、偽量子化情報を使って真の量子化モデルに変換します。

結果の比較と考察

  • モデルサイズ: Static, QATで、FP32モデルと比較して大幅なモデルサイズ削減(INT8なら約1/4)が実現できました。Dynamic Quantizationは最も手軽ですが、今回は大きなモデルサイズの削減には至りませんでした(コードの実装に問題があるのかもしれません)。
  • 精度:
    • Dynamic Quantizationは最も手軽で、今回はもとのモデルと精度が変わりませんでした。通常は、ベースモデルよりも大きな精度悪化が見られるので、やはりコードの実装に問題があるのかもしれません。
    • Static Quantizationは、適切なキャリブレーションにより、高い精度を維持できることができました。
    • QATは、学習プロセスで量子化を考慮するため、3つの手法の中で最も高い精度を維持する(あるいはFP32に匹敵する精度を出す)ことが期待されます。
  • 推論速度: (このコード例では直接測定していませんが) 一般的に、Static QuantizationとQATは、活性化も事前に量子化されているため、Dynamic Quantizationよりも高速な推論が期待できます。整数演算器を効率的に利用できるためです。
手法モデルサイズ (MB)精度 (%)FP32比サイズ (%)備考
FP32 (ベースライン)0.4381 [MB]98.3%元のモデル
Dynamic Quantization0.4127 [MB]98.3%94.20%手軽、重みのみ量子化
Static Quantization0.1257 [MB]98.11%28.69%キャリブレーション要、活性化も量子化
Quantization Aware Training0.1257[MB]99.09%28.69%再学習要、高精度維持が期待できる

この実践例を通して、PyTorchを用いた各量子化手法の基本的な実装フローと、それぞれの特徴を理解いただけたかと思います。モデルの特性、許容できる精度低下、開発コストなどを考慮して、最適な量子化戦略を選択することが重要です。

おわりに

本記事では、ディープラーニングモデルの運用効率を飛躍的に高める「量子化」技術について、その基本原理から代表的な手法(動的量子化、静的量子化、QAT)、そしてPyTorchを用いた具体的な実装例までを解説しました。ご覧いただいたように、量子化はモデルサイズの大幅な削減、推論速度の向上、そしてエッジデバイスへの展開を可能にする強力な手段です。

PyTorchが提供するtorch.ao.quantizationモジュールを活用することで、これらの量子化手法を比較的容易に既存のワークフローに組み込むことができます。もちろん、導入する手法によって精度や実装コストにトレードオフが存在するため、対象となるモデルやタスクの要件、許容できる精度低下の範囲を慎重に評価することが肝心です。

量子化技術は日々進化しており、今後のさらなる発展にも期待が寄せられます。ぜひ、ご自身のプロジェクトで量子化の導入を検討してみてください。

More Information: