GradCAM: 深層学習モデルの判断根拠を可視化してみる

深層学習は、画像認識、自然言語処理など、様々な分野で応用されています。しかし、深層学習モデルは、非常に複雑な構造のため、なぜそのような判断を下すのか、その根拠を人間が理解することが困難です。

このブラックボックスである機械学習モデルの判断根拠を理解するために、近年XAI(Explainable AI: 説明可能なAI)と呼ばれる技術が注目されています。XAIは、機械学習モデルの判断根拠を可視化し、人間が理解できるようにする技術の総称です。

XAIの中でも特に画像分類モデルの可視化に用いられる代表的な手法がGradCAMです。GradCAMを用いることで、モデルが画像のどの部分に着目して分類を行っているのか可視化して理解できます。

今回は、GradCAMの仕組みを解説し、実際にPythonで実装してみたいと思います。

XAIとは?

XAI は、「Explainable Artificial Intelligence」の略で、説明可能な人工知能 と訳されます。

近年、高度なAIモデルが複雑なタスクをこなせるようになってきましたが、その判断過程がブラックボックス化し、人間が理解するのが難しくなっています。XAIは、AIがどのようにしてその結論に至ったのか、人間が理解できる形式で説明することを目指す技術の総称です。

XAIが必要な理由

  • 信頼性の向上: AIの判断理由が透明化することで、人間はAIをより信頼できるようになります。
  • 責任性の確保: AIが誤った判断をした場合、その原因を特定し、改善することができます。
  • 規制への対応: AIの利用に関する規制が厳しくなる中で、XAIはコンプライアンス確保に役立ちます。

XAIの活用例

  • 医療診断: AIが診断を下した理由を医師に説明することで、診断の精度向上や誤診の防止に繋げることができます。
  • 金融: 融資審査の判断理由を顧客に説明することで、透明性を高め、顧客満足度向上に貢献できます。
  • 自動運転: 自律運転車がなぜ特定の行動を取ったのかを説明することで、安全性向上に繋がります。
様々なXAI技術: https://arxiv.org/pdf/2006.11371 より引用

GradCAMの仕組み

まずCAM(Class Activation Map)について解説します。CAMは、GradCAMのベースになったアルゴリズムで、画像分類タスクの畳み込みニューラルネットワークに特化したXAI技術です。

CAMの具体的な動作は、以下のようになります。

  1. 特徴の抽出: 入力された画像は、畳み込み層で処理され、様々な特徴が抽出されます。その後、Global Average Pooling 層 (GAP) を経由して全結合層に各特徴量が渡ります。
  2. 重要度の計算: それぞれ特徴量の重要度は、全結合層と呼ばれる部分の重み係数によって決まります。この重み係数は、学習中に調整され、ある特徴がどのくらい重要なのかを示す指標となります。
  3. ヒートマップの作成: 各特徴の重要度を元に、元の画像と同じサイズのヒートマップを作成します。ヒートマップの明るい部分は、その部分が最終的な判断に大きく影響していることを示します。

上記のように、CAMGlobal Average Pooling (GAP) を用いた特定のタイプの畳み込みニューラルネットワークにしか適用できません。そこで、GradCAMでは全結合層の重み係数の代わりに勾配情報を利用して、より一般的な畳み込みニューラルネットワークに適用できるように拡張されています。

GradCAMの具体的な動作は、以下のようになります。

  1. 順伝播: 入力画像を畳み込み層に入力し、最終的な出力(例えば、あるクラスに属する確率)を得ます。
  2. 逆伝播: 最終出力に対して、特定のクラス(例えば、「猫」)に関する勾配を計算します。この勾配は、そのクラスの出力に最も貢献した特徴マップ(畳み込み層の出力)を特定するために使用されます。
  3. 重み付け: 各特徴マップの勾配の平均を計算し、それを重みとして使用します。この重みは、各特徴マップが最終的な出力にどの程度貢献しているかを示します。
  4. ヒートマップ生成: 重みと特徴マップを掛け合わせ、要素ごとの総和をとることで、ヒートマップを生成します。ヒートマップの明るい部分は、そのクラスの予測に大きく貢献した画像の部分を示します。

PythonでGradCAMを試してみる

PyTorchで実装されたモデルであれば、pytorch-gradcamというPythonパッケージを使用することで、簡単にGradCAMを試すことができます。

pytorch-gradcamの特徴を以下に紹介します。

  • 包括的な手法をカバー: コンピュータビジョン向けのXAIアルゴリズムを幅広く網羅
  • 幅広いネットワーク対応: 多くの一般的なCNNネットワークやVision Transformerに対応
  • 様々なユースケース: 分類、物体検出、セマンティックセグメンテーション、埋め込み類似度などのユースケースに対応
  • スムージング手法: CAM (Class Activation Map) を視覚的にキレイにするためのスムージング手法を含む
  • 高性能: すべてのメソッドにおいてバッチ処理をサポートし、高いパフォーマンスを実現
  • 信頼性評価とチューニング: 説明の信頼性をチェックするためのメトリクスや、最適なパフォーマンスのためのチューニング機能を含む

また、pytorch-gradcamに実装されているアルゴリズムには、GradCAMのほかに、次のものが実装されています。

アルゴリズム説明
GradCAM2D活性化マップを平均勾配で重み付け
HiResCAMGradCAMと同様に活性化マップと勾配を使用しますが、特定のモデルにおいて忠実性を保証
GradCAMElementWiseGradCAMと同様に活性化マップと勾配を使用し、その後、ReLU関数を適用してから合計
GradCAM++GradCAMの勾配に2階微分を使用
XGradCAMGradCAMの勾配を正規化された活性化マップでスケーリング
AblationCAMある特定の活性化マップをゼロにし、出力の変化を測定(高速なバッチ処理の実装あり)
ScoreCAMスケーリングされた活性化マップに基づいて画像を摂動させ、出力の変化を測定
EigenCAM2D活性化マップの第一主成分を抽出 (クラス識別なしで優れた結果を示すことが知られている)
EigenGradCAMクラス識別ありのEigenCAM: 活性化マップと勾配の積の第一主成分を抽出する。GradCAMに似ているが、より明確な結果が得られる。
LayerCAM正の勾配によって活性化マップを空間的に重み付けする
FullGradネットワーク全体のバイアスの勾配を計算し、それらを合計
Deep Feature Factorizations2D活性化マップに対して非負値行列因子分解を行う

ではまずは、PyTorchをインストールしましょう。使用するPyTorchのバージョンの指定は、公式ページを参考にしてください。

$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

次に、pytorch-gradcamをインストールします。

$ pip install grad-cam

公式のドキュメントを参考に、GradCAMを実行するコードを書くと、次のようになります。各処理の内容については、コード中のコメントを参考にしてください。

# -*- coding: utf-8 -*-

import cv2
import matplotlib.pyplot as plt
import numpy as np

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from torchvision.models import resnet18

# 学習済みのモデルを読み込む
model = resnet18(pretrained=True)

# GradCAMの対象とする層を指定
# CNNの場合は、全結合に繋ぐプーリング層の直前のレイヤーを指定すればよい
target_layers = [model.layer4[-1]]

# 説明に使用する画像を読み込む
# 必要な前処理があれば、ここで実行する
rgb_img = cv2.imread("both.png", 1)[:, :, ::-1]
rgb_img = np.float32(rgb_img) / 255
input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to("cuda:0")

# CAMの対象とするターゲットを指定する
# ここでは、ImageNetの「tabby cat」を対象とするので、そのラベル番号(=281)を指定
targets = [ClassifierOutputTarget(281)]

# GradCAMのインスタンスを生成
cam = GradCAM(model=model, target_layers=target_layers)

# CAM (Class Activation Map) を計算する
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]

# 入力画像にCAMのヒートマップを重畳表示する
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
plt.imshow(visualization)
plt.show()

上記のコードをベースにして、GradCAM系のアルゴリズムを比較したものを、参考までに以下に添付します。

まとめ

今回は、画像分類の畳込みニューラルネットワークを説明するためのXAIアルゴリズムであるGradCAMを紹介しました。所見の場合、アルゴリズムが分かりにくいかもしれませんが、落ち着いて処理のステップを追えば理解できると思います。また、pytorch-gradcamというパッケージを利用することで、簡単にGradCAMを試すことができます。より詳しい内容については、GitHubリポジトリ公式ドキュメントを参考にしてください。

More Information:

  • arXiv:1910.10045, Alejandro Barredo Arrieta et al., 「Explainable Artificial Intelligence (XAI): Concepts, Taxonomies, Opportunities and Challenges toward Responsible AI」, https://arxiv.org/abs/1910.10045
  • arXiv:2006.11371, Arun Das, Paul Rad, 「Opportunities and Challenges in Explainable Artificial Intelligence (XAI): A Survey」, https://arxiv.org/abs/2006.11371
  • arXiv:2111.06420, Waddah Saeed, Christian Omlin, 「Explainable AI (XAI): A Systematic Meta-Survey of Current Challenges and Future Opportunities」, https://arxiv.org/abs/2111.06420