TabNet: 表形式データ向け深層学習モデル

ディープラーニングは画像やテキストなどの分野で大きな成功を収めていますが、表形式データにおいては未だに決定木をベースにしたブースティング手法が主流です。しかし、表形式データは実世界において最も一般的なデータであり、ディープラーニングの適用が期待されています。
そこで今回は、表形式データ向けに提案されたDNNアーキテクチャである「TabNet」について紹介します。TabNetは、注意機構を用いて特徴量を選択するなど、解釈性や性能に優れたモデルです。また、TabNetの実装であるPythonパッケージ「pytorch-tabnet」の使い方についても解説します。
TabNetの仕組み
TabNetは、表形式のデータを扱うための新しい種類のニューラルネットワークです。従来のニューラルネットワークとは異なり、データの中から、その時に最も重要な情報(特徴量)を自動的に選び出し、学習するという特徴を持っています。
アルゴリズム
- 特徴量の選択:
- TabNetは、データ中に含まれる複数の特徴量(例えば、年齢、性別、職業など)を処理する。
- この特徴量の中から、最も重要な特徴量を自動的に選択する(注意機構による)。
- 重要な特徴量だけを重点的に学習することで、より正確な予測が可能になる。
- シーケンシャルな処理:
- TabNetは、一度に全ての情報を見るのではなく、段階的に情報を見る。
- 各段階で、異なる特徴量に注目することで、より複雑なパターンを捉えることが可能。
- スパース性:
- TabNetは、重要な特徴量だけに注目するため、不要な情報に振り回されることがない。
- このため、学習が効率的に行われ、よりシンプルなモデルを作ることが可能。

TabNetのメリット
- 高い予測精度: 重要な特徴量にだけ注目するため、より正確な予測が可能。
- 解釈性の高さ: モデルがどのような特徴量に注目して予測しているのかを可視化できるため、モデルの信頼性が高まる。
- 教師なし学習: ラベルなしのデータでも学習できるため、データが少ない場合でも活用できる。
TabNetの性能
TabNetは、決定木ベースの手法や従来の深層ニューラルネットワークを上回る精度を達成すると報告されています。以下に紹介するように、様々なデータセットにおいて高い汎化性能を示すことが実証されており、実用的な場面で幅広く活用できることが証明されています。
データセット①: Forest Cover Type
TabNetは、決定木ベースのアンサンブル手法(XGBoost、LightGBM、CatBoost)やAutoMLフレームワークであるAutoML Tablesを上回る高い精度(テスト精度96.99%)を達成しています。
Model | Test accuracy (%) |
---|---|
XGBoost | 89.34 |
LightGBM | 89.28 |
CatBoost | 85.14 |
AutoMLTables | 94.95 |
TabNet | 96.99 |
データセット②: Poker Hand
このタスクは、ポーカーのカード情報(マークと数字)をもとに、その手が何役になるのかを分類するもので、データ不均衡や複雑なルールにより、従来のDNN、決定木、ハイブリッドモデルは低い精度しか得られませんでした。TabNetは、インスタンスごとの特徴量選択によって過学習を抑え、高い非線形処理能力を発揮することで、99.2%の精度を達成し、他の手法を大幅に上回っています。
Model | Test accuracy (%) |
---|---|
DT | 50 |
MLP | 50 |
Deep neural DT | 65.1 |
XGBoost | 71.1 |
LightGBM | 70 |
CatBoost | 66.6 |
TabNet | 99.2 |
Rule-based | 100 |
データセット③: Sarcos
ロボットアームの逆動力学を回帰するタスクです。 モデルサイズが小さい場合、TabNetは既存手法と同等の性能を、はるかに少ないパラメータ数で達成しています。 モデルサイズを大きくすると、TabNetの性能が大幅に向上し、テスト誤差を大幅に低減できています。
Model | Test MSE | Model size |
---|---|---|
Random forest | 2.39 | 16.7K |
Stochastic DT | 2.11 | 28K |
MLP | 2.13 | 0.14M |
Adaptive neural tree | 1.23 | 0.60M |
Gradient boosted tree | 1.44 | 0.99M |
TabNet-S | 1.25 | 6.3K |
TabNet-M | 0.28 | 0.59M |
TabNet-L | 0.14 | 1.75M |
データセット④: Higgs Boson
ヒッグス粒子の検出と背景事象の区別を行うタスクです。 TabNetは、MLPよりも小規模なサイズで高い精度を達成しています。 スパース化手法同士の比較では、TabNetは同程度の性能を達成することが実証されています。
Model | Test acc. (%) | Model size |
---|---|---|
Sparse evolutionary MLP | 78.47 | 81K |
Gradient boosted tree-S | 74.22 | 0.12M |
Gradient boosted tree-M | 75.97 | 0.69M |
MLP | 78.44 | 2.04M |
Gradient boosted tree-L | 76.98 | 6.96M |
TabNet-S | 78.25 | 81K |
TabNet-M | 78.84 | 0.66M |
データセット⑤: Rossmann Store Sales
小売店の売上を予測するタスクです。 TabNetは、MLP、XGBoost、LightGBM、CatBoostなどの手法を上回る性能を達成しています。 特に、祝日などの特殊な場合において、インスタンスごとの特徴量選択が有効であることが確認されたようです。
Model | Test MSE |
---|---|
MLP | 512.62 |
XGBoost | 490.83 |
LightGBM | 504.76 |
CatBoost | 489.75 |
TabNet | 485.12 |
pytorch-tabnet を試してみる
TabNetには、PyPIで管理されているPythonパッケージがあるので、簡単に試すことができます。このパッケージは、PyTorchベースで実装されています。まずは、PyTorchをインストールします(インストールバージョンの指定については公式ページを参照してください)。
$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
続けて、pytorch-tabnet をインストールします。
$ pip install pytorch-tabnet
今回は、pytorch-tabnet のお試しとして、California Housing データセット(カリフォルニアの住宅価格を予測するデータセット)を使用します。このデータセットは、8つの説明変数(すべて連続変数でカテゴリ変数なし)と目的変数で構成されています。
まずは、必要なパッケージをインポートします。
import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn from pytorch_tabnet.augmentations import RegressionSMOTE from pytorch_tabnet.tab_model import TabNetRegressor from sklearn.datasets import fetch_california_housing from sklearn.metrics import mean_squared_error, r2_score from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler
データセットについては、scikit-learn に内包されているものを使用します。
california_housing_data = fetch_california_housing() # 説明変数については標準化しておく data_x = StandardScaler().fit_transform(california_housing_data.data) data_y = california_housing_data.target # データを、訓練データ、検証データ、テストデータに分割 train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.1) train_x, valid_x, train_y, valid_y = train_test_split(train_x, train_y, test_size=0.2)
では、TabNetの回帰タスク向けのインスタンスを生成します。各パラメータについては、コメントを参照してください。
model = TabNetRegressor( n_d=8, # Decision Prediction Layer の幅: 値が大きいとモデルの容量が大きくなり、オーバーフィッティングのリスクがある。値は通常 8~64 の範囲をとる n_a=8, # 各マスクのAttention Embedding の幅。論文によると、通常は n_d = n_a が良い選択である。 n_steps=4, # アーキテクチャーのステップ数(通常 3~10) gamma=1.3, # マスクの特徴再利用の係数である。値に 1.0 を近づけると、レイヤー間のマスク選択の相関が最も低くなります。値の範囲は 1.0 ~ 2.0。 cat_idxs=[], # カテゴリ変数のインデックス cat_dims=[], # カテゴリ変数に対するユニークな値の数 cat_emb_dim=[], # 各カテゴリ変数に対するエンベディングサイズ n_independent=2, # 各ステップにおける Gated Linear Unit (GLU) のレイヤー数。通常の値は 1 ~ 5。 n_shared=3, # 各ステップにおける Gated Linear Unit (GLU) の共有数。通常は 1 ~ 5 momentum=0.02, # バッチ正規化のモメンタム。通常 0.01~ 0.4 clip_value=None, # floatが与えられた場合、この値で勾配をクリップする。 lambda_sparse=1.0E-3, # スパース性損失係数: この係数が大きいほど、特徴選択においてモデルがスパースになる mask_type="sparsemax", # "sparsemax" or "entmax": 特徴選択するために使用するマスキング関数の指定 grouped_features=[], # 類似した特徴量をグループ化する際に指定する n_shared_decoder=1, # デコーダーの共有GLUブロックの数。TabNetPretrainerにのみ有効 n_indep_decoder=1, # デコーダーの独立したGLUブロックの数。TabNetPretrainerにのみ有効。 optimizer_fn=torch.optim.Adam, # オプティマイザの指定 optimizer_params={"lr": 0.02}, # オプティマイザのパラメータ scheduler_fn=None, # 学習率のスケジューラー (Ex: torch.optim.lr_scheduler.LambdaLR) scheduler_params={}, # 学習率のスケジューラーに指定するパラメータ device_name="auto", # デバイス名の指定 verbose=1, # 学習中のログを出力するかどうか (0 or 1) seed=12345, # 乱数のシード )
ここまでできたら、次のようにすることで学習を行うことができます。いくつかパラメータがありますが、コード中のコメントを参照してください。
model.fit( X_train=train_x, # 訓練データの説明変数を指定 y_train=train_y.reshape(-1, 1), # 訓練データの目的変数を指定: 入力は、データ数 x 出力次元 の形状にする eval_set=[(train_x, train_y.reshape(-1, 1)), (valid_x, valid_y.reshape(-1, 1))], # 評価用のデータセットを指定: 訓練データと検証データに対して評価する eval_name=["train", "valid"], # 評価対象に名前をつける eval_metric=["rmsle", "mae", "rmse", "mse"], # 評価指標を指定 (リストの最後に指定したものが Early Stopping の評価指標に使われる) max_epochs=200, # 最大学習エポック数 patience=50, # Early Stoppingの条件とする連続未改善エポック数 loss_fn=nn.MSELoss(), # 損失関数の指定 batch_size=256, # バッチサイズの指定 virtual_batch_size=64, # "Ghost Batch Normalization" に使用されるバッチのサイズ num_workers=0, # torch.utils.data.Dataloader で使用されるワーカー数 drop_last=False, # 最後のバッチをドロップするかどうか warm_start=False, # 訓練でWarm Startをできるようにするかどうか compute_importance=True, # 特徴量の重要度を計算するかどうか augmentations=RegressionSMOTE(p=0.2), # オーグメンテーションの手法を指定: 不要であれば None を指定する )
学習が完了したら、次のようにすることでモデルの精度を確認できます。
preds = model.predict(test_x) # テストデータについて、MSEとR2スコアを計算 test_mse = mean_squared_error(y_pred=preds, y_true=test_y.reshape(-1, 1)) test_r2 = r2_score(y_pred=preds, y_true=test_y.reshape(-1, 1)) print(f"Valid MSE for California Housing : {model.best_cost}") print(f"Test MSE for California Housing : {test_mse}") print(f"Test R2 Score for California Housing : {test_r2}")
さらにfit()
メソッドに、compute_importance=True
を指定していれば、説明変数の重要度を確認することができます。この重要度は、XAI(説明可能なAI)における大域的説明に相当します。
importances = sorted( zip(model.feature_importances_, california_housing_data.feature_names), reverse=True, key=lambda x: x[0] ) print(importances)
各データについての局所的説明も確認することができます。
explain_matrix, masks = model.explain(test_x) # 説明対象のデータに対するインデックス target_index = 10 importances = sorted( zip(explain_matrix[target_index , :], california_housing_data.feature_names), reverse=True, key=lambda x: x[0] ) print(importances)
さらに、TabNetエンコーダーの各ステップ内のマスクの情報を、次のようにすることで可視化することができます。図中の縦方向はデータのインデックス、横方向が説明変数のインデックスを指しています。色の濃いところほど、モデルが重要視しているデータとなります。
fig, axs = plt.subplots(1, len(masks.keys()), figsize=(5, 5)) # テストデータの先頭30個を可視化 for i in masks.keys(): axs[i].imshow(masks[i][:30]) axs[i].set_title(f"mask {i}") plt.tight_layout() plt.show()

以上、pytorch-tabnet の基本的な使い方を紹介しました。今回は回帰タスクについて説明しましたが、pytorch-tabnet には、この他にも様々な使い方ができます。いくつかサンプルコードも用意されているので、興味があれば確認してみてください。
- binary classification examples (2値分類タスク)
- multi-class classification examples (多クラス分類タスク)
- regression examples (回帰タスク)
- multi-task regression examples (マルチ回帰タスク)
- multi-task multi-class classification examples (マルチタスク・多クラス分類)
- kaggle moa 1st place solution using tabnet (Kaggleコンペ への応用)
まとめ
今回は、表形式データのためのディープラーニングモデルであるTabNetを紹介しました。TabNetは、注意機構を用いた特徴量選択により、高い性能と解釈性を両立しています。様々なデータセットにおける実験結果から、TabNetが既存の手法を上回る性能を示すことも報告されています。
また、TabNetのPythonパッケージである pytorch-tabnet も紹介しました。このパッケージを使用することで、高精度のモデルを簡単に構築することができます。ブースティング系のモデルの性能に満足できない場合、TabNetを代替モデルに考えてみては如何でしょう?
More Informations
- arXiv:1908.07442, Sercan O. Arik, Tomas Pfister, 「TabNet: Attentive Interpretable Tabular Learning」, https://arxiv.org/abs/1908.07442