LoRAによる大規模言語モデルのファインチューニング

ChatGPTをはじめとする大規模言語モデル(LLM)は、その高度な言語処理能力で注目を集めています。しかし、LLMベンダーが提供するモデルを利用する場合、利用料金の問題や、機密情報を含むデータの外部への漏洩リスクが懸念されます。

そこで注目されているのが、自社のデータでオープンソースのLLMをファインチューニングする手法です。

今回は、中でも効率的なファインチューニング手法であるLoRAに焦点を当て、LLMをカスタマイズする方法をコードを交えて解説します。

LoRAとは?

LoRA (Low-Rank Adaptation) は、大規模の深層学習モデルを効率的にファインチューニングするための手法です。従来のファインチューニングでは、すべてのモデルパラメータを更新するため、大量のメモリと計算資源が必要でした。しかし、LoRAは、更新するパラメータの数を大幅に削減することで、これらの課題を解決します。

LoRAの肝となる考え方は、低ランク行列分解によるパラメータの更新にあります。元の重み行列 \(W_0\)​ に対して、低ランク行列 \(B \in \mathbb{R}^{d \times r}\) と \(A \in \mathbb{R}^{r \times k}\) を用いて、更新後の重み行列 \(W\) を次のように表します。

$$
h = W_0 x + \Delta W x = W_0 x + B A x
$$

ここで、\(r\) は低ランクの次元であり、\(r \ll \min(d,k)\) (\(d\) は入力次元、\(k\) は出力次元) を満たします。この手法により、更新するパラメータの数は大幅に減少し、メモリ消費と計算コストが削減されます。

LoRAは、自然言語処理だけでなく、拡散モデルによる画像生成など、他の深層学習モデルにも適用することができます。例えば、Stable Diffusionなどの画像生成モデルにおいても、LoRAを用いて、特定のスタイルや概念に特化したモデルを効率的に作成することができます。

図1. (a) Full Fine-tuning、(b) LoRA、(c) 低ランクのボトルネックの解消、(d) 動的ランク割り当て (https://arxiv.org/pdf/2407.11046 より引用)

LoRAの主な利点は以下の通りです:

  • メモリ消費の削減: 更新するパラメータ数が大幅に減るため、メモリ消費を大幅に削減
  • 計算コストの削減: 更新するパラメータ数が少ないため、計算コストを大幅に削減
  • 高速なファインチューニング: 少ないパラメータを更新するため、ファインチューニングの速度が向上
  • 柔軟なモデル切り替え: 推論時に、異なるタスク用に異なるパラメータ \(B\) と \(A\) をロードすることで、簡単にモデルを切り替えることができる。

LoRAの実践

では、ここからはPythonコードを書いてLLMをファインチューニングしていきます。まずは、次の通り、必要なパッケージをインストールします。

# PyTorchのインストール
# インストールするバージョンの指定については、公式ページを参照のこと。
# 参照: https://pytorch.org/get-started/previous-versions/
$ pip install torch torchvision torchaudio

# Transformersと、その関連パッケージをインストール
$ pip install --upgrade transformers accelerate datasets

# PEFTにLoRA関連の機能が実装されている
$ pip install --upgrade peft

ここからPythonコードを書いていきますが、まずは必要なパッケージを読み込んでおきます。

import datasets
import torch
import wandb

from peft import get_peft_model, LoraConfig, PeftModel, TaskType
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

LLMのトレーニングには、WandBへのログインが必要になるので、ここで実行しておきます。APIキーの入力が求められるので、ご自身で取得したものを入力して、Enterキーを押してください。

# WandB へのログイン
wandb.login(relogin=True)

次は、ファインチューニング対象のモデルをダウンロードします。今回は、Llama 2 に 日本語データで追加事前学習をさせた ELYZA-japanese-Llama-2-7b を使用します。

model_path = "elyza/ELYZA-japanese-Llama-2-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

model.eval()
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

ファインチューニングを実施する前に、適当な入力でこのモデルの応答を確認してみましょう。

request = "ソフトウェアエンジニアが優先的に学ぶべきプログラミング言語を教えてください。"

input_ids = tokenizer.encode(request, return_tensors="pt").to(device=model.device)
tokens = model.generate(
    input_ids,
    max_new_tokens=1024,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

start_pos = len(input_ids[0])
response = tokenizer.decode(tokens[0][start_pos:], skip_special_tokens=True)
print(response)
現在プログラミングを学ぼうとお考えの方、今後ソフトウェアエンジニアとしてキャリアを向上させたい方、その学習の第一歩として今一度プログラミング言語について考える必要があります。
学習サイクルの組み立て方、効率の良い学習の進め方など、その道のプロであるエンジニアの方々からのアドバイスをもとに自身に最適な学習計画を立てていきましょう。

現在のソフトウェア業界ではさまざまなプログラミング言語が用いられています。様々な企業で用いられている言語の特徴や違いについて理解し、自身の強みや最終目標と照らし合わせながら、自分に最も合ったプログラミング言語を見つけましょう。
Java、C、Pythonなど、今一般的に用いられている言語は様々な場面で活用されており、その多くはオブジェクト指向的な設計思想を持っています。オブジェクト指向とは、プログラムを小さな部品で構成し、それを組み合わせて最終的な機能を作り上げることを指します。今後ソフトウェア業界でキャリアをスタートさせる上で非常に役に立つ思想なので、今後学習計画の中に組み込んでいくことをおすすめします。
また機械学習やアプリケーションサーバー関連の言語としてGo、Ruby、Swiftがあり、 IoTやAIサービスの分野ではKotlinなどが注目されています。その他これから注目される言語としてPHPなどがあります。技術進化に伴い、絶えず言語が入れ替わっており、今後もまだまだ新しい言語が誕生していくでしょう。今一度自分の遅れていないのか、言語の最新トレンドが押さえられているかを見つめ直し、自身のレベルに合わせて言語を選択していきましょう。
PHPを学ぶメリットについてご紹介します。
PHPはさまざまなシステムが動くインフラとして広く使われています。また、世界中の多くのWEBサイトにPHPは採用されており、PHPのスキルを持っていることでWEBエンジニアとして幅広い就職先が選択できる可能性があります。PHPの高い需要に伴い、求人数も多く、PHPエンジニアの需要は今後も増加すると予想されています。 PHPに関心を持ち、エンジニアとしてキャリアを向上させたいとお考えの方は、学習を始めてみましょう。
また、PHPはデータベースにも対応しているため、フレームワーク

最大トークンの指定があるため、出力が途中で切れていますが、自然な日本語のテキストが生成されていることがわかります。

では、これをファインチューニングしてきますが、今回使用するデータセットは HuggingFace で公開されている databricks-dolly-15k-ja-gozaru にします。早速、ダウンロードして、データの中身を確認します。

# データセットのダウンロード
dolly_ja = datasets.load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozaru")

print(f"データ数: {len(dolly_ja['train'])}")
print(f"データサンプル:\n{dolly_ja['train'][0]}")
データ数: 15015
データサンプル:
{'category': 'closed_qa', 'instruction': 'ヴァージン・オーストラリア航空はいつから運航を開始したのですか?', 'input': 'ヴァージン・オーストラリア航空(Virgin Australia Airlines Pty Ltd)はオーストラリアを拠点とするヴァージン・ブランドを冠する最大の船団規模を持つ航空会社です。2000年8月31日に、ヴァージン・ブルー空港として、2機の航空機、1つの空路を運行してサービスを開始しました。2001年9月のアンセット・オーストラリア空港の崩壊後、オーストラリアの国内市場で急速に地位を確立しました。その後はブリスベン、メルボルン、シドニーをハブとして、オーストラリア国内の32都市に直接乗り入れるまでに成長しました。', 'output': 'ヴァージン・オーストラリア航空は、2000年8月31日にヴァージン・ブルー航空として、2機の航空機で単一路線の運航を開始しましたでござる。', 'index': '0'}

データはdict形式になっていますが、outputが出力に相当します。見るとわかりますが、末尾が「~ござる」になっていてサムライ口調です。今回は、モデルをファインチューニングすることで、回答を「ござる」口調にしていきます。

モデルへの指示に相当する箇所であるinstructionと、出力のoutputからデータセットを準備していきます。

template = "ユーザー: {instruction}\nシステム: {output}"

datalist = []
for i in range(len(dolly_ja["train"])):
    d = dolly_ja["train"][i]
    if d["input"] == "":
        ptext = template.format_map(d)
        # 今回は、GPUのメモリ不足でエラーにならないように長文を学習対象から除外
        if (len(ptext) < 200):
            datalist.append(ptext)

上記で作成したデータをもとに、torch.utils.data.Datasetを継承してデータセットのクラスを作成します。

class DollyJaDataset(Dataset):
    def __init__(self, datalist, tokenizer):
        self.tokenizer = tokenizer
        self.features = []
        for ptext in datalist:
            input_ids = self.tokenizer.encode(ptext)
            input_ids = input_ids + [ self.tokenizer.eos_token_id ]
            input_ids = torch.LongTensor(input_ids)
            self.features.append({'input_ids': input_ids})

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]


train_dataset = DollyJaDataset(datalist, tokenizer)

以上でデータセットの準備は完了です。次は、LoRAを利用するのでもとのモデルのどの層にアダプタを追加していくか決めていきます。対象となるのは、線形変換の層(Linear層)です。上記で表示したモデルの構成を確認して、対象の線形変換の層を決定します。では、今回はLlamaMLPの中で使用されている線形変換に対してアダプタを追加することにしましょう。

lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    
    # アダプタを追加する線形変換の層を指定
    target_modules=["gate_proj", "up_proj", "down_proj"],

    lora_dropout=0.05,
    bias="none",
    fan_in_fan_out=False,
    task_type=TaskType.CAUSAL_LM
)

lora_model = get_peft_model(model, lora_config)

LoRAモデルができたので、あとはこれを学習させていくだけです。早速、次のようにして学習を開始します(Google Colaboratory T4 GPU の環境であれば90分程度で学習が完了します)。

training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    save_steps=200,
    fp16=True,
    save_strategy='epoch',
    per_device_train_batch_size=1,
    logging_steps=20,
)

trainer = Trainer(
    model=lora_model,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    args=training_args,
    train_dataset=train_dataset
)
trainer.train()

学習が完了したら、モデルに適当な指示を与えて「ござる」口調になっているか確認しましょう。

ptext = template.format_map({"instruction": "東京の魅力的なところは?", "output": ""})
input_ids = tokenizer.encode(ptext, add_special_tokens=False, return_tensors="pt").to(model.device)

with torch.no_grad():
    tokens = lora_model.generate(
        input_ids=input_ids,
        max_new_tokens=200,
        temperature=1.0,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

start_pos = len(input_ids[0])
output = tokenizer.decode(tokens[0][start_pos:], skip_special_tokens=True)
print(output)
東京は、日本の首都であり、日本の経済の中心地でもありますでござる。東京には、日本の歴史を垣間見ることができるような場所もありますでござる。東京には、日本の文化を体験することができるような場所もありますでござる。東京には、日本の食文化を味わうことができるような場所もありますでござる。東京には、日本の芸術を体験することができるような場所もありますでご

指定した最大トークン数で出力が途切れていますが、期待通り、「ござる」口調になっていることが確認できました。

なお、学習したパラメータはTrainingArgumentsに指定したディレクトリに保存されています。今回の場合であれば./outputにチェックポイントのデータが保存されていることを確認できます。これの学習済みパラメータをもとにLoRAモデルを構築する場合は、次のようにしてください。

# ベースモデル
model_path = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

# LoRAの適用
lora_name = "output/checkpoint-3115"   # 環境によってファイル名は異なります
lora_model = PeftModel.from_pretrained(base_model, lora_name)

おわりに

今回は、大規模言語モデル(LLM)のファインチューニング手法の一つであるLoRAについて、その仕組みやメリット、そしてPythonによる実装例を紹介しました。

LoRAは、従来のファインチューニング手法に比べて、少ない計算資源で、既存のモデルを微調整し、特定のタスクに特化したモデルを構築できるという大きなメリットがあります。このため、LLMを自社でカスタマイズしたいと考えている企業や研究機関にとって、非常に魅力的な技術です。

ただし、LoRAは万能ではありません。バッチ処理の複雑化など、いくつかの課題も存在します。実際、LoRAの問題を指摘し、それを改善する方法を提案する研究論文はいくつか存在します。興味ある方は、LoRAの派生手法などを調べてみるとよいかと思います。

More Information: