KAN: Kolmogorov–Arnold Networks

ディープラーニングモデルの多くは、多層パーセプトロン(MLP)に大きく依存していますが、MLPには、解釈が難しさや、Transformerなどのモデルでは埋め込みパラメータ以外のほぼすべてのパラメータを消費してしまうという欠点があります。

そこで、Ziming Liu等は、MLPに代わる有望なモデルとして、コルモゴロフ・アーノルド・ネットワーク(KAN)を提案しました。 MLPは普遍性近似定理により、表現力の高さを保証されているに対し、KANはコルモゴロフ・アーノルド表現定理に触発されています。 MLPと同様に、KANも全結合層を持ちますが、MLPがノード(「ニューロン」)に固定活性化関数を配置するのに対し、KANはエッジ(「重み」)に学習可能な活性化関数を配置します。

結果として、KANには線形重み行列は一切なく、代わりに各重みパラメータは学習可能な1次元関数に置き換えられ、スプライン関数としてパラメータ化されます。 KANのノードは、非線形性を適用することなく、入力信号を加算するだけです。

コルモゴロフ・アーノルド表現定理とは?

コルモゴロフ・アーノルド表現定理は、多次元の連続関数を、1変数の連続関数の合成と加算によって表現できることを示した数学の定理です。つまり、任意の多次元関数は、適切な1変数関数を組み合わせることで表現できることを意味します。

参考までに、数学的な表現を示すと、以下のようになります。

$$
f(\mathbf{x}) = f (x_1, \cdots, x_n) = \sum_{q=0}^{2n} \Phi_q \left(\sum_{p=1}^n \phi_{q,p}(x_p)\right) \hspace{5mm} \text{where} \hspace{2mm} \phi_{q,p} :[0, 1]\rightarrow \mathbb{R} \hspace{2mm} \text{and} \hspace{2mm} \Phi_q \rightarrow \mathbb{R}
$$

KANのアーキテクチャ

KANを定式化すると、以下のようになります。

$$
\text{KAN}(\mathbf{x}) = \left(\Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_1 \circ \Phi_0 \right)\mathbf{x}
$$

ここで、\( \Phi_l \)は関数行列を表しており、以下のような構造を持ちます。

$$
\Phi_l =
\begin{pmatrix}
\phi_{l,1,1}(\cdot) & \phi_{l,1,2}(\cdot) & \cdots & \phi_{l,1,n_l}(\cdot) \\
\phi_{l,2,1}(\cdot) & \phi_{l,2,2}(\cdot) & \cdots & \phi_{l,2,n_l}(\cdot) \\
\vdots & \vdots & \ddots & \vdots \\
\phi_{l,n_l+1,1}(\cdot) & \phi_{l,n_l+1,2}(\cdot) & \cdots & \phi_{l,n_l+1,n_l}(\cdot) \\
\end{pmatrix}
$$

さらに、要素関数\( \phi_l \)は次の通り定義されます。

$$
\begin{eqnarray}
\phi(x) &=& w_b b(x) + w_s \text{spline}(x) \\
b(x) &=& \text{silu}(x) = \frac{x}{1 + e^{-x}}
\end{eqnarray}
$$

また、スプライン関数\( \text{spline}(x) \)はBスプラインの線形結合で定義されます。

$$
\text{spline}(x) = \sum_{i} c_i B_i(x)
$$

KANの解釈性

KANは、その構造上、解釈性の高いモデルとなっています。KANの解釈性を高めるための技術として、以下のものがあります。

  1. スパース化: L1正則化とエントロピー正則化を用いることで、KANのスパース化を促進します。これにより、重要でない活性化関数を特定し、モデルの解釈性を向上させることができます。
  2. 可視化: 活性化関数の透明度をその大きさによって変化させることで、どの入力変数が重要であるかを視覚的に把握することができます。
  3. シンボル化: 活性化関数が特定のシンボル関数 \( \cos \) や \(\log \) などで表現できる可能性がある場合、その関数を指定して固定することができます。

KANの性能

KANは、MLPと比較して、いくつかのタスクにおいてより優れた性能を示すことが示されています。例えば、合成データセットを用いた実験では、KANはMLPよりも少ないパラメータ数で同等以上の精度を達成しています。

また、偏微分方程式(PDE)の解を求めるタスクにおいても、KANはMLPよりも高速に収束し、より低い誤差を達成することが示されています。

特に、KANはデータが組成的な構造を持つ場合に、次元の呪いを克服できる可能性があります。これは、KANが各変数を個別に扱うため、高次元データに対しても効率的に学習できるためです。

しかし、KANの学習は、MLPと比較して一般的に低速であるという欠点も存在します。これは、KANの活性化関数がそれぞれ異なるため、バッチ計算を利用できないことが原因の一つとして考えられます。

KANの実装

KANのPython実装はpykanと呼ばれるパッケージに実装されています。pipコマンドによるインストールは以下の通りです。

$ pip install pykan

pykanを利用したサンプルコードを以下に掲載します。

import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from kan import *


def load_datasets():
    dataset = {}

    data = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(data["data"], data["target"], test_size=0.25)

    dataset["train_input"] = torch.from_numpy(x_train)
    dataset["test_input"] = torch.from_numpy(x_test)
    dataset["train_label"] = torch.from_numpy(y_train[:, None])
    dataset["test_label"] = torch.from_numpy(y_test[:, None])

    return dataset


def print_accuracy(model, dataset):
    train_acc = torch.mean(
        (torch.round(model(dataset["train_input"])[:, 0]) == dataset["train_label"][:, 0]).float()
    )

    test_acc = torch.mean(
        (torch.round(model(dataset["test_input"])[:, 0]) == dataset["test_label"][:, 0]).float()
    )

    print(f"train_acc={train_acc}, test_acc={test_acc}")


if __name__ == "__main__":
    dataset = load_datasets()

    model = KAN(width=[4, 2, 1], grid=3, k=3)

    model.train(dataset, opt="LBFGS", steps=20)
    print_accuracy(model, dataset)

    model.auto_symbolic(lib=['x','x^2', 'tanh','sin','tan','abs'], verbose=0)
    formula = model.symbolic_formula()[0][0]
    print(formula)
More Information: 「KAN: Kolmogorov-Arnold Networks」, arXiv:2404.19756, Ziming Liu et el, https://arxiv.org/abs/2404.19756