深層学習×状態空間モデル: Mambaアーキテクチャの概要

近年、普段の生活やビジネスでは欠かせないAI技術ですが、ChatGPTをはじめとする大規模言語モデル(LLM)の登場で、Transformerと呼ばれるアーキテクチャが注目を集めています。Transformerは、文章や画像など、様々なデータを高度に扱うことができる強力なアーキテクチャです。しかし、Transformerには計算量が多いという課題があります。

そこで登場したのが、Mamba2023.12, Gu & Dao)と呼ばれる新しいモデルです。Mambaは、Transformerの優れた部分を継承しつつ、計算量を大幅に削減できるという特徴を持っています。また、自然言語や時系列データだけでなく、画像処理など様々なタスクに応用できる汎用性も兼ね備えています。

今回は、このMambaについて解説し、Pythonによる実装例も紹介しようと思います。

Mambaとは?

従来の代表的な時系列処理モデルであるTransformerは、画像や音声、言語など様々なデータを扱うことができますが、長いシーケンスを処理する際に計算量が増大してしまうという課題がありました。Mambaは、この課題を克服するために開発されたモデルで、Transformerと同等の性能を持ちながら、計算量を大幅に削減できることが特徴です。

Mambaは大きく分けて、以下の2つのバージョンがあります。

  • Mamba-1
  • Mamba-2

Mamba-1

Mamba-1は、従来の状態空間モデル (SSM)と呼ばれるモデルの枠組みを利用して作られています。状態空間モデルは、系列データの時間的な依存関係を表現することに長けていますが、複雑な情報を扱うタスクには性能が十分ではありませんでした。Mamba-1は、状態空間モデルに以下の3つの新しい技術を導入することにより、この課題を解決しています。

  • HiPPOに基づく初期化
  • 選択メカニズム
  • ハードウェアに配慮した計算
図1. Mamba-1のアーキテクチャ

HiPPOに基づく初期化

HiPPO (High-order Polynomial Projection Operator) は、Mamba-1の初期化に用いられる手法です。状態空間モデルは、過去の情報に基づいて現在の状態を予測するモデルですが、HiPPOを用いることで、より長期的な依存関係を捉えることができるようになります。

HiPPOは、過去の全ての情報に均一な重みを割り当てることで、長期的な記憶を保持します。これにより、Mamba-1は、長い文章や複雑な時系列データをより正確にモデル化できるようになります。

選択メカニズム

選択メカニズムは、Mamba-1入力データに合わせた処理を行うための仕組みです。従来の状態空間モデルは、全ての入力に対して同じ処理を行っていましたが、選択メカニズムにより、入力データの内容に応じて、異なる処理を行うことができるようになりました。

このメカニズムは、TransformerAttention機構と類似しており、Mamba-1が重要な情報に注目し、不要な情報を無視することを可能にします。

ハードウェアに配慮した計算

ハードウェアに配慮した計算は、Mamba-1を効率的に実行するための工夫です。Mamba-1は、Parallel Associative Scanと呼ばれるアルゴリズムを用いて、計算を並列化し、GPUなどのハードウェアを最大限に活用します。

また、メモリ再計算という手法を用いることで、メモリ使用量を削減し、より大規模なモデルを学習できるようになります。

Mamba-2

Mamba-2は、Mamba-1をさらに発展させたモデルです。Mamba-1状態空間モデルの枠組みを利用していましたが、Mamba-2Transformerと同じようなAttention機構を利用することで、計算効率を向上させています。

また、Mamba-2は理論的な裏付けも強化されており、Mamba-1よりも高速で、かつ、精度も向上しています。

図2. Mamba-1 と Mamba-2 の比較

Mamba アーキテクチャの進化

Mambaは、近年、基盤モデルのバックボーンとして有望な代替手段として注目されています。大規模なMambaベースのモデルは、学術研究だけでなく、Falcon Mamba 7BMistral 7B などに代表されるように、ビジネスにおける応用も進んでおり、GPU上での効果的なトレーニングが実証されています。

しかしながら、Mambaアーキテクチャには、記憶損失(Memory Loss)、多様なタスクへの汎化性、Transformerベースの言語モデルに比べて複雑なパターンを捉える能力が劣る、といった課題があります。

これらの課題を克服するために、Mambaアーキテクチャを改善するための多くの取り組みが行われています。既存の研究は主に、ブロック設計(Block Design)、スキャンモード(Scan Mode)、記憶管理(Memory Management)に焦点を当てています。このセクションでは、これらの3つの側面の概要を紹介します。

ブロック設計(Block Design)

Mambaブロックの構築方法について、既存の研究は以下の3つのカテゴリに分類できます。

  • 統合法 (Integration methods): Mambaと他アーキテクチャ(TransformerGNNなど)を統合することで、有効性と効率性のバランスを図る。
  • 置換法 (Substitution methods): 先進的なモデルフレームワーク(U-Net拡散モデルなど)の主要な層をMambaブロックで置き換える。
  • 修正法 (Modification methods): 従来のMambaブロック内のコンポーネントを修正する。
図3. Mambaブロック構築方法の3つのカテゴリ

統合法 (Integration methods)

長期的なダイナミクスを効率よく処理できるMambaの特徴は、様々なアーキテクチャと相性がよく、特定のシナリオに適した堅牢なフレームワークを構築できます。特に、Transformerグラフニューラルネットワーク(GNN)リカレントニューラルネットワーク(RNN)畳み込みニューラルネットワーク(CNN)スパイクニューラルネットワーク(SNN)といった、代表的なアーキテクチャとの統合が進んでいます。

モデルの組み合わせ特徴利点
Transformer + Mamba長距離依存関係を効率的に捉える生成タスクにおける性能向上、スループットの向上
GNN + Mamba高次隣接信号のキャプチャグラフ表現学習能力の向上
RNN + Mamba時系列ダイナミクスを効果的に捉える時空間予測における性能向上、FLOPsの削減
CNN + Mambaグローバルおよび長距離文脈のキャプチャ画像解析タスクにおける性能向上
SNN + Mamba低消費電力、長距離依存関係のキャプチャ時間的ビデオグラウンディングタスクの有効性向上

置換法 (Substitution methods)

Mambaにおける選択的状態空間モデルは、効率的な計算と長いシーケンスの学習において優れた能力を持っています。このことが、U-Net拡散モデルなどの従来のフレームワークの重要なコンポーネントをMambaモジュールで置き換えることを可能にしています。

  • U-Net: Mamba-UNetは、Vision-Mambaブロックのみを使用して、U-Netのようなモデルを構築します。
  • 拡散モデル: DiSは、典型的なバックボーン(CNNAttentionU-Netなど)を状態空間バックボーンに置き換えることで、拡散モデルを使用してより長いシーケンスを生成します。

修正法 (Modification methods)

Mambaブロックを直接使用する統合法置換法とは別に、様々なシナリオでMambaブロックの性能を向上させることを目的として、Mambaブロック自体を修正する取り組みも行われています。

  • MoE (Mix-of-Experts)の導入: Jambaは、MoEの概念を利用することで、ハイブリッド(Transformer-Mamba)デコーダのみのモデルをはるかに少ない計算量で事前学習できるようにし、柔軟な目的固有の設定を可能にします。
  • SSM層のK-way構造: Sigmaは、並列SSM層を使用して多様な入力に対応する新しいMambaベースの視覚エンコーダを開発しています。
  • レジスタの導入: Mamba®は、視覚入力トークン内に均等にレジスタを組み込み、その後、入力トークンをSSM層に渡すという新しいアプローチを導入しています。

スキャンモード(Scan Mode)

Mambaモデルでは、Parallel Associative Scan が重要な役割を果たし、計算問題への対処、トレーニングの高速化、メモリ要件の削減に寄与しています。これは、時変状態空間モデルの線形特性を活用し、ハードウェアレベルでカーネルの融合や再計算を実現するものです。

しかしながら、Mambaにおける単方向シーケンスモデリングは、画像やビデオといった多様なデータについて、効率的な学習ができないという課題があります。この問題を解決するため、モデルの性能向上やトレーニングプロセスの効率化を目指したスキャン方法の設計が研究されています。

  • フラットスキャン (Flatten Scan): モデル入力をトークンシーケンスに平坦化し、それに応じてさまざまな方向からスキャンします。
  • ステレオスキャン (Stereo Scan): 次元、チャネル、スケールにわたってモデル入力をスキャンします。
図4. Mambaにおけるスキャンモードの種類

フラットスキャン (Flatten Scan)

スキャン方法特徴・説明具体例・効果
双方向スキャン (Bi-Scan)Bi-RNNの概念を応用し、順方向と逆方向のSSMを同時進行で使用して入力トークンを処理。空間認識能力を強化。視覚データ処理において、モデルの学習能力を向上。Mambaベースモデルで広く研究されている。
スイープスキャン (Sweeping Scan)モデル入力を掃除機のように特定方向に処理。Cross Scanでは、入力画像をパッチ分割し、4つの異なるパスで平坦化。各パッチが隣接パッチから効率的に情報を統合可能。情報豊富な受容野を確立し、空間情報の活用を向上。
連続スキャン (Continuous Scan)入力シーケンスの連続性を確保。列または行間の隣接トークンをスキャン。ヒルベルト行列を用いたHilbert Scanも含む。PlainMambaでは2D空間入力に適用し、Cross Scanと異なり、隣接トークンに焦点を当てた処理を実現。
効率的なスキャン (Efficient Scan)入力をいくつかの部分に分割して並列処理。トレーニング・推論プロセスを高速化。計算要求を削減しつつグローバルな特徴を保持。Efficient-2D Scanは、パッチをスキップしつつ画像を処理。計算時間を短縮(要求を4分の1に削減)し、効率的な学習を可能に。

ステレオスキャン (Stereo Scan)

スキャン方法特徴・説明具体例・効果
階層型スキャン (Hierarchical Scan)グローバルからローカル、またはマクロからミクロまでの視点でさまざまなカーネルサイズを走査し、セマンティック知識をキャプチャ。Mamba-in-Mamba 階層型エンコーダは、ローカルパターン抽出(内部ブロック)とグローバル特徴キャプチャ(外部ブロック)を組み合わせた手法を提供。
時空間スキャン (Spatiotemporal Scan)動的システム向けにMambaブロックを強化。元の2Dスキャンを、空間優先スキャンと時間優先スキャンという2つの3Dスキャンに拡張。VideoMambaは、長く高解像度なビデオ処理において優れた効率性を示し、2Dスキャンに比べて時空間情報の活用を向上。
ハイブリッドスキャン (Hybrid Scan)多様なスキャン方法を統合して包括的な機能モデリングを追求。複数のスキャン方法の利点を動的に利用するSwitch of Scanを採用。Mambamixerは、Cross-Scan、Zigzag Scan、Local Scanを組み合わせ、トークンとチャネル間で情報を混合。Pan-Mambaは、チャネル交換スキャンとクロスモーダルスキャンで画像パンシャープニングを効率化。

記憶管理(Memory Management)

Mambaにおける状態空間モデル(SSM)の隠れ状態に関するメモリは、直前の演算ステップからの情報を効果的に保持するため、SSM全体の機能において重要な役割を担っています。Mambaでは、メモリ初期化にHiPPOベースの手法を導入しましたが、層間の隠れ情報の伝達やロスレスなメモリ圧縮といった課題が依然として残されています。

  • 初期化プロセスの改善: モデルの再トレーニング時に、バランスの取れた切り捨て手法を用いることで、選択的SSMの初期化プロセスを改良。
  • 隠れ状態の抑制: DGMambaは、状態空間モデル内の隠れ状態のドメイン汎化能力を向上させるため、隠れ状態抑制手法を導入。この手法により、隠れ状態が引き起こす悪影響を軽減し、異なるドメイン間の隠れ状態のギャップを縮小。
  • 層間の隠れ情報の伝播の強化: DenseMambaでは、SSMの層間における隠れ情報の伝播を強化するため、密結合方式を提案。この戦略は、浅い層の隠れ状態を選択的に深い層へ統合することで、メモリ劣化を軽減し、出力生成に必要な詳細な情報を保持することを目的としている。
  • “State”の概念の再考: 広範な実験を通じて、SSMにおける”State”の概念に潜む制約を明らかにした。例えば、チェスの動きを監視するような実際的な状態追跡課題において、この制約が顕著に現れる。この知見に基づき、入力依存の遷移行列を導入し、SSMの機能を強化することで、効果的な状態追跡と順列合成を可能にする。

Mamba の 実践

では、ここからは Mamba の使い方の例として、時系列データの予測を行ってみます。まずは、必要なPythonパッケージをインストールします。

# PyTorch のインストール
$ pip install torch

# PyTorch 関連のパッケージ
$ pip install torcheval torchinfo

# Mamba のインストール
# CUDAがインストールされていないとエラーになります
$ pip install mamba-ssm

次に使用するデータを準備します。今回は、ETD(Electricity Transformer Dataset)を使用します。このデータセットは、電力変圧器の油温予測と過剰な負荷容量の調査を目的として構築されたデータセットです。

背景: 電力需要は時間帯、曜日、休日、季節、天候、気温などによって変動するため、正確な予測が困難です。予測の失敗は変圧器の損傷につながる可能性があり、現状では経験に基づいた過剰な電力供給が行われ、電力と設備の無駄が生じています。そこで、変圧器の油温を予測することで安全性を確保し、無駄を削減する戦略が有効となります。

データ収集: 北京国旺富達科技発展有限公司との共同で、中国の一省の2つの地域から2年間の実データが収集されました。データは1分ごと(m)と1時間ごと(h)に記録されており、それぞれETT-small-m1/m2、ETT-small-h1/h2と名付けられています。

データ内容: 各データポイントは、日付、予測対象である「油温」、および6種類の外部電力負荷の特徴量を含む、合計7個の特徴量で構成されています。

データの特徴: データには、短期的な周期的パターン、長期的な周期的パターン、長期的なトレンド、および多くの不規則なパターンが含まれています。データには明らかな季節変動が示されており、自己相関グラフでは、油温は短期的な連続性を示し、他の変数(電力負荷)は短期的な日周パターン(24時間ごと)と長期的な週パターン(7日ごと)を示しています。

なお、データセットの各特徴量の意味は次の通りです。予測対象は OT (Oil Temperature) となります。

FielddateHUFLHULLMUFLMULLLUFLLULLOT
DescriptionThe recorded dateHigh UseFul LoadHigh UseLess LoadMiddle UseFul LoadMiddle UseLess LoadLow UseFul LoadLow UseLess LoadOil Temperature (target)

データセットはGitHubで公開されているので、リポジトリをクローンしてデータを利用することにします。

$ git clone https://github.com/zhouhaoyi/ETDataset

これで準備が整ったので、コードの実装に移ります。まずは、必要なパッケージの読み込みです。

import math

from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from mamba_ssm import Mamba
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from torcheval.metrics import R2Score
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader

次にデータセットを読み込みましょう。今回は4つあるデータセットの中から、ETTh1.csv を使用することにします。

ETTh1 = pd.read_csv("ETDataset/ETT-small/ETTh1.csv", index_col=0,  parse_dates=True)
# ETTh2 = pd.read_csv("ETDataset/ETT-small/ETTh2.csv", index_col=0,  parse_dates=True)
# ETTm1 = pd.read_csv("ETDataset/ETT-small/ETTm1.csv", index_col=0,  parse_dates=True)
# TTm2 = pd.read_csv("ETDataset/ETT-small/ETTm2.csv", index_col=0,  parse_dates=True)

print(ETTh1.head())
                      HUFL   HULL   MUFL   MULL   LUFL   LULL         OT
date
2016-07-01 00:00:00  5.827  2.009  1.599  0.462  4.203  1.340  30.531000
2016-07-01 01:00:00  5.693  2.076  1.492  0.426  4.142  1.371  27.787001
2016-07-01 02:00:00  5.157  1.741  1.279  0.355  3.777  1.218  27.787001
2016-07-01 03:00:00  5.090  1.942  1.279  0.391  3.807  1.279  25.044001
2016-07-01 04:00:00  5.358  1.942  1.492  0.462  3.868  1.279  21.948000

データの統計情報も確認しておきましょう。

ETTh1.describe()
               HUFL          HULL          MUFL          MULL          LUFL          LULL            OT
count  17420.000000  17420.000000  17420.000000  17420.000000  17420.000000  17420.000000  17420.000000
mean       7.375141      2.242242      4.300239      0.881568      3.066062      0.856932     13.324672
std        7.067744      2.042342      6.826978      1.809293      1.164506      0.599552      8.566946
min      -22.705999     -4.756000    -25.087999     -5.934000     -1.188000     -1.371000     -4.080000
25%        5.827000      0.737000      3.296000     -0.284000      2.315000      0.670000      6.964000
50%        8.774000      2.210000      5.970000      0.959000      2.833000      0.975000     11.396000
75%       11.788000      3.684000      8.635000      2.203000      3.625000      1.218000     18.079000
max       23.643999     10.114000     17.341000      7.747000      8.498000      3.046000     46.007000

次に、データをグラフに描画して可視化してみます。

def plot_time_series(df: pd.DataFrame, n_split: int = 2) -> None:
    plt.style.use("ggplot")

    n_cols = len(df.columns)  # 列数を取得
    n_rows = math.ceil(n_cols / n_split) # 行数を計算
    fig = plt.figure(figsize=(20, n_rows * 3))

    x_index = df.index  # x軸のインデックスを取得

    for i, column_name in enumerate(df.columns):
        ax = fig.add_subplot(n_rows, n_split, i + 1)
        ax.set_title(column_name)
        ax.plot(x_index, df[column_name], color="blue")
        ax.set_xlabel(df.index.name)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()


plot_time_series(ETTh1)

このデータをもとに、DataLoaderを準備していきます。その前に、データを正規化しておきましょう。

ETTh1_scaled = MinMaxScaler().fit_transform(ETTh1)
ETTh1_scaled = pd.DataFrame(ETTh1_scaled, columns=ETTh1.columns, index=ETTh1.index)
print(ETTh1_scaled.head())
                         HUFL      HULL      MUFL      MULL      LUFL      LULL        OT
date
2016-07-01 00:00:00  0.615599  0.454943  0.628980  0.467510  0.556576  0.613765  0.691018
2016-07-01 01:00:00  0.612708  0.459449  0.626458  0.464878  0.550279  0.620783  0.636233
2016-07-01 02:00:00  0.601143  0.436920  0.621438  0.459689  0.512595  0.586144  0.636233
2016-07-01 03:00:00  0.599698  0.450437  0.621438  0.462320  0.515693  0.599955  0.581468
2016-07-01 04:00:00  0.605480  0.450437  0.626458  0.467510  0.521990  0.599955  0.519656

では、Datasetクラスを作成して、DataLoaderを準備していきます。今回は、過去24時間のデータをもとに、1時間後のOT (Oil Temperature) を予測するようにデータを準備します(データには季節性や、週単位でのパターンも見られますが、今回は簡単な実験のため無視します)。また、通常は、訓練/検証/テスト用とデータを分割するの通例ですが、今回はすべて訓練データに使用します。

class TimeSeriesDataset(Dataset):
    def __init__(self, df: pd.DataFrame, seq_length: int) -> None:
        self.df = df
        self.seq_length = seq_length
        self.data = torch.tensor(df.values, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.df) - self.seq_length

    def __getitem__(self, index: int) -> tuple[torch.Tensor, float]:
        x = self.data[index:index + self.seq_length]

        # 1時間語のOTをターゲット
        y = self.data[index + self.seq_length, -1]

        return x, y


dataset = TimeSeriesDataset(ETTh1_scaled, seq_length=24)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

データの準備は、以上でOKです。次は、予測モデルを定義していきます。以下のように、ネットワークのバックボーンをMambaアーキテクチャにします。

class TimeSeriesMamba(nn.Module):
    def __init__(self, in_features: int, out_features: int, n_dim: int) -> None:
        super(TimeSeriesMamba, self).__init__()
        # パラメータ数は、ざっくり 3 * expand * d_model^2 程度になります。
        mamba_config = dict(
            d_model=n_dim,  # モデルの特徴量の次元
            d_state=2,      # 状態空間モデルの拡張パラメータ
            d_conv=5,       # 局所畳み込みの幅
            expand=4,       # ブロック拡張因子
        )

        self.backbone = nn.Sequential(
            *[Mamba(**mamba_config) for _ in range(1)],
        )

        self.head = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=in_features // 2, bias=True),
            nn.LayerNorm(normalized_shape=in_features // 2),

            nn.Linear(in_features=in_features // 2, out_features=out_features, bias=True)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.backbone(x)
        z = torch.adaptive_avg_pool1d(z, 1)
        z = self.head(z.squeeze())
        return z

上記モデルのパラメータ数も確認しておきましょう。結果を確認すると分かりますが、パラメータ数は、1,373と非常に小さいネットワークとしました。

batch, length, dim = 2, 24, 7
x = torch.randn(batch, length, dim).to("cuda")

model = TimeSeriesMamba(in_features=length, out_features=1, n_dim=dim).to("cuda")
summary(model, input_size=(batch, length, dim))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
TimeSeriesMamba                          [2, 1]                    --
├─Sequential: 1-1                        [2, 24, 7]                --
│    └─Mamba: 2-1                        [2, 24, 7]                532
│    │    └─Conv1d: 3-1                  [2, 28, 28]               168
│    │    └─SiLU: 3-2                    [2, 28, 24]               --
│    │    └─Linear: 3-3                  [48, 5]                   140
│    │    └─Linear: 3-4                  [2, 24, 7]                196
├─Sequential: 1-2                        [2, 1]                    --
│    └─Linear: 2-2                       [2, 12]                   300
│    └─LayerNorm: 2-3                    [2, 12]                   24
│    └─Linear: 2-4                       [2, 1]                    13
==========================================================================================
Total params: 1,373
Trainable params: 1,373
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.02
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.02
==========================================================================================

あとは、モデルを訓練していくだけなので、そのための関数を定義します。

def train(model, train_loader, epochs, learning_rate, device):
    # 損失関数とオプティマイザを定義
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    metric = R2Score()

    model.to(device)
    model.train()

    for epoch in range(epochs):
        with tqdm(train_loader) as pbar:
            pbar.set_description(f"[Epoch {epoch + 1}/{epochs}]")
            for inputs, labels in pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                
                # 順伝搬処理
                optimizer.zero_grad()
                outputs = model(inputs).squeeze()

                # バックプロパゲーション
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # R2スコアをログ出力
                metric.update(outputs.cpu(), labels.cpu())
                r2_score = metric.compute().item()
                pbar.set_postfix(OrderedDict(R2=r2_score))

            metric.reset()

上記の関数を使用して、モデルを学習させてみましょう。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train(model, dataloader, epochs=50, learning_rate=1.0E-3, device=device)
[Epoch 1/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 153.20it/s, R2=0.384]
[Epoch 2/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 226.75it/s, R2=0.593]
[Epoch 3/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 177.35it/s, R2=0.547]
[Epoch 4/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 189.88it/s, R2=0.531]
[Epoch 5/50]: 100%|██████████████████████████████████████████████████| 272/272 [00:01<00:00, 154.09it/s, R2=0.52]
[Epoch 6/50]: 100%|███████████████████████████████████████████████████| 272/272 [00:01<00:00, 192.10it/s, R2=0.5]
[Epoch 7/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 184.46it/s, R2=0.481]
[Epoch 8/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 151.82it/s, R2=0.465]
[Epoch 9/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 201.42it/s, R2=0.479]
[Epoch 10/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 186.28it/s, R2=0.522]
[Epoch 11/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 173.24it/s, R2=0.584]
[Epoch 12/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 162.90it/s, R2=0.656]
[Epoch 13/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 185.51it/s, R2=0.706]
[Epoch 14/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 176.87it/s, R2=0.733]
[Epoch 15/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 157.12it/s, R2=0.761]
[Epoch 16/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 184.83it/s, R2=0.787]
[Epoch 17/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 191.92it/s, R2=0.813]
[Epoch 18/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 160.83it/s, R2=0.827]
[Epoch 19/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 167.03it/s, R2=0.844]
[Epoch 20/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 182.19it/s, R2=0.857]
[Epoch 21/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 178.36it/s, R2=0.871]
[Epoch 22/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 163.23it/s, R2=0.879]
[Epoch 23/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 169.40it/s, R2=0.89]
[Epoch 24/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 193.27it/s, R2=0.894]
[Epoch 25/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 159.22it/s, R2=0.902]
[Epoch 26/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 193.59it/s, R2=0.907]
[Epoch 27/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 177.95it/s, R2=0.912]
[Epoch 28/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 162.18it/s, R2=0.915]
[Epoch 29/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 174.27it/s, R2=0.919]
[Epoch 30/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 174.64it/s, R2=0.921]
[Epoch 31/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 200.30it/s, R2=0.925]
[Epoch 32/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 157.71it/s, R2=0.924]
[Epoch 33/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 189.44it/s, R2=0.927]
[Epoch 34/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 178.47it/s, R2=0.927]
[Epoch 35/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 156.92it/s, R2=0.93]
[Epoch 36/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 200.91it/s, R2=0.93]
[Epoch 37/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 194.81it/s, R2=0.933]
[Epoch 38/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 153.43it/s, R2=0.933]
[Epoch 39/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 180.30it/s, R2=0.936]
[Epoch 40/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 180.05it/s, R2=0.934]
[Epoch 41/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 190.87it/s, R2=0.937]
[Epoch 42/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 156.98it/s, R2=0.936]
[Epoch 43/50]: 100%|█████████████████████████████████████████████████| 272/272 [00:01<00:00, 196.15it/s, R2=0.94]
[Epoch 44/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 177.40it/s, R2=0.939]
[Epoch 45/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 172.27it/s, R2=0.943]
[Epoch 46/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 176.18it/s, R2=0.947]
[Epoch 47/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 179.32it/s, R2=0.948]
[Epoch 48/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 169.89it/s, R2=0.947]
[Epoch 49/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 167.25it/s, R2=0.948]
[Epoch 50/50]: 100%|████████████████████████████████████████████████| 272/272 [00:01<00:00, 171.74it/s, R2=0.948]

訓練データに対するR2スコアを見ると0.948と、まずまずの精度になっていることが分かります。今回は、検証データやテストデータを準備していないので、モデルが過学習している可能性があるものの、非常に少ないパラメータで、Mambaが高い精度を実現できるのか確認できました。

おわりに

というわけで、今回は深層学習アーキテクチャの1つである Mamba について解説してきました。このアーキテクチャは、自然言語処理から画像認識、レコメンド、さらには創薬まで、幅広い分野で応用されています。その背景には、優れたモデリング能力と計算効率の良さがあります。最近はさらに高性能・低コストなモデル開発の研究も進んでいます。

今回の記事では、Mambaのアーキテクチャの進化や、Pythonによる実装例などをまとめてみました。Mambaの研究はまだまだ発展途上で、課題もあれば、これからが楽しみなポイントもたくさんあります。今後のMambaの進化に期待したいです。

More Information

  • arXiv:2312.00752, Albert Gu, Tri Dao, 「Mamba: Linear-Time Sequence Modeling with Selective State Spaces」, https://arxiv.org/abs/2312.00752
  • arXiv:2405.21060, Tri Dao, Albert Gu, 「Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality」, https://arxiv.org/abs/2405.21060
  • arXiv:2408.01129, Haohao Qu, Liangbo Ning, Rui An, Wenqi Fan, Tyler Derr, Hui Liu, Xin Xu, Qing Li, 「A Survey of Mamba」, https://arxiv.org/abs/2408.01129