diff --git a/.github/workflows/unit-tests-recipes.yml b/.github/workflows/unit-tests-recipes.yml index c0cb3da83c..eb5d40333f 100644 --- a/.github/workflows/unit-tests-recipes.yml +++ b/.github/workflows/unit-tests-recipes.yml @@ -64,7 +64,7 @@ jobs: CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }} run: | # Get all recipe and model directories - ALL_DIRS=$(ls -d bionemo-recipes/models/*/ bionemo-recipes/recipes/*/ 2>/dev/null | jq -R -s -c 'split("\n")[:-1] | map(rtrimstr("/"))') + ALL_DIRS=$(ls -d bionemo-recipes/models/*/ bionemo-recipes/recipes/*/ bionemo-recipes/recipes/vllm_inference/*/ 2>/dev/null | grep -v 'recipes/vllm_inference/$' | jq -R -s -c 'split("\n")[:-1] | map(rtrimstr("/"))') # Helper to check for a PR label has_label() { diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ba59732d7..8c8c788f94 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,17 @@ repos: - id: ruff # 1. Attempt to automatically fix any lint issues. args: ["--fix"] + # Exclude check_copied_files destinations; they are verbatim copies + # of source files and must not be reformatted independently. 
+ exclude: | + (?x)^( + bionemo-recipes/recipes/vllm_inference/llama3/(modeling_llama_te|convert|state)\.py + )$ - id: ruff-format + exclude: | + (?x)^( + bionemo-recipes/recipes/vllm_inference/llama3/(modeling_llama_te|convert|state)\.py + )$ - repo: https://github.com/executablebooks/mdformat rev: 0.7.22 # Use the latest stable version hooks: diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/.ci_build.sh b/bionemo-recipes/recipes/vllm_inference/llama3/.ci_build.sh new file mode 100755 index 0000000000..d41dd81522 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/.ci_build.sh @@ -0,0 +1,3 @@ +#!/bin/bash -x +PIP_CONSTRAINT= pip install -r requirements.txt +./install_vllm.sh diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/Dockerfile b/bionemo-recipes/recipes/vllm_inference/llama3/Dockerfile new file mode 100644 index 0000000000..d5697592f4 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/Dockerfile @@ -0,0 +1,24 @@ +FROM nvcr.io/nvidia/pytorch:26.02-py3 +WORKDIR /workspace/bionemo +COPY . . +RUN --mount=type=cache,target=/root/.cache/pip \ + PIP_CONSTRAINT= pip install -r requirements.txt + +WORKDIR /workspace +ARG INSTALL_VLLM=false +ARG TORCH_CUDA_ARCH_LIST="" +ARG MAX_JOBS=8 +ARG UV_BREAK_SYSTEM_PACKAGES=1 +RUN if [ "$INSTALL_VLLM" = "true" ]; then \ + if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then \ + echo "ERROR: TORCH_CUDA_ARCH_LIST must be set when INSTALL_VLLM=true" && exit 1; \ + fi && \ + git clone --branch v0.15.1 --depth 1 https://github.com/vllm-project/vllm.git && \ + cd vllm && \ + python use_existing_torch.py && \ + uv pip install -r requirements/build.txt --system && \ + uv pip install --no-build-isolation -e . 
--system && \ + pip install --upgrade "transformers[torch]"; \ + fi + +WORKDIR /workspace/bionemo diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/README.md b/bionemo-recipes/recipes/vllm_inference/llama3/README.md new file mode 100644 index 0000000000..96726ed4f3 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/README.md @@ -0,0 +1,52 @@ +# Llama-3 vLLM Inference + +This recipe demonstrates serving a round-tripped +[Llama-3 TE checkpoint](../../../models/llama3/) via +[vLLM](https://github.com/vllm-project/vllm) (>= 0.14). + +The workflow is: + +1. Convert an HF checkpoint to TE format, then back to HF + (`export_llama3.py`). +2. Serve the round-tripped checkpoint with vLLM. + +See [tests/test_vllm.py](tests/test_vllm.py) for golden-value validation +confirming the round-tripped model matches the original. + +## Installing vLLM in the container + +There are two ways to get vLLM installed in the Docker image. + +**Option 1: Build-time installation via Dockerfile build arg** + +Pass `--build-arg INSTALL_VLLM=true` and `--build-arg TORCH_CUDA_ARCH_LIST=` when +building the image. `TORCH_CUDA_ARCH_LIST` is required when `INSTALL_VLLM=true` (the +Dockerfile will error if it is not set): + +```bash +docker build -t llama3-vllm \ + --build-arg INSTALL_VLLM=true \ + --build-arg TORCH_CUDA_ARCH_LIST="9.0" . +``` + +**Option 2: Post-build installation via `install_vllm.sh`** + +Build the base image normally, then run `install_vllm.sh` inside the container. The script +auto-detects the GPU architecture, or you can pass an explicit arch argument: + +```bash +docker build -t llama3 . 
+docker run --rm -it --gpus all llama3 bash -c "./install_vllm.sh" +# or with an explicit architecture: +docker run --rm -it --gpus all llama3 bash -c "./install_vllm.sh 9.0" +``` + +## Benchmarking + +The recipe includes benchmark scripts for comparing HuggingFace native and vLLM +inference: + +```bash +python benchmark_hf.py +python benchmark_vllm.py +``` diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_common.py b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_common.py new file mode 100644 index 0000000000..d48ebf328c --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_common.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared configuration, prompt generation, timing helpers, and reporting for benchmarks. + +Both benchmark_hf.py and benchmark_vllm.py import from this module so that +the sweep grid, synthetic inputs, metric computation, and output format are +identical -- guaranteeing an apples-to-apples comparison. 
+""" + +import argparse +import csv +import statistics +import time +from dataclasses import dataclass, fields + +import torch +from transformers import PreTrainedTokenizerBase + + +DEFAULT_BATCH_SIZES = [1, 4, 8] +DEFAULT_PROMPT_LENS = [64, 256, 512] +DEFAULT_OUTPUT_LENS = [16, 64, 128] +DEFAULT_WARMUP = 2 +DEFAULT_REPEATS = 5 + +STOCK_TEXT = ( + "The quick brown fox jumps over the lazy dog. " + "Pack my box with five dozen liquor jugs. " + "How vexingly quick daft zebras jump. " + "The five boxing wizards jump quickly. " + "Bright vixens jump; dozy fowl quack. " +) + + +@dataclass +class BenchmarkConfig: + """Holds all parameters for a benchmark run.""" + + model: str + batch_sizes: list[int] + prompt_lens: list[int] + output_lens: list[int] + warmup: int + repeats: int + csv_path: str | None + + +@dataclass +class BenchmarkResult: + """One row of benchmark output.""" + + batch_size: int + prompt_len: int + output_len: int + e2e_ms: float + ttft_ms: float + tpot_ms: float + throughput_tok_s: float + + +def build_prompts( + tokenizer: PreTrainedTokenizerBase, + batch_size: int, + prompt_length: int, +) -> tuple[list[str], torch.Tensor]: + """Generate deterministic synthetic prompts of exactly *prompt_length* tokens. + + Returns: + A tuple of (prompt_strings, input_ids_tensor). + *prompt_strings* is a list[str] of length *batch_size* (for vLLM). + *input_ids_tensor* is a (batch_size, prompt_length) int64 tensor (for HF). + Both represent byte-identical inputs. 
+ """ + repeated = STOCK_TEXT * ((prompt_length // 10) + 2) + token_ids = tokenizer.encode(repeated, add_special_tokens=False)[:prompt_length] + prompt_str = tokenizer.decode(token_ids) + + prompt_strings = [prompt_str] * batch_size + input_ids = torch.tensor([token_ids] * batch_size, dtype=torch.long) + return prompt_strings, input_ids + + +def compute_metrics( + e2e_seconds: float, + ttft_seconds: float, + batch_size: int, + output_len: int, +) -> BenchmarkResult: + """Derive TPOT and throughput from raw wall-clock timings.""" + e2e_ms = e2e_seconds * 1000.0 + ttft_ms = ttft_seconds * 1000.0 + tpot_ms = ((e2e_seconds - ttft_seconds) / max(output_len - 1, 1)) * 1000.0 + total_output_tokens = batch_size * output_len + throughput = total_output_tokens / e2e_seconds if e2e_seconds > 0 else 0.0 + return BenchmarkResult( + batch_size=batch_size, + prompt_len=0, + output_len=output_len, + e2e_ms=e2e_ms, + ttft_ms=ttft_ms, + tpot_ms=tpot_ms, + throughput_tok_s=throughput, + ) + + +def median_timing(fn, repeats: int) -> float: + """Run *fn* multiple times and return the median wall-clock duration in seconds.""" + times = [] + for _ in range(repeats): + t0 = time.perf_counter() + fn() + times.append(time.perf_counter() - t0) + return statistics.median(times) + + +_HEADER = ["batch_size", "prompt_len", "output_len", "e2e_ms", "ttft_ms", "tpot_ms", "throughput_tok_s"] + + +def print_results(results: list[BenchmarkResult]) -> None: + """Pretty-print a results table to stdout.""" + col_widths = [max(len(h), 12) for h in _HEADER] + header_line = " ".join(h.rjust(w) for h, w in zip(_HEADER, col_widths)) + print(header_line) + print("-" * len(header_line)) + for r in results: + vals = [ + str(r.batch_size), + str(r.prompt_len), + str(r.output_len), + f"{r.e2e_ms:.1f}", + f"{r.ttft_ms:.1f}", + f"{r.tpot_ms:.2f}", + f"{r.throughput_tok_s:.1f}", + ] + print(" ".join(v.rjust(w) for v, w in zip(vals, col_widths))) + + +def write_csv(results: list[BenchmarkResult], path: str) -> None: 
+ """Write results to a CSV file.""" + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow([field.name for field in fields(BenchmarkResult)]) + for r in results: + writer.writerow( + [r.batch_size, r.prompt_len, r.output_len, r.e2e_ms, r.ttft_ms, r.tpot_ms, r.throughput_tok_s] + ) + print(f"Results written to {path}") + + +def add_common_args(parser: argparse.ArgumentParser, default_model: str) -> None: + """Register the shared CLI flags on *parser*.""" + parser.add_argument("--model", type=str, default=default_model, help="Model ID or checkpoint path.") + parser.add_argument( + "--batch-sizes", + type=str, + default=",".join(str(x) for x in DEFAULT_BATCH_SIZES), + help="Comma-separated batch sizes.", + ) + parser.add_argument( + "--prompt-lens", + type=str, + default=",".join(str(x) for x in DEFAULT_PROMPT_LENS), + help="Comma-separated prompt lengths (tokens).", + ) + parser.add_argument( + "--output-lens", + type=str, + default=",".join(str(x) for x in DEFAULT_OUTPUT_LENS), + help="Comma-separated output lengths (tokens).", + ) + parser.add_argument("--warmup", type=int, default=DEFAULT_WARMUP, help="Warmup iterations per grid point.") + parser.add_argument("--repeats", type=int, default=DEFAULT_REPEATS, help="Timed iterations per grid point.") + parser.add_argument("--csv", type=str, default=None, dest="csv_path", help="Optional CSV output path.") + + +def parse_config(args: argparse.Namespace) -> BenchmarkConfig: + """Convert parsed CLI args into a BenchmarkConfig.""" + return BenchmarkConfig( + model=args.model, + batch_sizes=[int(x) for x in args.batch_sizes.split(",")], + prompt_lens=[int(x) for x in args.prompt_lens.split(",")], + output_lens=[int(x) for x in args.output_lens.split(",")], + warmup=args.warmup, + repeats=args.repeats, + csv_path=args.csv_path, + ) diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_hf.py b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_hf.py new file mode 100644 index 
0000000000..a8787eb5ac --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_hf.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark HuggingFace native inference on the reference Llama-3 model. + +Sweeps over a grid of (batch_size, prompt_length, output_length) and reports +end-to-end latency, time-to-first-token, time-per-output-token, and throughput. 
+ +Usage: + python benchmark_hf.py + python benchmark_hf.py --model meta-llama/Llama-3.2-1B-Instruct --csv hf_results.csv +""" + +import argparse +import itertools + +import torch +from benchmark_common import ( + add_common_args, + build_prompts, + compute_metrics, + median_timing, + parse_config, + print_results, + write_csv, +) +from transformers import AutoModelForCausalLM, AutoTokenizer + + +DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct" + + +def main() -> None: + """Run the HF benchmark sweep.""" + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + add_common_args(parser, default_model=DEFAULT_MODEL) + config = parse_config(parser.parse_args()) + + print(f"Loading model: {config.model}") + model = AutoModelForCausalLM.from_pretrained(config.model, torch_dtype=torch.bfloat16).to("cuda").eval() + tokenizer = AutoTokenizer.from_pretrained(config.model) + tokenizer.pad_token = tokenizer.eos_token + + results = [] + grid = list(itertools.product(config.batch_sizes, config.prompt_lens, config.output_lens)) + + for batch_size, prompt_len, output_len in grid: + label = f"batch={batch_size} prompt={prompt_len} output={output_len}" + print(f"\n[{label}]") + + _, input_ids = build_prompts(tokenizer, batch_size, prompt_len) + input_ids = input_ids.to("cuda") + + def _generate(max_new: int) -> None: + with torch.no_grad(): + model.generate(input_ids, max_new_tokens=max_new, do_sample=False, use_cache=True) + torch.cuda.synchronize() + + for _ in range(config.warmup): + _generate(output_len) + + ttft_s = median_timing(lambda: _generate(1), config.repeats) + e2e_s = median_timing(lambda: _generate(output_len), config.repeats) + + result = compute_metrics(e2e_s, ttft_s, batch_size, output_len) + result.prompt_len = prompt_len + results.append(result) + print( + f" e2e={result.e2e_ms:.1f}ms ttft={result.ttft_ms:.1f}ms " + f"tpot={result.tpot_ms:.2f}ms throughput={result.throughput_tok_s:.1f} tok/s" + ) + + 
print("\n" + "=" * 60) + print_results(results) + if config.csv_path: + write_csv(results, config.csv_path) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_vllm.py b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_vllm.py new file mode 100644 index 0000000000..01fc652cb9 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/benchmark_vllm.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark vLLM inference on the round-tripped Llama-3 checkpoint. + +Sweeps over a grid of (batch_size, prompt_length, output_length) and reports +end-to-end latency, time-to-first-token, time-per-output-token, and throughput. + +The checkpoint is produced by export_llama3.py (HF -> TE -> HF round-trip). 
+ +Usage: + python benchmark_vllm.py + python benchmark_vllm.py --model ./llama3_hf_roundtrip_checkpoint --csv vllm_results.csv +""" + +import argparse +import itertools + +from benchmark_common import ( + add_common_args, + build_prompts, + compute_metrics, + median_timing, + parse_config, + print_results, + write_csv, +) +from vllm import LLM, SamplingParams + + +DEFAULT_MODEL = "./llama3_hf_roundtrip_checkpoint" + + +def main() -> None: + """Run the vLLM benchmark sweep.""" + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + add_common_args(parser, default_model=DEFAULT_MODEL) + config = parse_config(parser.parse_args()) + + print(f"Loading model: {config.model}") + engine = LLM(model=config.model, runner="generate", dtype="bfloat16") + + # vLLM needs a tokenizer to build prompts -- reuse the one bundled with the checkpoint. + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(config.model) + tokenizer.pad_token = tokenizer.eos_token + + results = [] + grid = list(itertools.product(config.batch_sizes, config.prompt_lens, config.output_lens)) + + for batch_size, prompt_len, output_len in grid: + label = f"batch={batch_size} prompt={prompt_len} output={output_len}" + print(f"\n[{label}]") + + prompts, _ = build_prompts(tokenizer, batch_size, prompt_len) + + def _generate(max_tokens: int) -> None: + engine.generate(prompts, SamplingParams(max_tokens=max_tokens, temperature=0)) + + for _ in range(config.warmup): + _generate(output_len) + + ttft_s = median_timing(lambda: _generate(1), config.repeats) + e2e_s = median_timing(lambda: _generate(output_len), config.repeats) + + result = compute_metrics(e2e_s, ttft_s, batch_size, output_len) + result.prompt_len = prompt_len + results.append(result) + print( + f" e2e={result.e2e_ms:.1f}ms ttft={result.ttft_ms:.1f}ms " + f"tpot={result.tpot_ms:.2f}ms throughput={result.throughput_tok_s:.1f} tok/s" + ) + + print("\n" + "=" * 60) + 
print_results(results) + if config.csv_path: + write_csv(results, config.csv_path) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/convert.py b/bionemo-recipes/recipes/vllm_inference/llama3/convert.py new file mode 100644 index 0000000000..f4fd870e8b --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/convert.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion utilities between HuggingFace Llama3 and TransformerEngine formats.""" + +import inspect + +import torch +from transformers import LlamaConfig, LlamaForCausalLM + +import state +from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM + + +mapping = { + "model.embed_tokens.weight": "model.embed_tokens.weight", + "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight", + "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight", + "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight", + "model.norm.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", +} + +# Reverse mapping from TE to HF format by reversing the original mapping +reverse_mapping = {v: k for k, v in mapping.items()} + + +def convert_llama_hf_to_te(model_hf: LlamaForCausalLM, **config_kwargs) -> NVLlamaForCausalLM: + """Convert a Hugging Face model to a Transformer Engine model. + + Args: + model_hf (nn.Module): The Hugging Face model. + **config_kwargs: Additional configuration kwargs to be passed to NVLlamaConfig. + + Returns: + nn.Module: The Transformer Engine model. 
+ """ + te_config = NVLlamaConfig(**model_hf.config.to_dict(), **config_kwargs) + with torch.device("meta"): + model_te = NVLlamaForCausalLM(te_config) + + if model_hf.config.tie_word_embeddings: + state_dict_ignored_entries = ["lm_head.weight"] + else: + state_dict_ignored_entries = [] + + output_model = state.apply_transforms( + model_hf, + model_te, + mapping, + [ + state.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="model.layers.*.self_attention.layernorm_qkv.weight", + fn=state.TransformFns.merge_qkv, + ), + state.state_transform( + source_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + target_key="model.layers.*.layernorm_mlp.fc1_weight", + fn=state.TransformFns.merge_fc1, + ), + ], + state_dict_ignored_entries=state_dict_ignored_entries, + ) + + output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone() + + return output_model + + +def convert_llama_te_to_hf(model_te: NVLlamaForCausalLM, **config_kwargs) -> LlamaForCausalLM: + """Convert a Transformer Engine model to a Hugging Face model. + + Args: + model_te (nn.Module): The Transformer Engine model. + **config_kwargs: Additional configuration kwargs to be passed to LlamaConfig. + + Returns: + nn.Module: The Hugging Face model. 
+ """ + # Filter out keys from model_te.config that are not valid LlamaConfig attributes + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(LlamaConfig.__init__).parameters) + filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = LlamaConfig(**filtered_config, **config_kwargs) + + with torch.device("meta"): + model_hf = LlamaForCausalLM(hf_config) + + output_model = state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [ + state.state_transform( + source_key="model.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + fn=state.TransformFns.split_qkv, + ), + state.state_transform( + source_key="model.layers.*.layernorm_mlp.fc1_weight", + target_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + fn=state.TransformFns.split_fc1, + ), + ], + state_dict_ignored_entries=model_hf._tied_weights_keys, + ) + + output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone() + output_model.tie_weights() + + return output_model diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/export_llama3.py b/bionemo-recipes/recipes/vllm_inference/llama3/export_llama3.py new file mode 100644 index 0000000000..a812b1e923 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/export_llama3.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a TransformerEngine Llama-3 checkpoint to vLLM-ready HuggingFace format. + +The llama3_native_te recipe produces NVLlamaForCausalLM checkpoints (TE format). +This script converts those checkpoints to standard LlamaForCausalLM so they can +be served by vLLM, SGLang, or loaded by plain transformers without +trust_remote_code. + +Usage with an existing TE checkpoint (from training): + + python export_llama3.py --te-checkpoint /path/to/recipe/final_model + +Demo mode (no training needed -- creates a TE checkpoint from HuggingFace): + + python export_llama3.py +""" + +import argparse +import json +import shutil +from pathlib import Path + +from convert import convert_llama_hf_to_te, convert_llama_te_to_hf +from modeling_llama_te import AUTO_MAP, NVLlamaForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer + + +THIS_DIR = Path(__file__).resolve().parent +HF_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +HF_CHECKPOINT_DIR = THIS_DIR / "llama3_hf_roundtrip_checkpoint" + + +def create_te_checkpoint(output_dir: Path) -> Path: + """Create a TE checkpoint by converting a pretrained HF model. + + Follows the same workflow as export.py: convert weights, patch config.json + with auto_map, and copy the modeling file so the checkpoint is self-contained + and loadable with ``trust_remote_code=True``. 
+ """ + print(f" Loading pretrained model: {HF_MODEL_ID}") + model_hf = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID) + + print(" Converting HF -> TE") + model_te = convert_llama_hf_to_te(model_hf) + del model_hf + + te_dir = output_dir / "te_checkpoint" + te_dir.mkdir(parents=True, exist_ok=True) + model_te.save_pretrained(te_dir) + del model_te + + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID) + tokenizer.save_pretrained(te_dir) + + config_path = te_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + config["auto_map"] = AUTO_MAP + with open(config_path, "w") as f: + json.dump(config, f, indent=2, sort_keys=True) + + shutil.copy(THIS_DIR / "modeling_llama_te.py", te_dir / "modeling_llama_te.py") + + print(f" TE checkpoint saved to {te_dir}") + return te_dir + + +def convert_te_to_vllm(te_checkpoint: Path, output_dir: Path) -> Path: + """Convert a TE checkpoint to standard HF format for vLLM serving.""" + print(f" Loading TE checkpoint: {te_checkpoint}") + model_te = NVLlamaForCausalLM.from_pretrained(te_checkpoint) + + # convert_llama_te_to_hf creates the target LlamaForCausalLM on the meta + # device in float32. The TE checkpoint may store weights in bfloat16 + # (typical for training), so we cast to float32 to match. + model_te = model_te.float() + + print(" Converting TE -> HF") + model_hf = convert_llama_te_to_hf(model_te) + del model_te + + output_dir.mkdir(parents=True, exist_ok=True) + model_hf.save_pretrained(output_dir) + del model_hf + + print(f" vLLM-ready checkpoint saved to {output_dir}") + return output_dir + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--te-checkpoint", + type=Path, + default=None, + help="Path to an NVLlamaForCausalLM checkpoint (e.g. from llama3_native_te recipe). 
" + "If omitted, a TE checkpoint is created from HuggingFace for demo purposes.", + ) + parser.add_argument( + "--output", type=Path, default=HF_CHECKPOINT_DIR, help="Output directory for the vLLM checkpoint." + ) + parser.add_argument( + "--tokenizer", + type=str, + default=HF_MODEL_ID, + help="Tokenizer to bundle with the checkpoint (HF model ID or local path).", + ) + args = parser.parse_args() + + # Phase 1: Obtain a TE checkpoint + if args.te_checkpoint is not None: + te_path = args.te_checkpoint + print(f"[1/2] Using existing TE checkpoint: {te_path}") + else: + print("[1/2] No TE checkpoint provided -- creating one from HuggingFace (demo mode)") + te_path = create_te_checkpoint(args.output.parent) + + # Phase 2: Convert TE -> HF for vLLM + print("[2/2] Converting TE checkpoint to vLLM-ready HF format") + convert_te_to_vllm(te_path, args.output) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + tokenizer.save_pretrained(args.output) + + print(f"\nDone. Serve with: vllm serve {args.output}") diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/install_vllm.sh b/bionemo-recipes/recipes/vllm_inference/llama3/install_vllm.sh new file mode 100755 index 0000000000..a761046837 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/install_vllm.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARCH="${1:-$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}.{cc[1]}')")}" +MAX_JOBS="${MAX_JOBS:-8}" +export UV_BREAK_SYSTEM_PACKAGES=1 + +echo "Building vLLM for CUDA arch: $ARCH (MAX_JOBS=$MAX_JOBS)" + +cd /workspace +if [ ! -d vllm ]; then + git clone --branch v0.15.1 --depth 1 https://github.com/vllm-project/vllm.git +fi +cd vllm +python use_existing_torch.py +TORCH_CUDA_ARCH_LIST="$ARCH" MAX_JOBS="$MAX_JOBS" \ + uv pip install -r requirements/build.txt --system +TORCH_CUDA_ARCH_LIST="$ARCH" MAX_JOBS="$MAX_JOBS" \ + uv pip install --no-build-isolation -e . 
--system +pip install --upgrade "transformers[torch]" + +echo "vLLM installed for arch $ARCH" diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/llama3_transformers_inference_example.py b/bionemo-recipes/recipes/vllm_inference/llama3/llama3_transformers_inference_example.py new file mode 100644 index 0000000000..12faf9aeeb --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/llama3_transformers_inference_example.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Quick smoke-test: load the round-tripped checkpoint and generate text.""" + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +CHECKPOINT = "./llama3_hf_roundtrip_checkpoint" + +model = AutoModelForCausalLM.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16) +model.to("cuda") + +tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) +tokenizer.pad_token = tokenizer.eos_token + +inputs = tokenizer("The quick brown fox", return_tensors="pt") +inputs = {k: v.to("cuda") for k, v in inputs.items()} + +with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + +print(tokenizer.decode(output_ids[0], skip_special_tokens=True)) diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/llama3_vllm_inference_example.py b/bionemo-recipes/recipes/vllm_inference/llama3/llama3_vllm_inference_example.py new file mode 100644 index 0000000000..f47845fc8c --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/llama3_vllm_inference_example.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Quick smoke-test: load the round-tripped checkpoint in vLLM and generate text."""

from vllm import LLM, SamplingParams


def main() -> None:
    """Spin up a vLLM engine on the converted checkpoint and print one completion."""
    llm = LLM(
        model="./llama3_hf_roundtrip_checkpoint",
        runner="generate",
        dtype="bfloat16",
    )
    params = SamplingParams(max_tokens=16)
    completions = llm.generate(["The quick brown fox"], params)
    print(completions[0].outputs[0].text)


if __name__ == "__main__":
    main()
"""TransformerEngine-optimized Llama model."""

import warnings
from collections import OrderedDict
from typing import ClassVar, Unpack

import torch
import torch.nn as nn
import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import LlamaConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.utils.generic import TransformersKwargs


AUTO_MAP = {
    "AutoConfig": "modeling_llama_te.NVLlamaConfig",
    "AutoModel": "modeling_llama_te.NVLlamaModel",
    "AutoModelForCausalLM": "modeling_llama_te.NVLlamaForCausalLM",
    "AutoModelForSequenceClassification": "modeling_llama_te.NVLlamaForSequenceClassification",
    "AutoModelForQuestionAnswering": "modeling_llama_te.NVLlamaForQuestionAnswering",
    "AutoModelForTokenClassification": "modeling_llama_te.NVLlamaForTokenClassification",
}


class NVLlamaConfig(LlamaConfig):
    """NVLlama configuration."""

    # Attention input format:
    # "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
    # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
    attn_input_format: str = "thd"
    self_attn_mask_type: str = "padding_causal"


class NVLlamaPreTrainedModel(PreTrainedModel):
    """Base class for NVLlama models."""

    config_class = NVLlamaConfig
    base_model_prefix = "model"
    _no_split_modules = ("TransformerLayer",)
    _skip_keys_device_placement = ("past_key_values",)

    def init_empty_weights(self):
        """Handles moving the model from the meta device to the cuda device and initializing the weights."""
        # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight
        # initialization we passed them during module creation.
        for module in self.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

        # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use
        # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
        # deviation.
        self.model.embed_tokens.to_empty(device="cuda")
        self.model.embed_tokens.apply(self._init_weights)

        self.model.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=self.model.config).inv_freq.to("cuda")

        # Meta-device init seems to break weight tying, so we re-tie the weights here.
        self.tie_weights()

    def _init_weights(self, module):
        """Initialize module weights.

        We only use this method for standard pytorch modules, TE modules handle their own weight initialization through
        `init_method` parameters and the `reset_parameters` method.
        """
        if module.__module__.startswith("transformer_engine.pytorch"):
            # Notably, we need to avoid calling this method for TE modules, since the default _init_weights will assume
            # any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking `LayerNormLinear`
            # and `LayerNormMLP` modules that use `weight` for the linear layer and `layer_norm_weight` for the layer
            # norm.
            return

        super()._init_weights(module)

    def state_dict(self, *args, **kwargs):
        """Override state_dict to filter out TransformerEngine's _extra_state keys.

        TransformerEngine layers add _extra_state attributes that are not compatible with
        standard PyTorch/HuggingFace model loading. These are filtered out to ensure
        checkpoints can be loaded with from_pretrained().
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Filter out _extra_state keys which are TransformerEngine-specific and not loadable
        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


class NVLlamaModel(NVLlamaPreTrainedModel):
    """Llama3 model implemented in Transformer Engine."""

    def __init__(self, config: LlamaConfig):
        """Initialize the NVLlama model."""
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

        def _init_method(x):
            torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

        self.layers = nn.ModuleList(
            [
                transformer_engine.pytorch.TransformerLayer(
                    hidden_size=config.hidden_size,
                    ffn_hidden_size=config.intermediate_size,
                    num_attention_heads=config.num_attention_heads,
                    bias=False,
                    layernorm_epsilon=config.rms_norm_eps,
                    hidden_dropout=0,
                    attention_dropout=0,
                    fuse_qkv_params=True,
                    qkv_weight_interleaved=True,
                    normalization="RMSNorm",
                    activation="swiglu",
                    attn_input_format=config.attn_input_format,
                    self_attn_mask_type=config.self_attn_mask_type,
                    num_gqa_groups=config.num_key_value_heads,
                    layer_number=layer_idx + 1,
                    params_dtype=config.dtype,
                    device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                    init_method=_init_method,
                    output_layer_init_method=_init_method,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = transformer_engine.pytorch.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=config.dtype,
            device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
        )

        # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
        # LlamaRotaryEmbedding.
        self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: InferenceParams | None = None,
        inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Forward pass for the NVLlama model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.Tensor): The position ids.
            past_key_values (InferenceParams): TE inference parameters used for kv-caching.
            inputs_embeds (torch.Tensor): The inputs embeds.
            use_cache (bool): Whether to use cache.
            **kwargs: Additional keyword arguments.

        Returns:
            BaseModelOutputWithPast: The output of the model.
        """
        all_hidden_states = []
        output_hidden_states = kwargs.get("output_hidden_states", False)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # TE-specific input handling.
        has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
        should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

        if should_pack_inputs:
            # Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
            # dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
            # This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
            # attention backend, but it should be faster for the flash attention backend.
            assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
            batch_size = hidden_states.size(0)
            # BUGFIX: read the padded length from hidden_states rather than input_ids, so packing also
            # works when callers pass inputs_embeds (input_ids is None in that case). The two are
            # identical whenever input_ids is provided.
            padded_seq_len = hidden_states.size(1)
            hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
            kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
            kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen

        if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
            # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
            # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
            hidden_states = hidden_states.squeeze(0)

        if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
            # Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend)
            attention_mask = ~attention_mask[:, None, None, :].bool()

        if isinstance(past_key_values, InferenceParams):  # InferenceParams is TE's way of managing kv-caching.
            # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
            # compute the lengths of each sequence in the batch.
            lengths = (
                attention_mask.sum(dim=1).tolist()
                if attention_mask.shape == input_ids.shape
                else [1] * input_ids.shape[0]
            )
            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

        # Ensure that rotary embeddings are computed with at a higher precision
        with torch.autocast(device_type="cuda", enabled=False):
            te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
        # BUGFIX: warn when the rope table is NOT float32. The original condition was inverted
        # (`== torch.float32`) and fired exactly when the dtype was already the recommended one.
        if te_rope_emb.dtype != torch.float32:
            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states = (*all_hidden_states, hidden_states)

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
                rotary_pos_emb=te_rope_emb,
                inference_params=past_key_values,
                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                max_seqlen_q=kwargs.get("max_length_q", None),
                max_seqlen_kv=kwargs.get("max_length_k", None),
                pad_between_seqs=kwargs.get("pad_between_seqs", None),
            )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
        # these with the same _pad_input call as below if we wanted them returned in BSHD format.
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

        if should_pack_inputs:
            # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
            hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )


class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin):
    """Llama3 model with causal language head."""

    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        """Initialize the NVLlamaForCausalLM model."""
        super().__init__(config)
        self.model = NVLlamaModel(config)
        self.vocab_size = config.vocab_size
        with transformer_engine.pytorch.quantized_model_init(enabled=False):
            self.lm_head = transformer_engine.pytorch.Linear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
                params_dtype=config.dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
            )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor, ...], ...] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        shift_labels: torch.Tensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Forward pass for the NVLlamaForCausalLM model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.Tensor): The position ids.
            past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
            inputs_embeds (torch.Tensor): The inputs embeds.
            labels (torch.Tensor): The labels.
            shift_labels (torch.Tensor): Labels that have already been shifted by the dataloader, to be used instead of
                labels for the loss function. For context parallelism, it is more reliable to shift the labels before
                splitting the batch into shards.
            use_cache (bool): Whether to use cache.
            cache_position (torch.Tensor): The cache position.
            logits_to_keep (int | torch.Tensor): Whether to keep only the last logits to reduce the memory footprint of
                the model during generation.
            **kwargs: Additional keyword arguments.

        Returns:
            CausalLMOutputWithPast: The output of the model.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

        with transformer_engine.pytorch.autocast(enabled=False):
            if hidden_states.ndim == 3:
                logits = self.lm_head(hidden_states[:, slice_indices, :])
            else:  # With THD inputs, batch and sequence dimensions are collapsed in the first dimension.
                logits = self.lm_head(hidden_states[slice_indices, :])

        loss = None
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class NVLlamaForSequenceClassification(
    transformers.modeling_layers.GenericForSequenceClassification, NVLlamaPreTrainedModel
):
    """Llama3 model with sequence classification head."""


class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestionAnswering, NVLlamaPreTrainedModel):
    """Llama3 model with question answering head."""

    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`


class NVLlamaForTokenClassification(
    transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
):
    """Llama3 model with token classification head."""


# _unpad_input calls .item() on a tensor; allow torch.compile to capture that scalar.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def _pad_input(hidden_states, indices, batch, seqlen):
    """Convert a THD tensor to a BSHD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    # Scatter the packed tokens back into a zero-initialized padded buffer, then reshape to BSHD.
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


@torch.compile
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Convert a BSHD tensor to a THD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    batch_size = hidden_states.size(0)
    seq_length = hidden_states.size(1)

    if attention_mask.shape[1] != seq_length:  # Likely in generation mode with kv-caching
        # Each sequence contributes exactly one new token, so the "packed" view is trivial.
        return (
            hidden_states.squeeze(1),  # hidden_states
            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),  # indices
            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),  # cu_seqlens
            1,  # max_seqlen
            1,  # seqused
        )

    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


class HFInferenceParams(InferenceParams):
    """Extension of the InferenceParams class to support HF generate() and beam search."""

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the current cached sequence length.

        Required by HuggingFace transformers generate() to determine how many
        tokens have already been cached.
        """
        if not self.sequences:
            return 0
        return max(self.sequences.values())

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache based on the beam indices."""
        if isinstance(self.cache_manager, PagedKVCacheManager):
            raise NotImplementedError("Beam search is not supported for paged cache manager.")
        for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items():
            updated_key_cache = key_cache.index_select(0, beam_idx)
            updated_value_cache = value_cache.index_select(0, beam_idx)
            self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
+ +"""State dict conversion utilities adapted from nemo.lightning.io.state. + +This module provides the transform system used by convert.py to map state dicts between model formats: + +- ``mapping``: A dict of simple key renames (source_key -> target_key). Each source key is copied directly + to the corresponding target key with no modification to the tensor values. + +- ``transforms``: A list of ``StateDictTransform`` objects for multi-key merges and splits. These handle + cases where multiple source keys must be combined into one target key (e.g., merging Q/K/V into fused QKV), + or one source key must be split into multiple target keys. + + Important: When ``source_key`` is a tuple (many-to-one merge), the transform function's parameter names + are used to map each source key to a function argument. This means ``*args`` style parameters do not work; + each parameter must be explicitly named (e.g., ``def fn(q, k, v)`` not ``def fn(*args)``). +""" + +import inspect +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger(__name__) + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + """Transform Data class Definition.""" + + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +class _ModelState: + """Helper class for used for to modify state dict of a source model during model conversion.""" + + def __init__(self, state_dict, config=None): + self._state_dict = state_dict + self.config = config + + def state_dict(self): + # pylint: disable=C0115,C0116 + return self._state_dict + + def to(self, dtype): + # pylint: disable=C0115,C0116 + for k, v in self._state_dict.items(): + if v.dtype != dtype: + 
logger.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + +@torch.no_grad +def apply_transforms( + source: Union[nn.Module, _ModelState], + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None, + state_dict_ignored_entries: Optional[List] = None, + cast_dtype: Optional[torch.dtype] = None, +) -> TargetModuleT: + """Transform the state dictionary of a source module to match the structure of a target module's state dictionary. + + This function renames keys according to a provided mapping and modifies values using a list + of transformation functions. Each transformation function typically is decorated + with `io.state_transform`. + + Args: + source (nn.Module): The source module from which parameters and buffers are taken. + target (TargetModuleT): The target module to which parameters and buffers are adapted. + mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary + is mapped to a corresponding key in the target state dictionary. + transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions + that modify the `TransformCTX` object. If None, no transformations beyond key renaming + are applied. Defaults to None. + state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases + where multiple entries in model's state_dict point to one entry in model's named_parameter. + E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`, + `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight + in T5 Huggingface implementation.). In these cases, ignore redundant entries. + cast_dtype: case the output state dict to a certain precision. 
+ + Returns: + TargetModuleT: The modified target module with its state dictionary adjusted according to + the specified mappings and transformations. + + Raises: + ValueError: If there's a mismatch in shape between corresponding source and target parameters + or buffers. + RuntimeError: If the target state dictionary contains keys that are not present in the source + state dictionary after all transformations. + + Examples: + >>> source_module = nn.Linear(10, 5) + >>> target_module = nn.Linear(10, 5) + >>> mapping = {'weight': 'weights', 'bias': 'biases'} + @io.state_transform( + source_key="weight", + target_key="weights" + ) + def scale_weights(ctx): + ctx.target_state['weights'] = ctx.source_state['weight'] * 2 + return ctx + >>> transformed_target = apply_transforms( + ... source_module, target_module, mapping, [scale_weights] + ... ) + >>> print(transformed_target.state_dict()['weights']) + + See Also: + - `TransformCTX`: For more details on the context object used in transformations. + - `StateDictTransform`: For creating complex transformations. + + Note: + This function is particularly useful when adapting models from different frameworks or + when consolidating models with different architectural changes. + """ + if transforms is None: + transforms = [] + if state_dict_ignored_entries is None: + state_dict_ignored_entries = [] + + # Track dtypes to make sure they weren't modified during conversion. 
+ target_orig_dtypes = extract_dtypes(target.named_parameters()) + + target_state = target.state_dict() + ctx = TransformCTX( + source=source, + source_state=source.state_dict(), + target=target, + target_state=target_state, + ) + + for key, val in mapping.items(): + logger.debug(f"Mapping {key} -> {val}") + ctx = StateDictTransform(key, val)(ctx) + + for transform in transforms: + logger.debug(f"Transforming {transform.source_key} -> {transform.target_key}") + ctx = transform(ctx) + + _params: Dict[str, nn.Parameter] = {} + for name, param in target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError( + f"Shape mismatch for parameter {name}: target shape {param.shape} vs " + f"converted source shape {target_param.shape}" + ) + + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in target model but is in source model.") + + for key, val in _params.items(): + _module, _key = target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_parameter(_key, val) + + _buffers = {} + for name, buffer in target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = target, key + if "." 
    # NOTE(review): the lines below are the tail of a checkpoint-conversion
    # function whose `def` line lies above this excerpt. From the visible code
    # it registers leftover entries as buffers, fails on unmapped target keys,
    # checks for un-materialized (meta) parameters, optionally casts dtypes,
    # and returns the populated `target` module. The opener of the
    # `... in key:` check was truncated by the chunk boundary — verify the
    # reconstructed condition against the original file.
        if "." in key:  # NOTE(review): reconstructed opener — TODO confirm
            # Walk the dotted key path down to the owning submodule, then
            # register the final path component as a buffer on it.
            for part in key.split(".")[:-1]:
                _module = getattr(_module, part)
            _key = key.split(".")[-1]

        _module.register_buffer(_key, val)

    # Any target keys left unconsumed (other than TE `_extra_state` blobs and
    # the explicit ignore list) mean the source->target mapping is incomplete.
    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
    keys = [key for key in keys if key not in state_dict_ignored_entries]
    if len(keys) != 0:
        raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.")

    # Re-tie shared embeddings/output weights if the model supports it.
    if hasattr(target, "tie_weights"):
        target.tie_weights()

    # A parameter still on the meta device was never written by any mapping
    # or transform — surface the full list in the assertion message.
    meta_tensor_keys = []
    for name, param in target.named_parameters():
        if param.is_meta:
            meta_tensor_keys.append(name)

    assert not meta_tensor_keys, (
        f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion."
        f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?"
    )

    if cast_dtype:
        logger.info(f"Casting model to {cast_dtype}...")
        target.to(cast_dtype)
        logger.info(f"Casting model to {cast_dtype} complete.")
    else:
        # No explicit cast requested: verify the conversion did not silently
        # change any parameter dtype relative to the snapshot taken earlier.
        target_new_dtypes = extract_dtypes(target.named_parameters())
        for key in target_orig_dtypes.keys():
            if key in target_new_dtypes:  # For tied weights, these parameters may disappear.
                assert target_orig_dtypes[key] == target_new_dtypes[key], (
                    f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}"
                )

    return target


def _default_transform(inp):
    # Identity transform used when a StateDictTransform is created without an
    # explicit `transform` callable: the source value is copied through as-is.
    return inp


class StateDictTransform(Generic[F]):
    """A transformation class for state dictionaries.

    Allows for flexible key matching and transformation of values between source and target state dictionaries.

    Attributes:
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        transform: A callable that performs the transformation on matched keys' values.

    Examples:
        >>> def example_transform(ctx, *args):
        ...     return sum(args)
        >>> transform = StateDictTransform(
        ...     source_key="model.layers.*.self_attn.*_proj.weight",
        ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight",
        ...     transform=example_transform
        ... )
    """

    def __init__(
        self,
        source_key: Union[str, Tuple[str, ...], Dict[str, str]],
        target_key: Union[str, Tuple[str, ...]],
        transform: F = _default_transform,
    ):
        """Initialize the StateDictTransform."""
        self.source_key = source_key
        self.target_key = target_key
        self.transform = transform

    def __call__(self, ctx: TransformCTX) -> TransformCTX:
        """Perform the transformation on the given context.

        Matches `source_key` patterns against `ctx.source_state` and
        `target_key` patterns against `ctx.target_state`, then writes
        `transform(...)` outputs into the target state dict in place.
        Returns the same `ctx` for chaining.
        """
        source_key = self.source_key
        target_key = self.target_key
        source_dict, target_dict = ctx.source_state, ctx.target_state
        # Keep numpy arrays short when interpolated into log/error f-strings.
        np.set_printoptions(threshold=10)
        # Parameter names of the transform, minus the optional `ctx` argument;
        # used to bind matched source tensors to keyword arguments.
        fn_params = dict(inspect.signature(self.transform).parameters)
        fn_params.pop("ctx", None)
        matched = False
        if isinstance(source_key, (dict, tuple)):
            # Multi-source case: each transform parameter gets its own source
            # key pattern (tuple -> positional order; dict -> by param name).
            if isinstance(source_key, tuple):
                source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)}
            else:
                source_key_dict = source_key
            source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
            target_matches = _match_keys(list(target_dict.keys()), target_key)
            param_names = list(filter(lambda x: x in source_matches_dict, fn_params))
            # Normalize 0-d match arrays to 1-element lists so zip works.
            source_matches = [
                source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()]
                for v in param_names
            ]
            target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]]
            for layer_names_group in zip(*(source_matches + target_matches)):
                # Wrap in a list if it's a single layer (ie non-expert)
                if isinstance(layer_names_group[0], str):
                    layer_names_group = [[x] for x in layer_names_group]  # noqa: PLW2901
                for layer_names in zip(*layer_names_group):
                    # Last entry is the target key; the rest are the per-param
                    # source keys in `param_names` order.
                    target_dict[layer_names[-1]] = self.call_transform(
                        ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
                    )
                    logger.debug(f"Matched (transform)! {layer_names_group=}")
                    matched = True
        else:
            source_keys = list(source_dict.keys())
            target_keys = list(target_dict.keys())

            source_matches = _match_keys(source_keys, source_key)
            # _match_keys returns an object array; an all-None singleton means
            # nothing matched the pattern.
            if source_matches.size == 1 and source_matches == np.array(None):
                raise ValueError(f"No matches found for source key: {source_key}")

            if isinstance(target_key, str):
                target_matches = _match_keys(target_keys, target_key)
                if target_matches.size == 1 and target_matches == np.array(None):
                    raise ValueError(f"No matches found for target key: {target_key}")
            else:
                if isinstance(target_key, dict):
                    raise ValueError("Target key must be a string or a tuple of strings.")
                # Tuple of target patterns: stack so the last axis indexes the
                # per-pattern key for one source match (e.g. split q/k/v).
                _matches = [_match_keys(target_keys, key) for key in target_key]
                target_matches = np.stack(_matches, axis=-1)

            # Determine if we are dealing with multiple source matches or multiple target matches
            multiple_sources = source_matches.ndim >= target_matches.ndim
            accepts_var_args = any(
                param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values()
            )

            if multiple_sources:
                # N source keys feed each target key (merge-style transforms).
                for target_index, target_match in np.ndenumerate(target_matches):
                    try:
                        source_match = source_matches[target_index]
                    except IndexError as e:
                        logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
                        raise e
                    if accepts_var_args:
                        source_values = [source_dict[k] for k in source_match]
                        target_dict[target_match] = self.call_transform(ctx, *source_values)
                    else:
                        _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match)
                        if len(fn_params) != len(_source_match_list):
                            raise ValueError(
                                f"Mismatch between source and target keys: {source_match} vs {target_match}"
                            )

                        kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)}
                        target_dict[target_match] = self.call_transform(ctx, **kwargs)
                    logger.debug(f"Matched (multi source)! {target_match=} {source_match=}")
                    matched = True
            else:
                # One source key feeds one-or-many target keys (split-style).
                for source_index, source_match in np.ndenumerate(source_matches):
                    target_match = target_matches[source_index]
                    source_values = (
                        [source_dict[source_match]]
                        if np.isscalar(source_match)
                        else [source_dict[k] for k in source_match]
                    )
                    if accepts_var_args:
                        outputs = self.call_transform(ctx, *source_values)
                    else:
                        kwargs = dict(zip(fn_params, source_values))
                        outputs = self.call_transform(ctx, **kwargs)

                    if isinstance(target_match, str):
                        target_dict[target_match] = outputs
                    else:
                        # Transform returned a tuple: fan out to the stacked
                        # per-pattern target keys in order.
                        for i, t in enumerate(outputs):
                            target_dict[target_match[i]] = t
                    logger.debug(f"Matched (single source)! {target_match=} {source_match=}")
                    matched = True
        if not matched:
            logger.warning(f"No matches found for source key: {source_key=} {target_key=}")
        return ctx

    def call_transform(self, ctx: TransformCTX, *args, **kwargs):
        """Perform transform and check if the given args valid.

        Validates arity against the transform's signature (skipped for
        *args transforms) and forwards `ctx` only when the transform
        declares a `ctx` parameter.
        """
        func_params = inspect.signature(self.transform).parameters
        expected_num_args = len([p for p in func_params if p not in ["self", "ctx"]])
        provided_num_args = len(args) + len(kwargs)
        accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values())

        if not accepts_var_args and provided_num_args != expected_num_args:
            raise ValueError(
                f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}."
            )

        if "ctx" in func_params:
            return self.transform(ctx, *args, **kwargs)

        return self.transform(*args, **kwargs)


def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
    # Compile `pattern` into a regex: `**` matches across dots, `*` matches a
    # single dot-separated component, literal dots are escaped.
    escaped_pattern = ""
    i = 0
    wildcard_positions = []
    while i < len(pattern):
        if pattern[i : i + 2] == "**":
            escaped_pattern += r"(.+)"  # Match any characters including dots
            wildcard_positions.append("**")
            i += 2
        elif pattern[i] == "*":
            escaped_pattern += r"([^.]+)"  # Match any characters except dots
            wildcard_positions.append("*")
            i += 1
        else:
            if pattern[i] == ".":
                escaped_pattern += r"\."  # Escape the dot
            else:
                escaped_pattern += pattern[i]
            i += 1

    regex_pattern = re.compile("^" + escaped_pattern + "$")
    num_wildcards = len(wildcard_positions)
    # Collect, per wildcard position, the ordered unique values it captured.
    wildcard_matches = [[] for _ in range(num_wildcards)]

    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            for i, group in enumerate(match.groups()):
                if group not in wildcard_matches[i]:
                    wildcard_matches[i].append(group)

    # Sort the wildcard matches to maintain consistent ordering.
    # NOTE(review): numeric captures sort numerically; a wildcard that captures
    # a mix of digit and non-digit values would make this key raise TypeError
    # on int/str comparison — presumably never the case here; verify.
    for i in range(len(wildcard_matches)):
        wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x)

    # Determine the shape of the output array based on the unique matches for each wildcard
    shape = [len(matches) for matches in wildcard_matches]

    if len(wildcard_matches) == 0:
        # If there is no wildcard matches, assuming it is a single match
        shape = [1]
    # Initialize an empty array with the determined shape; unmatched slots
    # stay None, which callers test for via `== np.array(None)`.
    output_array = np.empty(shape, dtype=object)

    # Populate the array with the keys, now that we have the correct shape and ordering
    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            # Convert match groups to indices based on their position in wildcard_matches
            indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())]
            output_array[tuple(indices)] = key  # Place the key in the array based on the indices

    return output_array
@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
) -> "Callable[[F], StateDictTransform[F]]": ...


@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: "F"
) -> "StateDictTransform[F]": ...


def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
    fn: "Optional[F]" = None,
):
    """Create a StateDictTransform instance with specified source and target keys, and a transformation function.

    Args:
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        fn: An optional callable that performs the transformation on matched keys' values. If not
            provided, the decorator can be used to wrap a function definition.

    Returns:
        A StateDictTransform instance if `fn` is provided, otherwise a decorator that
        takes a function and returns a StateDictTransform instance.

    Examples:
        >>> @state_transform(
        ...     source_key="model.layers.*.self_attn.*_proj.weight",
        ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight"
        ... )
        ... def sum_transform(ctx, *args):
        ...     return sum(args)
    """

    def wrapper(fn) -> "StateDictTransform":
        return StateDictTransform(source_key, target_key, fn)

    if fn is None:
        return wrapper

    return wrapper(fn)


class TransformFns:
    """A collection of common functions used in state dict transformation."""

    @staticmethod
    def split_qkv(ctx: "TransformCTX", linear_qkv: torch.Tensor):
        """Split interleave-concatenated qkv to q, k, v.

        Example: export layer linear_qkv to HF {q|k|v}_proj
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        qkv_total_dim = head_num + 2 * num_query_groups

        linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1])
        # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size
        # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r
        hidden_size = linear_qkv.size(-1)
        # Rows are laid out per KV group as [q_0 .. q_{heads_per_group-1}, k, v];
        # gather the row indices belonging to each role.
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
        k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
        v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

        return q_proj, k_proj, v_proj

    @staticmethod
    def split_qkv_bias(ctx: "TransformCTX", qkv_bias: torch.Tensor):
        """Split interleave-concatenated qkv bias to separate q, k, v bias.

        Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
        """
        megatron_config = ctx.source.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        qkv_total_dim = head_num + 2 * num_query_groups

        qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
        # Same interleaved [q..., k, v] per-group layout as split_qkv.
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_bias = qkv_bias[q_slice].reshape(-1).cpu()
        k_bias = qkv_bias[k_slice].reshape(-1).cpu()
        v_bias = qkv_bias[v_slice].reshape(-1).cpu()

        return q_bias, k_bias, v_bias

    @staticmethod
    def merge_qkv_concat(ctx: "TransformCTX", qkv: torch.Tensor):
        """Merge naively concatenated q, k, v to interleave-concatenated qkv.

        Example: import HF qkv to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0)
        return TransformFns.merge_qkv(ctx, q, k, v)

    @staticmethod
    def merge_qkv(ctx: "TransformCTX", q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Merge q, k, v to interleave-concatenated qkv.

        Example: import HF {q|k|v}_proj to layer linear_qkv
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        old_tensor_shape = q.size()
        new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:])
        new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:])

        q = q.view(*new_q_tensor_shape)
        k = k.view(*new_kv_tensor_shape)
        v = v.view(*new_kv_tensor_shape)

        # Interleave per KV group: [q_0 .. q_{heads_per_group-1}, k, v].
        qkv_weights_l = []
        for i in range(num_query_groups):
            qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
            qkv_weights_l.append(k[i : i + 1, :, :])
            qkv_weights_l.append(v[i : i + 1, :, :])
        qkv_weights = torch.cat(qkv_weights_l)
        assert qkv_weights.ndim == 3, qkv_weights.shape
        assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
        assert qkv_weights.shape[1] == head_size, qkv_weights.shape
        assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

        return qkv_weights

    @staticmethod
    def merge_qkv_bias_concat(ctx: "TransformCTX", qkv_bias: torch.Tensor):
        """Merge naively concatenated q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF qkv bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        qb, kb, vb = qkv_bias.split(
            [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
        )
        return TransformFns.merge_qkv_bias(ctx, qb, kb, vb)

    @staticmethod
    def merge_qkv_bias(ctx: "TransformCTX", qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor):
        """Merge q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF {q|k|v}_proj bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels

        new_q_tensor_shape = (head_num, head_size)
        new_kv_tensor_shape = (num_query_groups, head_size)

        qb = qb.view(*new_q_tensor_shape)
        kb = kb.view(*new_kv_tensor_shape)
        vb = vb.view(*new_kv_tensor_shape)

        # Accumulate slices and concatenate once: the previous implementation
        # re-allocated the growing tensor with torch.cat on every iteration
        # (accidentally quadratic in the number of KV groups).
        bias_parts = []
        for i in range(num_query_groups):
            bias_parts.append(qb[i * heads_per_group : (i + 1) * heads_per_group, :])
            bias_parts.append(kb[i : i + 1, :])
            bias_parts.append(vb[i : i + 1, :])
        qkv_bias = torch.cat(bias_parts).reshape(
            [
                head_size * (head_num + 2 * num_query_groups),
            ]
        )
        return qkv_bias

    @staticmethod
    def merge_fc1(gate: torch.Tensor, up: torch.Tensor):
        """Merge gate and up proj into concatenated fc1.

        Example: import HF {gate|up}_proj to layer linear_fc1
        """
        return torch.cat((gate, up), dim=0)

    @staticmethod
    def split_fc1(linear_fc1: torch.Tensor):
        """Split concatenated fc1 to gate and up proj.

        Example: export layer linear_fc1 to HF {gate|up}_proj
        """
        gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)
        return gate_proj, up_proj

    @staticmethod
    def duplicate2(param: torch.Tensor):
        """Duplicate the source parameter to two target parameters.

        Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A
        """
        return param, param

    @staticmethod
    def duplicate3(param: torch.Tensor):
        """Duplicate the source parameter to three target parameters.

        Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A
        """
        return param, param, param

    @staticmethod
    def prune_padding(ctx: "TransformCTX", embedding: torch.Tensor):
        """Prune the embedding size to vocab size.

        Example: export embedding/output layer to HF with non-padded vocab size
        """
        megatron_config = ctx.target.config
        return embedding[: megatron_config.vocab_size, :]


def extract_dtypes(ckpt):
    """Extract dtype from the input iterator.

    ckpt can be module.named_parameters or module.state_dict().items()
    """
    dtypes = {}
    for key, val in ckpt:
        if hasattr(val, "dtype"):
            dtypes[key] = val.dtype
        elif hasattr(val, "data") and hasattr(val.data, "dtype"):
            # if it's ShardedTensor populated with data.
            dtypes[key] = val.data.dtype
    return dtypes
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + + +sys.path.append(Path(__file__).parent.parent.as_posix()) diff --git a/bionemo-recipes/recipes/vllm_inference/llama3/tests/test_vllm.py b/bionemo-recipes/recipes/vllm_inference/llama3/tests/test_vllm.py new file mode 100644 index 0000000000..1545e9361b --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/llama3/tests/test_vllm.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden-value tests for the Llama-3 train-then-serve workflow. + +Validates that a round-tripped (HF -> TE -> HF) checkpoint served via vLLM +produces the same outputs as the original HuggingFace model. + +Three tests: +- **test_greedy_text_match**: greedy-decoded text must be identical. +- **test_logprob_similarity**: per-token log-probabilities must be close. +- **test_top_token_overlap**: top-K most likely tokens must overlap at every step. 
+""" + +import numpy as np +import pytest +import torch +from torch.nn import functional as f +from transformers import AutoModelForCausalLM, AutoTokenizer + + +try: + from vllm import LLM, SamplingParams + + _VLLM_AVAILABLE = True +except ImportError: + _VLLM_AVAILABLE = False + +from export_llama3 import HF_MODEL_ID, convert_te_to_vllm, create_te_checkpoint + + +PROMPTS = [ + "The quick brown fox", + "In a hole in the ground there lived", +] +MAX_TOKENS = 16 +TOP_K = 10 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def vllm_checkpoint(tmp_path_factory): + """Create the round-tripped HF checkpoint (HF -> TE -> HF) once per module.""" + base_dir = tmp_path_factory.mktemp("llama3_export") + te_path = create_te_checkpoint(base_dir) + + hf_path = base_dir / "hf_roundtrip" + convert_te_to_vllm(te_path, hf_path) + + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID) + tokenizer.save_pretrained(hf_path) + + return str(hf_path) + + +@pytest.fixture(scope="module") +def hf_reference_outputs(): + """Run HF reference model: greedy text + per-token log-probs for each prompt.""" + model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.bfloat16).to("cuda").eval() + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID) + tokenizer.pad_token = tokenizer.eos_token + + results = [] + with torch.no_grad(): + for prompt in PROMPTS: + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + output_ids = model.generate(**inputs, max_new_tokens=MAX_TOKENS, do_sample=False, use_cache=False) + text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generated_ids = output_ids[0, inputs["input_ids"].shape[1] :] + log_probs = [] + top_k_ids = [] + current_ids = inputs["input_ids"] + for token_id in generated_ids: + outputs = 
model(current_ids) + logits = outputs.logits[0, -1, :] + lp = f.log_softmax(logits.float(), dim=-1) + log_probs.append(lp[token_id].cpu().item()) + top_k_ids.append(torch.topk(logits, TOP_K).indices.cpu().tolist()) + current_ids = torch.cat([current_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=1) + + results.append( + { + "text": text, + "log_probs": np.array(log_probs), + "token_ids": generated_ids.cpu().tolist(), + "top_k_ids": top_k_ids, + } + ) + + del model, tokenizer + torch.cuda.empty_cache() + return results + + +@pytest.fixture(scope="module") +def vllm_outputs(vllm_checkpoint): + """Run vLLM on the exported checkpoint: greedy text + per-token log-probs + top-K.""" + engine = LLM(model=vllm_checkpoint, runner="generate", dtype="bfloat16") + params = SamplingParams(max_tokens=MAX_TOKENS, temperature=0, logprobs=TOP_K) + raw_outputs = engine.generate(PROMPTS, params) + + results = [] + for prompt, output in zip(PROMPTS, raw_outputs): + text = prompt + output.outputs[0].text + token_ids = list(output.outputs[0].token_ids) + log_probs = [] + top_k_ids = [] + for tid, step in zip(token_ids, output.outputs[0].logprobs): + log_probs.append(step[tid].logprob) + top_k_ids.append(sorted(step.keys(), key=lambda t: step[t].logprob, reverse=True)) + results.append( + { + "text": text, + "log_probs": np.array(log_probs), + "token_ids": token_ids, + "top_k_ids": top_k_ids, + } + ) + + del engine + torch.cuda.empty_cache() + return results + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="vllm not installed") +def test_greedy_text_match(hf_reference_outputs, vllm_outputs): + """Greedy-decoded text from the round-tripped checkpoint must match the reference.""" + for i, prompt in enumerate(PROMPTS): + hf_text = hf_reference_outputs[i]["text"] + vllm_text = vllm_outputs[i]["text"] + assert hf_text == 
vllm_text, f"Prompt {i} ({prompt!r}):\n HF: {hf_text!r}\n vLLM: {vllm_text!r}" + + +@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="vllm not installed") +def test_logprob_similarity(hf_reference_outputs, vllm_outputs): + """Per-token log-probabilities must be close between HF reference and vLLM.""" + atol = 0.1 + + for i, prompt in enumerate(PROMPTS): + hf_lp = hf_reference_outputs[i]["log_probs"] + vllm_lp = vllm_outputs[i]["log_probs"] + hf_ids = hf_reference_outputs[i]["token_ids"] + vllm_ids = vllm_outputs[i]["token_ids"] + + n = min(len(hf_lp), len(vllm_lp)) + assert n > 0, f"Prompt {i}: no tokens generated" + + assert hf_ids[:n] == vllm_ids[:n], ( + f"Prompt {i} ({prompt!r}): token ID mismatch\n HF: {hf_ids[:n]}\n vLLM: {vllm_ids[:n]}" + ) + + max_diff = float(np.abs(hf_lp[:n] - vllm_lp[:n]).max()) + mean_diff = float(np.abs(hf_lp[:n] - vllm_lp[:n]).mean()) + assert max_diff < atol, ( + f"Prompt {i} ({prompt!r}): log-prob max |diff| = {max_diff:.6f} exceeds atol={atol}\n" + f" HF log-probs: {hf_lp[:n]}\n" + f" vLLM log-probs: {vllm_lp[:n]}" + ) + print(f" Prompt {i}: max |diff| = {max_diff:.6f}, mean |diff| = {mean_diff:.6f}") + + +@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="vllm not installed") +def test_top_token_overlap(hf_reference_outputs, vllm_outputs): + """Top-K most likely tokens must overlap at every generation step. + + Unlike atol-based log-prob checks, this is naturally robust to bfloat16 + numerical noise: small logit perturbations only affect the ranking at + tie boundaries, so the top-K set is stable. 
+ """ + min_overlap = 0.9 + + for i, prompt in enumerate(PROMPTS): + hf_topk = hf_reference_outputs[i]["top_k_ids"] + vllm_topk = vllm_outputs[i]["top_k_ids"] + n = min(len(hf_topk), len(vllm_topk)) + + for step in range(n): + hf_set = set(hf_topk[step][:TOP_K]) + vllm_set = set(vllm_topk[step][:TOP_K]) + overlap = len(hf_set & vllm_set) / TOP_K + assert overlap >= min_overlap, ( + f"Prompt {i} ({prompt!r}), step {step}: " + f"top-{TOP_K} overlap = {overlap:.0%} < {min_overlap:.0%}\n" + f" HF top-{TOP_K}: {sorted(hf_set)}\n" + f" vLLM top-{TOP_K}: {sorted(vllm_set)}" + ) + + overlaps = [] + for step in range(n): + hf_set = set(hf_topk[step][:TOP_K]) + vllm_set = set(vllm_topk[step][:TOP_K]) + overlaps.append(len(hf_set & vllm_set) / TOP_K) + print(f" Prompt {i}: mean top-{TOP_K} overlap = {np.mean(overlaps):.0%}, min = {np.min(overlaps):.0%}") diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 8c952432df..f801430e6a 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -46,13 +46,21 @@ "bionemo-recipes/models/amplify/src/amplify/state.py", "bionemo-recipes/models/llama3/state.py", "bionemo-recipes/models/mixtral/state.py", + "bionemo-recipes/recipes/vllm_inference/llama3/state.py", ], "bionemo-recipes/models/llama3/modeling_llama_te.py": [ "bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py", + "bionemo-recipes/recipes/vllm_inference/llama3/modeling_llama_te.py", ], "bionemo-recipes/models/llama3/nucleotide_fast_tokenizer": [ "bionemo-recipes/recipes/llama3_native_te/tokenizers/nucleotide_fast_tokenizer", ], + "bionemo-recipes/models/llama3/convert.py": [ + "bionemo-recipes/recipes/vllm_inference/llama3/convert.py", + ], + "bionemo-recipes/models/llama3/requirements.txt": [ + "bionemo-recipes/recipes/vllm_inference/llama3/requirements.txt", + ], # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common",