LitServe: 機械学習モデルの効率的なデプロイ

機械学習モデルは、FlaskFlastAPIなどのWebフレームワークを使用して、WebAPIとしてデプロイされることが一般的です。これらのフレームワークは、WebAPIを構築するための便利な機能が豊富に含まれています。しかしながら、機械学習モデルに特化したWebAPをリリースしたいケースでは、機能過多である場合があります。

そこで、最近注目されているPythonパッケージとして「LitServe」というものがあります。LitServeは、機械学習モデルを効率的にデプロイできることが特徴です。今回は、このLitServeについて、入門的な解説をしたいと思います。

LitServeとは?

以下に、LitServeの特徴を簡潔にまとめます。

  • 高性能な推論エンジンFastAPIをベースに、デプロイ速度を従来の2倍以上に向上。
  • 柔軟なマルチモデル対応PyTorchTensorFlowHugging Faceモデルの統合をサポート。
  • GPU自動スケーリング:負荷に応じたリソースの動的調整を実現。
  • 簡易性:インストールから設定まで数行のコードで開始可能。
  • クラウド対応AWSGCPを利用したスケーラブルなクラウド展開をサポート。

LitServeの応用事例

LitServe ExamplesというGitHubリポジトリで、LitServeの実装例が紹介されています。

NameDescription
Whisper Speech to Text APIOpenAI Whisper を搭載した音声テキスト変換 API
Chat with Llama 3.2-Vision Multimodal LLM画像入力にも対応した Llama 3.2 を使用した会話型 AI
Speech Generation API using Parler TTSParler TTS を使用したテキスト読み上げ API
Chat with Qwen2-VLQwen2-VL を使用したマルチモーダル チャットボットの実装
Jina CLIP v2 Embeddings APIテキストと画像用の多言語マルチモーダルのエンベディング API

LitServeのインストール方法と使い方

LitServeは、PyPIから簡単にインストール可能です。

$ pip install litserve

# Linux環境では uvloop をインストールすることでパフォーマンスが向上する。
# (この記事を書いている段階では、Windows環境で未サポート)
$ pip install uvloop

では、LitServeのサンプルコードを次に紹介します。

# -*- coding: utf-8 -*-

import litserve as ls
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class TextGenAPI(ls.LitAPI):
    def setup(self, device: str) -> None:
        self.model = AutoModelForCausalLM.from_pretrained(
            "cyberagent/open-calm-small",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-small")

        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.01,
        )

    def decode_request(self, request: dict) -> str:
        return request["input"]

    def predict(self, x: str) -> list[dict]:
        return self.pipeline(x)

    def encode_response(self, output: list[dict]) -> dict:
        return {"text": output[0]["generated_text"]}


if __name__ == "__main__":
    server = ls.LitServer(TextGenAPI(), accelerator="auto", max_batch_size=1)
    server.run(port=8000)

まず、LitAPIクラスを継承して新しいクラスを定義します。このクラスには、以下の4つのメソッドを実装します。

  • setup: サーバーの起動時に、1 回だけ呼び出され、主に、次の目的で使用します。
    • データの読み込み
    • モデルの読み込みと初期化
    • データベース接続
  • decode_request: 受信したリクエスト ペイロードをモデル対応の入力に変換します。
  • predict: decode_request の出力を使用して、モデルで推論を実行します。
  • encode_response: predict の出力をレスポンスのフォーマットに変換します。

次に、上記のメソッドを実装したクラスをLitServerクラスに指定してインスタンスを生成し、runメソッドでサーバーを起動します。

あとは、ここまで実装したファイルをスクリプトとして実行するだけで、サーバーが立ち上がります。

$ python <スクリプトファイル名>

サーバーを立ち上げると、以下の内容が実装された client.py というファイルが自動で生成されるので、サーバーの動作確認などに使用しましょう。

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests

response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
print(f"Status: {response.status_code}\nResponse:\n {response.text}")

まとめ

今回は、LitServeの基本的な使い方を紹介しました。FlaskFastAPIなどと比べると、はるかに簡単にWebAPIを構築できることが分かったかと思います。機械学習モデルの推論機能のみをWebAPIとしてデプロイしたい場合は、LitServeを利用することで開発コストを削減できます。

さらに詳しく知りたい場合は、公式のドキュメントGitHubリポジトリなどを参照してください。