This example shows how to deploy JAX large language models in a multi-node environment using PyTriton. In this tutorial we work with the HuggingFace OPT model family in sizes of up to 530B parameters. Pretrained checkpoints go up to 66B; the larger sizes use randomly initialized weights.
To run JAX in a multi-GPU and/or multi-node environment we use the jax.distributed and jax.experimental.pjit modules. To learn more about using pjit and jax.distributed for running multi-node models, please visit the JAX docs.
The example consists of the following scripts:
- server.py - runs the Triton server (with --rank 0) or a JAX worker (with --rank greater than 0) on each node. It contains the code that distributes the inputs from the server to the workers.
- client.py - an example of a simple client that calls the server with a single sample.
- opt_utils.py - lower-level code used by server.py. In this file we define functions that create a sharding strategy, copy model parameters from the CPU to multiple devices, and run inference.
- modeling_flax_opt.py - a slightly modified HuggingFace file with the OPT model definition. The main difference is that in the HuggingFace repository the model is initialized with FP32 weights even when the operations are in FP16. In our file we use FP16 both for storing parameters and for performing operations.
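To make the FP16 storage choice concrete, here is a small standard-library sketch (not code taken from modeling_flax_opt.py). It compares the per-parameter storage of FP32 and FP16 using the struct format codes 'f' and 'e', and shows how a value is rounded when it is stored in half precision.

```python
import struct

# Bytes per stored parameter: '<f' is IEEE 754 single precision (FP32),
# '<e' is IEEE 754 half precision (FP16).
FP32_BYTES = struct.calcsize("<f")
FP16_BYTES = struct.calcsize("<e")

def to_fp16(value: float) -> float:
    """Round-trip a Python float through FP16 to see the precision that survives."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

print(FP32_BYTES, FP16_BYTES)  # 4 2
print(to_fp16(0.1))            # 0.0999755859375
print(to_fp16(65504.0))        # 65504.0, the largest finite FP16 value
```

Storing weights in FP16 halves their memory footprint compared with FP32, at the cost of roughly three significant decimal digits of precision per value.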
And configurations:
- kubernetes - example Helm Charts for serving and test inference in a Kubernetes cluster.
Below you can find a list of available models:
model name | pretrained | source |
---|---|---|
facebook/opt-125m | True | HuggingFace |
facebook/opt-350m | True | HuggingFace |
facebook/opt-1.3b | True | HuggingFace |
facebook/opt-2.7b | True | HuggingFace |
facebook/opt-6.7b | True | HuggingFace |
facebook/opt-13b | True | HuggingFace |
facebook/opt-30b | True | HuggingFace |
facebook/opt-66b | True | HuggingFace |
random/125M | False | |
random/350M | False | |
random/1.3B | False | |
random/2.7B | False | |
random/5B | False | |
random/6.7B | False | |
random/13B | False | |
random/20B | False | |
random/30B | False | |
random/66B | False | |
random/89B | False | |
random/17B | False | |
random/310B | False | |
random/530B | False | |
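The table also hints at why the larger configurations need several nodes. The rough estimate below is not part of the example: it parses the size suffix of a model name, multiplies by 2 bytes per FP16 parameter, and computes the minimum number of GPUs needed to hold the weights alone. The 80 GB per-GPU default and the suffix parsing are assumptions. Activations, the KV cache, and framework overhead are ignored, so a real deployment needs extra headroom.

```python
import math
import re

FP16_BYTES = 2  # parameters are stored in FP16
_UNITS = {"m": 1e6, "b": 1e9}

def num_params(model_name: str) -> int:
    """Parse the size suffix of names such as 'facebook/opt-1.3b' or 'random/530B'."""
    match = re.search(r"(\d+(?:\.\d+)?)([mb])$", model_name.lower())
    if match is None:
        raise ValueError(f"cannot parse model size from {model_name!r}")
    value, unit = match.groups()
    return round(float(value) * _UNITS[unit])

def min_gpus(model_name: str, gpu_memory_gb: float = 80.0) -> int:
    """Lower bound on GPUs needed for the FP16 weights only."""
    weight_gb = num_params(model_name) * FP16_BYTES / 1e9
    return math.ceil(weight_gb / gpu_memory_gb)

for name in ["facebook/opt-125m", "facebook/opt-66b", "random/530B"]:
    print(name, num_params(name), min_gpus(name))
```

Under these assumptions, the weights of random/530B alone (about 1060 GB) need at least 14 GPUs with 80 GB each, which is more than a single 8-GPU node.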
In this section we describe running JAX in a multi-GPU and/or multi-node environment where the environment is set up manually on each node.
Each node must meet the following requirements:
- NVIDIA Container Toolkit installed.
- NVIDIA Driver compatible with the chosen version of the framework container.
The easiest way of running this example is inside an nvcr.io TensorFlow2 container. An example Dockerfile that can be used to run the server:
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:24.10-tf2-py3
FROM ${FROM_IMAGE_NAME}
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false
ENV NCCL_LAUNCH_MODE="PARALLEL"
WORKDIR /workdir
COPY install.sh .
RUN ./install.sh
RUN pip install <pytriton package>
COPY . .
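The two ENV lines in the Dockerfile configure XLA and NCCL for the whole container. If you run server.py outside this image, a rough equivalent is to set the same variables from Python before JAX is imported, because XLA reads XLA_PYTHON_CLIENT_PREALLOCATE when its GPU backend is initialized. This is a sketch; the helper apply_recommended_env is hypothetical and not part of the example.

```python
import os

# Same values as the ENV lines in the example Dockerfile.
RECOMMENDED_ENV = {
    # Do not let XLA preallocate most of the GPU memory; allocate on demand.
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    # Controls how NCCL launches its CUDA kernels; the Dockerfile selects PARALLEL.
    "NCCL_LAUNCH_MODE": "PARALLEL",
}

def apply_recommended_env(environ=os.environ) -> dict:
    """Hypothetical helper: set each variable unless the user already set it.

    Call it before `import jax`, otherwise XLA may already be initialized.
    Returns the values in effect afterwards.
    """
    for key, value in RECOMMENDED_ENV.items():
        environ.setdefault(key, value)
    return {key: environ[key] for key in RECOMMENDED_ENV}

print(apply_recommended_env())
# import jax  # import JAX only after the environment is configured
```

Using setdefault keeps any value the user exported explicitly, so the defaults never override a deliberate choice.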
On each node we have to build the image (or download it from a registry).
docker build -t jax-llm:latest .
On each node run:
docker run --net host --rm --gpus all jax-llm python server.py \
--head-url "<head node IP>:<port (e.g. 1234)>" \
--number-of-nodes <number of nodes> \
--rank <current node index, head node has index 0> \
--model-name <model_name> \
--number-of-gpus <number of GPUs per node>
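Each process started by the command above learns its place in the cluster from --head-url, --number-of-nodes and --rank. The sketch below assumes (server.py is not reproduced in this README) that these flags map onto the coordinator_address, num_processes and process_id arguments of jax.distributed.initialize, with rank 0 hosting the Triton server and every other rank acting as a JAX worker. NodeRole and node_role are hypothetical names.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class NodeRole:
    """Hypothetical summary of what one node does in the deployment."""
    coordinator_address: str  # "<head node IP>:<port>", identical on every node
    num_processes: int        # one server.py process per node
    process_id: int           # this node's --rank
    runs_triton: bool         # True only on the head node (rank 0)

def node_role(head_url: str, number_of_nodes: int, rank: int) -> NodeRole:
    """Validate the CLI values and map them to jax.distributed-style arguments."""
    if number_of_nodes < 1:
        raise ValueError("number_of_nodes must be at least 1")
    if not 0 <= rank < number_of_nodes:
        raise ValueError(f"rank must be in [0, {number_of_nodes - 1}], got {rank}")
    host, sep, port = head_url.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"head_url must look like 'IP:port', got {head_url!r}")
    return NodeRole(head_url, number_of_nodes, rank, runs_triton=(rank == 0))

role = node_role("10.0.0.1:1234", number_of_nodes=3, rank=2)
print(role.runs_triton)  # False: rank 2 is a JAX worker
# On a real node the values would then be handed to JAX (not executed here):
#   import jax
#   jax.distributed.initialize(coordinator_address=role.coordinator_address,
#                              num_processes=role.num_processes,
#                              process_id=role.process_id)
```

Validating the rank against the node count early turns a typo in the command into an immediate error instead of a coordinator that waits forever for a process that never joins.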
The server expects two inputs:
- input - string array of shape (batch_size, 1),
- output_length - int64 array of shape (batch_size, 1).
It returns a single output:
- output - string array of shape (batch_size, 1).
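A client sends both inputs as 2-D arrays with one row per sample. The NumPy sketch below builds a batch that matches the shapes and dtypes listed above, encoding strings as UTF-8 bytes (the usual way to pass Triton string tensors from Python), and decodes the output back to text. make_inputs and decode_output are hypothetical helpers; client.py may do this differently.

```python
import numpy as np

def make_inputs(prompts, output_length):
    """Build the 'input' and 'output_length' arrays, each of shape (batch_size, 1)."""
    texts = np.array([[p.encode("utf-8")] for p in prompts], dtype=np.bytes_)
    lengths = np.full((len(prompts), 1), output_length, dtype=np.int64)
    return {"input": texts, "output_length": lengths}

def decode_output(output):
    """Turn an 'output' array of shape (batch_size, 1) back into Python strings."""
    return [row[0].decode("utf-8") for row in output]

batch = make_inputs(["Hello, my name is", "The capital of France is"], output_length=16)
print(batch["input"].shape, batch["input"].dtype)                   # (2, 1) |S24
print(batch["output_length"].shape, batch["output_length"].dtype)   # (2, 1) int64
```

With PyTriton's ModelClient these arrays would be passed as keyword arguments to infer_batch. The model name that server.py registers is not stated in this README, so check the script before wiring this up.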
To read more about the Triton server, please visit the Triton docs.
To use the example client, run the following on any machine:
docker run --net host jax-llm python client.py \
--server-url "http://<head node IP>:8000" \
--input "<input text>" \
--output-length <output length>
This section describes how to serve the JAX model on a Kubernetes cluster. The following prerequisites must be met to run the example:
- Kubernetes cluster with NVIDIA GPU nodes
- NVIDIA Device Plugin installed in the Kubernetes cluster
- Docker container registry accessible from the Kubernetes cluster
- Helm installed, for creating the deployment and the test job
Optionally, you may install the NVIDIA Container Toolkit and the NVIDIA GPU Operator, which enable more features such as MIG or time-slicing support in the cluster. To learn more about setting up a Kubernetes cluster with NVIDIA GPUs, review the NVIDIA Cloud Native Documentation.
Below, we present a step-by-step guide, assuming that all commands are executed from the root of the repository.
Follow these steps to run and test the example in the cluster:
- [Optional] Build the PyTriton wheel following the build instructions.
- Prepare the tag under which the image will be pushed to your Docker container registry accessible from the Kubernetes cluster. Example for a local cluster (minikube, k3s) with a registry hosted inside the cluster:
export DOCKER_IMAGE_NAME_WITH_TAG=localhost:5000/jax-example:latest
- Build and push the Docker container image to your registry:
# Export the base image used for the build. We use the TensorFlow image for JAX.
export FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:24.10-tf2-py3
./examples/huggingface_opt_multinode_jax/kubernetes/build_and_push.sh
Note: By default the container is built using the pytriton package from pypi.org. To build the container with a locally built wheel, run export BUILD_FROM=dist before executing the script.
- At this point there are 2 options to deploy the model, depending on the size of the model:
a) Install the Helm Chart with deployment and service for single-node:
helm upgrade -i --set deployment.image=${DOCKER_IMAGE_NAME_WITH_TAG} \
--set deployment.numOfGPUs=1 \
jax-example \
./examples/huggingface_opt_multinode_jax/kubernetes/single-node
b) Install the Helm Chart with deployment and service for multi-node:
Important: Running multi-node requires creating a Persistent Volume Claim in the cluster that is shared between PODs. You can pass its name as an argument to the Helm Chart during installation. Read more about how to create a Persistent Volume Claim.
Please note: Scaling the multi-node deployment requires improved configuration of services and load balancing.
helm upgrade -i --set statefulset.image=${DOCKER_IMAGE_NAME_WITH_TAG} \
--set statefulset.persistentVolumeClaim=llm-cache-pvc \
--set statefulset.numOfNodes=3 \
--set statefulset.numOfGPUs=1 \
jax-example \
./examples/huggingface_opt_multinode_jax/kubernetes/multi-node
- Install the Helm Chart with the client test:
helm install --set image=${DOCKER_IMAGE_NAME_WITH_TAG} \
jax-example-test \
./examples/huggingface_opt_multinode_jax/kubernetes/test
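In the multi-node variant the pods come from a StatefulSet. Kubernetes names StatefulSet pods <statefulset-name>-<ordinal>, a pod's hostname defaults to its pod name, and a headless Service gives each pod the DNS name <pod>.<service>.<namespace>.svc.<cluster-domain>. The chart templates are not reproduced here, so the following is only a sketch of how such a chart could derive --rank and --head-url for each pod. The service name, namespace, port, the default cluster.local domain, and the assumption that --head-url accepts a host name instead of an IP are all hypothetical.

```python
import re

# StatefulSet pods are named "<statefulset-name>-<ordinal>", e.g. "jax-example-2".
_POD_NAME = re.compile(r"(.+)-(\d+)")

def _split_pod_name(hostname: str):
    match = _POD_NAME.fullmatch(hostname)
    if match is None:
        raise ValueError(f"not a StatefulSet pod name: {hostname!r}")
    return match.group(1), int(match.group(2))

def rank_from_hostname(hostname: str) -> int:
    """Use the pod ordinal as the --rank value (pod 0 becomes the head node)."""
    return _split_pod_name(hostname)[1]

def head_url_from_hostname(hostname: str, service: str, namespace: str,
                           port: int, cluster_domain: str = "cluster.local") -> str:
    """Address of pod 0 through the headless Service DNS record."""
    statefulset, _ = _split_pod_name(hostname)
    return f"{statefulset}-0.{service}.{namespace}.svc.{cluster_domain}:{port}"

# Inside a pod the hostname is typically available as os.environ["HOSTNAME"].
print(rank_from_hostname("jax-example-2"))
print(head_url_from_hostname("jax-example-2", "jax-example", "default", 1234))
```

Deriving both values from the pod name means every replica can run the same container command, and pod 0 is always the one that starts the Triton server.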
Now you can review the logs from the running PODs to verify that inference is running. To show the logs of a given POD, first list all running pods:
kubectl get pods
Next, show the logs from the server or the client:
kubectl logs {NAME}
To remove the installed charts, run:
helm uninstall jax-example-test
helm uninstall jax-example
This section describes how to create a Persistent Volume Claim in a Kubernetes cluster using a CSI or NFS drive.
When you are running on a local machine (e.g. Minikube or k3s), you can use the CSI host path driver to create a persistent volume claim. Make sure that the appropriate extension for your cluster has been installed, and run:
kubectl apply -f ./examples/huggingface_opt_multinode_jax/kubernetes/persistent-volume-claim-csi.yaml
When you are running a Kubernetes cluster on a Cloud Service Provider, you can create a persistent volume claim using an NFS disk.
First, create the NFS disk and obtain its IP address. Make sure the disk is in the same network as the Kubernetes cluster.
The predefined file share name for the NFS storage is llm.
Next, modify the ./examples/huggingface_opt_multinode_jax/kubernetes/persistent-volume-nfs.yaml file and update the {IP} value. Then run:
kubectl apply -f ./examples/huggingface_opt_multinode_jax/kubernetes/persistent-volume-nfs.yaml
Once the persistent volume is ready, the claim can be created using:
kubectl apply -f ./examples/huggingface_opt_multinode_jax/kubernetes/persistent-volume-claim-nfs.yaml
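The manual {IP} edit described above can be scripted. This sketch validates the address with Python's ipaddress module before substituting it, and only assumes that the manifest contains the literal placeholder {IP}. render_manifest and render_file are hypothetical names.

```python
import ipaddress
from pathlib import Path

PLACEHOLDER = "{IP}"

def render_manifest(template: str, ip: str) -> str:
    """Replace every {IP} placeholder with a validated IPv4/IPv6 address."""
    address = ipaddress.ip_address(ip)  # raises ValueError on malformed input
    if PLACEHOLDER not in template:
        raise ValueError("template does not contain the {IP} placeholder")
    return template.replace(PLACEHOLDER, str(address))

def render_file(src: Path, dst: Path, ip: str) -> Path:
    """Write a rendered copy so the original template stays reusable."""
    dst.write_text(render_manifest(src.read_text(), ip))
    return dst
```

Point render_file at the manifest you were asked to edit, then pass the rendered copy to kubectl apply -f. Keeping the template untouched makes it easy to redeploy against a different NFS disk later.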