GKE で Saxml を実行してマルチホスト TPU を使用して LLM を提供する


このチュートリアルでは、Saxml を使用して、Google Kubernetes Engine(GKE)でマルチホスト TPU スライス ノードプールを利用し、大規模言語モデル(LLM)をデプロイして提供する方法について説明します。これにより、効率的でスケーラブルなアーキテクチャを実現できます。

背景

Saxml は、PaxmlJAXPyTorch の各フレームワークを提供する試験運用版のシステムです。TPU を使用すると、これらのフレームワークでデータ処理を高速化できます。GKE で TPU のデプロイのデモを行うため、このチュートリアルでは 175B の LmCloudSpmd175B32Test テストモデルを使用します。GKE は、このテストモデルをそれぞれ 4x8 トポロジの 2 つの v5e TPU スライス ノードプールにデプロイします。

テストモデルを適切にデプロイするために、TPU トポロジはモデルのサイズに基づいて定義されています。N x 10 億の 16 ビットモデルには約 2 倍(2 x N)の GB 数のメモリが必要ですが、175B LmCloudSpmd175B32Test モデルには約 350 GB のメモリが必要です。TPU v5e シングル TPU チップの容量は 16 GB です。350 GB をサポートするには、GKE に 21 個の v5e TPU チップが必要です(350÷16= 21)。TPU 構成のマッピングに基づいて、このチュートリアルの適切な TPU 構成は次のようになります。

  • マシンタイプ: ct5lp-hightpu-4t
  • トポロジ: 4x8(32 個の TPU チップ)

GKE に TPU をデプロイする場合は、モデルの提供に適した TPU トポロジを選択することが重要です。詳細については、TPU 構成の計画をご覧ください。

目標

このチュートリアルは、データモデルを提供するために GKE オーケストレーション機能を使用する MLOps または DevOps エンジニア、プラットフォーム管理者を対象としています。

このチュートリアルでは、次の手順について説明します。

  1. GKE Standard クラスタで環境を準備します。クラスタには、4x8 トポロジの 2 つの v5e TPU スライス ノードプールがあります。
  2. Saxml をデプロイします。Saxml には、管理者サーバー、モデルサーバーとして機能する Pod のグループ、事前に構築された HTTP サーバー、ロードバランサが必要です。
  3. Saxml を使用して LLM を提供します。

次の図は、このチュートリアルで実装するアーキテクチャを示しています。

GKE 上のマルチホスト TPU のアーキテクチャ。
図: GKE 上のマルチホスト TPU のアーキテクチャ例。

始める前に

  • In the Trusted Cloud console, on the project selector page, select or create a Trusted Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Trusted Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin, roles/iam.policyAdmin

    Check for the roles

    1. In the Trusted Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Trusted Cloud console, go to the IAM page.

      IAM に移動
    2. プロジェクトを選択します。
    3. [ アクセスを許可] をクリックします。
    4. [新しいプリンシパル] フィールドに、ユーザー ID を入力します。 これは通常、Workforce Identity プール内のユーザーの ID です。詳細については、IAM ポリシーで Workforce プールユーザーを表すをご覧いただくか、管理者にお問い合わせください。

    5. [ロールを選択] リストでロールを選択します。
    6. 追加のロールを付与するには、 [別のロールを追加] をクリックして各ロールを追加します。
    7. [保存] をクリックします。
    8. 環境を準備する

      1. Trusted Cloud コンソールで、Cloud Shell インスタンスを起動します。
        Cloud Shell を開く

      2. デフォルトの環境変数を設定します。

          gcloud config set project PROJECT_ID
          export PROJECT_ID=$(gcloud config get project)
          export ZONE=COMPUTE_ZONE
          export BUCKET_NAME=PROJECT_ID-gke-bucket
        

        次の値を置き換えます。

        このコマンドの BUCKET_NAME で、Saxml 管理者サーバーの構成を保存する Trusted CloudStorage バケットの名前を指定します。

      GKE Standard クラスタを作成する

      Cloud Shell で以下の操作を行います。

      1. Workload Identity Federation for GKE を使用する Standard クラスタを作成します。

        gcloud container clusters create saxml \
            --zone=${ZONE} \
            --workload-pool=${PROJECT_ID}.svc.id.goog \
            --cluster-version=VERSION \
            --num-nodes=4
        

        VERSION は、GKE のバージョン番号に置き換えます。GKE は、バージョン 1.27.2-gke.2100 以降で TPU v5e をサポートしています。詳細については、GKE での TPU の可用性をご覧ください。

        クラスタの作成には数分かかることもあります。

      2. tpu1 という名前で 1 つ目のノードプールを作成します。

        gcloud container node-pools create tpu1 \
            --zone=${ZONE} \
            --machine-type=ct5lp-hightpu-4t \
            --tpu-topology=4x8 \
            --num-nodes=8 \
            --cluster=saxml
        

        --num-nodes フラグの値は、TPU トポロジを TPU スライスあたりの TPU チップ数で除算して計算されます。この例の場合は、(4 * 8)/4 の計算になります。

      3. tpu2 という名前で 2 つ目のノードプールを作成します。

        gcloud container node-pools create tpu2 \
            --zone=${ZONE} \
            --machine-type=ct5lp-hightpu-4t \
            --tpu-topology=4x8 \
            --num-nodes=8 \
            --cluster=saxml
        

        --num-nodes フラグの値は、TPU トポロジを TPU スライスあたりの TPU チップ数で除算して計算されます。この例の場合は、(4 * 8)/4 の計算になります。

      次のリソースを作成しました。

      • 4 つの CPU ノードを持つ Standard クラスタ。
      • 4x8 トポロジを持つ 2 つの v5e TPU スライス ノードプール。各ノードプールは、それぞれ 4 つの TPU チップを持つ 8 つの TPU スライスノードを表します。

      175B モデルは、少なくとも 4x8 トポロジ スライス(32 個の v5e TPU チップ)を持つマルチホスト v5e TPU スライスで提供する必要があります。

      Cloud Storage バケットを作成する

      Saxml 管理者サーバーの構成を保存する Cloud Storage バケットを作成します。実行中の管理者サーバーは、その状態と公開モデルの詳細を定期的に保存します。

      Cloud Shell で次のコマンドを実行します。

      gcloud storage buckets create gs://${BUCKET_NAME}
      

      Workload Identity Federation for GKE を使用してワークロード アクセスを構成する

      アプリケーションに Kubernetes ServiceAccount を割り当て、IAM サービス アカウントとして機能するようにその Kubernetes ServiceAccount を構成します。

      1. クラスタと通信を行うように kubectl を構成します。

        gcloud container clusters get-credentials saxml --zone=${ZONE}
        
      2. アプリケーションで使用する Kubernetes ServiceAccount を作成します。

        kubectl create serviceaccount sax-sa --namespace default
        
      3. アプリケーションの IAM サービス アカウントを作成します。

        gcloud iam service-accounts create sax-iam-sa
        
      4. IAM サービス アカウントの IAM ポリシー バインディングを追加して、Cloud Storage に対する読み取りと書き込みを行います。

        gcloud projects add-iam-policy-binding ${PROJECT_ID} \
          --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.s3ns-system.iam.gserviceaccount.com" \
          --role roles/storage.admin
        
      5. 2 つのサービス アカウントの間に IAM ポリシー バインディングを追加して、Kubernetes ServiceAccount が IAM サービス アカウントの権限を借用できるようにします。このバインドで、Kubernetes ServiceAccount が IAM サービス アカウントとして機能するようになるため、Kubernetes ServiceAccount が Cloud Storage に対して読み書きを行うことができます。

        gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.s3ns-system.iam.gserviceaccount.com \
          --role roles/iam.workloadIdentityUser \
          --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
        
      6. Kubernetes サービス アカウントに IAM サービス アカウントのメールアドレスでアノテーションを付けます。これにより、サンプルアプリが Trusted Cloud サービスへのアクセスに使用するサービス アカウントを認識できます。そのため、アプリが標準の Google API クライアント ライブラリを使用して Trusted Cloud サービスにアクセスする場合は、その IAM サービス アカウントを使用します。

        kubectl annotate serviceaccount sax-sa \
          iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.s3ns-system.iam.gserviceaccount.com
        

      Saxml をデプロイする

      このセクションでは、Saxml 管理者サーバーと Saxml モデルサーバーをデプロイします。

      Saxml 管理者サーバーをデプロイする

      1. 次の sax-admin-server.yaml マニフェストを作成します。

        apiVersion: apps/v1
        kind: Deployment
        metadata:
          name: sax-admin-server
        spec:
          replicas: 1
          selector:
            matchLabels:
              app: sax-admin-server
          template:
            metadata:
              labels:
                app: sax-admin-server
            spec:
              hostNetwork: false
              serviceAccountName: sax-sa
              containers:
              - name: sax-admin-server
                image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
                securityContext:
                  privileged: true
                ports:
                - containerPort: 10000
                env:
                - name: GSBUCKET
                  value: BUCKET_NAME
      2. BUCKET_NAME を、前に作成した Cloud Storage に置き換えます。

        perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-admin-server.yaml
        
      3. 次のようにマニフェストを適用します。

        kubectl apply -f sax-admin-server.yaml
        
      4. 管理者サーバーの Pod が稼働していることを確認します。

        kubectl get deployment
        

        出力は次のようになります。

        NAME               READY   UP-TO-DATE   AVAILABLE   AGE
        sax-admin-server   1/1     1            1           52s
        

      Saxml モデルサーバーをデプロイする

      マルチホスト TPU スライスで実行されるワークロードでは、同じ TPU スライス内のピアを検出するために、各 Pod に安定したネットワーク識別子が必要です。これらの識別子を定義するには、IndexedJobStatefulSet ヘッドレス Service、または JobSet を使用します。JobSet を使用すると、それに属するすべての Job に対してヘッドレス Service が自動的に作成されます。Jobset は、Kubernetes Job のグループをユニットとして管理できるワークロード API です。JobSet の最も一般的なユースケースは分散トレーニングですが、バッチ ワークロードの実行にも使用できます。

      次のセクションでは、JobSet を使用してモデルサーバー Pod の複数のグループを管理する方法について説明します。

      1. v0.2.3 以降の JobSet をインストールします。

        kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
        

        JOBSET_VERSION は、JobSet のバージョンに置き換えます。例: v0.2.3

      2. JobSet コントローラが jobset-system Namespace で実行されていることを確認します。

        kubectl get pod -n jobset-system
        

        出力は次のようになります。

        NAME                                        READY   STATUS    RESTARTS   AGE
        jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
        
      3. 2 つの TPU スライス ノードプールに 2 つのモデルサーバーをデプロイします。次の sax-model-server-set マニフェストを保存します。

        apiVersion: jobset.x-k8s.io/v1alpha2
        kind: JobSet
        metadata:
          name: sax-model-server-set
          annotations:
            alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          failurePolicy:
            maxRestarts: 4
          replicatedJobs:
            - name: sax-model-server
              replicas: 2
              template:
                spec:
                  parallelism: 8
                  completions: 8
                  backoffLimit: 0
                  template:
                    spec:
                      serviceAccountName: sax-sa
                      hostNetwork: true
                      dnsPolicy: ClusterFirstWithHostNet
                      nodeSelector:
                        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                        cloud.google.com/gke-tpu-topology: 4x8
                      containers:
                      - name: sax-model-server
                        image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                        args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                        ports:
                        - containerPort: 10001
                        - containerPort: 8471
                        securityContext:
                          privileged: true
                        env:
                        - name: SAX_ROOT
                          value: "gs://BUCKET_NAME/sax-root"
                        - name: MEGASCALE_NUM_SLICES
                          value: ""
                        resources:
                          requests:
                            google.com/tpu: 4
                          limits:
                            google.com/tpu: 4
      4. BUCKET_NAME を、前に作成した Cloud Storage に置き換えます。

        perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-model-server-set.yaml
        

        このマニフェストの内容:

        • replicas: 2 は、Job のレプリカの数です。各ジョブはモデルサーバーを表します。したがって、8 つの Pod のグループになります。
        • parallelism: 8completions: 8 は、各ノードプール内のノード数と等しくなります。
        • Pod が失敗した場合に Job を失敗としてマークするには、backoffLimit: 0 を 0 にする必要があります。
        • ports.containerPort: 8471 は、VM 通信用のデフォルト ポートです。
        • GKE はマルチスライス トレーニングを実行していないため、name: MEGASCALE_NUM_SLICES は環境変数の設定を解除します。
      5. 次のようにマニフェストを適用します。

        kubectl apply -f sax-model-server-set.yaml
        
      6. Saxml 管理サーバーと Model Server Pod のステータスを確認します。

        kubectl get pods
        

        出力は次のようになります。

        NAME                                              READY   STATUS    RESTARTS   AGE
        sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
        sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
        sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
        

      この例では、16 個のモデルサーバー コンテナがあります。sax-model-server-set-sax-model-server-0-0-nj4smsax-model-server-set-sax-model-server-1-0-8mlf5 は、各グループの 2 つのプライマリ モデルサーバーです。

      Saxml クラスタには、それぞれ 4x8 トポロジを持つ 2 つの v5e TPU スライス ノードプールにデプロイされた 2 つのモデルサーバーがあります。

      Saxml HTTP Server とロードバランサをデプロイする

      1. 次のビルド済みイメージの HTTP サーバー イメージを使用します。次の sax-http.yaml マニフェストを保存します。

        apiVersion: apps/v1
        kind: Deployment
        metadata:
          name: sax-http
        spec:
          replicas: 1
          selector:
            matchLabels:
              app: sax-http
          template:
            metadata:
              labels:
                app: sax-http
            spec:
              hostNetwork: false
              serviceAccountName: sax-sa
              containers:
              - name: sax-http
                image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
                ports:
                - containerPort: 8888
                env:
                - name: SAX_ROOT
                  value: "gs://BUCKET_NAME/sax-root"
        ---
        apiVersion: v1
        kind: Service
        metadata:
          name: sax-http-lb
        spec:
          selector:
            app: sax-http
          ports:
          - protocol: TCP
            port: 8888
            targetPort: 8888
          type: LoadBalancer
      2. BUCKET_NAME を、前に作成した Cloud Storage に置き換えます。

        perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-http.yaml
        
      3. sax-http.yaml マニフェストを適用します。

        kubectl apply -f sax-http.yaml
        
      4. HTTP サーバー コンテナの作成が完了するまで待ちます。

        kubectl get pods
        

        出力は次のようになります。

        NAME                                              READY   STATUS    RESTARTS   AGE
        sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
        sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
        sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
        ...
        
      5. Service に外部 IP アドレスが割り当てられるまで待ちます。

        kubectl get svc
        

        出力は次のようになります。

        NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
        sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
        

      Saxml を使用する

      v5e TPU マルチホスト スライスの Saxml でモデルを読み込んでデプロイし、提供します。

      モデルを読み込む

      1. Saxml のロードバランサの IP アドレスを取得します。

        LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
        PORT="8888"
        
      2. 2 つの v5e TPU スライス ノードプールに LmCloudSpmd175B テストモデルを読み込みます。

        curl --request POST \
        --header "Content-type: application/json" \
        -s ${LB_IP}:${PORT}/publish --data \
        '{
            "model": "/sax/test/spmd",
            "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
            "checkpoint": "None",
            "replicas": 2
        }'
        

        テストモデルにはファインチューニングされたチェックポイントがなく、重みはランダムに生成されます。モデルの読み込みには最大 10 分かかります。

        出力は次のようになります。

        {
            "model": "/sax/test/spmd",
            "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
            "checkpoint": "None",
            "replicas": 2
        }
        
      3. モデルの準備状況を確認します。

        kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
        

        出力は次のようになります。

        ...
        loading completed.
        Successfully loaded model for key: /sax/test/spmd
        

        モデルが完全に読み込まれました。

      4. モデルに関する情報を取得します。

        curl --request GET \
        --header "Content-type: application/json" \
        -s ${LB_IP}:${PORT}/listcell --data \
        '{
            "model": "/sax/test/spmd"
        }'
        

        出力は次のようになります。

        {
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "max_replicas": 2,
        "active_replicas": 2
        }
        

      モデルを提供する

      プロンプト リクエストを処理します。

      curl --request POST \
      --header "Content-type: application/json" \
      -s ${LB_IP}:${PORT}/generate --data \
      '{
        "model": "/sax/test/spmd",
        "query": "How many days are in a week?"
      }'
      

      出力には、モデルのレスポンスの例が表示されます。テストモデルにはランダムな重みがあるため、このレスポンスは意味をなさない可能性があります。

      モデルの公開を停止する

      次のコマンドを実行して、モデルを非公開にします。

      curl --request POST \
      --header "Content-type: application/json" \
      -s ${LB_IP}:${PORT}/unpublish --data \
      '{
          "model": "/sax/test/spmd"
      }'
      

      出力は次のようになります。

      {
        "model": "/sax/test/spmd"
      }
      

      クリーンアップ

      このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

      デプロイされたリソースを削除する

      1. このチュートリアル用に作成したクラスタを削除します。

        gcloud container clusters delete saxml --zone ${ZONE}
        
      2. サービス アカウントを削除します。

        gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.s3ns-system.iam.gserviceaccount.com
        
      3. Cloud Storage バケットを削除します。

        gcloud storage rm -r gs://${BUCKET_NAME}
        

      次のステップ