diff --git a/jax-inference-offloading/examples/example-standalone.sh b/jax-inference-offloading/examples/example-standalone.sh new file mode 100755 index 000000000..9c4a79822 --- /dev/null +++ b/jax-inference-offloading/examples/example-standalone.sh @@ -0,0 +1,265 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standalone example: weight transfer + rollout generation without Tunix. +# This script launches the gateway, vLLM rollout worker, and trainer_standalone.py. + +set -euo pipefail + +DIR="$(dirname "$0")" +JAX_COMPILATION_CACHE_DIR=${JAX_COMPILATION_CACHE_DIR:-/tmp/jax-compilation-cache} +mkdir -p ${JAX_COMPILATION_CACHE_DIR} + +# Set default values +DEBUG="false" +OUTPUT_DIR=${OUTPUT_DIR:-$(mktemp -d)} + +# Model configuration +MODEL_NAME="" +MODEL_PATH="" +PARAM_MAPPING_PATH="" + +# Transfer mode +TRANSFER_MODE="" + +# vLLM runtime +VLLM_ENFORCE_EAGER="0" +VLLM_GPU_MEMORY_UTILIZATION="0.9" + +# Debug-only: use dummy weights for JAX model +USE_DUMMY_WEIGHT="0" + +# Device assignment +N_GPUS_VLLM="4" +N_GPUS_JAX="4" + +# Gateway +GATEWAY_PORT="50051" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case "$1" in + # General + --debug) + DEBUG="true" + shift + ;; + --output-dir=*) + OUTPUT_DIR="${1#*=}" + shift + ;; + # Model configuration + --model-name=*) + MODEL_NAME="${1#*=}" + shift + ;; + --model-path=*) + MODEL_PATH="${1#*=}" + shift + ;; + --param-mapping-path=*) + PARAM_MAPPING_PATH="${1#*=}" + shift + ;; + # Transfer mode + --transfer-mode=*) + TRANSFER_MODE="${1#*=}" + shift + ;; + # vLLM runtime + --vllm-enforce-eager) + VLLM_ENFORCE_EAGER="1" + shift + ;; + --no-vllm-enforce-eager) + VLLM_ENFORCE_EAGER="0" + shift + ;; + --vllm-gpu-memory-utilization=*) + VLLM_GPU_MEMORY_UTILIZATION="${1#*=}" + shift + ;; + --use-dummy-weight) + USE_DUMMY_WEIGHT="1" + shift + ;; + # Device assignment + --n-gpus-vllm=*) + N_GPUS_VLLM="${1#*=}" + shift + ;; + --n-gpus-jax=*) + N_GPUS_JAX="${1#*=}" + shift + ;; + # Gateway + --gateway-port=*) + GATEWAY_PORT="${1#*=}" + shift + ;; + --help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Standalone example: weight transfer + rollout generation." + echo "" + echo "This example uses Tunix for model loading, but trainer_standalone.py" + echo "demonstrates how to integrate with custom RL frameworks. See the comments" + echo "in trainer_standalone.py for guidance on replacing the model loading code." + echo "" + echo "Options:" + echo " --debug Enable debug mode with verbose logging." + echo " --output-dir=DIR Directory to save logs and outputs. Default is a temporary directory." + echo "" + echo " --model-name=NAME HF model name (required by Tunix's loader for architecture selection)." + echo " --model-path=PATH HF snapshot directory containing model weights." + echo " --param-mapping-path=PATH Path to JSON param mapping file (optional, uses hardcoded if not set)." 
+ echo "" + echo " --transfer-mode=MODE Transfer mode for trainer->vLLM weights (grouped/stacked/fused/unfused)." + echo "" + echo " --vllm-enforce-eager Force vLLM eager mode (sets VLLM_ENFORCE_EAGER=1)." + echo " --no-vllm-enforce-eager Disable vLLM eager mode (sets VLLM_ENFORCE_EAGER=0)." + echo " --vllm-gpu-memory-utilization=FLOAT vLLM GPU memory utilization (e.g., 0.7)." + echo " --use-dummy-weight Use randomly initialized JAX weights (DEBUG ONLY)." + echo "" + echo " --n-gpus-vllm=N Number of GPUs for vLLM (default: 4)." + echo " --n-gpus-jax=N Number of GPUs for JAX (default: 4)." + echo "" + echo " --gateway-port=PORT gRPC gateway port (default: 50051)." + echo " --help Show this help message and exit." + exit 0 + ;; + *) + echo "Unknown argument: $1" + shift + ;; + esac +done + +# Model selection default +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"} + +# ------------------------------------------------------------------------------ +# Kill all processes when done. +# ------------------------------------------------------------------------------ +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT + +# ------------------------------------------------------------------------------ +# load environment variables from .env file +# ------------------------------------------------------------------------------ +if [[ -f "${PWD}/.env" ]]; then + echo "Loading ${PWD}/.env" + set -a && source "${PWD}/.env" && set +a +else + echo ".env not found in ${PWD}, skipping" +fi + +# ------------------------------------------------------------------------------ +# Ensure model is already present on disk (download only when using real weights) +# ------------------------------------------------------------------------------ + +if [[ -z "${HF_TOKEN:-}" ]]; then + echo "HF_TOKEN is not set. Please set it in the .env file or export it." +fi + +if [[ "${USE_DUMMY_WEIGHT}" == "1" ]]; then + echo "Using dummy weights for JAX model (DEBUG ONLY)." + MODEL_PATH= +else + if [[ -n "${MODEL_PATH:-}" ]]; then + echo "Using provided MODEL_PATH: ${MODEL_PATH}" + else + echo "MODEL_PATH not provided, downloading HF snapshot..." 
+ MODEL_PATH=$(python "${DIR}/download_model.py" --hub=hf --model="${MODEL_NAME}" --ignore="*.pth") + fi +fi + +# ------------------------------------------------------------------------------ +# assign GPUs to vLLM and JAX +# ------------------------------------------------------------------------------ +N_GPUS=$((N_GPUS_VLLM + N_GPUS_JAX)) + +# Derive CUDA_VISIBLE_DEVICES_ARRAY +if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then + CUDA_VISIBLE_DEVICES_ARRAY=($(seq 0 $((N_GPUS - 1)))) +else + IFS=',' read -r -a CUDA_VISIBLE_DEVICES_ARRAY <<< "$CUDA_VISIBLE_DEVICES" +fi + +VLLM_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:0:N_GPUS_VLLM}") +JAX_GPU_ARRAY=("${CUDA_VISIBLE_DEVICES_ARRAY[@]:N_GPUS_VLLM:N_GPUS}") + +# ------------------------------------------------------------------------------ +# common environment +# ------------------------------------------------------------------------------ +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export CUDA_DEVICE_MAX_CONNECTIONS=16 +export NCCL_BUFFSIZE=16777216 +export GATEWAY_PORT +export GATEWAY_URL="localhost:${GATEWAY_PORT}" +export MODEL_NAME +export MODEL_PATH +export PARAM_MAPPING_PATH +export USE_DUMMY_WEIGHT +export VLLM_ENFORCE_EAGER +export VLLM_GPU_MEMORY_UTILIZATION +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL + --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 + --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 + --xla_gpu_all_gather_combine_threshold_bytes=8589934592 + --xla_gpu_all_reduce_combine_threshold_bytes=8589934592" +if [[ -n "${TRANSFER_MODE:-}" ]]; then + export TRANSFER_MODE +fi + +if [ "$DEBUG" == "true" ]; then + set -x + export TF_CPP_MIN_LOG_LEVEL=0 + export NCCL_DEBUG=INFO # Enable NCCL debug logs +else + export TF_CPP_MIN_LOG_LEVEL=2 # Suppress TensorFlow debug logs + export VLLM_CONFIGURE_LOGGING=0 # Suppress vLLM logging +fi + +PIDS=() + +mkdir -p "${OUTPUT_DIR}" +echo "Logs will be saved to: ${OUTPUT_DIR}" + +# ------------------------------------------------------------------------------ +# Launch components +# ------------------------------------------------------------------------------ + +# Gateway server (no GPU) +CUDA_VISIBLE_DEVICES= \ +python "${DIR}/../jax_inference_offloading/controller/gateway.py" 2>&1 | tee "${OUTPUT_DIR}/gateway.log" & +PIDS+=($!) + +# vLLM rollout worker +CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${VLLM_GPU_ARRAY[*]}") \ +MODEL_NAME=${MODEL_PATH:-$MODEL_NAME} \ +python "${DIR}/rollout.py" 2>&1 | tee "${OUTPUT_DIR}/rollout.log" & +PIDS+=($!) + +# Standalone trainer (weight transfer + generation demo) +CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${JAX_GPU_ARRAY[*]}") \ +JAX_COMPILATION_CACHE_DIR=${JAX_COMPILATION_CACHE_DIR} \ +JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=0.1 \ +python "${DIR}/trainer_standalone.py" 2>&1 | tee "${OUTPUT_DIR}/trainer.log" & +PIDS+=($!) 
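+
+# ------------------------------------------------------------------------------
+# Wait for all launched components. The trap installed above kills the whole
+# process group (gateway, vLLM rollout worker, trainer) on SIGINT/SIGTERM/EXIT.
+# ------------------------------------------------------------------------------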
+ +wait "${PIDS[@]}" diff --git a/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json b/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json new file mode 100644 index 000000000..c13aeb9fc --- /dev/null +++ b/jax-inference-offloading/examples/mappings/llama3_1b_param_mapping.json @@ -0,0 +1,131 @@ +{ + "num_layers": 16, + "mappings": [ + { + "jax_param": { + "name": "model.embedder.input_embedding" + }, + "vllm_param": { + "name": "model.embed_tokens.weight", + "shape": [128256, 2048] + } + }, + { + "jax_param": { + "name": "model.final_norm.w" + }, + "vllm_param": { + "name": "model.norm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.input_layernorm.w" + }, + "vllm_param": { + "name": "model.layers.{layer}.input_layernorm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.post_attention_layernorm.w" + }, + "vllm_param": { + "name": "model.layers.{layer}.post_attention_layernorm.weight", + "shape": [2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.mlp.gate_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.gate_proj.weight", + "shape": [8192, 2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.mlp.up_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.up_proj.weight", + "shape": [8192, 2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.mlp.down_proj.kernel", + "transform": { + "transpose": [1, 0] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.mlp.down_proj.weight", + "shape": [2048, 8192] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.attn.q_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.q_proj.weight", + "shape": [2048, 2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.attn.k_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048], + "replication_axis": 2 + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.k_proj.weight", + "shape": [512, 2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.attn.v_proj.w", + "transform": { + "transpose": [1, 2, 0], + "reshape": [-1, 2048], + "replication_axis": 2 + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.v_proj.weight", + "shape": [512, 2048] + } + }, + { + "jax_param": { + "name": "model.layers.{layer}.attn.o_proj.w", + "transform": { + "transpose": [2, 0, 1], + "reshape": [2048, -1] + } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.o_proj.weight", + "shape": [2048, 2048] + } + } + ] +} diff --git a/jax-inference-offloading/examples/rollout.py b/jax-inference-offloading/examples/rollout.py index 61e91df50..00b069cf4 100644 --- a/jax-inference-offloading/examples/rollout.py +++ b/jax-inference-offloading/examples/rollout.py @@ -40,6 +40,8 @@ def main(): model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") model_path = os.environ.get("MODEL_PATH", None) model = model_path or model_name + # Optional: path to custom param_mapping.json for JAX-to-vLLM parameter mapping + param_mapping_path = os.environ.get("PARAM_MAPPING_PATH", None) logging.basicConfig(level=logging.INFO) @@ -58,7 +60,7 @@ def main(): # subscribe to control messages from the gateway rollout_client = make_rollout_client(gateway_url) - 
rollout_client.subscribe_to_control_messages(llm) + rollout_client.subscribe_to_control_messages(llm, mapping_json_path=param_mapping_path) if __name__ == "__main__": diff --git a/jax-inference-offloading/examples/trainer.py b/jax-inference-offloading/examples/trainer.py index 63565e7bd..0687bfff6 100644 --- a/jax-inference-offloading/examples/trainer.py +++ b/jax-inference-offloading/examples/trainer.py @@ -29,7 +29,7 @@ from jax_inference_offloading.jax import OffloadingBridge from jax_inference_offloading.sharding import PolymorphicMesh from jax_inference_offloading.timer import Timer -from jax_inference_offloading.tunix.load_model import load_model +from jax_inference_offloading.integrations.tunix.load_model import load_model from jax_inference_offloading.models import get_named_parameters # logging.basicConfig(level=logging.INFO) diff --git a/jax-inference-offloading/examples/trainer_grpo.py b/jax-inference-offloading/examples/trainer_grpo.py index c7dca3811..ff36eb66b 100644 --- a/jax-inference-offloading/examples/trainer_grpo.py +++ b/jax-inference-offloading/examples/trainer_grpo.py @@ -35,8 +35,8 @@ from tunix.rl.rollout import base_rollout from jax_inference_offloading.timer import Timer -from jax_inference_offloading.tunix.load_model import load_model -from jax_inference_offloading.tunix.rollout import VllmGPURollout +from jax_inference_offloading.integrations.tunix.load_model import load_model +from jax_inference_offloading.integrations.tunix.rollout import VllmGPURollout logger = logging.getLogger(__name__) timer = Timer() diff --git a/jax-inference-offloading/examples/trainer_standalone.py b/jax-inference-offloading/examples/trainer_standalone.py new file mode 100644 index 000000000..a93291c92 --- /dev/null +++ b/jax-inference-offloading/examples/trainer_standalone.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Standalone example: Using jax-inference-offloading without Tunix. + +This example demonstrates how to use the VLLMRolloutEngine directly +without depending on the Tunix RL framework. This is useful for: +- Custom RL training loops +- Integration with other frameworks (OpenRLHF, TRL, etc.) 
+- Testing and benchmarking
+
+Prerequisites:
+- Gateway server running (python -m jax_inference_offloading.controller.gateway)
+- vLLM rollout worker running (python examples/rollout.py)
+
+Environment variables:
+- GATEWAY_URL: URL of the gateway server (e.g., "localhost:50051")
+- MODEL_NAME: HuggingFace model name (e.g., "meta-llama/Llama-3.2-1B-Instruct")
+- MODEL_PATH: Path to HuggingFace model checkpoint (optional, uses dummy weights if not set)
+- PARAM_MAPPING_PATH: Path to JSON parameter mapping file (optional, uses hardcoded mappings if not set)
+"""
+
+import os
+
+import jax
+import jax.numpy as jnp
+
+# Framework-agnostic imports from jax-inference-offloading
+from jax_inference_offloading import (
+    InferenceConfig,
+    VLLMRolloutEngine,
+)
+from jax_inference_offloading.timer import Timer
+from jax_inference_offloading.models import get_named_parameters
+
+from transformers import AutoTokenizer
+
+# =============================================================================
+# CUSTOM MODEL LOADING
+# =============================================================================
+# IMPORTANT: This example uses the model implementation from Tunix, so the
+# model loading below uses Tunix's load_model function.
+# When integrating with your own RL framework (OpenRLHF, TRL, custom, etc.),
+# you MUST replace this with model loading code that interfaces with your own
+# model implementation.
+#
+# NOTE: Tunix's load_model requires MODEL_NAME because it uses regex matching
+# on the model name to determine which architecture/config to use. A custom
+# loader could instead read the model type from checkpoint_path/config.json.
+#
+# Your custom load_model function should:
+# 1. Load the model architecture specific to your framework
+# 2. Load checkpoint weights from MODEL_PATH (HuggingFace safetensors format)
+# 3. Return a model object from which parameters can be extracted
+#
+# The key requirement is that parameters must have JAX-compatible shapes that
+# match the parameter mapping (see examples/mappings/ for the JSON mapping format).
+# The mapping defines how JAX parameter shapes are transformed to vLLM shapes
+# during weight transfer.
+#
+# Example custom implementation:
+#
+# def load_model(checkpoint_path, mesh, dtype):
+#     # 1. Create your model architecture (could infer from config.json)
+#     config = json.load(open(f"{checkpoint_path}/config.json"))
+#     model = MyCustomModel(config)
+#
+#     # 2. Load weights from checkpoint
+#     model = load_weights_from_safetensors(model, checkpoint_path)
+#
+#     # 3. Shard across mesh if needed
+#     model = shard_model(model, mesh)
+#
+#     return model
+#
+# =============================================================================
+from jax_inference_offloading.integrations.tunix.load_model import load_model
+
+timer = Timer()
+
+# --- Configuration ---
+model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.2-1B-Instruct")
+model_path = os.environ.get("MODEL_PATH", None)
+gateway_url = os.environ.get("GATEWAY_URL", "localhost:50051")
+transfer_mode = os.environ.get("TRANSFER_MODE", "grouped")
+
+# Load tokenizer for pad_id
+tokenizer = AutoTokenizer.from_pretrained(model_path or model_name)
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+# Create mesh
+mesh = jax.make_mesh((jax.process_count(), jax.local_device_count()), ("fsdp", "tp"))
+
+# --- Load Model ---
+# NOTE: Replace this section with your own model loading code.
+# The load_model function must return a model compatible with get_named_parameters(). +with timer.section("load_model"): + model = load_model( + model_name, + mesh, + checkpoint_path=model_path, + dtype=jnp.bfloat16, + random_seed=42, + ) + +# Extract named parameters for transfer +# This flattens the model's parameter tree into a dict with dot-separated keys +# e.g., {"model.layers.0.attn.q_proj.w": array(...), ...} +params = get_named_parameters(model) + +# --- Create VLLMRolloutEngine (framework-agnostic) --- +if jax.process_index() == 0: + print(f"Creating VLLMRolloutEngine with gateway_url={gateway_url}") + +with timer.section("create_engine"): + engine = VLLMRolloutEngine( + gateway_url=gateway_url, + mesh=mesh, + model_name=model_name, + transfer_mode=transfer_mode, + timer=timer, + ) + +# --- Transfer Weights --- +if jax.process_index() == 0: + print("Transferring weights to vLLM...") + +with timer.section("warmup_transfer"): + engine.update_weights(params) + +if jax.process_index() == 0: + print("Weights transferred successfully!") + +# --- Benchmark weight transfer --- +for r in range(3): + with timer.section(f"transfer.run{r}"): + engine.update_weights(params) + +# --- Generate Completions --- +if jax.process_index() == 0: + print("\n" + "=" * 80) + print("Generating completions...") + print("=" * 80) + + # Example 1: Simple text prompt + config = InferenceConfig( + max_tokens=256, + temperature=0.7, + top_p=0.95, + ) + output = engine.generate(["Quick facts about the moon:"], config) + + print("\n--- Text Prompt ---") + print(f"Prompt: Quick facts about the moon:") + print(f"Response: {output.texts[0]}") + + # Example 2: Multiple prompts with multiple outputs per prompt + config_multi = InferenceConfig( + max_tokens=100, + temperature=0.9, + top_p=0.95, + n=2, # Generate 2 completions per prompt + ) + prompts = [ + "What is 2 + 2?", + "Name a color:", + ] + output = engine.generate(prompts, config_multi) + + print("\n--- Multiple Prompts (n=2) ---") + for i, completion in enumerate(output.completions): + print(f"\nCompletion {i + 1}:") + print(f" Text: {completion.text[:100]}...") + print(f" Token count: {len(completion.token_ids)}") + + # Example 3: Using to_arrays() for training + arrays = output.to_arrays( + max_prompt_length=64, + max_completion_length=100, + pad_id=tokenizer.pad_token_id, + ) + print("\n--- Arrays for Training ---") + print(f" prompt_tokens shape: {arrays['prompt_tokens'].shape}") + print(f" completion_tokens shape: {arrays['completion_tokens'].shape}") + +# --- Print timing summary --- +if jax.process_index() == 0: + print("\n" + "=" * 80) + print("Timing Summary") + print("=" * 80) + timer.summary(sort_by="name", precision=3) + +# --- Shutdown --- +if jax.process_index() == 0: + engine.shutdown() + print("\nEngine shutdown complete. Exiting.") diff --git a/jax-inference-offloading/jax_inference_offloading/__init__.py b/jax-inference-offloading/jax_inference_offloading/__init__.py new file mode 100644 index 000000000..4e61684e1 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/__init__.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX-vLLM Inference Offloading Bridge. + +This package provides infrastructure for offloading inference/rollout +generation from JAX training to vLLM, enabling efficient RL post-training. + +Quick Start: + >>> from jax_inference_offloading import VLLMRolloutEngine, InferenceConfig + >>> + >>> engine = VLLMRolloutEngine( + ... gateway_url="localhost:50051", + ... model_name="meta-llama/Llama-3.1-8B-Instruct", + ... mesh=jax.make_mesh((8,), ("tp",)), + ... ) + >>> + >>> engine.update_weights(my_params) + >>> output = engine.generate(prompts, InferenceConfig(max_tokens=128)) +""" + +# Core API types +from jax_inference_offloading.api import ( + CompletionOutput, + InferenceConfig, + InferenceOutput, +) + +# Engine implementations +from jax_inference_offloading.engines import VLLMRolloutEngine + +# Low-level access (for advanced users) +from jax_inference_offloading.jax import OffloadingBridge + +__all__ = [ + # Core API types + "CompletionOutput", + "InferenceConfig", + "InferenceOutput", + # Engines + "VLLMRolloutEngine", + # Advanced + "OffloadingBridge", +] diff --git a/jax-inference-offloading/jax_inference_offloading/api/__init__.py b/jax-inference-offloading/jax_inference_offloading/api/__init__.py new file mode 100644 index 000000000..4a8387cab --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/api/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API types for jax-inference-offloading.""" + +from jax_inference_offloading.api.types import ( + CompletionOutput, + InferenceConfig, + InferenceOutput, + pad_left, + pad_right, +) + +__all__ = [ + "CompletionOutput", + "InferenceConfig", + "InferenceOutput", + "pad_left", + "pad_right", +] diff --git a/jax-inference-offloading/jax_inference_offloading/api/types.py b/jax-inference-offloading/jax_inference_offloading/api/types.py new file mode 100644 index 000000000..54412215a --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/api/types.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Framework-agnostic types for inference offloading.""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import numpy as np + + +def pad_left(seq: list, length: int, pad_value) -> list: + """Left-pad a sequence to the specified length. + + Args: + seq: Sequence to pad. + length: Target length. + pad_value: Value to use for padding. + + Returns: + Left-padded sequence. + + Raises: + AssertionError: If sequence is longer than target length. + """ + assert len(seq) <= length, f"Sequence too long: {len(seq)} > {length}" + return [pad_value] * (length - len(seq)) + list(seq) + + +def pad_right(seq: list, length: int, pad_value) -> list: + """Right-pad a sequence to the specified length. + + Args: + seq: Sequence to pad. + length: Target length. + pad_value: Value to use for padding. + + Returns: + Right-padded sequence. + + Raises: + AssertionError: If sequence is longer than target length. + """ + assert len(seq) <= length, f"Sequence too long: {len(seq)} > {length}" + return list(seq) + [pad_value] * (length - len(seq)) + + +@dataclass(frozen=True) +class InferenceConfig: + """Framework-agnostic inference configuration. + + Maps to vLLM SamplingParams. + + Attributes: + max_tokens: Maximum number of tokens to generate per output sequence. + temperature: Temperature for sampling. 0.0 = greedy, higher = more random. + top_p: Top-p (nucleus) sampling. 1.0 = no filtering. + top_k: Top-k sampling. -1 = no filtering. + n: Number of output sequences per prompt (for best-of-n, GRPO groups, etc.). + seed: Random seed for reproducibility. + stop_token_ids: Stop token IDs (e.g., EOS tokens). + """ + + max_tokens: int = 64 + temperature: float = 0.9 + top_p: float = 1.0 + top_k: int = -1 + n: int = 1 + seed: Optional[int] = None + stop_token_ids: List[int] = field(default_factory=list) + + +@dataclass +class CompletionOutput: + """Single completion/output from the model. + + Attributes: + text: Generated text. + token_ids: Generated token IDs. + logprobs: Log probabilities per generated token (optional). + prompt_token_ids: Prompt token IDs (useful for log-prob calculations). + """ + + text: str + token_ids: List[int] + logprobs: Optional[List[float]] = None + prompt_token_ids: Optional[List[int]] = None + + +@dataclass +class InferenceOutput: + """Output from inference/rollout generation. + + Contains one or more CompletionOutput per prompt (based on config.n). + + Attributes: + completions: List of completions, flattened across all prompts. + Length = num_prompts * config.n + """ + + completions: List[CompletionOutput] + + @property + def texts(self) -> List[str]: + """Get all generated texts.""" + return [c.text for c in self.completions] + + @property + def token_ids(self) -> List[List[int]]: + """Get all generated token ID sequences.""" + return [c.token_ids for c in self.completions] + + def to_arrays( + self, + max_prompt_length: int, + max_completion_length: int, + pad_id: int, + ) -> Dict[str, np.ndarray]: + """Convert to padded numpy arrays for training. + + Args: + max_prompt_length: Maximum prompt length (for left-padding). 
+ max_completion_length: Maximum completion length (for right-padding). + pad_id: Padding token ID. + + Returns: + dict with keys: + - 'prompt_tokens': [batch, max_prompt_length] left-padded + - 'completion_tokens': [batch, max_completion_length] right-padded + - 'completion_logprobs': [batch, max_completion_length] if available + + Raises: + AssertionError: If any sequence exceeds the specified max length. + """ + result: Dict[str, np.ndarray] = { + "prompt_tokens": np.array( + [ + pad_left(c.prompt_token_ids or [], max_prompt_length, pad_id) + for c in self.completions + ], + dtype=np.int32, + ), + "completion_tokens": np.array( + [ + pad_right(c.token_ids, max_completion_length, pad_id) + for c in self.completions + ], + dtype=np.int32, + ), + } + + if all(c.logprobs is not None for c in self.completions): + result["completion_logprobs"] = np.array( + [ + pad_right(c.logprobs or [], max_completion_length, 0.0) + for c in self.completions + ], + dtype=np.float32, + ) + + return result diff --git a/jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py b/jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py index 2f7900491..4904be217 100644 --- a/jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py +++ b/jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py @@ -36,11 +36,12 @@ class RolloutServicer: - def __init__(self, llm): + def __init__(self, llm, mapping_json_path=None): llm.collective_rpc("set_sharding") self._llm = llm self._tok = llm.get_tokenizer() + self._mapping_json_path = mapping_json_path @staticmethod def as_proto(vllm_response) -> ctrl.InferenceResponse: @@ -62,7 +63,8 @@ def from_vllm_output(vllm_output) -> ctrl.InferenceResponse.Output: return response_proto def handshake(self, request): - mapping_specs = get_tp_model_mapping(request.model_name or self._llm.llm_engine.model_config.model) + model_name = request.model_name or self._llm.llm_engine.model_config.model + mapping_specs = get_tp_model_mapping(model_name, mapping_json_path=self._mapping_json_path) mapping_specs, vllm_tp_size = add_sharding_specs(mapping_specs, self._llm, request.jax_parallelism.tp) self._mapping_specs = mapping_specs @@ -160,10 +162,10 @@ def __init__(self, executor, controller_stub, broker_stub, channel=None): super().__init__(executor, controller_stub, broker_stub, channel) self._update_future = None - def subscribe_to_control_messages(self, llm): + def subscribe_to_control_messages(self, llm, mapping_json_path=None): assert self._update_future is None - servicer = RolloutServicer(llm) + servicer = RolloutServicer(llm, mapping_json_path=mapping_json_path) def call(): try: diff --git a/jax-inference-offloading/jax_inference_offloading/engines/__init__.py b/jax-inference-offloading/jax_inference_offloading/engines/__init__.py new file mode 100644 index 000000000..62540adf4 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/engines/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference engine implementations.""" + +from jax_inference_offloading.engines.vllm_engine import VLLMRolloutEngine + +__all__ = ["VLLMRolloutEngine"] diff --git a/jax-inference-offloading/jax_inference_offloading/engines/vllm_engine.py b/jax-inference-offloading/jax_inference_offloading/engines/vllm_engine.py new file mode 100644 index 000000000..5e8b86fc4 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/engines/vllm_engine.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""vLLM-based rollout engine using the JAX-vLLM offloading bridge.""" + +from typing import Dict, List, Optional, Union + +import jax + +from jax_inference_offloading.api.types import ( + CompletionOutput, + InferenceConfig, + InferenceOutput, +) +from jax_inference_offloading.jax import OffloadingBridge +from jax_inference_offloading.models import flatten_state, get_named_parameters +from jax_inference_offloading.timer import Timer +import jax_inference_offloading.api.controller_pb2 as ctrl + + +class VLLMRolloutEngine: + """vLLM-based rollout engine for inference offloading. + + This is the main entry point for users who want to use vLLM + for inference offloading without depending on Tunix or other + RL frameworks. + + Example: + >>> engine = VLLMRolloutEngine( + ... gateway_url="localhost:50051", + ... model_name="meta-llama/Llama-3.1-8B-Instruct", + ... mesh=jax.make_mesh((8,), ("tp",)), + ... ) + >>> + >>> # Transfer weights from JAX model + >>> engine.update_weights(my_jax_params) + >>> + >>> # Generate completions + >>> config = InferenceConfig(max_tokens=128, temperature=0.9) + >>> output = engine.generate(["What is 2+2?"], config) + >>> print(output.texts[0]) + """ + + def __init__( + self, + gateway_url: str, + mesh: jax.sharding.Mesh, + *, + model_name: Optional[str] = None, + transfer_mode: str = "grouped", + timer: Optional[Timer] = None, + ): + """Initialize the vLLM rollout engine. + + Args: + gateway_url: URL of the gateway server (e.g., "localhost:50051"). + mesh: JAX device mesh for sharded parameter handling. + model_name: HuggingFace model name for tensor mapping resolution. + Optional if PARAM_MAPPING_PATH is set on the vLLM side. + transfer_mode: Weight transfer mode ('fused', 'unfused', 'grouped'). + timer: Optional timer for performance profiling. 
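+
+        Note: This engine talks to an external gateway; the gateway server and
+        the vLLM rollout worker are expected to be running already (see
+        examples/example-standalone.sh, which launches all three components).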
+ """ + self._timer = timer or Timer() + self._bridge = OffloadingBridge( + gateway_url=gateway_url, + model_name=model_name, + mesh=mesh, + transfer_mode=transfer_mode, + timer=self._timer, + ) + + def generate( + self, + prompts: Union[List[str], List[List[int]]], + config: InferenceConfig, + ) -> InferenceOutput: + """Generate completions using vLLM. + + Args: + prompts: Text prompts or pre-tokenized prompts. + config: Inference configuration. + + Returns: + InferenceOutput with generated completions. + """ + # Build protobuf config + proto_config = ctrl.RolloutConfig( + max_tokens=config.max_tokens, + temperature=config.temperature, + top_p=config.top_p, + top_k=config.top_k, + num_outputs=config.n, + seed=config.seed or 42, + ) + proto_config.stop_token_ids.extend(config.stop_token_ids) + + # Call gateway + with self._timer.section("inference"): + response = self._bridge.gateway.inference(prompts, config=proto_config) + + # Convert response to framework-agnostic output + completions = [] + for output in response.outputs: + completions.append( + CompletionOutput( + text=output.generated_text, + token_ids=list(output.generated_tokens.ids), + logprobs=( + list(output.generated_token_logps) + if output.generated_token_logps + else None + ), + prompt_token_ids=( + list(output.tokenized_prompt.ids) + if output.tokenized_prompt.ids + else None + ), + ) + ) + + return InferenceOutput(completions=completions) + + def update_weights( + self, + params: Union[Dict[str, jax.Array], "nnx.State", "nnx.Module"], # noqa: F821 + ) -> None: + """Transfer model weights to vLLM. + + Args: + params: Model parameters in various formats: + - Dict[str, jax.Array]: Direct flattened params + - flax.nnx.State: Flax state object + - flax.nnx.Module: Flax module (state extracted automatically) + """ + + with self._timer.section("update_weights"): + # Handle different input formats + if isinstance(params, dict): + named_params = params + else: + # Try flax.nnx formats + try: + from flax import nnx + + if isinstance(params, nnx.Module): + named_params = get_named_parameters(params) + elif isinstance(params, nnx.State): + named_params = flatten_state(params) + else: + raise TypeError(f"Unsupported params type: {type(params)}") + except ImportError: + raise TypeError( + f"Unsupported params type: {type(params)}. " + "Expected Dict[str, jax.Array] or install flax for nnx support." + ) + + # Transfer via bridge + self._bridge.transfer(named_params) + + def shutdown(self) -> None: + """Shutdown the gateway connection.""" + try: + self._bridge.gateway.shutdown() + except Exception: + pass # Ignore shutdown errors + + @property + def gateway(self): + """Access the underlying gateway client for advanced usage.""" + return self._bridge.gateway + + @property + def timer(self) -> Timer: + """Access the timer for performance analysis.""" + return self._timer + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures shutdown is called.""" + self.shutdown() + return False diff --git a/jax-inference-offloading/jax_inference_offloading/integrations/__init__.py b/jax-inference-offloading/jax_inference_offloading/integrations/__init__.py new file mode 100644 index 000000000..dcd32cf2c --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/integrations/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Framework-specific integrations for jax-inference-offloading.""" diff --git a/jax-inference-offloading/jax_inference_offloading/tunix/__init__.py b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/__init__.py similarity index 100% rename from jax-inference-offloading/jax_inference_offloading/tunix/__init__.py rename to jax-inference-offloading/jax_inference_offloading/integrations/tunix/__init__.py diff --git a/jax-inference-offloading/jax_inference_offloading/tunix/load_model.py b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/load_model.py similarity index 94% rename from jax-inference-offloading/jax_inference_offloading/tunix/load_model.py rename to jax-inference-offloading/jax_inference_offloading/integrations/tunix/load_model.py index 3fef74854..6cfd7766c 100644 --- a/jax-inference-offloading/jax_inference_offloading/tunix/load_model.py +++ b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/load_model.py @@ -30,11 +30,11 @@ def load_model(name, mesh: jax.sharding.Mesh = None, checkpoint_path: str = None from tunix.models.llama3.model import Llama3, ModelConfig from tunix.models.llama3.params import create_model_from_safe_tensors config_factory = { - '1B': ModelConfig.llama3_2_1b, - '3B': ModelConfig.llama3_2_3b, - '8B': ModelConfig.llama3_1_8b, - '70B': ModelConfig.llama3_70b, - '405B': ModelConfig.llama3_405b, + '1B': ModelConfig.llama3p2_1b, + '3B': ModelConfig.llama3p2_3b, + '8B': ModelConfig.llama3p1_8b, + '70B': ModelConfig.llama3p1_70b, + '405B': ModelConfig.llama3p1_405b, } try: config = config_factory[m.group('size')]() diff --git a/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py new file mode 100644 index 000000000..c77b792a5 --- /dev/null +++ b/jax-inference-offloading/jax_inference_offloading/integrations/tunix/rollout.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tunix adapter for vLLM rollout offloading. + +This module provides VllmGPURollout, a Tunix BaseRollout implementation +that delegates to the framework-agnostic VLLMRolloutEngine. 
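+
+See examples/trainer_grpo.py for how this adapter is wired into a Tunix GRPO
+training run.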
+""" +from typing import Any, List, Optional, Tuple + +import jax +import jax.numpy as jnp +import jaxtyping + +from jax_inference_offloading.api import InferenceConfig, pad_left, pad_right +from jax_inference_offloading.engines import VLLMRolloutEngine +from jax_inference_offloading.timer import Timer +from tunix.rl.rollout.base_rollout import BaseRollout, RolloutConfig, RolloutOutput + + +class VllmGPURollout(BaseRollout): + """Tunix adapter wrapping VLLMRolloutEngine. + + This class implements Tunix's BaseRollout interface by delegating + to the framework-agnostic VLLMRolloutEngine. It handles the conversion + between Tunix's RolloutConfig/RolloutOutput and the bridge's + InferenceConfig/InferenceOutput. + """ + + def __init__( + self, + gateway_url: str, + model_name: str, + *, + rollout_actor, # AKA rollout model (unused for remote engine) + tokenizer, + mesh: jax.sharding.Mesh, + rollout_config: RolloutConfig, # Initial config (unused, passed per-call) + extra_stop_tokens: List[str] | None = None, + transfer_mode: str = "fused", + timer: Any | None = None, + ): + """Initialize the Tunix vLLM rollout adapter. + + Args: + gateway_url: URL of the gateway server. + model_name: HuggingFace model name for tensor mapping. + rollout_actor: The rollout model (unused for remote engine). + tokenizer: Tunix tokenizer for encoding/decoding. + mesh: JAX device mesh. + rollout_config: Initial rollout config (unused, provided per generate call). + extra_stop_tokens: Additional stop tokens as strings. + transfer_mode: Weight transfer mode ('fused', 'unfused', 'grouped'). + timer: Optional timer for profiling. + """ + del rollout_actor # Not used for remote engine + del rollout_config # Config passed per generate() call + + self._timer = timer or Timer() + self._tokenizer = tokenizer + + # Resolve extra stop tokens to IDs + self._extra_stop_token_ids: List[int] = [] + for t in extra_stop_tokens or []: + token_ids = self._tokenizer.encode(t) + assert len(token_ids) == 1, f"Stop token {t} must be a single token, got {token_ids}" + self._extra_stop_token_ids.extend(token_ids) + + # Delegate to the framework-agnostic engine + self._engine = VLLMRolloutEngine( + gateway_url=gateway_url, + model_name=model_name, + mesh=mesh, + transfer_mode=transfer_mode, + timer=self._timer, + ) + + def generate( + self, + prompts: List[str], + rollout_config: RolloutConfig, + ) -> RolloutOutput: + """Generate completions for the given prompts. + + Args: + prompts: List of text prompts. + rollout_config: Tunix rollout configuration. + + Returns: + Tunix RolloutOutput with generated samples. 
+ """ + with self._timer.section("rollout.generate"): + # Convert Tunix RolloutConfig -> InferenceConfig + stop_token_ids = list(self._extra_stop_token_ids) + if rollout_config.eos_tokens is not None: + stop_token_ids = list(rollout_config.eos_tokens) + stop_token_ids + else: + stop_token_ids = [self._tokenizer.eos_id()] + stop_token_ids + + config = InferenceConfig( + max_tokens=rollout_config.max_tokens_to_generate, + temperature=rollout_config.temperature, + top_p=rollout_config.top_p if rollout_config.top_p is not None else 1.0, + top_k=rollout_config.top_k if rollout_config.top_k is not None else -1, + seed=rollout_config.seed, + stop_token_ids=stop_token_ids, + ) + + # Call engine + with self._timer.section("inference"): + output = self._engine.generate([str(p) for p in prompts], config) + + # Convert InferenceOutput -> Tunix RolloutOutput + with self._timer.section("process_outputs"): + generated_text = [] + input_tokens = [] + output_tokens = [] + + for i, completion in enumerate(output.completions): + generated_text.append(completion.text) + input_tokens.append( + pad_left( + completion.prompt_token_ids or [], + rollout_config.max_prompt_length, + self._tokenizer.pad_id(), + ) + ) + output_tokens.append( + pad_right( + completion.token_ids, + rollout_config.max_tokens_to_generate, + self._tokenizer.pad_id(), + ) + ) + + return RolloutOutput( + text=generated_text, + logits=[], # Not needed for GRPO + tokens=jnp.array(output_tokens, dtype=jnp.int32), + left_padded_prompt_tokens=jnp.array(input_tokens, dtype=jnp.int32), + logprobs=None, # GRPOLearner will recalculate + ) + + def get_per_token_logps( + self, + prompt_tokens: jax.Array, + completion_tokens: jax.Array, + completion_mask: jax.Array | None = None, + ) -> jax.Array: + """Get per-token log probabilities. + + Not implemented for remote engine - use GRPOLearner's recalculation. + """ + raise NotImplementedError( + "get_per_token_logps is not supported for remote vLLM engine. " + "Use GRPOLearner which recalculates logprobs locally." + ) + + def update_params( + self, + params: jaxtyping.PyTree, + filter_types: Optional[Tuple[Any, ...]] = None, + ) -> None: + """Update the rollout model parameters. + + Args: + params: Model parameters to transfer. + filter_types: Unused for remote engine. + """ + del filter_types # Not used for remote engine + with self._timer.section("rollout.update_params"): + self._engine.update_weights(params) + + def pad_id(self) -> int: + """Return the padding token ID.""" + return self._tokenizer.pad_id() + + def eos_id(self) -> int: + """Return the end-of-sequence token ID.""" + return self._tokenizer.eos_id() + + def model(self): + """Return the local model (None for remote engine).""" + return None + + def shutdown(self) -> None: + """Gracefully shutdown the remote gateway.""" + self._engine.shutdown() + + def __del__(self): + """Destructor - attempt graceful shutdown.""" + try: + self.shutdown() + except Exception: + # Suppress destructor-time errors during interpreter shutdown. + pass diff --git a/jax-inference-offloading/jax_inference_offloading/models/auto.py b/jax-inference-offloading/jax_inference_offloading/models/auto.py index c6902d970..e3bc3d3dd 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/auto.py +++ b/jax-inference-offloading/jax_inference_offloading/models/auto.py @@ -14,14 +14,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os + import jax_inference_offloading.api.param_mapping_pb2 as mapping from .gemma import get_gemma_2b_mapping, get_gemma_7b_mapping from .gemma3 import get_gemma3_1b_mapping -from .llama3 import get_llama3_8b_mapping, get_llama3_70b_mapping, get_llama3_405b_mapping +from .llama3 import get_llama3_1b_mapping, get_llama3_8b_mapping, get_llama3_70b_mapping, get_llama3_405b_mapping +from .mapping_util import load_mapping_from_json + + +def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model", mapping_json_path=None) -> mapping.TpModelMappingSpecs: + """Get the parameter mapping for a model. + Args: + model_name: HuggingFace model name (e.g., "meta-llama/Llama-3.1-8B"). + jax_prefix: Prefix for JAX parameter names. + vllm_prefix: Prefix for vLLM parameter names. + mapping_json_path: Optional path to a custom param_mapping.json file. + If provided and the file exists, it will be used instead of hardcoded mappings. -def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model") -> mapping.TpModelMappingSpecs: + Returns: + TpModelMappingSpecs protobuf with parameter mappings. + """ + # Check for custom JSON mapping file first + if mapping_json_path is not None and os.path.exists(mapping_json_path): + return load_mapping_from_json(mapping_json_path) if model_name in ("google/gemma-2b", "google/gemma-2b-it"): return get_gemma_2b_mapping(jax_prefix, vllm_prefix) elif model_name in ("google/gemma-7b", "google/gemma-7b-it"): @@ -57,4 +75,9 @@ def get_tp_model_mapping(model_name, jax_prefix="model", vllm_prefix="model") -> "meta-llama/Llama-3.1-405B-Instruct-FP8", ): return get_llama3_405b_mapping(jax_prefix, vllm_prefix) + elif model_name in ( + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B-Instruct", + ): + return get_llama3_1b_mapping(jax_prefix, vllm_prefix) raise Exception(f"Unknown model {model_name}.") diff --git a/jax-inference-offloading/jax_inference_offloading/models/llama3.py b/jax-inference-offloading/jax_inference_offloading/models/llama3.py index 2346c9090..9c4eca7c3 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/llama3.py +++ b/jax-inference-offloading/jax_inference_offloading/models/llama3.py @@ -30,6 +30,7 @@ def _get_llama3_mapping( ffn_size, jax_prefix: str = "model", vllm_prefix: str = "model", + tie_word_embeddings: bool = False, ) -> mapping.TpModelMappingSpecs: param_mapping = partial(make_mapping, jax_prefix=jax_prefix, vllm_prefix=vllm_prefix) @@ -41,13 +42,18 @@ def _get_llama3_mapping( [vocab_size, hidden_size], ), param_mapping("final_norm.w", "norm.weight", [hidden_size]), - make_mapping( - "lm_head.w", "lm_head.weight", [vocab_size, hidden_size], - transform=make_transform(transpose=[1, 0]), - jax_prefix=jax_prefix, vllm_prefix='' - ), ] + # Only add lm_head mapping if embeddings are not tied + if not tie_word_embeddings: + params.append( + make_mapping( + "lm_head.w", "lm_head.weight", [vocab_size, hidden_size], + transform=make_transform(transpose=[1, 0]), + jax_prefix=jax_prefix, vllm_prefix='' + ) + ) + # per-layer for layer_id in range(n_layers): params.extend( @@ -116,6 +122,20 @@ def _get_llama3_mapping( model_mapping.mappings.extend(params) return model_mapping +def get_llama3_1b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: + return _get_llama3_mapping( + vocab_size=128_256, + n_layers=16, + hidden_size=2048, + q_heads=32, + kv_heads=8, + head_dim=64, + ffn_size=8192, + jax_prefix=jax_prefix, + vllm_prefix=vllm_prefix, + 
tie_word_embeddings=True, + ) + def get_llama3_8b_mapping(jax_prefix: str = "model", vllm_prefix: str = "model") -> mapping.TpModelMappingSpecs: return _get_llama3_mapping( vocab_size=128_256, diff --git a/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py b/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py index 5030a455d..0dc83e28e 100644 --- a/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py +++ b/jax-inference-offloading/jax_inference_offloading/models/mapping_util.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from copy import deepcopy import google.protobuf.text_format as text_format @@ -73,6 +74,154 @@ def load_mapping_spec(filename: str) -> mapping.TpModelMappingSpecs: return text_format.Parse(file.read(), mapping.TpModelMappingSpecs()) +def _parse_slice_from_json(slice_list: list) -> mapping.TensorSlice: + """Parse a JSON slice specification into a TensorSlice proto. + + Args: + slice_list: A list of slice specifications. Each element can be: + - "..." for ellipsis + - An integer for index + - A list [start, stop] for slice (use null for None) + + Returns: + A TensorSlice protobuf message. + """ + result = mapping.TensorSlice() + for dim in slice_list: + slice_dim = mapping.TensorSlice.Dim() + if dim == "...": + slice_dim.ellipsis.SetInParent() + elif isinstance(dim, int): + slice_dim.index.index = dim + elif isinstance(dim, list): + slice_dim.slice.SetInParent() + if len(dim) >= 1 and dim[0] is not None: + slice_dim.slice.start = dim[0] + if len(dim) >= 2 and dim[1] is not None: + slice_dim.slice.stop = dim[1] + else: + raise ValueError(f"Invalid slice specification: {dim}") + result.dims.append(slice_dim) + return result + + +def _parse_transform_from_json(transform_dict: dict) -> mapping.JaxParam.Transform: + """Parse a JSON transform specification into a Transform proto. + + Args: + transform_dict: A dictionary with optional keys: + - "transpose": list of ints (axis permutation) + - "reshape": list of ints (new shape, -1 for inferred) + - "slice": list of slice specs + - "replication_axis": int + - "replication_count": int + + Returns: + A JaxParam.Transform protobuf message. + """ + result = mapping.JaxParam.Transform() + + if "slice" in transform_dict: + result.slice.CopyFrom(_parse_slice_from_json(transform_dict["slice"])) + + if "transpose" in transform_dict: + result.transpose.extend(transform_dict["transpose"]) + + if "reshape" in transform_dict: + result.reshape.extend(transform_dict["reshape"]) + + if "replication_axis" in transform_dict: + result.replication_axis = int(transform_dict["replication_axis"]) + + if "replication_count" in transform_dict: + result.replication_count = int(transform_dict["replication_count"]) + + return result + + +def load_mapping_from_json(json_path: str) -> mapping.TpModelMappingSpecs: + """Load parameter mapping from a JSON configuration file. + + The JSON schema is: + { + "num_layers": 32, + "mappings": [ + { + "jax_param": { + "name": "model.layers.{layer}.attn.q_proj.w", + "transform": { "transpose": [1, 2, 0], "reshape": [-1, 4096] } + }, + "vllm_param": { + "name": "model.layers.{layer}.self_attn.q_proj.weight", + "shape": [4096, 4096] + } + }, + ... 
+ ] + } + + - Mappings with `{layer}` placeholder are expanded into `num_layers` copies + - Mappings without `{layer}` are kept as singletons + - Transform fields (transpose, reshape, slice, replication_axis) are optional + + Args: + json_path: Path to the JSON configuration file. + + Returns: + A TpModelMappingSpecs protobuf message with all mappings expanded. + """ + with open(json_path, "r") as f: + config = json.load(f) + + num_layers = config.get("num_layers", 0) + json_mappings = config.get("mappings", []) + + model_mapping = mapping.TpModelMappingSpecs() + + for json_mapping in json_mappings: + jax_param_spec = json_mapping["jax_param"] + vllm_param_spec = json_mapping["vllm_param"] + + jax_name = jax_param_spec["name"] + vllm_name = vllm_param_spec["name"] + vllm_shape = vllm_param_spec["shape"] + + # Check if this is a templated per-layer mapping + if "{layer}" in jax_name or "{layer}" in vllm_name: + # Expand for all layers + layer_indices = range(num_layers) + else: + # Singleton mapping - use None as sentinel + layer_indices = [None] + + for layer_idx in layer_indices: + param_mapping = mapping.ParamMapping() + + # Expand layer placeholder if present + if layer_idx is not None: + expanded_jax_name = jax_name.replace("{layer}", str(layer_idx)) + expanded_vllm_name = vllm_name.replace("{layer}", str(layer_idx)) + else: + expanded_jax_name = jax_name + expanded_vllm_name = vllm_name + + # Set JAX param + param_mapping.jax_param.name = expanded_jax_name + + # Parse and set transform if present + if "transform" in jax_param_spec: + transform = _parse_transform_from_json(jax_param_spec["transform"]) + param_mapping.jax_param.transform.CopyFrom(transform) + + # Set vLLM param + param_mapping.vllm_param.name = expanded_vllm_name + param_mapping.vllm_param.shape.extend(vllm_shape) + + model_mapping.mappings.append(param_mapping) + + return model_mapping + + def add_sharding_specs(model_mapping: mapping.TpModelMappingSpecs, llm: LLM, jax_tp_size: int): per_rank_sharding_specs = llm.collective_rpc("get_tp_sharding_specs") vllm_tp_size = len(per_rank_sharding_specs) diff --git a/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py b/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py deleted file mode 100644 index af8dc9adf..000000000 --- a/jax-inference-offloading/jax_inference_offloading/tunix/rollout.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Rollout worker with offloading to vLLM.""" -from typing import Any, Optional, Tuple - -import jax -import jax.numpy as jnp -import jaxtyping - -import jax_inference_offloading.api.controller_pb2 as ctrl -from jax_inference_offloading.jax import OffloadingBridge -from jax_inference_offloading.timer import Timer -from tunix.rl.rollout.base_rollout import BaseRollout, RolloutConfig, RolloutOutput - - -class VllmGPURollout(BaseRollout): - - def __init__( - self, - gateway_url, - model_name, - *, - rollout_actor, # AKA rollout model - tokenizer, - mesh, - rollout_config, - extra_stop_tokens: list[str] | None = None, - transfer_mode: str = 'fused', - timer: Any | None = None, - ): - self._timer = timer or Timer() - self._tokenizer = tokenizer - self._bridge = OffloadingBridge( - gateway_url=gateway_url, - model_name=model_name, - mesh=mesh, - transfer_mode=transfer_mode, - timer=self._timer, - ) - self._extra_stop_token_ids = [] - for t in extra_stop_tokens or []: - i = self._tokenizer.encode(t) - assert len(i) == 1, f"Stop token {t} must be a single token, got {i}" - self._extra_stop_token_ids.extend(i) - - def generate( - self, - prompts: list[str], - rollout_config: RolloutConfig, - ): - """Generates samples from the model.""" - with self._timer.section("rollout.generate"): - remote_rollout_config = ctrl.RolloutConfig( - top_p=rollout_config.top_p, - top_k=rollout_config.top_k, - temperature=rollout_config.temperature, - max_tokens=rollout_config.max_tokens_to_generate, - seed=rollout_config.seed, - ) - if rollout_config.eos_tokens is not None: - remote_rollout_config.stop_token_ids.extend(rollout_config.eos_tokens) - else: - remote_rollout_config.stop_token_ids.extend([self._tokenizer.eos_id()]) - remote_rollout_config.stop_token_ids.extend(self._extra_stop_token_ids) - - with self._timer.section("inference"): - response = self._bridge.gateway.inference([str(p) for p in prompts], config=remote_rollout_config) - - with self._timer.section("process_outputs"): - generated_text = [] - input_tokens = [] - output_tokens = [] - - def pad_to_left(original, length, pad_value): - assert len(original) <= length - return [pad_value] * (length - len(original)) + original - - def pad_to_right(original, length, pad_value): - assert len(original) <= length - return original + [pad_value] * (length - len(original)) - - for i, output in enumerate(response.outputs): - if i < 1: - print(f"# Rollout {i} of {len(prompts)}") - print(f"## Prompt:\n{prompts[i]}") - print(f"## Response:\n{output.generated_text}") - print("-" * 80) - generated_text.append(output.generated_text) - input_tokens.append( - pad_to_left(list(output.tokenized_prompt.ids), rollout_config.max_prompt_length, self._tokenizer.pad_id()) - ) - output_tokens.append( - pad_to_right(list(output.generated_tokens.ids), rollout_config.max_tokens_to_generate, self._tokenizer.pad_id()) - ) - - return RolloutOutput( - text=generated_text, - logits=[], # not needed for GRPO - tokens=jnp.array(output_tokens, dtype=jnp.int32), - left_padded_prompt_tokens=jnp.array(input_tokens, dtype=jnp.int32), - logprobs=None, # needed for GRPO, GRPOLearner will recalc - ) - - def get_per_token_logps( - self, - prompt_tokens: jax.Array, - completion_tokens: jax.Array, - completion_mask: jax.Array | None = None, - ) -> jax.Array: - raise NotImplementedError() - - def update_params( - self, - params: jaxtyping.PyTree, - filter_types: Optional[Tuple[Any, ...]] = None, - ) -> None: - """Updates the rollout model parameters.""" - with 
self._timer.section("rollout.update_params"): - self._bridge.transfer(params) - - def pad_id(self) -> int: - return self._tokenizer.pad_id() - - def eos_id(self) -> int: - return self._tokenizer.eos_id() - - def model(self): - return None - - def shutdown(self) -> None: - """Gracefully shutdown the remote gateway if available.""" - try: - self._bridge.gateway.shutdown() - except Exception: - # Ignore shutdown errors; process teardown or remote unavailability is expected. - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # Suppress destructor-time errors during interpreter shutdown. - pass diff --git a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py index 6bd052756..030934ac7 100644 --- a/jax-inference-offloading/setup.py +++ b/jax-inference-offloading/setup.py @@ -66,7 +66,7 @@ def run(self): 'jax==0.8.1', 'jaxtyping', 'kagglehub', - 'vllm==0.11.2', + 'vllm==0.14.0', ], extras_require={ 'test': ['pytest>=7.0'],