PyTorch Metric Learningではじめる深層距離学習

近年のAI技術の発展において、画像認識、自然言語処理、推薦システムなど、様々なタスクでデータ間の「類似性」を理解し、活用することが重要となっています。このような背景から、入力データを効果的な特徴空間にマッピングし、類似するデータは近くに、類似しないデータは遠くに配置するよう学習する深層距離学習(Deep Metric Learning)が注目を集めています。

深層距離学習を実現するためには、様々な損失関数、サンプリング戦略、評価指標などが提案されており、その実装は多岐にわたります。このような状況において、PyTorchユーザーが効率的に深層距離学習の研究開発を進めることを目的として開発されたのが、今回ご紹介するライブラリ PyTorch Metric Learning です。

本記事では、この PyTorch Metric Learning ライブラリに焦点を当て、その全体像を把握するための概要から、距離学習タスクの根幹をなす主要なコンポーネントについて詳しく解説します。

深層距離学習とは

深層距離学習は、ディープラーニングモデルを用いて、入力データ間の「距離」や「類似度」を学習する機械学習の手法です。従来の分類問題が「入力がどのカテゴリに属するか」を学習するのに対し、距離学習は「入力同士がどれだけ似ているか・似ていないか」を学習することに主眼を置きます。

具体的には、ニューラルネットワークを使って入力データ(画像、テキスト、音声など)を、低次元のベクトル空間(特徴空間、埋め込み空間)にマッピングします。この際、似ているデータは特徴空間内で近くに配置され、似ていないデータは遠くに配置されるようにネットワークのパラメータを学習します。学習された特徴ベクトル間の距離(ユークリッド距離、コサイン類似度など)が、元のデータ間の類似度を表現することになります。

なぜ深層距離学習が必要か

  • 分類タスクの補完・強化: クラス間の微妙な違いを捉えるのに役立ちます。特に、クラス数が非常に多い場合や、未知のクラスが出現する可能性のあるタスクに適しています。
  • 未知のカテゴリへの対応: 学習時に見たことのないカテゴリのデータに対しても、その類似度に基づいて適切な処理(ゼロショット学習や少数ショット学習など)を行うことが可能になります。
  • 検索・レコメンデーションシステム: クエリデータに類似するデータを膨大なデータベースから効率的に検索したり、ユーザーが過去に興味を示したアイテムに類似するアイテムを推薦したりする基盤となります。
  • データの可視化と構造理解: 学習された特徴空間を可視化することで、データの潜在的な構造やクラス間の関係性を理解するのに役立ちます。

学習のメカニズム(基本的な考え方)

深層距離学習の学習は、データ間の関係性に基づいた特殊な損失関数を用いて行われます。代表的なアプローチとして、以下のようなサンプリング手法とそれに付随する損失関数があります。

  • ペアベース (Pair-based):
    • ポジティブペア: 似ているデータ同士のペア(例: 同じ人物の顔画像、同じカテゴリの物体画像)。これらの間の距離を小さくするように学習します。
    • ネガティブペア: 似ていないデータ同士のペア(例: 異なる人物の顔画像、異なるカテゴリの物体画像)。これらの間の距離を大きくするように学習します。
    • 代表的な損失関数: Contrastive Loss
  • トリプレットベース (Triplet-based):
    • アンカー: 基準となるデータ。
    • ポジティブ: アンカーと似ているデータ。
    • ネガティブ: アンカーと似ていないデータ。
    • アンカーとポジティブの距離が、アンカーとネガティブの距離よりも一定のマージンをもって小さくなるように学習します。
    • 代表的な損失関数: Triplet Loss
  • Proxyベース (Proxy-based) / Angularベース (Angular-based):
    • クラスごとに代表点(Proxy)を設けたり、特徴ベクトルの角度に着目したりして学習を進めます。
    • 代表的な損失関数: N-pair Loss, ArcFace, CosFace など

これらの損失関数は、ネットワークが出力する特徴ベクトル間の距離や角度に基づいて勾配を計算し、ネットワークのパラメータを更新することで、特徴空間の構造を目的の形(類似データが近く、非類似データが遠い)に近づけていきます。

深層距離学習の応用例

  • 顔認証・顔照合: 未知の人物の顔画像が登録済みのデータベース中の誰であるか、あるいは同一人物であるか判定するタスク。
  • 画像検索: 入力画像に写っている物体やシーンに類似する画像をデータベースから探し出すタスク(例: ファッションアイテム検索、風景検索)。
  • レコメンデーションシステム: ユーザーの行動履歴や嗜好に基づいて、類似するアイテムやコンテンツを推薦するタスク。
  • 異常検知: 正常なデータの分布から大きく外れるデータを検出するタスク。
  • 生体認証: 指紋、声紋などの個人固有の特徴の照合。

深層距離学習は、これらの応用分野において、高精度な類似度計算や検索、識別を実現するための重要な技術となっています。PyTorch Metric Learningのようなライブラリは、これらの複雑な損失関数やサンプリング手法の実装を提供することで、研究開発を効率化します。

PyTorch Metric Learning の概要

PyTorch Metric Learning は、深層距離学習(Deep Metric Learning)の研究開発を効率化するために設計された、PyTorchベースのオープンソースライブラリです。このライブラリは、距離学習の様々な側面に対応するモジュール化されたコンポーネントを提供しており、ユーザーはこれらの部品を組み合わせて柔軟に実験を行うことができます。

主要なコンポーネントとしては、データ間の距離を計算する Distance、学習の目的関数を定義する Loss、学習に用いるポジティブ・ネガティブペアやトリプレットなどを選択する Miner、損失関数に正則化項を追加する Regularizer、そして損失値を集約する Reducer などがあります。

これらの主要コンポーネントに加え、学習プロセス全体を管理する Trainer、バッチ内のサンプリングを行う Sampler、モデルの評価を行う Tester、そして学習中の特定のイベントで処理を実行する HookContainer といった補助的なコンポーネントも豊富に備えており、End-to-Endでの距離学習パイプライン構築を強力にサポートします。これにより、多様な距離学習手法や最新の研究成果を容易に試すことが可能となっています。

Distance コンポーネント

Distance コンポーネントは、深層距離学習において、ネットワークが出力した特徴ベクトル間の「距離」または「類似度」を計算する役割を担います。距離学習の多くは、この特徴空間上の距離に基づいてデータ間の類似性を判断し、損失を計算するため、Distance コンポーネントは非常に基本的な要素となります。

このコンポーネントは、様々な距離・類似度の定義に対応しており、ユークリッド距離のような一般的な距離から、コサイン類似度のような角度に基づく類似度まで、多様な計算方法を提供します。これにより、損失関数やモデルの設計思想に合わせて適切な距離尺度を選択することが可能です。

クラス名説明種類(計算する値)
CosineSimilarity2つのベクトルの方向の類似度を計算します。ベクトルのノルムに影響されにくい性質があります。類似度(コサイン類似度)
DotProductSimilarity2つのベクトルの内積を計算します。コサイン類似度とは異なり、ベクトルのノルムも考慮されます。類似度(内積)
LpDistance2つのベクトル間のLp距離(ミンコフスキー距離)を計算します。特に p=2 の場合はユークリッド距離となります。距離(Lp距離、ユークリッド距離など)
SNRDistance特徴ベクトルのシグナル対ノイズ比(SNR)に基づいた距離を計算します。特徴のばらつきを考慮した距離尺度です。距離(SNRベースの距離)

Loss コンポーネント

Loss コンポーネントは、深層距離学習モデルの学習において最も核となる部分であり、特徴空間の構造を望ましい形に近づけるための目的関数(損失関数)を定義します。このコンポーネントは、モデルから出力された特徴ベクトルと、それに対応するラベル情報を受け取り、データ間の距離や類似度、あるいはそれらの関係性に基づいて損失値を計算します。

学習プロセスでは、この損失値を最小化するようにネットワークのパラメータが更新されます。PyTorch Metric Learningは、距離学習分野で提案されてきた非常に多岐にわたる損失関数を実装しており、研究者や開発者は様々な手法を容易に試すことができます。損失関数は、データのペア間の距離を調整するもの、トリプレット間の距離関係を強制するもの、特徴空間上での角度に着目するもの、クラスごとの代表点を用いるものなど、多様なアプローチに基づいています。

クラス名説明タイプ/カテゴリ
AngularLossアンカー、ポジティブ、ネガティブのトリプレットに対して、特徴ベクトルの角度関係に基づいて損失を計算します。トリプレット/角度ベース
ArcFaceLoss分類層を距離学習に応用し、特徴ベクトルとクラス重みベクトルの角度の差に加算マージンを導入することで識別能力を高めます。顔認証で広く用いられます。角度/マージンベース
CircleLossポジティブペアとネガティブペアの類似度(内積またはコサイン類似度)に対して、最適化すべき目標値とマージンを設けて損失を計算します。ペア/マージンベース
ContrastiveLossポジティブペアの距離を小さく、ネガティブペアの距離を大きくするように学習します。最も基本的なペアベースの損失関数の一つです。ペアベース
CosFaceLoss分類層を距離学習に応用し、特徴ベクトルとクラス重みベクトルの角度のコサイン値から減算マージンを導入することで識別能力を高めます。顔認証で広く用いられます。角度/マージンベース
CrossBatchMemory損失計算のために、現在のバッチだけでなく過去のバッチの特徴ベクトルをメモリバンクに保存して利用します。大規模データでの効率的なサンプリングに寄与します。メカニック/サンプリング補助
DynamicSoftMarginLossマージンを固定せず、学習中に動的に調整するソフトマージンベースの損失関数です。マージンベース
FastAPLossAverage Precision (AP) を直接最適化することを目的とした損失関数です。検索性能との相関が高いとされます。リストワイズ/ランクベース
GeneralizedLiftedStructureLossバッチ内の全てのポジティブペアに対して、最も近いネガティブペアとの距離に基づいて損失を計算します。バッチ構造ベース
InstanceLossデータインスタンスそのものに着目し、データ間の類似度に基づいて損失を計算します。インスタンスベース
HistogramLossポジティブペアとネガティブペアの距離分布をヒストグラムで捉え、それらの分布を分離するように損失を計算します。分布ベース
IntraPairVarianceLoss同じクラス内のペア間の距離の分散を小さくするように促す損失関数です。ペアベース/バリアンス抑制
LargeMarginSoftmaxLoss標準的なSoftmax損失にマージンを導入し、より識別性の高い特徴空間を学習することを目的とします。分類/マージンベース
LiftedStructureLossバッチ内のポジティブペアと、それに関連する最も近いネガティブペアとの距離に基づいて損失を計算します。Generalized版の原型です。バッチ構造ベース
ManifoldLoss特徴空間上のデータの局所的な構造(マニホールド)を保存するように損失を計算します。構造保存ベース
MarginLossポジティブペアとネガティブペアの距離に対して、固定または適応的なマージンを設けて損失を計算します。Tripliet Lossの派生形を含む場合があります。ペア/マージンベース
MultiSimilarityLossポジティブペアとネガティブペアの類似度(内積またはコサイン類似度)の両方を考慮し、それらの相対的な関係に基づいて損失を計算します。ペアベース
NCALossNeighborhood Component Analysis (NCA) に基づく損失関数で、確率的に近傍を維持するように学習します。確率/近傍ベース
NormalizedSoftmaxLoss標準的なSoftmax損失において、特徴ベクトルとクラス重みベクトルを正規化して計算します。角度情報に焦点を当てやすくなります。分類/正規化ベース
NPairsLossN個のアンカーと、それぞれのポジティブペア、そして他の全てのネガティブペアを用いて損失を計算します。効率的なバッチ処理を可能にします。バッチ構造ベース
NTXentLoss自己教師あり学習で用いられるSimCLRの損失関数であり、データ拡張されたペアをポジティブとし、バッチ内の他のサンプルをネガティブとして対比的に学習します。自己教師あり/対比学習
P2SGradLossPull to Sphere Grad Loss の略で、特徴ベクトルを単位超球面に近づけつつ、識別性を高めるように学習します。球面埋め込み/勾配ベース
PNPLossProxy NCA and Proxy Anchor Loss の略で、NCAとProxy Anchor Lossの考え方を組み合わせた損失関数です。Proxy/近傍ベース
ProxyAnchorLoss各クラスの代表点(Proxy)とデータのペア間の距離に基づいて損失を計算します。Proxyをアンカーとして機能させます。Proxyベース
ProxyNCALossNCAの考え方をProxyに拡張し、Proxyとデータ間の確率的な関係に基づいて損失を計算します。Proxy/確率/近傍ベース
RankedListLossクエリに対する検索結果のランクリストの評価指標を最適化することを目的とした損失関数です。検索性能に特化しています。リストワイズ/ランクベース
SelfSupervisedLoss自己教師あり学習タスクのための汎用的なラッパーまたは基底クラス。特定の自己教師あり損失の実装を含みます。自己教師あり
SignalToNoiseRatioContrastiveLossSNRDistanceとContrastiveLossの考え方を組み合わせ、SNRに基づいて対比的に学習します。ペアベース/SNRベース
SoftTripleLoss各クラスを複数の代表点(Triple)で表現し、データとこれらの代表点との距離に基づいてソフトマージン付きで学習します。クラス内の多様性を考慮できます。Proxy/ソフトマージンベース
SphereFaceLoss分類層を距離学習に応用し、特徴ベクトルとクラス重みベクトルに対して角度マージンを導入する初期の手法の一つです。角度/マージンベース
SubCenterArcFaceLossArcFaceLossを拡張し、各クラスを複数のサブセンターで表現することで、クラス内のばらつきに対応できるようにした損失関数です。角度/マージンベース/サブセンター
SupConLossSupervised Contrastive Loss の略で、同一クラスのサンプルをポジティブペア、異なるクラスのサンプルをネガティブペアとして、バッチ内の全てのペアに対して対比的に学習します。NTXentLossの教師あり版とも言えます。対照学習/教師あり
ThresholdConsistentMarginLossマージンを一貫性のある閾値に基づいて設定する損失関数です。マージンベース
TripletMarginLossアンカー、ポジティブ、ネガティブのトリプレットに対して、アンカー-ポジティブ間の距離がアンカー-ネガティブ間の距離よりもマージン以上小さくなるように学習します。最も基本的なトリプレットベースの損失関数です。トリプレットベース
TupletMarginLossTriplet Lossを拡張し、ポジティブとネガティブのタプル(トリプレットよりも多い要素数)に対してマージンベースの損失を計算します。タプルベース
WeightRegularizerMixin損失関数ではなく、重み(特に分類層の重み)に正則化項(L2正則化など)を追加するためのMixinクラスです。複数の損失関数と組み合わせて使用できます。正則化
VICRegLoss自己教師あり学習手法であるVICRegの損失関数であり、特徴ベクトルの分散、共分散、不変性(ペア間の距離)を同時に最適化します。深層距離学習的な側面も持ちます。自己教師あり/共同埋め込み

Miner コンポーネント

Miner コンポーネントは、深層距離学習のバッチ学習において非常に重要な役割を果たします。特に、大きなバッチサイズで学習を行う際に、バッチ内の全てのデータ間の組み合わせ(ペアやトリプレット)を評価するのは計算コストが高いだけでなく、学習にあまり寄与しない「イージー」(既に距離が適切になっている)なサンプルが多数含まれることがあります。

Miner は、このような非効率性を解消し、学習効率と性能を向上させるために導入されました。その主な役割は、現在のバッチ内のデータから、損失関数の計算に特に有用であると考えられるサンプルペアやトリプレットを選択(「マイニング」)することです。これは通常、まだ距離関係が適切に学習されていない「ハード」なサンプルを中心に選択することで実現されます。

例えば、Triplet Lossの場合、アンカーに対して「ハードポジティブ」(アンカーから遠いポジティブ)や「ハードネガティブ」(アンカーに近いネガティブ)を重点的に選択することで、モデルはより困難なケースを克服し、識別能力を高めることができます。Miner は様々な戦略に基づいてこれらのサンプルを効率的に見つけ出し、損失計算部にその情報(通常はサンプルのインデックス)を渡します。

クラス名説明マイニング対象
AngularMiner特徴ベクトルの角度に基づいて、トリプレット学習に有用なサンプルをマイニングします。角度マージンベースの損失関数と組み合わせて使用されることがあります。トリプレット
BatchEasyHardMinerバッチ内の全てのペアまたはトリプレットを評価し、イージー、セミハード、ハードなサンプルに分類して、指定された割合でサンプリングします。ペアまたはトリプレット
BatchHardMinerバッチ内の各アンカーに対して、最も距離が遠いポジティブ(ハードポジティブ)と、最も距離が近いネガティブ(ハードネガティブ)を選択します。トリプレット
DistanceWeightedMinerデータ間の距離に基づいて、サンプリング確率を重み付けしてペアを選択します。特定の距離範囲のサンプルを重点的にサンプリングするのに有用です。ペア
HDCMinerHard-aware Deep Contrastive Miner の略。ハードなポジティブペアとネガティブペアを効率的に選択する手法に基づいています。ペアまたはトリプレット
MultiSimilarityMinerMultiSimilarity Loss の考え方に基づき、データ間の類似度(特にポジティブとネガティブの類似度)を考慮して、学習に有用なペアをマイニングします。ペア
PairMarginMinerペア間の距離が特定のマージンを満たしているかどうかを基準にペアをマイニングします。マージンベースの損失関数と組み合わせて使用されます。ペア
TripletMarginMinerトリプレットのアンカー-ポジティブ間距離とアンカー-ネガティブ間距離の関係が特定のマージンを満たしているかどうかを基準にトリプレットをマイニングします。トリプレット
UniformHistogramMinerポジティブペアとネガティブペアの距離ヒストグラムを均一化するようにサンプルをマイニングします。距離分布を平滑化することを目的とします。ペア

Reducer コンポーネント

Reducer コンポーネントは、Loss コンポーネントによって計算された複数の損失値や、損失計算の過程で得られる様々な値を、最終的にバックプロパゲーションで利用可能な単一のスカラー値(または勾配計算が可能な集約された値)に「縮小」(Reduce)する役割を担います。

多くの損失関数は、バッチ内の複数のサンプルペアやトリプレットに対して損失を計算します。その結果は、個々のサンプルやペア/トリプレットに対応する損失値の集合となります。しかし、モデルのパラメータを更新するための勾配計算(バックプロパゲーション)では、通常、単一のスカラー損失値が必要です。

Reducer は、これらの個別の損失値をどのように集約するかを定義します。単純な平均や合計だけでなく、非ゼロの損失値のみの平均を取る、クラスごとに重み付けを行うなど、様々な集約戦略を提供します。これにより、損失の性質や学習の目的に応じて、勾配の伝播方法を調整することが可能になります。

クラス名説明集約方法の例
AvgNonZeroReducer計算された損失値のうち、ゼロでない値のみを対象として平均を計算します。損失が発生しているサンプルに焦点を当てたい場合に有用です。非ゼロ要素の平均
ClassWeightedReducer各クラスの損失値に対して、指定された重み付けを行ってから合計または平均を計算します。クラス間のサンプル数に偏りがある場合などに利用されます。クラスごとの重み付き合計/平均
DivisorReducer損失値を特定の除数(例えば、マイニングされたペアやトリプレットの数など)で割ることで正規化します。特定の値で除算(正規化)
DoNothingReducer受け取った損失値をそのまま返します。損失関数が既にスカラー値を返す場合や、後段で別の処理を行う場合に利用されます。何もしない(Identity)
MeanReducer計算された全ての損失値の単純平均を計算します。最も一般的な集約方法の一つです。平均
MultipleReducers複数のReducerを組み合わせて使用できます。例えば、異なる種類の損失(複数のタスクの損失など)に対して異なるReducerを適用する場合に便利です。複数の集約処理を適用
PerAnchorReducerTriplet Lossなどで得られる損失値を、アンカーごとに集約します。アンカー単位での勾配計算を行いたい場合に有用です。アンカー単位での集約(例: 平均)
SumReducer計算された全ての損失値の合計を計算します。合計
ThresholdReducer特定の閾値を超える損失値のみを対象として集約します。損失が小さい(イージーな)サンプルを無視したい場合に利用されます。閾値を超える値のみ集約

Regularizer コンポーネント

Regularizer コンポーネントは、深層距離学習モデルの学習プロセスにおいて、過学習を抑制したり、学習される特徴空間やモデルのパラメータに特定の望ましい性質を導入したりするために使用されます。これは、通常、メインとなる損失関数(Loss コンポーネントで計算される損失)に追加される項として機能します。

正則化項を加えることで、モデルは訓練データに対して過度に適合することを避け、未知のデータに対する汎化性能を高めることができます。また、距離学習の文脈では、学習された特徴ベクトルが特定の分布に従うように促したり、特定の制約(例えば、特徴ベクトルのノルムを制限するなど)を満たすようにしたりするためにも利用されます。

Regularizer は、モデルの出力である特徴ベクトル、あるいはネットワークの特定の層(特に最後の分類層や全結合層)の重みに対して作用することが一般的です。

クラス名説明正則化対象/種類
CenterInvariantRegularizer各クラスの中心(平均特徴ベクトルなど)に対する特徴ベクトルの相対的な位置関係を保つように促す正則化です。特徴ベクトル/クラス中心
LpRegularizer特徴ベクトルまたはモデルパラメータのLpノルムに対する正則化です。特に p=2 のL2正則化は、パラメータの値が大きくなりすぎるのを抑制します。特徴ベクトルまたはパラメータ/Lpノルム
RegularFaceRegularizerSphereFaceなどの角度マージンベースの損失関数と組み合わせて使用され、学習された特徴ベクトルや重みベクトルの分布を正規化・調整する正則化です。特徴ベクトルまたは重み/正規化
SparseCentersRegularizerSoftTripleLossなどで使用されるクラスごとの代表点(Triple)がスパースになるように促す正則化です。不要な代表点の数を減らすことを目的とします。代表点(Proxy)/スパース性
ZeroMeanRegularizer学習された特徴ベクトルの平均がゼロに近づくように促す正則化です。特徴空間の中心を原点に近づけることを目的とします。特徴ベクトル/平均ゼロ

PyTorch Metric Learning を使ってみる

ここでは、PyTorch Metric Learningライブラリを使用して、MNISTデータセットで距離学習を実装する方法を解説します。

まず、必要なPythonパッケージをインストールします。

# PyTorchのインストール
$ pip install torch torchvision torchaudio

# PyTorch Metric Learning のインストール
$ pip install pytorch-metric-learning

# その他、デモコードで使用するパッケージをインストール
$ pip install pytorch_lightning torchinfo matplotlib scikit-learn

では、ここからコードの実装方法について説明しますが、まずは必要なパッケージをインポートしておきます。

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from matplotlib.colors import TABLEAU_COLORS
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import MeanReducer
from pytorch_metric_learning import trainers
from sklearn.manifold import TSNE
from torchinfo import summary

次に、GPUを利用できるかどうかを確認します。深層距離学習は計算コストが高いため、可能であればGPUの使用をお勧めします。

# デバイスの設定 (GPUが利用可能ならGPU, なければCPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

今回はMNISTの手書き数字データセットを使用します。このデータセットは28×28ピクセルのグレースケール画像で、0から9までの数字が含まれています。PyTorchのDataLoaderを使用してデータを読み込めるようにします。なお、訓練データセットを訓練用と検証用に8:2の割合で分割しています。これにより、モデルの過学習を防ぎ、汎化性能を評価できます。

# MNISTデータセットのロード
# transforms.Composeを使って複数の前処理をシーケンシャルに適用
# ToTensor()は画像をPyTorchのTensorに変換し、画素値を[0, 1]の範囲に正規化
# Normalize()は指定された平均と標準偏差でテンソルを正規化 (MNISTの標準的な値を使用)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# バッチサイズ
batch_size = 256
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transform,
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transform)

# 訓練データセットを訓練用と検証用に分割
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

# データローダー
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False) 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

では、次にディープラーニングのモデルを作成していきますが、今回はシンプルな畳み込みニューラルネットワーク(CNN)を定義します。このモデルは特徴抽出器(バックボーン)と埋め込み層(ヘッド)で構成されています。最終的に32次元の埋め込みベクトルを出力します。

# 畳込みニューラルネットワーク (CNN) モデルの定義
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=0),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(num_features=32),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),

            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.head = nn.Linear(in_features=64, out_features=32, bias=True)

    def forward(self, x):
        z = self.backbone(x)
        return self.head(z.squeeze())

モデルの構造を確認すると、パラメータ総数は約25,568で、入力画像(28×28ピクセル)を32次元の埋め込みベクトルに変換します。これは比較的軽量なモデルですが、MNISTのような単純なデータセットに対しては十分な表現力を持っています。

summary(
    SimpleCNN(),
    input_size=(2, 1, 28, 28),
    col_names=["output_size", "num_params"],
)
==========================================================================================
SimpleCNN                                [2, 32]                   --
├─Sequential: 1-1                        [2, 64, 1, 1]             --
│    └─Conv2d: 2-1                       [2, 16, 26, 26]           160
│    └─Conv2d: 2-2                       [2, 32, 24, 24]           4,640
│    └─BatchNorm2d: 2-3                  [2, 32, 24, 24]           64
│    └─MaxPool2d: 2-4                    [2, 32, 23, 23]           --
│    └─ReLU: 2-5                         [2, 32, 23, 23]           --
│    └─Conv2d: 2-6                       [2, 64, 21, 21]           18,496
│    └─BatchNorm2d: 2-7                  [2, 64, 21, 21]           128
│    └─ReLU: 2-8                         [2, 64, 21, 21]           --
│    └─AdaptiveAvgPool2d: 2-9            [2, 64, 1, 1]             --
├─Linear: 1-2                            [2, 32]                   2,080
==========================================================================================
Total params: 25,568
Trainable params: 25,568
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 21.88
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 1.67
Params size (MB): 0.10
Estimated Total Size (MB): 1.78
==========================================================================================

PyTorch Lightningを使用して学習プロセスを整理します。今回は、MultiSimilarityLossMultiSimilarityMinerを使用します。これらは距離学習のための重要なコンポーネントです。

通常の分類学習と距離学習の大きな違いは、「マイニング」と呼ばれるプロセスにあります。マイニングでは、バッチ内のすべてのサンプルペアではなく、学習にとって重要な「ハードペア」を選択します。ハードペアとは、異なるクラスなのに類似している(ハードネガティブ)、または同じクラスなのに類似していない(ハードポジティブ)サンプルペアのことです。

MultiSimilarityMinerは、コサイン類似度に基づいて、各バッチから効率的にハードペアを抽出します。これにより、学習が効率化され、より良い埋め込み空間を獲得できます。

class MnistEncorder(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = SimpleCNN()
        self.criterion = losses.MultiSimilarityLoss(alpha=2, beta=50, base=0.5, distance=CosineSimilarity(), reducer=MeanReducer())
        self.miner = miners.MultiSimilarityMiner(distance=self.criterion.distance)
    
        # 学習時の損失 (エポックごと)
        self.training_step_outputs: list[torch.Tensor] = []
        self.validation_step_outputs: list[torch.Tensor] = []
        self.train_losses: list[np.float32] = []
        self.val_losses: list[np.float32] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self) -> tuple:
        optimizer = optim.Adamax(self.model.parameters(), lr=1.0E-4, betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.95)
        
        return [optimizer, ], [scheduler, ]
    
    def configure_callbacks(self) -> list:
        checkpoint = ModelCheckpoint(
            monitor="val_loss",
            dirpath="checkpoint",
            filename="mnist_metric_learn",
            every_n_epochs=1,
            save_weights_only=False,
        )
        early_stopping = EarlyStopping("val_loss", mode="min", patience=5)

        return [checkpoint, early_stopping]

    def training_step(self, train_batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._common_step(train_batch, batch_idx)
        self.log("loss", loss, on_step=True, on_epoch=True, batch_size=len(train_batch[0]), prog_bar=True)
        self.training_step_outputs.append(loss)
        return loss
    
    def on_train_epoch_end(self) -> None:
        losses = np.array([item.detach().cpu() for item in self.training_step_outputs])
        self.train_losses.append(np.mean(losses))
        self.training_step_outputs.clear()
    
    def validation_step(self, val_batch: list[torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._common_step(val_batch, batch_idx)
        self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=len(val_batch[0]), prog_bar=True)
        self.validation_step_outputs.append(loss)
        return loss
    
    def on_validation_epoch_end(self) -> None:
        losses = np.array([item.cpu() for item in self.validation_step_outputs])
        self.val_losses.append(np.mean(losses))
        self.validation_step_outputs.clear()

    def save_weights(self) -> None:
        torch.save(self.model.cpu().state_dict(), "mnist_metric_learn.pth")

    def plot_train_history(self) -> None:
        x = [i + 1 for i in range(len(self.train_losses))]

        plt.figure(figsize=(8, 5))
        plt.plot(x, self.train_losses, label="loss")
        plt.plot(x, self.val_losses, label="val_loss")
        plt.title("Training Loss vs Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.grid()
        plt.legend(loc="best")
        plt.show()

    def _common_step(self, batch: list[torch.Tensor], _batch_idx: int) -> torch.Tensor:
        x, y = batch
        pred_y = self.forward(x)
        
        hard_pairs = self.miner(pred_y, y)
        loss = self.criterion(pred_y, y, hard_pairs)
        
        return loss

では、PyTorch Lightningの Trainer を使用してモデルを学習させます。

module = MnistEncorder()
trainer = pl.Trainer(max_epochs=50, accelerator="gpu", devices=1, num_sanity_val_steps=0)
trainer.fit(model=module, train_dataloaders=train_loader, val_dataloaders=val_loader)

module.save_weights()
module.plot_train_history()

学習が完了すると、訓練損失と検証損失の推移を確認できます。損失関数が順調に減少し、検証損失も同様に減少していれば、モデルは適切に学習できています。過学習の兆候がなければ、モデルは汎化性能を獲得していると判断できます。

次に、学習したモデルを使用して、テストデータセットの埋め込みベクトルを計算します。

predicts = []
labels = []

module.model.eval()
module.model.to(device)

with torch.no_grad():
    for x, y in test_loader:
        z = module.model(x.to(device))
        predicts.append(z.squeeze().cpu())
        labels.append(y)

predicts = torch.cat(predicts).numpy()
labels = torch.cat(labels).numpy()

計算した32次元ベクトルを直接可視化することは難しいため、t-SNE(t-distributed Stochastic Neighbor Embedding)アルゴリズムを使用して2次元に次元削減します。t-SNEは高次元データの局所的構造を保持しながら低次元に写像する非線形次元削減技術です。これにより、32次元の埋め込みベクトルを2次元平面上に視覚化でき、クラスタリングの品質を評価できます。

def show_tsne(x: np.ndarray, y: np.ndarray) -> None:
    x_tsne = TSNE(random_state=123).fit_transform(x)
    labels = np.unique(y)

    # 凡例をグラフ外に表示したいので、横長に指定しておく
    plt.figure(figsize=(10, 6))
    plt.xlim(x_tsne[:, 0].min() - 1, x_tsne[:, 0].max() + 1)
    plt.ylim(x_tsne[:, 1].min() - 1, x_tsne[:, 1].max() + 1)

    # ラベルごとに色を変える
    color_names = list(TABLEAU_COLORS.keys())
    marker_size = matplotlib.rcParams["lines.markersize"] ** 2

    # ラベルごとに散布図に表示しておく
    for ii, label in enumerate(labels):
        index = np.where(y == label)
        plt.scatter(x_tsne[index, 0], x_tsne[index, 1], label=f"{label}", color=color_names[ii], s=marker_size, marker=".")

    # グラフ保存
    plt.title("t-SNE visualization of MNIST features")
    plt.xlabel("t-SNE feature-0")
    plt.ylabel("t-SNE feature-1")
    plt.grid()
    plt.legend(bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0)
    plt.tight_layout()
    plt.show()


# 可視化
show_tsne(predicts, labels)

t-SNEによる可視化結果を確認すると、各数字(0〜9)のクラスタが明確に形成されていることがわかります。これは、モデルが各数字の特徴を適切に学習し、同じクラスのサンプルを埋め込み空間内で近づけ、異なるクラスのサンプルを遠ざけることに成功していることを示しています。

いくつかのデータポイントが他のクラスタの近くにプロットされていますが、これは以下の理由によるものと考えられます。

  1. モデルの表現力の限界(今回は小規模なCNNを使用)
  2. 一部の手書き数字が形状的に類似している(例:4と9、3と8など)
  3. 個々の筆跡の多様性

より大きなモデルや異なる損失関数を使用することで、さらに改善できる可能性があります。

おわりに

ここまで、深層距離学習のための強力なライブラリである PyTorch Metric Learning の主要なコンポーネントについて概観してきました。データ間の距離を定義する Distance、学習の根幹となる Loss、効率的なサンプル選択を行う Miner、損失を集約する Reducer、そして学習を安定させる Regularizer といったモジュール化された要素が、本ライブラリの柔軟性と機能性を支えています。

これらの豊富なコンポーネントを組み合わせることで、ユーザーは多様な深層距離学習手法を容易に実装・比較したり、自身のタスクに最適なカスタム設定を追求したりすることが可能です。顔認証、画像検索、レコメンデーションなど、データの類似性に着目する多くの現代AIアプリケーションにおいて、深層距離学習は欠かせない技術となっています。

深層学習が登場する以前から、距離学習は機械学習の重要な分野として研究されてきました。ディープラーニングを使用しない古典的な距離学習の手法にご興味のある方は、以前執筆した記事「距離学習入門 ~様々なタスクに応用できる機械学習手法~」もぜひご参照ください。本記事と合わせてお読みいただくことで、距離学習の背景とディープラーニングによる発展について、より深くご理解いただけるかと思います。

More Information