Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions jax-inference-offloading/examples/example-standalone.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standalone example: weight transfer + rollout generation without Tunix.
# This script launches the gateway, vLLM rollout worker, and trainer_standalone.py.

# Strict mode: abort on command failure (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Directory part of the path this script was invoked with; the helper scripts
# launched further down are addressed relative to it.
DIR="$(dirname "$0")"

# JAX compilation cache location: a value inherited from the environment wins,
# otherwise a fixed directory under /tmp is used. The directory is created up
# front. The expansion is quoted so that a path containing whitespace stays a
# single directory instead of being split into several mkdir arguments.
JAX_COMPILATION_CACHE_DIR="${JAX_COMPILATION_CACHE_DIR:-/tmp/jax-compilation-cache}"
mkdir -p -- "${JAX_COMPILATION_CACHE_DIR}"

# Defaults. Most of these can be overridden by the command-line options that
# are parsed afterwards.

# General: debug tracing is off; OUTPUT_DIR keeps a non-empty value inherited
# from the environment, otherwise a fresh temporary directory is created.
DEBUG=false
if [[ -z "${OUTPUT_DIR:-}" ]]; then
  OUTPUT_DIR=$(mktemp -d)
fi

# Model selection (all empty until chosen by an option or a later default)
MODEL_NAME=
MODEL_PATH=
PARAM_MAPPING_PATH=

# Weight transfer mode (empty means: do not export TRANSFER_MODE)
TRANSFER_MODE=

# vLLM runtime knobs
VLLM_ENFORCE_EAGER=0
VLLM_GPU_MEMORY_UTILIZATION=0.9

# Debug-only switch: use dummy weights for the JAX model
USE_DUMMY_WEIGHT=0

# Number of GPUs handed to each side
N_GPUS_VLLM=4
N_GPUS_JAX=4

# Gateway port
GATEWAY_PORT=50051

# Parse command line arguments

# Print the help text to stdout.
usage() {
  printf '%s\n' \
    "Usage: $0 [OPTIONS]" \
    "" \
    "Standalone example: weight transfer + rollout generation." \
    "" \
    "This example uses Tunix for model loading, but trainer_standalone.py" \
    "demonstrates how to integrate with custom RL frameworks. See the comments" \
    "in trainer_standalone.py for guidance on replacing the model loading code." \
    "" \
    "Options:" \
    " --debug Enable debug mode with verbose logging." \
    " --output-dir=DIR Directory to save logs and outputs. Default is a temporary directory." \
    "" \
    " --model-name=NAME HF model name (required by Tunix's loader for architecture selection)." \
    " --model-path=PATH HF snapshot directory containing model weights." \
    " --param-mapping-path=PATH Path to JSON param mapping file (optional, uses hardcoded if not set)." \
    "" \
    " --transfer-mode=MODE Transfer mode for trainer->vLLM weights (grouped/stacked/fused/unfused)." \
    "" \
    " --vllm-enforce-eager Force vLLM eager mode (sets VLLM_ENFORCE_EAGER=1)." \
    " --no-vllm-enforce-eager Disable vLLM eager mode (sets VLLM_ENFORCE_EAGER=0)." \
    " --vllm-gpu-memory-utilization=FLOAT vLLM GPU memory utilization (e.g., 0.7)." \
    " --use-dummy-weight Use randomly initialized JAX weights (DEBUG ONLY)." \
    "" \
    " --n-gpus-vllm=N Number of GPUs for vLLM (default: 4)." \
    " --n-gpus-jax=N Number of GPUs for JAX (default: 4)." \
    "" \
    " --gateway-port=PORT gRPC gateway port (default: 50051)." \
    " --help Show this help message and exit."
}

# Apply the given command-line options to the script's configuration variables.
# Arguments: the options to parse (normally "$@").
# Globals written: DEBUG, OUTPUT_DIR, MODEL_NAME, MODEL_PATH,
#   PARAM_MAPPING_PATH, TRANSFER_MODE, VLLM_ENFORCE_EAGER,
#   VLLM_GPU_MEMORY_UTILIZATION, USE_DUMMY_WEIGHT, N_GPUS_VLLM, N_GPUS_JAX,
#   GATEWAY_PORT.
# Valued options must use the --name=value form; "${1#*=}" keeps everything
# after the first '=' so values may themselves contain '='.
# --help prints the usage text and exits the script with status 0.
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      # General
      --debug)
        DEBUG="true"
        ;;
      --output-dir=*)
        OUTPUT_DIR="${1#*=}"
        ;;
      # Model configuration
      --model-name=*)
        MODEL_NAME="${1#*=}"
        ;;
      --model-path=*)
        MODEL_PATH="${1#*=}"
        ;;
      --param-mapping-path=*)
        PARAM_MAPPING_PATH="${1#*=}"
        ;;
      # Transfer mode
      --transfer-mode=*)
        TRANSFER_MODE="${1#*=}"
        ;;
      # vLLM runtime
      --vllm-enforce-eager)
        VLLM_ENFORCE_EAGER="1"
        ;;
      --no-vllm-enforce-eager)
        VLLM_ENFORCE_EAGER="0"
        ;;
      --vllm-gpu-memory-utilization=*)
        VLLM_GPU_MEMORY_UTILIZATION="${1#*=}"
        ;;
      --use-dummy-weight)
        USE_DUMMY_WEIGHT="1"
        ;;
      # Device assignment
      --n-gpus-vllm=*)
        N_GPUS_VLLM="${1#*=}"
        ;;
      --n-gpus-jax=*)
        N_GPUS_JAX="${1#*=}"
        ;;
      # Gateway
      --gateway-port=*)
        GATEWAY_PORT="${1#*=}"
        ;;
      --help)
        usage
        exit 0
        ;;
      *)
        # Unrecognized arguments are reported and skipped, not treated as
        # fatal. The diagnostic goes to stderr so it does not mix with regular
        # output.
        echo "Unknown argument: $1" >&2
        ;;
    esac
    # Every option handled above consumes exactly one argument.
    shift
  done
}

parse_args "$@"

# Model selection default
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"}

# ------------------------------------------------------------------------------
# Kill all processes when done.
# ------------------------------------------------------------------------------
# On SIGINT, SIGTERM, or any exit of this script, send kill's default signal to
# the process group whose ID equals this script's PID ("kill -- -PID"). The
# SIGTERM handler is reset to the default first, so the signal delivered to the
# script itself does not re-enter this handler.
# The trap string is double-quoted, so $$ is expanded once, at the moment the
# trap is installed.
# NOTE(review): this presumably relies on the script being the leader of its own
# process group so that the group contains the background components — confirm
# for the way the script is launched.
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT

# ------------------------------------------------------------------------------
# Optionally pull environment variables from a .env file in the current
# working directory
# ------------------------------------------------------------------------------
dotenv_file="${PWD}/.env"
if [[ ! -f "${dotenv_file}" ]]; then
  echo ".env not found in ${PWD}, skipping"
else
  echo "Loading ${dotenv_file}"
  # set -a switches on allexport so that variables assigned by the file are
  # exported; set +a switches it back off after a successful load.
  set -a && source "${dotenv_file}" && set +a
fi

# ------------------------------------------------------------------------------
# Decide where the model weights come from: nowhere (dummy weights), a
# caller-provided MODEL_PATH, or a freshly downloaded HF snapshot
# ------------------------------------------------------------------------------

# A missing token is only reported here; the script carries on regardless.
[[ -n "${HF_TOKEN:-}" ]] || echo "HF_TOKEN is not set. Please set it in the .env file or export it."

if [[ "${USE_DUMMY_WEIGHT}" == "1" ]]; then
  # Dummy-weight mode clears MODEL_PATH, discarding any value given earlier.
  echo "Using dummy weights for JAX model (DEBUG ONLY)."
  MODEL_PATH=
elif [[ -n "${MODEL_PATH:-}" ]]; then
  echo "Using provided MODEL_PATH: ${MODEL_PATH}"
else
  echo "MODEL_PATH not provided, downloading HF snapshot..."
  # Whatever download_model.py prints on stdout becomes MODEL_PATH.
  MODEL_PATH=$(python "${DIR}/download_model.py" --hub=hf --model="${MODEL_NAME}" --ignore="*.pth")
fi

# ------------------------------------------------------------------------------
# assign GPUs to vLLM and JAX
# ------------------------------------------------------------------------------

# Split the available devices between the vLLM rollout worker and the JAX side.
# Globals read:
#   N_GPUS_VLLM, N_GPUS_JAX  number of devices requested for each side.
#   CUDA_VISIBLE_DEVICES     optional comma-separated device list; when unset or
#                            empty, device IDs 0..N_GPUS-1 are used instead.
# Globals written:
#   N_GPUS                      total number of requested devices.
#   CUDA_VISIBLE_DEVICES_ARRAY  the device list as an array.
#   VLLM_GPU_ARRAY              the first N_GPUS_VLLM devices.
#   JAX_GPU_ARRAY               the next N_GPUS_JAX devices.
# Outputs: a warning on stderr when fewer devices are listed than requested.
assign_gpus() {
  local i
  N_GPUS=$((N_GPUS_VLLM + N_GPUS_JAX))

  # Derive CUDA_VISIBLE_DEVICES_ARRAY
  if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    CUDA_VISIBLE_DEVICES_ARRAY=()
    for ((i = 0; i < N_GPUS; i++)); do
      CUDA_VISIBLE_DEVICES_ARRAY+=("${i}")
    done
  else
    IFS=',' read -r -a CUDA_VISIBLE_DEVICES_ARRAY <<< "${CUDA_VISIBLE_DEVICES}"
  fi

  # Slicing silently truncates when the list is too short, which would leave
  # one side with fewer devices than requested (possibly none); say so.
  if (( ${#CUDA_VISIBLE_DEVICES_ARRAY[@]} < N_GPUS )); then
    echo "Warning: ${N_GPUS} GPUs requested (vLLM: ${N_GPUS_VLLM}, JAX: ${N_GPUS_JAX})" \
      "but only ${#CUDA_VISIBLE_DEVICES_ARRAY[@]} listed in CUDA_VISIBLE_DEVICES." >&2
  fi

  # Array slices are ${array[@]:offset:length}. The JAX slice length must be
  # N_GPUS_JAX: using the total N_GPUS would hand JAX every remaining device
  # whenever CUDA_VISIBLE_DEVICES lists more devices than were requested.
  VLLM_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:0:N_GPUS_VLLM}")
  JAX_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:N_GPUS_VLLM:N_GPUS_JAX}")
}

assign_gpus

# ------------------------------------------------------------------------------
# Environment exported to every component launched afterwards
# ------------------------------------------------------------------------------

# CUDA / NCCL settings
export CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_DEVICE_MAX_CONNECTIONS=16 NCCL_BUFFSIZE=16777216

# Gateway endpoint, derived from the configured port
GATEWAY_URL="localhost:${GATEWAY_PORT}"
export GATEWAY_PORT GATEWAY_URL

# Model, weight-loading and vLLM settings chosen earlier in the script
export MODEL_NAME MODEL_PATH PARAM_MAPPING_PATH USE_DUMMY_WEIGHT
export VLLM_ENFORCE_EAGER VLLM_GPU_MEMORY_UTILIZATION

# XLA settings; the individual flags are joined with newlines into XLA_FLAGS.
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
xla_flag_list=(
  --xla_gpu_enable_latency_hiding_scheduler=true
  --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
  --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
  --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
  --xla_gpu_all_gather_combine_threshold_bytes=8589934592
  --xla_gpu_all_reduce_combine_threshold_bytes=8589934592
)
XLA_FLAGS=$(IFS=$'\n'; printf '%s' "${xla_flag_list[*]}")
export XLA_FLAGS

# TRANSFER_MODE is exported only when a non-empty value was chosen.
[[ -z "${TRANSFER_MODE:-}" ]] || export TRANSFER_MODE

# Logging verbosity follows the debug switch.
case "${DEBUG}" in
  true)
    set -x
    export TF_CPP_MIN_LOG_LEVEL=0
    export NCCL_DEBUG=INFO # Enable NCCL debug logs
    ;;
  *)
    export TF_CPP_MIN_LOG_LEVEL=2 # Suppress TensorFlow debug logs
    export VLLM_CONFIGURE_LOGGING=0 # Suppress vLLM logging
    ;;
esac

mkdir -p "${OUTPUT_DIR}"
echo "Logs will be saved to: ${OUTPUT_DIR}"

# ------------------------------------------------------------------------------
# Launch components
# ------------------------------------------------------------------------------
# Each component is started as a background pipeline: its stdout and stderr are
# merged and piped through tee into a per-component log file under OUTPUT_DIR.
# $! is recorded after every launch (for a background pipeline this is the PID
# of its last command, i.e. tee) so the script can wait on all of them below.
launched_pids=()
gateway_log="${OUTPUT_DIR}/gateway.log"
rollout_log="${OUTPUT_DIR}/rollout.log"
trainer_log="${OUTPUT_DIR}/trainer.log"

# Gateway server: started with an empty CUDA_VISIBLE_DEVICES (no GPU)
CUDA_VISIBLE_DEVICES="" \
  python "${DIR}/../jax_inference_offloading/controller/gateway.py" 2>&1 \
  | tee "${gateway_log}" &
launched_pids+=("$!")

# vLLM rollout worker: sees only the vLLM devices (comma-joined); its MODEL_NAME
# is MODEL_PATH when that is non-empty, otherwise the configured MODEL_NAME.
CUDA_VISIBLE_DEVICES="$(IFS=','; echo "${VLLM_GPU_ARRAY[*]}")" \
  MODEL_NAME="${MODEL_PATH:-${MODEL_NAME}}" \
  python "${DIR}/rollout.py" 2>&1 \
  | tee "${rollout_log}" &
launched_pids+=("$!")

# Standalone trainer (weight transfer + generation demo): sees only the JAX
# devices and is pointed at the JAX compilation cache directory.
CUDA_VISIBLE_DEVICES="$(IFS=','; echo "${JAX_GPU_ARRAY[*]}")" \
  JAX_COMPILATION_CACHE_DIR="${JAX_COMPILATION_CACHE_DIR}" \
  JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=0.1 \
  python "${DIR}/trainer_standalone.py" 2>&1 \
  | tee "${trainer_log}" &
launched_pids+=("$!")

# Block until every recorded background pipeline has finished.
wait "${launched_pids[@]}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{
"num_layers": 16,
"mappings": [
{
"jax_param": {
"name": "model.embedder.input_embedding"
},
"vllm_param": {
"name": "model.embed_tokens.weight",
"shape": [128256, 2048]
}
},
{
"jax_param": {
"name": "model.final_norm.w"
},
"vllm_param": {
"name": "model.norm.weight",
"shape": [2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.input_layernorm.w"
},
"vllm_param": {
"name": "model.layers.{layer}.input_layernorm.weight",
"shape": [2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.post_attention_layernorm.w"
},
"vllm_param": {
"name": "model.layers.{layer}.post_attention_layernorm.weight",
"shape": [2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.mlp.gate_proj.kernel",
"transform": {
"transpose": [1, 0]
}
},
"vllm_param": {
"name": "model.layers.{layer}.mlp.gate_proj.weight",
"shape": [8192, 2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.mlp.up_proj.kernel",
"transform": {
"transpose": [1, 0]
}
},
"vllm_param": {
"name": "model.layers.{layer}.mlp.up_proj.weight",
"shape": [8192, 2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.mlp.down_proj.kernel",
"transform": {
"transpose": [1, 0]
}
},
"vllm_param": {
"name": "model.layers.{layer}.mlp.down_proj.weight",
"shape": [2048, 8192]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.attn.q_proj.w",
"transform": {
"transpose": [1, 2, 0],
"reshape": [-1, 2048]
}
},
"vllm_param": {
"name": "model.layers.{layer}.self_attn.q_proj.weight",
"shape": [2048, 2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.attn.k_proj.w",
"transform": {
"transpose": [1, 2, 0],
"reshape": [-1, 2048],
"replication_axis": 2
}
},
"vllm_param": {
"name": "model.layers.{layer}.self_attn.k_proj.weight",
"shape": [512, 2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.attn.v_proj.w",
"transform": {
"transpose": [1, 2, 0],
"reshape": [-1, 2048],
"replication_axis": 2
}
},
"vllm_param": {
"name": "model.layers.{layer}.self_attn.v_proj.weight",
"shape": [512, 2048]
}
},
{
"jax_param": {
"name": "model.layers.{layer}.attn.o_proj.w",
"transform": {
"transpose": [2, 0, 1],
"reshape": [2048, -1]
}
},
"vllm_param": {
"name": "model.layers.{layer}.self_attn.o_proj.weight",
"shape": [2048, 2048]
}
}
]
}
Loading