TabNet: 表形式データ向け深層学習モデル

ディープラーニングは画像やテキストなどの分野で大きな成功を収めていますが、表形式データにおいては未だに決定木をベースにしたブースティング手法が主流です。しかし、表形式データは実世界において最も一般的なデータであり、ディープラーニングの適用が期待されています。

そこで今回は、表形式データ向けに提案されたDNNアーキテクチャである「TabNet」について紹介します。TabNetは、注意機構を用いて特徴量を選択するなど、解釈性や性能に優れたモデルです。また、TabNetの実装であるPythonパッケージ「pytorch-tabnet」の使い方についても解説します。

TabNetの仕組み

TabNetは、表形式のデータを扱うための新しい種類のニューラルネットワークです。従来のニューラルネットワークとは異なり、データの中から、その時に最も重要な情報(特徴量)を自動的に選び出し、学習するという特徴を持っています。

アルゴリズム

  1. 特徴量の選択:
    • TabNetは、データ中に含まれる複数の特徴量(例えば、年齢、性別、職業など)を処理する。
    • この特徴量の中から、最も重要な特徴量を自動的に選択する(注意機構による)。
    • 重要な特徴量だけを重点的に学習することで、より正確な予測が可能になる。
  2. シーケンシャルな処理:
    • TabNetは、一度に全ての情報を見るのではなく、段階的に情報を見る
    • 各段階で、異なる特徴量に注目することで、より複雑なパターンを捉えることが可能。
  3. スパース性:
    • TabNetは、重要な特徴量だけに注目するため、不要な情報に振り回されることがない。
    • このため、学習が効率的に行われ、よりシンプルなモデルを作ることが可能。
TabNetアーキテクチャ: TabNetは、EncoderDecoder、そしてFeature Transfomerと呼ばれるブロックから構成されるニューラルネットワークモデルです。(a)Encoderは、EncoderAttentive Transformer、そしてFeature Maskingの3つの主要な要素から成り立っています。Feature Transfomerは、入力データをより高次の表現に変換し、Attentive Transformerは、入力データのどの部分が重要かを学習します。Feature Maskingは、モデルがどの特徴に注目しているかを明らかにし、モデルの解釈性を高める役割を果たします。また、Splitブロックは、処理された表現を分割し、後続のステップのAttentive Transformerや全体の出力に利用されます。各ステップで生成される特徴選択マスクは、モデルの学習過程に関する解釈可能な情報を提供し、マスクを集約することで、グローバルな特徴の重要度を把握することができます。(b)Decoderは、各ステップでFeature Transfomerブロックを繰り返し、最終的な予測を行います。(c) Feature Transfomerブロックは、通常4層のネットワークで構成されます。このうち2層はすべてのDecision Stepで共有され、残りの2層はDecision Stepに応じて変化します。各層は、全結合層、バッチ正規化、そしてGLU非線形処理から構成されています。(d)Attentive Transformerブロックは、単一層のマッピングによって構成され、現在のDecision Stepの前に各特徴がどの程度使用されたかを表す事前スケール情報で調整されます。係数の正規化にはsparsemaxが用いられ、重要な特徴がスパースに選択されるように設計されています。

TabNetのメリット

  • 高い予測精度: 重要な特徴量にだけ注目するため、より正確な予測が可能。
  • 解釈性の高さ: モデルがどのような特徴量に注目して予測しているのかを可視化できるため、モデルの信頼性が高まる。
  • 教師なし学習: ラベルなしのデータでも学習できるため、データが少ない場合でも活用できる。

TabNetの性能

TabNetは、決定木ベースの手法や従来の深層ニューラルネットワークを上回る精度を達成すると報告されています。以下に紹介するように、様々なデータセットにおいて高い汎化性能を示すことが実証されており、実用的な場面で幅広く活用できることが証明されています。

データセット①: Forest Cover Type

TabNetは、決定木ベースのアンサンブル手法(XGBoost、LightGBM、CatBoost)やAutoMLフレームワークであるAutoML Tablesを上回る高い精度(テスト精度96.99%)を達成しています。

ModelTest accuracy (%)
XGBoost89.34
LightGBM89.28
CatBoost85.14
AutoMLTables94.95
TabNet96.99

データセット②: Poker Hand

このタスクは、ポーカーのカード情報(マークと数字)をもとに、その手が何役になるのかを分類するもので、データ不均衡や複雑なルールにより、従来のDNN、決定木、ハイブリッドモデルは低い精度しか得られませんでした。TabNetは、インスタンスごとの特徴量選択によって過学習を抑え、高い非線形処理能力を発揮することで、99.2%の精度を達成し、他の手法を大幅に上回っています。

ModelTest accuracy (%)
DT50
MLP50
Deep neural DT65.1
XGBoost71.1
LightGBM70
CatBoost66.6
TabNet99.2
Rule-based100

データセット③: Sarcos

ロボットアームの逆動力学を回帰するタスクです。 モデルサイズが小さい場合、TabNetは既存手法と同等の性能を、はるかに少ないパラメータ数で達成しています。 モデルサイズを大きくすると、TabNetの性能が大幅に向上し、テスト誤差を大幅に低減できています。

ModelTest MSEModel size
Random forest2.3916.7K
Stochastic DT2.1128K
MLP2.130.14M
Adaptive neural tree1.230.60M
Gradient boosted tree1.440.99M
TabNet-S1.256.3K
TabNet-M0.280.59M
TabNet-L0.141.75M

データセット④: Higgs Boson

ヒッグス粒子の検出と背景事象の区別を行うタスクです。 TabNetは、MLPよりも小規模なサイズで高い精度を達成しています。 スパース化手法同士の比較では、TabNetは同程度の性能を達成することが実証されています。

ModelTest acc. (%)Model size
Sparse evolutionary MLP78.4781K
Gradient boosted tree-S74.220.12M
Gradient boosted tree-M75.970.69M
MLP78.442.04M
Gradient boosted tree-L76.986.96M
TabNet-S78.2581K
TabNet-M78.840.66M

データセット⑤: Rossmann Store Sales

小売店の売上を予測するタスクです。 TabNetは、MLP、XGBoost、LightGBM、CatBoostなどの手法を上回る性能を達成しています。 特に、祝日などの特殊な場合において、インスタンスごとの特徴量選択が有効であることが確認されたようです。

ModelTest MSE
MLP512.62
XGBoost490.83
LightGBM504.76
CatBoost489.75
TabNet485.12

pytorch-tabnet を試してみる

TabNetには、PyPIで管理されているPythonパッケージがあるので、簡単に試すことができます。このパッケージは、PyTorchベースで実装されています。まずは、PyTorchをインストールします(インストールバージョンの指定については公式ページを参照してください)。

$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

続けて、pytorch-tabnet をインストールします。

$ pip install pytorch-tabnet

今回は、pytorch-tabnet のお試しとして、California Housing データセット(カリフォルニアの住宅価格を予測するデータセット)を使用します。このデータセットは、8つの説明変数(すべて連続変数でカテゴリ変数なし)と目的変数で構成されています。

まずは、必要なパッケージをインポートします。

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from pytorch_tabnet.augmentations import RegressionSMOTE
from pytorch_tabnet.tab_model import TabNetRegressor

from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

データセットについては、scikit-learn に内包されているものを使用します。

california_housing_data = fetch_california_housing()

# 説明変数については標準化しておく
data_x = StandardScaler().fit_transform(california_housing_data.data)
data_y = california_housing_data.target

# データを、訓練データ、検証データ、テストデータに分割
train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.1)
train_x, valid_x, train_y, valid_y = train_test_split(train_x, train_y, test_size=0.2)

では、TabNetの回帰タスク向けのインスタンスを生成します。各パラメータについては、コメントを参照してください。

model = TabNetRegressor(
    n_d=8,  # Decision Prediction Layer の幅: 値が大きいとモデルの容量が大きくなり、オーバーフィッティングのリスクがある。値は通常 8~64 の範囲をとる
    n_a=8,  # 各マスクのAttention Embedding の幅。論文によると、通常は n_d = n_a が良い選択である。
    n_steps=4,  # アーキテクチャーのステップ数(通常 3~10)
    gamma=1.3,  # マスクの特徴再利用の係数である。値に 1.0 を近づけると、レイヤー間のマスク選択の相関が最も低くなります。値の範囲は 1.0 ~ 2.0。
    cat_idxs=[],  # カテゴリ変数のインデックス
    cat_dims=[],  # カテゴリ変数に対するユニークな値の数
    cat_emb_dim=[],  # 各カテゴリ変数に対するエンベディングサイズ
    n_independent=2,  # 各ステップにおける Gated Linear Unit (GLU) のレイヤー数。通常の値は 1 ~ 5。
    n_shared=3,  # 各ステップにおける Gated Linear Unit (GLU) の共有数。通常は 1 ~ 5
    momentum=0.02,  # バッチ正規化のモメンタム。通常 0.01~ 0.4
    clip_value=None,  # floatが与えられた場合、この値で勾配をクリップする。
    lambda_sparse=1.0E-3,  # スパース性損失係数: この係数が大きいほど、特徴選択においてモデルがスパースになる
    mask_type="sparsemax",  # "sparsemax" or "entmax": 特徴選択するために使用するマスキング関数の指定
    grouped_features=[],  # 類似した特徴量をグループ化する際に指定する
    n_shared_decoder=1,  # デコーダーの共有GLUブロックの数。TabNetPretrainerにのみ有効
    n_indep_decoder=1,  # デコーダーの独立したGLUブロックの数。TabNetPretrainerにのみ有効。
    optimizer_fn=torch.optim.Adam,  # オプティマイザの指定
    optimizer_params={"lr": 0.02},  # オプティマイザのパラメータ
    scheduler_fn=None,  # 学習率のスケジューラー (Ex: torch.optim.lr_scheduler.LambdaLR)
    scheduler_params={},  # 学習率のスケジューラーに指定するパラメータ
    device_name="auto",  # デバイス名の指定
    verbose=1,  # 学習中のログを出力するかどうか (0 or 1)
    seed=12345,  # 乱数のシード
)

ここまでできたら、次のようにすることで学習を行うことができます。いくつかパラメータがありますが、コード中のコメントを参照してください。

model.fit(
    X_train=train_x,  # 訓練データの説明変数を指定
    y_train=train_y.reshape(-1, 1),  # 訓練データの目的変数を指定: 入力は、データ数 x 出力次元 の形状にする
    eval_set=[(train_x, train_y.reshape(-1, 1)), (valid_x, valid_y.reshape(-1, 1))],  # 評価用のデータセットを指定: 訓練データと検証データに対して評価する
    eval_name=["train", "valid"],  # 評価対象に名前をつける
    eval_metric=["rmsle", "mae", "rmse", "mse"],  # 評価指標を指定 (リストの最後に指定したものが Early Stopping の評価指標に使われる)
    max_epochs=200,  # 最大学習エポック数
    patience=50,  # Early Stoppingの条件とする連続未改善エポック数
    loss_fn=nn.MSELoss(),  # 損失関数の指定
    batch_size=256,  # バッチサイズの指定
    virtual_batch_size=64,  # "Ghost Batch Normalization" に使用されるバッチのサイズ
    num_workers=0,  # torch.utils.data.Dataloader で使用されるワーカー数
    drop_last=False,  # 最後のバッチをドロップするかどうか
    warm_start=False,  # 訓練でWarm Startをできるようにするかどうか
    compute_importance=True,  # 特徴量の重要度を計算するかどうか
    augmentations=RegressionSMOTE(p=0.2),  # オーグメンテーションの手法を指定: 不要であれば None を指定する
)

学習が完了したら、次のようにすることでモデルの精度を確認できます。

preds = model.predict(test_x)

# テストデータについて、MSEとR2スコアを計算
test_mse = mean_squared_error(y_pred=preds, y_true=test_y.reshape(-1, 1))
test_r2 = r2_score(y_pred=preds, y_true=test_y.reshape(-1, 1))

print(f"Valid MSE for California Housing : {model.best_cost}")
print(f"Test MSE for California Housing : {test_mse}")
print(f"Test R2 Score for California Housing : {test_r2}")

さらにfit()メソッドに、compute_importance=Trueを指定していれば、説明変数の重要度を確認することができます。この重要度は、XAI(説明可能なAI)における大域的説明に相当します。

importances = sorted(
    zip(model.feature_importances_, california_housing_data.feature_names),
    reverse=True,
    key=lambda x: x[0]
)
print(importances)

各データについての局所的説明も確認することができます。

explain_matrix, masks = model.explain(test_x)

# 説明対象のデータに対するインデックス
target_index = 10

importances = sorted(
    zip(explain_matrix[target_index , :], california_housing_data.feature_names),
    reverse=True,
    key=lambda x: x[0]
)
print(importances)

さらに、TabNetエンコーダーの各ステップ内のマスクの情報を、次のようにすることで可視化することができます。図中の縦方向はデータのインデックス、横方向が説明変数のインデックスを指しています。色の濃いところほど、モデルが重要視しているデータとなります。

fig, axs = plt.subplots(1, len(masks.keys()), figsize=(5, 5))

# テストデータの先頭30個を可視化
for i in masks.keys():
    axs[i].imshow(masks[i][:30])
    axs[i].set_title(f"mask {i}")

plt.tight_layout()
plt.show()

以上、pytorch-tabnet の基本的な使い方を紹介しました。今回は回帰タスクについて説明しましたが、pytorch-tabnet には、この他にも様々な使い方ができます。いくつかサンプルコードも用意されているので、興味があれば確認してみてください。

まとめ

今回は、表形式データのためのディープラーニングモデルであるTabNetを紹介しました。TabNetは、注意機構を用いた特徴量選択により、高い性能と解釈性を両立しています。様々なデータセットにおける実験結果から、TabNetが既存の手法を上回る性能を示すことも報告されています。

また、TabNetのPythonパッケージである pytorch-tabnet も紹介しました。このパッケージを使用することで、高精度のモデルを簡単に構築することができます。ブースティング系のモデルの性能に満足できない場合、TabNetを代替モデルに考えてみては如何でしょう?

More Informations