From 3b08beda172f3972792010573cf6c3dff962a4a0 Mon Sep 17 00:00:00 2001 From: gramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 08:38:20 -0700 Subject: [PATCH 1/8] [#12332][feat] AutoDeploy: SuperV3 MTP Support (#12326) Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../model_registry/configs/super_v3_mtp.yaml | 16 + .../auto_deploy/model_registry/models.yaml | 2 + .../custom_ops/attention_interface.py | 73 +++ .../mamba/flashinfer_backend_mamba.py | 13 +- .../custom_ops/mamba/mamba_backend_common.py | 112 ++++- .../mamba/triton_backend_causal_conv.py | 82 +++- .../custom_ops/mamba/triton_backend_mamba.py | 114 ++++- tensorrt_llm/_torch/auto_deploy/llm_args.py | 36 +- .../models/custom/modeling_eagle.py | 414 ++++++++++++++---- .../models/custom/modeling_nemotron_h.py | 130 +++++- .../_torch/auto_deploy/models/eagle.py | 102 +++-- .../_torch/auto_deploy/shim/ad_executor.py | 9 +- .../_torch/auto_deploy/shim/interface.py | 157 +++++-- .../transform/library/collectives.py | 7 + .../transform/library/hidden_states.py | 20 +- .../auto_deploy/transform/library/sharding.py | 12 + tensorrt_llm/llmapi/llm_args.py | 2 +- .../defs/accuracy/references/gsm8k.yaml | 2 + .../defs/accuracy/test_llm_api_autodeploy.py | 99 ++++- .../examples/test_ad_speculative_decoding.py | 218 ++++++++- .../test_lists/qa/llm_function_core.txt | 1 + .../qa/llm_function_core_sanity.txt | 1 + .../test_lists/test-db/l0_dgx_b200.yml | 1 + .../test_lists/test-db/l0_h100.yml | 1 + .../_utils_test/_model_test_utils.py | 24 + .../mamba/test_flashinfer_mamba_cached_op.py | 1 + .../mamba/test_triton_mamba_cached_op.py | 2 + .../custom_ops/test_resource_handlers.py | 63 ++- .../test_triton_causal_conv_cached_op.py | 4 + .../singlegpu/models/test_eagle.py | 8 +- .../shim/test_cached_sequence_interface.py | 72 +++ .../smoke/test_ad_speculative_decoding.py | 203 ++++++++- 32 files changed, 1713 insertions(+), 288 deletions(-) create mode 100644 examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml diff --git a/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml new file mode 100644 index 000000000000..053d2598fab8 --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml @@ -0,0 +1,16 @@ +# Config for SuperV3 with MTP speculative decoding (requires triton SSM/conv backends). +# TODO: Replace with sharding and footprint-improving transforms in the style of super_v3.yaml. 
+compile_backend: torch-simple +attn_backend: flashinfer +max_batch_size: 128 +max_seq_len: 16384 +max_num_tokens: 16384 +speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 7 + mtp_eagle_one_model: true +transforms: + insert_cached_ssm_attention: + backend: triton_ssm + insert_cached_causal_conv: + backend: triton_causal_conv diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index afb3df48eaee..6e65ebee22df 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -201,6 +201,8 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_8.yaml', 'multimodal.yaml', 'llama4_maverick_lite.yaml'] - name: nvidia/NVIDIA-Nemotron-3-Super-120B-BF16-BF16KV-010726 yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml','super_v3.yaml'] +- name: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16 + yaml_extra: ['dashboard_default.yaml', 'world_size_4.yaml', 'super_v3_mtp.yaml'] - name: zai-org/GLM-4.7-Flash yaml_extra: ['glm-4.7-flash.yaml'] - name: Nanbeige/Nanbeige4.1-3B diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 5e3f3ca8faa1..748ce06ba22b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -1686,6 +1686,79 @@ def state_shape(self) -> Tuple[int, int]: return (self.conv_dim, self.d_conv - 1) +class SpecSSMResourceHandler(StateResourceHandler): + """Intermediate SSM state cache descriptor for speculative decoding. + + Acts as a type marker conveying the per-layer SSM shape to the cache interface. + The actual buffer shape (including cache_steps = max_draft_len + 1) is determined + by the MambaHybridCacheManager using spec_config, not by this handler. + + Inherits from StateResourceHandler (not SSMResourceHandler) so that + isinstance(h, SSMResourceHandler) returns False for spec handlers, eliminating + the need for exclusion guards throughout the codebase. + """ + + def __init__( + self, + num_heads: int, + head_dim: int, + d_state: int, + dtype: torch.dtype, + ) -> None: + self.num_heads = num_heads + self.head_dim = head_dim + self.d_state = d_state + super().__init__(dtype=dtype) + + @property + def state_shape(self) -> Tuple[int, int, int]: + return (self.num_heads, self.head_dim, self.d_state) + + @classmethod + def from_base(cls, base: Optional["SSMResourceHandler"]) -> Optional["SpecSSMResourceHandler"]: + """Create a spec handler from a base SSM handler, or return None.""" + if base is None: + return None + return cls( + num_heads=base.num_heads, head_dim=base.head_dim, d_state=base.d_state, dtype=base.dtype + ) + + +class SpecCausalConvResourceHandler(StateResourceHandler): + """Intermediate conv state cache descriptor for speculative decoding. + + Acts as a type marker conveying the per-layer conv shape to the cache interface. + The actual buffer shape (including cache_steps = max_draft_len + 1) is determined + by the MambaHybridCacheManager using spec_config, not by this handler. + + Inherits from StateResourceHandler (not CausalConvResourceHandler) so that + isinstance(h, CausalConvResourceHandler) returns False for spec handlers. 
+ """ + + def __init__( + self, + conv_dim: int, + d_conv: int, + dtype: torch.dtype, + ) -> None: + self.conv_dim = conv_dim + self.d_conv = d_conv + super().__init__(dtype=dtype) + + @property + def state_shape(self) -> Tuple[int, int]: + return (self.conv_dim, self.d_conv - 1) + + @classmethod + def from_base( + cls, base: Optional["CausalConvResourceHandler"] + ) -> Optional["SpecCausalConvResourceHandler"]: + """Create a spec handler from a base conv handler, or return None.""" + if base is None: + return None + return cls(conv_dim=base.conv_dim, d_conv=base.d_conv, dtype=base.dtype) + + class UnpagedResourceHandler(ResourceHandler): """Handler for per-token unpaged resources (e.g., unpaged KV caches). diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py index c3e1a00c91c3..60acdbd43056 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py @@ -60,9 +60,9 @@ def _flashinfer_cached_ssm( ) ssm_state_size = B.shape[3] batch_info = BatchInfo(batch_info_host) - num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() - num_seq = num_prefill + num_decode - num_total_tokens = num_prefill_tokens + num_decode + num_prefill, _, num_decode = batch_info.get_num_sequences() + num_prefill_tokens, _, num_decode_tokens = batch_info.get_num_tokens() + num_total_tokens = num_prefill_tokens + num_decode_tokens if out is not None: preallocated_ssm_out = out.view(bs, num_heads, head_dim) else: @@ -72,7 +72,7 @@ def _flashinfer_cached_ssm( device=hidden_states.device, ) - num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill( + _run_ssm_prefill( hs_flat, B_flat, C_flat, @@ -94,7 +94,6 @@ def _flashinfer_cached_ssm( preallocated_ssm_out[:num_prefill_tokens].unsqueeze(0), ) - num_decode = num_total_tokens - num_prefill_tokens decode_inputs = _prepare_ssm_decode_inputs( hs_flat, B_flat, @@ -106,8 +105,8 @@ def _flashinfer_cached_ssm( slot_idx, num_prefill, num_prefill_tokens, - num_seq, - num_total_tokens, + num_decode, + num_decode_tokens, num_heads, head_dim, ssm_state_size, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py index cbbb5f8231d0..37245cebc193 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -51,7 +51,8 @@ def _mamba_ssm_prepare_metadata( """ device = cu_seqlen.device batch_info = BatchInfo(batch_info_host) - num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() + + num_prefill, _, _ = batch_info.get_num_sequences() if num_prefill > 0: chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( @@ -132,14 +133,13 @@ def _run_ssm_prefill( time_step_limit: List[float], chunk_size: int, out: Optional[torch.Tensor] = None, -) -> Tuple[Optional[torch.Tensor], int, int, int, int]: +): batch_info = BatchInfo(batch_info_host) - num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() - num_seq = num_prefill + num_decode - num_total_tokens = num_prefill_tokens + num_decode + num_prefill, _, _ = batch_info.get_num_sequences() + num_prefill_tokens, _, _ = batch_info.get_num_tokens() if num_prefill <= 0: - return num_prefill, num_prefill_tokens, num_total_tokens, num_seq + return hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] @@ -186,7 +186,6 @@ def _run_ssm_prefill( ssm_state_cache.index_copy_( 0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype) ) - return num_prefill, num_prefill_tokens, num_total_tokens, num_seq def _prepare_ssm_decode_inputs( @@ -198,10 +197,10 @@ def _prepare_ssm_decode_inputs( D: torch.Tensor, dt_bias: torch.Tensor, slot_idx: torch.Tensor, - num_prefill: int, - num_prefill_tokens: int, - num_seq: int, - num_total_tokens: int, + decode_seq_start: int, + decode_token_start: int, + num_decode: int, + num_decode_tokens: int, num_heads: int, head_dim: int, ssm_state_size: int, @@ -217,22 +216,89 @@ def _prepare_ssm_decode_inputs( torch.Tensor, ] ]: - num_decode = num_total_tokens - num_prefill_tokens - if num_decode <= 0: + grouped_inputs = _prepare_ssm_grouped_state_update_inputs( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + slot_idx, + seq_start=decode_seq_start, + token_start=decode_token_start, + num_seq=num_decode, + num_tokens=num_decode_tokens, + num_heads=num_heads, + head_dim=head_dim, + ssm_state_size=ssm_state_size, + ) + if grouped_inputs is None: return None + ( + slot_idx_decode, + x_decode_g, + B_decode_g, + C_decode_g, + dt_hp_g, + A_full, + D_full, + dt_bias_hp, + ) = grouped_inputs + + # Reshape from [num_decode, 1, ...] to [num_decode, ...] 
+ x_decode = x_decode_g.reshape(num_decode, num_heads, head_dim) + B_decode = B_decode_g.reshape(num_decode, B_flat.shape[1], ssm_state_size) + C_decode = C_decode_g.reshape(num_decode, C_flat.shape[1], ssm_state_size) + dt_hp = dt_hp_g.reshape(num_decode, num_heads, head_dim) - slot_idx_decode = slot_idx[num_prefill:num_seq] - x_decode = hs_flat[num_prefill_tokens:num_total_tokens] # [nd, H, D] - B_decode = B_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] - C_decode = C_flat[num_prefill_tokens:num_total_tokens] # [nd, G, N] - dt_decode = dt_flat[num_prefill_tokens:num_total_tokens] # [nd, H] + return slot_idx_decode, x_decode, B_decode, C_decode, dt_hp, dt_bias_hp, A_full, D_full - dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim) - dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) + +def _prepare_ssm_grouped_state_update_inputs( + hs_flat: torch.Tensor, + B_flat: torch.Tensor, + C_flat: torch.Tensor, + dt_flat: torch.Tensor, + A: torch.Tensor, + D: torch.Tensor, + dt_bias: torch.Tensor, + slot_idx: torch.Tensor, + seq_start: int, + token_start: int, + num_seq: int, + num_tokens: int, + num_heads: int, + head_dim: int, + ssm_state_size: int, +) -> Optional[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +]: + if num_seq <= 0 or num_tokens <= 0: + return None + seq_len = num_tokens // num_seq + + seq_end = seq_start + num_seq + token_end = token_start + num_tokens + slot_idx_slice = slot_idx[seq_start:seq_end] + x_slice = hs_flat[token_start:token_end].view(num_seq, seq_len, num_heads, head_dim) + B_slice = B_flat[token_start:token_end].view(num_seq, seq_len, B_flat.shape[1], ssm_state_size) + C_slice = C_flat[token_start:token_end].view(num_seq, seq_len, C_flat.shape[1], ssm_state_size) + dt_slice = dt_flat[token_start:token_end].view(num_seq, seq_len, num_heads) + dt_hp_slice = dt_slice[..., None].expand(num_seq, seq_len, num_heads, head_dim) A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) D_full = D[..., None].expand(num_heads, head_dim) - - return slot_idx_decode, x_decode, B_decode, C_decode, dt_hp, dt_bias_hp, A_full, D_full + dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) + return slot_idx_slice, x_slice, B_slice, C_slice, dt_hp_slice, A_full, D_full, dt_bias_hp class BaseBackendSSM(AttentionDescriptor): diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py index ca877a2789c3..e90245e566ab 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_causal_conv.py @@ -33,11 +33,19 @@ causal_conv1d_update, ) -from ..attention_interface import AttentionRegistry, BatchInfo, MHACallable +from ..attention_interface import ( + AttentionRegistry, + BatchInfo, + MHACallable, + SpecCausalConvResourceHandler, +) from .causal_conv_common import BaseCausalConvDescriptor -@torch.library.custom_op("auto_deploy::triton_cached_causal_conv1d", mutates_args={"input"}) +@torch.library.custom_op( + "auto_deploy::triton_cached_causal_conv1d", + mutates_args=("input", "conv_state_cache", "intermediate_conv_state_cache"), +) def _triton_cached_causal_conv1d( # INPUTS (dense but may be flattened across sequences) input: torch.Tensor, # [b, s, c_in] @@ -53,6 +61,9 @@ def _triton_cached_causal_conv1d( # # CACHES conv_state_cache: 
torch.Tensor, # [max_batch_size, c_in, k-1] + intermediate_conv_state_cache: Optional[ + torch.Tensor + ], # [spec_state_size, max_draft_len+1, c_in, k-1] # CONSTANTS stride: int, padding: int, @@ -73,9 +84,10 @@ def _triton_cached_causal_conv1d( b, s = input.shape[:2] batch_info = BatchInfo(batch_info_host) - num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() - num_seq = num_prefill + num_decode - num_total_tokens = num_prefill_tokens + num_decode + num_prefill, num_extend, num_decode = batch_info.get_num_sequences() + num_prefill_tokens, num_extend_tokens, num_decode_tokens = batch_info.get_num_tokens() + num_seq = num_prefill + num_extend + num_decode + num_total_tokens = num_prefill_tokens + num_extend_tokens + num_decode_tokens # Flatten tokens bs = b * s @@ -113,9 +125,52 @@ def _triton_cached_causal_conv1d( # Scatter outputs back to input buffer inp_flat[:num_prefill_tokens] = y_varlen.transpose(0, 1) + # EXTEND: use the update kernel so extend tokens write the intermediate state cache. + if num_extend > 0: + # num_extend_tokens == num_extend * (max_draft_len + 1) for static draft lengths + # (dynamic lengths not supported) + tokens_per_extend = num_extend_tokens // num_extend + if intermediate_conv_state_cache.size(1) < tokens_per_extend: + raise RuntimeError( + "triton_cached_causal_conv1d received an intermediate_conv_state_cache " + "that is too small for the extend branch" + ) + + slot_idx_extend = slot_idx[num_prefill : num_prefill + num_extend].to(torch.int32) + + # The intermediate state cache will be stored in these indices and read by the mamba_cache_manager, + # which expects them in the indices arange(num_extend). They are not used across requests, so we + # do not need consistent slot indices. + intermediate_state_indices = torch.arange( + num_extend, dtype=torch.int32, device=slot_idx_extend.device + ) + + x_extend = ( + inp_flat[num_prefill_tokens : num_prefill_tokens + num_extend_tokens] + .view(num_extend, tokens_per_extend, -1) + .transpose(1, 2) + ) + y_extend = causal_conv1d_update( + x_extend, + conv_state_cache, + w2d, + bias, + activation=activation, + cache_seqlens=None, + conv_state_indices=slot_idx_extend, + intermediate_conv_window=intermediate_conv_state_cache, + intermediate_state_indices=intermediate_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + inp_flat[num_prefill_tokens : num_prefill_tokens + num_extend_tokens] = y_extend.transpose( + 1, 2 + ).view(-1, inp_flat.shape[1]) + # DECODE: batch update for single-token sequences if num_decode > 0: - x_decode = inp_flat[num_prefill_tokens:num_total_tokens] # [num_decode, C_in] + x_decode = inp_flat[ + num_prefill_tokens + num_extend_tokens : num_total_tokens + ] # [num_decode_tokens, C_in] # Note: Triton causal_conv1d_update returns a new tensor (not in-place like CUDA version) # so we need to capture the output and write it back @@ -126,10 +181,10 @@ def _triton_cached_causal_conv1d( bias, activation=activation, cache_seqlens=None, - conv_state_indices=slot_idx[num_prefill:num_seq].to(torch.int32), + conv_state_indices=slot_idx[num_prefill + num_extend : num_seq].to(torch.int32), pad_slot_id=PAD_SLOT_ID, ) - inp_flat[num_prefill_tokens:num_total_tokens] = y_decode + inp_flat[num_prefill_tokens + num_extend_tokens : num_total_tokens] = y_decode # Zero padding positions beyond valid tokens (for piecewise CUDA graph) if num_total_tokens < bs: @@ -152,6 +207,9 @@ def _triton_cached_causal_conv1d_fake( # # CACHES conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1] + 
intermediate_conv_state_cache: Optional[ + torch.Tensor + ], # [spec_state_size, max_draft_len+1, c_in, k-1] # CONSTANTS stride: int, padding: int, @@ -176,6 +234,14 @@ class TritonBackendCausalConv(BaseCausalConvDescriptor): Overrides get_standard_metadata_args to include seq_len (used directly by Triton kernel). """ + @classmethod + def get_cache_initializers(cls, source_attn_node, cache_config): + ret = super().get_cache_initializers(source_attn_node, cache_config) + ret["intermediate_conv_state_cache"] = SpecCausalConvResourceHandler.from_base( + ret["conv_state_cache"] + ) + return ret + @classmethod def get_standard_metadata_args(cls) -> List[str]: return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 27ea9e55d38a..75a9c7010263 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,16 +19,20 @@ from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update -from ..attention_interface import AttentionRegistry, BatchInfo, MHACallable +from ..attention_interface import AttentionRegistry, BatchInfo, MHACallable, SpecSSMResourceHandler from .mamba_backend_common import ( BaseBackendSSM, _flatten_ssm_inputs, _prepare_ssm_decode_inputs, + _prepare_ssm_grouped_state_update_inputs, _run_ssm_prefill, ) -@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args=("ssm_state_cache",)) +@torch.library.custom_op( + "auto_deploy::triton_cached_ssm", + mutates_args=("ssm_state_cache", "intermediate_ssm_state_cache"), +) def _triton_cached_ssm( # INPUTS (dense but may be flattened across sequences) hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] @@ -50,6 +54,9 @@ def _triton_cached_ssm( seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + intermediate_ssm_state_cache: Optional[ + torch.Tensor + ], # [spec_state_size, max_draft_len+1, num_heads, head_dim, d_state] # CONSTANTS time_step_limit: List[float], chunk_size: int, @@ -60,10 +67,9 @@ def _triton_cached_ssm( ) ssm_state_size = B.shape[3] batch_info = BatchInfo(batch_info_host) - num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() - num_seq = num_prefill + num_decode - num_total_tokens = num_prefill_tokens + num_decode - + num_prefill, num_extend, num_decode = batch_info.get_num_sequences() + num_prefill_tokens, num_extend_tokens, num_decode_tokens = batch_info.get_num_tokens() + num_total_tokens = num_prefill_tokens + num_extend_tokens + num_decode_tokens if out is not None: preallocated_ssm_out = out.view(bs, num_heads, head_dim) else: @@ -73,7 +79,11 @@ def _triton_cached_ssm( device=hidden_states.device, ) - num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill( + preallocated_ssm_out_e = preallocated_ssm_out[ + num_prefill_tokens : num_prefill_tokens + num_extend_tokens + ] + + _run_ssm_prefill( hs_flat, 
B_flat, C_flat, @@ -95,7 +105,72 @@ def _triton_cached_ssm( preallocated_ssm_out[:num_prefill_tokens].unsqueeze(0), ) - num_decode = num_total_tokens - num_prefill_tokens + # EXTEND: use the update kernel so extend tokens write the intermediate state cache. + extend_inputs = _prepare_ssm_grouped_state_update_inputs( + hs_flat, + B_flat, + C_flat, + dt_flat, + A, + D, + dt_bias, + slot_idx, + seq_start=num_prefill, + token_start=num_prefill_tokens, + num_seq=num_extend, + num_tokens=num_extend_tokens, + num_heads=num_heads, + head_dim=head_dim, + ssm_state_size=ssm_state_size, + ) + if extend_inputs is not None: + tokens_per_extend = num_extend_tokens // num_extend + if intermediate_ssm_state_cache.size(1) < tokens_per_extend: + raise RuntimeError( + "triton_cached_ssm received an intermediate_ssm_state_cache " + "that is too small for the extend branch" + ) + + ( + slot_idx_extend, + x_extend, + B_extend, + C_extend, + dt_extend, + A_full, + D_full, + dt_bias_hp, + ) = extend_inputs + + # The intermediate state cache will be stored in these indices and read by the mamba_cache_manager, + # which expects them in the indices arange(num_extend). They are not used across requests, so we + # do not need consistent slot indices. + intermediate_state_indices = torch.arange( + num_extend, dtype=torch.int32, device=slot_idx_extend.device + ) + preallocated_ssm_out_e = preallocated_ssm_out_e.view( + num_extend, tokens_per_extend, num_heads, head_dim + ) + selective_state_update( + ssm_state_cache, + x_extend, + dt_extend, + A_full, + B_extend, + C_extend, + D=D_full, + z=None, + dt_bias=dt_bias_hp, + dt_softplus=True, + state_batch_indices=slot_idx_extend, + out=preallocated_ssm_out_e, + disable_state_update=True, + intermediate_states_buffer=intermediate_ssm_state_cache, + cache_steps=tokens_per_extend, + intermediate_state_indices=intermediate_state_indices, + ) + + # DECODE decode_inputs = _prepare_ssm_decode_inputs( hs_flat, B_flat, @@ -105,10 +180,10 @@ def _triton_cached_ssm( D, dt_bias, slot_idx, - num_prefill, - num_prefill_tokens, - num_seq, - num_total_tokens, + num_prefill + num_extend, + num_prefill_tokens + num_extend_tokens, + num_decode, + num_decode_tokens, num_heads, head_dim, ssm_state_size, @@ -137,7 +212,7 @@ def _triton_cached_ssm( dt_bias=dt_bias_hp, dt_softplus=True, state_batch_indices=slot_idx_decode, - out=preallocated_ssm_out[num_prefill_tokens:num_total_tokens], + out=preallocated_ssm_out[num_prefill_tokens + num_extend_tokens : num_total_tokens], ) if out is not None: @@ -172,6 +247,9 @@ def _triton_cached_ssm_fake( seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + intermediate_ssm_state_cache: Optional[ + torch.Tensor + ], # [spec_state_size, max_draft_len+1, num_heads, head_dim, d_state] # CONSTANTS time_step_limit: List[float], chunk_size: int, @@ -192,3 +270,11 @@ class TritonBackendSSM(BaseBackendSSM): @classmethod def get_cached_attention_op(cls) -> MHACallable: return torch.ops.auto_deploy.triton_cached_ssm.default + + @classmethod + def get_cache_initializers(cls, source_attn_node, cache_config): + ret = super().get_cache_initializers(source_attn_node, cache_config) + ret["intermediate_ssm_state_cache"] = SpecSSMResourceHandler.from_base( + ret["ssm_state_cache"] + ) + return ret diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 58847c64dbeb..99e97aaf9d07 100644 --- 
a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -11,6 +11,7 @@ from ...llmapi.llm_args import ( BuildConfig, EagleDecodingConfig, + MTPDecodingConfig, SamplerType, TorchLlmArgs, _ParallelConfig, @@ -115,20 +116,39 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A @model_validator(mode="after") def setup_hidden_state_capture(self): - if self.speculative_config is None or not isinstance( - self.speculative_config, EagleDecodingConfig - ): + spec_config = self.speculative_config + if spec_config is None: + return self + + if isinstance(spec_config, MTPDecodingConfig): + if not spec_config.mtp_eagle_one_model: + return self + if spec_config.use_mtp_vanilla: + raise ValueError("mtp_eagle_one_model and use_mtp_vanilla cannot both be enabled") + if spec_config.max_draft_len is None: + raise ValueError( + "MTPDecodingConfig.max_draft_len must not be None when mtp_eagle_one_model is " + "enabled. Ensure num_nextn_predict_layers is set in the model config." + ) + capture_layers = {-1} + self.model_factory = "eagle_one_model" + elif isinstance(spec_config, EagleDecodingConfig): + if spec_config.max_draft_len is None: + raise ValueError( + "EagleDecodingConfig.max_draft_len must not be None. " + "Provide a positive integer for max_draft_len." + ) + capture_layers = spec_config.eagle3_layers_to_capture + if spec_config.eagle3_one_model: + self.model_factory = "eagle_one_model" + else: return self self.transforms["detect_hidden_states_for_capture"]["enabled"] = True self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = ( - self.speculative_config.eagle3_layers_to_capture + capture_layers ) - # Use the eagle_one_model factory for one-model Eagle speculative decoding - if self.speculative_config.eagle3_one_model: - self.model_factory = "eagle_one_model" - return self @model_validator(mode="after") diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py index 40e61d85cec7..28f553dc1486 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py @@ -13,16 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Eagle3 model implementation for AutoDeploy. +"""Eagle model implementation for AutoDeploy. -Eagle3 is a speculative decoding draft model that predicts next tokens based on +Eagle is a speculative decoding draft model that predicts next tokens based on hidden states from a target model (e.g., Llama-3.1-8B-Instruct). -This file contains model definitions used for executing Eagle3 speculative decoding in AutoDeploy. +This file contains: +- Generic Eagle infrastructure (EagleModel, EagleDrafterForCausalLM, EagleWrapper) +- Llama-specific Eagle layer implementation (LlamaEagleLayer) +- Layer dispatch functions for model-specific layer construction + +Model-specific layers for other architectures (e.g., NemotronH) are defined in their +respective model files and registered via get_eagle_layers(). 
""" from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from types import SimpleNamespace +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn @@ -30,9 +37,58 @@ from transformers.activations import ACT2FN from transformers.utils import ModelOutput +from ....pyexecutor.mamba_cache_manager import MambaHybridCacheManager from ...shim.interface import CachedSequenceInterface from ...utils._config import deep_merge_dicts from ...utils.logger import ad_logger +from .modeling_nemotron_h import build_nemotron_eagle_layers + +# ============================================================================= +# Layer Dispatch Functions +# ============================================================================= + + +def get_eagle_layers(config, model_type: str) -> Union[nn.ModuleList, nn.Module]: + """Build Eagle layers for the given model type. + + This function dispatches to model-specific layer builders based on model_type. + Each builder returns layers that implement the unified forward signature: + forward(hidden_states, inputs_embeds, position_ids) -> Tensor + + For backward compatibility with checkpoints: + - Single layer: returns layer directly (not wrapped in ModuleList) + - Multiple layers: returns nn.ModuleList of layer instances + + Args: + config: Model configuration (e.g., EagleConfig for Llama) + model_type: The base model type (e.g., "llama", "nemotron_h") + + Returns: + nn.ModuleList of layers for the Eagle model, or single layer if there is only one layer + """ + layers: list[nn.Module] + match model_type: + case "llama": + layers = build_llama_eagle_layers(config) + case "nemotron_h": + layers = build_nemotron_eagle_layers(config) + case _: + raise ValueError( + f"Model type '{model_type}' not supported for Eagle drafter. " + f"Supported types: llama, nemotron_h" + ) + + if len(layers) == 1: + return layers[0] + return nn.ModuleList(layers) + + +def build_llama_eagle_layers(config) -> list[nn.Module]: + """Build Llama-style Eagle decoder layers. + + Each layer handles RoPE internally, making the EagleModel fully model-agnostic. + """ + return [LlamaEagleLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] class EagleConfig(PretrainedConfig): @@ -43,16 +99,47 @@ class EagleConfig(PretrainedConfig): Args: config: Base config for the draft model from its config.json. - model_type: The base model type (e.g., "llama") used to look up defaults. + model_type: The base model type (e.g., "llama", "nemotron_h") used to look up defaults. """ # Map model_type -> default Eagle config values + # Includes _checkpoint_conversion_mapping for model-specific weight key transformations _drafter_defaults: Dict[str, Dict[str, Any]] = { "llama": { "load_embedding_from_target": True, "load_lm_head_from_target": False, "num_capture_layers": 3, + "normalize_target_hidden_state": False, + # Whether the final norm (pre-lm_head) is handled inside the layers. + # If False, the wrapper applies self.norm after the layers. + # If True, layers have their own final_layernorm and wrapper skips self.norm. 
+ "layers_handle_final_norm": False, + # Llama Eagle checkpoint: fc.*, midlayer.* -> model.fc.*, model.layers.* + "_checkpoint_conversion_mapping": { + "^(?!lm_head|norm)": "model.", + "midlayer": "layers", + }, }, + "nemotron_h": { + "load_embedding_from_target": True, + "load_lm_head_from_target": True, + "num_capture_layers": 1, + "normalize_target_hidden_state": True, + "mtp_hybrid_override_pattern": "*E", + # NemotronH MTP layers have final_layernorm on the last layer, + # so the wrapper should NOT apply an additional norm. + "layers_handle_final_norm": True, + # NemotronH MTP checkpoint: mtp.* -> model.* + "_checkpoint_conversion_mapping": { + r"^mtp\.": "model.", + }, + }, + } + # Some custom HF config classes expose backward-compatibility fields as properties instead of + # storing them directly in __dict__. Those values do not survive config.to_dict(), so carry + # them over explicitly before rebuilding a generic EagleConfig. + _preserved_config_attrs: Dict[str, tuple[str, ...]] = { + "nemotron_h": ("mtp_hybrid_override_pattern",), } def __init__(self, config: PretrainedConfig, model_type: str): @@ -64,6 +151,9 @@ def __init__(self, config: PretrainedConfig, model_type: str): defaults = self._drafter_defaults[model_type] config_dict = config.to_dict() + for key in self._preserved_config_attrs.get(model_type, ()): + if key not in config_dict and hasattr(config, key): + config_dict[key] = getattr(config, key) # Log when config overrides a default for key, value in defaults.items(): @@ -201,9 +291,16 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + dtype = config.torch_dtype + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias, dtype=dtype + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias, dtype=dtype + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=config.mlp_bias, dtype=dtype + ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -224,6 +321,7 @@ def __init__(self, config, layer_idx: int): self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.is_causal = True + dtype = config.torch_dtype # Note: Eagle3Attention expects 2 * hidden_size input, which is the concatenation of the hidden states # and the input embeddings. 
@@ -232,21 +330,25 @@ def __init__(self, config, layer_idx: int): 2 * config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias, + dtype=dtype, ) self.k_proj = nn.Linear( 2 * config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, + dtype=dtype, ) self.v_proj = nn.Linear( 2 * config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, + dtype=dtype, ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, + dtype=dtype, ) def forward( @@ -288,38 +390,80 @@ def forward( return attn_output -class Eagle3DecoderLayer(nn.Module): - """Eagle decoder layer with modified attention and hidden state normalization.""" +# ============================================================================= +# Llama-Specific Eagle Layer +# ============================================================================= + + +class LlamaEagleLayer(nn.Module): + """Eagle decoder layer for Llama-family models. + + Architecture: + - Normalize embeds and hidden states, concatenate to 2*hidden_size + - Self-attention with RoPE (computed internally from position_ids) + - Add residual + - Normalize, gated MLP (SwiGLU), add residual + """ def __init__(self, config, layer_idx: int = 0): super().__init__() + self.config = config + self.layer_idx = layer_idx self.dtype = config.torch_dtype - self.self_attn = Eagle3Attention(config, layer_idx=layer_idx) + + # Normalization layers self.hidden_norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Attention (expects 2*hidden_size input from concat) + self.self_attn = Eagle3Attention(config, layer_idx=layer_idx) + + # MLP (gated SwiGLU style) self.mlp = EagleMLP(config) + # RoPE + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.rotary_emb = LlamaRotaryEmbedding( + config=config, dim=self.head_dim, device=torch.device("cuda") + ) + def forward( self, hidden_states: torch.Tensor, - embeds: torch.Tensor, - position_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - residual = hidden_states - hidden_states = self.hidden_norm(hidden_states) + inputs_embeds: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + """Forward pass with unified interface. 
- embeds = self.input_layernorm(embeds) + Args: + hidden_states: Hidden states from target model [batch, seq, hidden_size] + inputs_embeds: Token embeddings [batch, seq, hidden_size] + position_ids: Position IDs for RoPE [batch, seq] + Returns: + Updated hidden states [batch, seq, hidden_size] + """ + # Compute RoPE internally + cos, sin = self.rotary_emb(hidden_states, position_ids) + position_embeddings = (cos, sin) + + # Normalize and concatenate embeds + hidden states + residual = hidden_states + hidden_states = self.hidden_norm(hidden_states) + embeds = self.input_layernorm(inputs_embeds) hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self-attention with RoPE hidden_states = self.self_attn( hidden_states=hidden_states, - position_embeddings=position_embeds, + position_embeddings=position_embeddings, ) - hidden_states = residual + hidden_states + # MLP residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) @@ -328,12 +472,23 @@ def forward( return hidden_states -class Eagle3Model(nn.Module): - """Core Eagle model architecture.""" +class EagleModel(nn.Module): + """Generic Eagle model architecture. - def __init__(self, config): - super().__init__() + This model is model-agnostic - it accepts layers from the factory and passes + position_ids through to them. Layers handle model-specific logic (e.g., RoPE) + internally. + Args: + config: Model configuration + layers: nn.ModuleList of layers (for multi-layer) or single nn.Module (for single layer). + Each layer implements the unified forward signature: + forward(hidden_states, inputs_embeds, position_ids) -> Tensor + """ + + def __init__(self, config, layers: Union[nn.ModuleList, nn.Module]): + super().__init__() + self.config = config self.dtype = config.torch_dtype load_embedding_from_target = getattr(config, "load_embedding_from_target", False) @@ -343,68 +498,66 @@ def __init__(self, config): else nn.Embedding(config.vocab_size, config.hidden_size) ) - if config.draft_vocab_size is not None and config.draft_vocab_size != config.vocab_size: - # Vocab mappings for draft <-> target token conversion - # Needed to convert draft outputs to target inputs for Eagle3. - # Since we reuse the target model's embedding in the drafter, we need - # to do this conversion after every draft iteration. + # Vocab mapping for draft -> target token conversion + draft_vocab_size = getattr(config, "draft_vocab_size", None) or config.vocab_size + if draft_vocab_size != config.vocab_size: self.d2t = nn.Parameter( - torch.empty((config.draft_vocab_size,), dtype=torch.int32), + torch.empty((draft_vocab_size,), dtype=torch.int32), requires_grad=False, ) - # Hidden size compression for target hidden states. 
- # Assumption: No feedforward fusion needed if we have just one capture layer (valid for MTPEagle) + # Hidden size compression for target hidden states (multi-layer capture) + num_capture_layers = getattr(config, "num_capture_layers", 1) self.fc = ( nn.Linear( - config.hidden_size * config.num_capture_layers, + config.hidden_size * num_capture_layers, config.hidden_size, bias=getattr(config, "bias", False), dtype=self.dtype, ) - if config.num_capture_layers > 1 + if num_capture_layers > 1 else None ) - self.head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) + # Layers (injected by factory - model-specific) + # Can be ModuleList (multi-layer) or single Module (single layer) for checkpoint compat + # No rotary_emb here - layers handle RoPE internally if needed + self.layers = layers - self.rotary_emb = LlamaRotaryEmbedding( - config=config, dim=self.head_dim, device=torch.device("cuda") - ) - - if config.num_hidden_layers > 1: - self.midlayer = nn.ModuleList( - [Eagle3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] - ) - else: - self.midlayer = Eagle3DecoderLayer(config, layer_idx=0) - - self.num_hidden_layers = config.num_hidden_layers - - # Assumption: The hidden states are already fused if necessary def forward( self, inputs_embeds: torch.Tensor, position_ids: torch.LongTensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - cos, sin = self.rotary_emb(hidden_states, position_ids) - position_embeds = (cos, sin) + """Forward pass through the Eagle model. + + Args: + inputs_embeds: Token embeddings [batch, seq, hidden_size] + position_ids: Position IDs [batch, seq] - passed to layers + hidden_states: Hidden states from target model [batch, seq, hidden_size] - if self.num_hidden_layers > 1: - for layer in self.midlayer: + Returns: + Updated hidden states [batch, seq, hidden_size] + """ + # Pass position_ids through to layers - they decide what to do with it + # (e.g., Llama layers compute RoPE, NemotronH layers ignore it) + if isinstance(self.layers, nn.ModuleList): + for layer in self.layers: hidden_states = layer( hidden_states=hidden_states, - embeds=inputs_embeds, - position_embeds=position_embeds, + inputs_embeds=inputs_embeds, + position_ids=position_ids, ) - else: - hidden_states = self.midlayer( + elif isinstance(self.layers, nn.Module): + hidden_states = self.layers( hidden_states=hidden_states, - embeds=inputs_embeds, - position_embeds=position_embeds, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + ) + else: + raise TypeError( + f"Expected self.layers to be nn.ModuleList or nn.Module, got {type(self.layers).__name__}" ) return hidden_states @@ -417,38 +570,59 @@ class Eagle3DraftOutput(ModelOutput): last_hidden_state: Optional[torch.FloatTensor] = None -class Eagle3DrafterForCausalLM(PreTrainedModel): +class EagleDrafterForCausalLM(PreTrainedModel): """HuggingFace-compatible wrapper for EagleModel. This wrapper makes EagleModel compatible with AutoDeploy's model loading - and inference pipeline. + and inference pipeline. It accepts layers from the factory to enable + model-specific layer implementations. + + Args: + config: Model configuration (should be EagleConfig with model-type specific defaults) + layers: Layers to use in EagleModel. Can be nn.ModuleList (multi-layer) or a single + nn.Module (single-layer). If None, builds based on model_type. 
""" base_model_prefix = "model" supports_gradient_checkpointing = False - _no_split_modules = ["Eagle3DecoderLayer"] - - # Checkpoint conversion mapping: Eagle checkpoints have keys like "fc.weight" - # but the wrapper model expects "model.fc.weight" (due to self.model = Eagle3Model). - # This mapping tells the factory to add "model." prefix when loading weights. - # Used by AutoModelForCausalLMFactory._remap_param_names_load_hook() - - _checkpoint_conversion_mapping = { - "^(?!lm_head|norm)": "model.", # Prepend "model." to all keys EXCEPT lm_head and norm - } + _no_split_modules = ["LlamaEagleLayer", "NemotronHEagleLayer"] - def __init__(self, config): + def __init__(self, config, layers: Optional[Union[nn.ModuleList, nn.Module]] = None): super().__init__(config) + # Read checkpoint conversion mapping from config (set by EagleConfig based on model_type) + self._checkpoint_conversion_mapping = getattr( + config, "_checkpoint_conversion_mapping", None + ) + self.load_embedding_from_target = getattr(config, "load_embedding_from_target", False) self.load_lm_head_from_target = getattr(config, "load_lm_head_from_target", False) - - self.model = Eagle3Model(config) - self.norm = EagleRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Whether layers handle the final norm (pre-lm_head) internally. + # If True, layers have their own final_layernorm and we skip self.norm in forward. + # If False (default), we apply self.norm after the layers. + self._layers_handle_final_norm = getattr(config, "layers_handle_final_norm", False) + + # If layers not provided, build based on model_type + if layers is None: + layers = get_eagle_layers(config, config.model_type) + + self.model = EagleModel(config, layers) + + # Only create norm if layers don't handle final normalization internally. + if not self._layers_handle_final_norm: + # Use fallback chain for eps: rms_norm_eps (Llama) -> layer_norm_epsilon (NemotronH) -> default + norm_eps = getattr(config, "rms_norm_eps", getattr(config, "layer_norm_epsilon", 1e-6)) + self.norm = EagleRMSNorm(config.hidden_size, eps=norm_eps) + else: + self.norm = None + # draft_vocab_size defaults to vocab_size if not specified + draft_vocab_size = getattr(config, "draft_vocab_size", None) or config.vocab_size self.lm_head = ( None if self.load_lm_head_from_target - else nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False) + else nn.Linear( + config.hidden_size, draft_vocab_size, bias=False, dtype=config.torch_dtype + ) ) eagle_config = getattr(config, "eagle_config", {}) @@ -457,24 +631,24 @@ def __init__(self, config): def forward( self, inputs_embeds: torch.LongTensor, - position_ids: Optional[torch.LongTensor] = None, + position_ids: torch.LongTensor, **kwargs, ) -> Eagle3DraftOutput: - """ - Kwargs: - hidden_states: Hidden states from the target model. Required. + """Forward pass for Eagle drafter. + + Args: + inputs_embeds: Input token embeddings [batch, seq, hidden_size] + position_ids: Position IDs [batch, seq]. Required. + **kwargs: Must contain 'hidden_states' from the target model. + + Returns: + Eagle3DraftOutput with norm_hidden_state and last_hidden_state. Raises: - ValueError: If hidden_states is not provided in kwargs. + ValueError: If hidden_states or position_ids is not provided. 
""" - batch_size, seq_len, _ = inputs_embeds.shape - device = inputs_embeds.device - - # Generate position_ids if not provided if position_ids is None: - position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - position_ids = position_ids.expand(batch_size, -1) - + raise ValueError("position_ids must be provided.") hidden_states = kwargs.get("hidden_states") if hidden_states is None: raise ValueError("hidden_states must be provided.") @@ -483,7 +657,13 @@ def forward( inputs_embeds=inputs_embeds, position_ids=position_ids, hidden_states=hidden_states ) - norm_hidden_state = self.norm(hidden_states) + # Apply final norm only if layers don't handle it internally. + # For Llama: layers don't normalize, so we apply self.norm here. + # For NemotronH: layers have final_layernorm, so hidden_states are already normalized. + if self.norm is not None: + norm_hidden_state = self.norm(hidden_states) + else: + norm_hidden_state = hidden_states # already normalized by layer last_hidden_state = norm_hidden_state if self._return_hidden_post_norm else hidden_states @@ -497,7 +677,7 @@ def get_input_embeddings(self): return self.model.embed_tokens else: raise NotImplementedError( - "Eagle3DrafterForCausalLM does not have an input embedding layer." + "EagleDrafterForCausalLM does not have an input embedding layer." ) def get_output_embeddings(self): @@ -505,7 +685,7 @@ def get_output_embeddings(self): return self.lm_head else: raise NotImplementedError( - "Eagle3DrafterForCausalLM does not have an output embedding layer." + "EagleDrafterForCausalLM does not have an output embedding layer." ) @@ -543,6 +723,7 @@ class EagleWrapperConfig: max_draft_len: int load_embedding_from_target: bool load_lm_head_from_target: bool + normalize_target_hidden_state: bool = False class EagleWrapper(nn.Module): @@ -563,12 +744,13 @@ def __init__(self, config: EagleWrapperConfig, target_model: nn.Module, draft_mo self.max_draft_len = config.max_draft_len self.load_embedding_from_target = config.load_embedding_from_target self.load_lm_head_from_target = config.load_lm_head_from_target + self.normalize_target_hidden_state = config.normalize_target_hidden_state @property def _draft_inner_model(self): """Get the inner model submodule of the draft model. - Before export: self.draft_model.model (Eagle3Model inside Eagle3DrafterForCausalLM). + Before export: self.draft_model.model (EagleModel inside EagleDrafterForCausalLM). After export: self.draft_model.model (preserved by DraftModelExportInfo.post_process). """ return self.draft_model.model @@ -578,6 +760,21 @@ def _draft_dtype(self): """Get the dtype of the draft model (works before and after export).""" return getattr(self._draft_inner_model, "dtype", None) or torch.bfloat16 + def normalize_target_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the target model's final normalization to hidden states. + + MTP hidden states are captured at the residual add (pre-norm_f), but the + MTP head expects post-norm_f input. The target model must expose + get_final_normalization() for this to work. + """ + norm_fn = getattr(self.target_model, "get_final_normalization", None) + if norm_fn is None: + raise RuntimeError( + "MTP requires the target model to expose get_final_normalization(), " + f"but {type(self.target_model).__name__} does not implement it." 
+ ) + return norm_fn()(hidden_states) + def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor: """Apply the fc layer that fuses hidden states from multiple target layers.""" hidden_states = hidden_states.to(self._draft_dtype) @@ -745,14 +942,28 @@ def _forward_with_kv_cache(self, csi: CachedSequenceInterface): # NOTE: we assume gather_context_logits is False so that gathering here works! target_logits = csi.info.maybe_gather_and_squeeze(out.logits) + # ---- Phase 2: Collect hidden states ---- # TODO: investigate root cause — without this sync the hidden_states_cache buffers # read by _collect_hidden_states can contain stale data, dropping the spec-dec # acceptance rate from ~31% to ~7%. torch.cuda.synchronize() - # ---- Phase 2: Collect hidden states from cache buffers ---- + # TODO: For MTP, a cleaner approach would return hidden states as a second output + # from the target model (final_norm_hidden_states). However, this causes NCCL hangs + # at TP>1 because the export/sharding pipeline doesn't handle multiple graph outputs. + # For now, both MTP and Eagle3 use the detect_hidden_states_for_capture graph transform. hidden_states = self._collect_hidden_states(csi.named_args, num_total_tokens) - hidden_states = self.apply_eagle3_fc(hidden_states) + if self.normalize_target_hidden_state: + # MTP: hidden states are captured at the residual add (pre-normalization). + # Apply the target model's final normalization to match the PyTorch backend + # which passes normalized hidden_states to MTPEagleWorker. + hidden_states = self.normalize_target_hidden_states(hidden_states) + # Cast to draft model dtype (e.g. target may be FP8, draft BF16). + hidden_states = hidden_states.to(self._draft_dtype) + else: + # Eagle3: compress hidden states from multiple captured layers via fc. + # apply_eagle3_fc also handles the target->draft dtype cast. + hidden_states = self.apply_eagle3_fc(hidden_states) # ---- Phase 3: Sample ---- # check dtype/device @@ -807,6 +1018,19 @@ def _forward_with_kv_cache(self, csi: CachedSequenceInterface): if num_extend > 0: new_tokens_lens[num_prefill:] = new_tokens_lens_extend + # MTP state promotion: commit accepted intermediate mamba states to base state + # immediately after verification, before cache offset computation and draft loop. + # Must happen inside model forward (not in ad_executor) for correct timing — + # update_mamba_states reads .num_seqs and .num_contexts from attn_metadata. + kv_cache_manager = csi.kv_cache_manager + if num_extend > 0 and isinstance(kv_cache_manager, MambaHybridCacheManager): + if kv_cache_manager.is_speculative(): + _ctx = SimpleNamespace(num_seqs=num_sequences, num_contexts=num_prefill) + kv_cache_manager.update_mamba_states( + attn_metadata=_ctx, + num_accepted_tokens=new_tokens_lens, + ) + # compute the cache and position offset based on the number of new tokens compared to the # maximum draft length. NOTE: cache is currently at the position corresponding to the last # draft token. Hence the following constraint is true: diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py index f87af7e93080..bb06b20c89fc 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. 
"""Slimmed down PyTorch NemotronH model implementation. @@ -616,6 +616,9 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) + def get_final_normalization(self): + return self.backbone.norm_f + def get_output_embeddings(self): return self.lm_head @@ -638,3 +641,128 @@ def forward( AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM) + + +# ============================================================================= +# Eagle Layer Builder for NemotronH MTP (Multi-Token Prediction) +# ============================================================================= + + +class NemotronHEagleLayer(nn.Module): + """Eagle layer for NemotronH models. + + NemotronH does not use RoPE, so position_ids is accepted but ignored. + The layer implements the MTP (Multi-Token Prediction) architecture: + - First layer fuses embeds + hidden_states via start projections (enorm, hnorm, eh_proj) + - All layers have pre-norm residual block with mixer (Attention or MoE) + - Last layer applies final_layernorm + + Supported layers are * (Attention with start projections) and E (MoE with final_layernorm) + """ + + def __init__( + self, + config, + layer_idx: int, + layer_type: str, + has_start_projections: bool, + has_end_norm: bool, + ): + super().__init__() + eps = getattr(config, "layer_norm_epsilon") + if eps is None: + raise ValueError("layer_norm_epsilon is not set in the config") + self.residual_in_fp32 = config.residual_in_fp32 + self.has_start_projections = has_start_projections + self.has_end_norm = has_end_norm + + # Start projections (only on first layer) + # These fuse embeds + hidden_states: eh_proj(cat(enorm(embeds), hnorm(hidden))) + if has_start_projections: + self.enorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + self.hnorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + + # Pre-layer norm + self.norm = NemotronHRMSNorm(config.hidden_size, eps=eps) + + # Mixer based on layer type + if layer_type == "*": + self.mixer = NemotronHAttention(config, layer_idx=layer_idx) + elif layer_type == "E": + self.mixer = NemotronHMOE(config, layer_idx=layer_idx) + else: + raise ValueError( + f"Unsupported MTP layer type in NemotronHEagleLayer. Only * and E are currently supported." + f"Layer type: {layer_type}" + ) + + # Final norm (only on last layer) + if has_end_norm: + self.final_layernorm = NemotronHRMSNorm(config.hidden_size, eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + """Forward pass with unified Eagle interface. 
+ + Args: + hidden_states: Hidden states from target model [batch, seq, hidden_size] + inputs_embeds: Token embeddings [batch, seq, hidden_size] + position_ids: Position IDs (ignored - NemotronH doesn't use RoPE) + + Returns: + Updated hidden states [batch, seq, hidden_size] + """ + # Note: position_ids is ignored - NemotronH doesn't use RoPE + + # Fuse embeddings and hidden states (first layer only) + if self.has_start_projections: + e_normed = self.enorm(inputs_embeds) + h_normed = self.hnorm(hidden_states) + hidden_states = self.eh_proj(torch.cat([e_normed, h_normed], dim=-1)) + + # Standard pre-norm residual block + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.mixer(hidden_states) + hidden_states = residual + hidden_states + + # Final layer norm (last layer only) + if self.has_end_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +def build_nemotron_eagle_layers(config) -> list[nn.Module]: + """Build NemotronH MTP layers for Eagle drafter. + + This function is called by get_eagle_layers() in modeling_eagle.py. + + Args: + config: Model configuration with NemotronH-specific parameters + (mtp_hybrid_override_pattern, n_routed_experts, etc.) + + Returns: + List of NemotronHEagleLayer instances + """ + pattern = getattr(config, "mtp_hybrid_override_pattern", None) + if pattern is None: + raise ValueError("mtp_hybrid_override_pattern is not set in the config") + + return [ + NemotronHEagleLayer( + config, + layer_idx=i, + layer_type=char, + has_start_projections=(i == 0), + has_end_norm=(i == len(pattern) - 1), + ) + for i, char in enumerate(pattern) + ] diff --git a/tensorrt_llm/_torch/auto_deploy/models/eagle.py b/tensorrt_llm/_torch/auto_deploy/models/eagle.py index b27abaf7f69e..9dda8de7da42 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/eagle.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,10 +32,11 @@ from torch.export import Dim from torch.fx import GraphModule +from ....llmapi.llm_args import MTPDecodingConfig from ..utils.logger import ad_logger from .custom.modeling_eagle import ( - Eagle3DrafterForCausalLM, EagleConfig, + EagleDrafterForCausalLM, EagleWrapper, EagleWrapperConfig, ) @@ -47,41 +48,27 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory): """Factory for building Eagle drafter models. - This factory handles the mapping from base model types (e.g., "llama") to - their corresponding Eagle drafter model implementations. It overrides - _build_model() to directly construct the appropriate drafter class based - on the checkpoint's model_type. + The drafter builds its own model-specific layers internally based on + config.model_type, allowing it to work with different base models + (Llama, NemotronH, etc.) without the factory needing to know the details. The checkpoint config is expected to have the base model's model_type (e.g., "llama") along with Eagle-specific fields like draft_vocab_size. 
""" - _drafter_classes: Dict[str, type] = { - "llama": Eagle3DrafterForCausalLM, - } - def _build_model(self, device: DeviceLikeType) -> nn.Module: model_config, unused_kwargs = self._get_model_config() - # Select the appropriate drafter class and config based on the base model type + # Get model type for config model_type = model_config.model_type - if model_type not in self._drafter_classes: - raise ValueError( - f"Unsupported model_type '{model_type}' for Eagle drafter. " - f"Supported types: {list(self._drafter_classes.keys())}" - ) - drafter_cls = self._drafter_classes[model_type] - ad_logger.info( - f"EagleDrafterFactory: model_type='{model_type}' -> drafter_cls={drafter_cls.__name__}" - ) + ad_logger.info(f"EagleDrafterFactory: building drafter for model_type='{model_type}'") # Convert base config to EagleConfig, preserving existing values # and applying model-specific defaults based on model_type model_config = EagleConfig(model_config, model_type) - # Build the model (same pattern as parent's _build_model) with (init_empty_weights if device == "meta" else nullcontext)(): - model = drafter_cls._from_config(model_config, **unused_kwargs) + model = EagleDrafterForCausalLM._from_config(model_config, **unused_kwargs) if device == "meta": # post-init must be called explicitly for HF models with init_empty_weights @@ -128,9 +115,6 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): """Preserve embedding (always) and optionally lm_head on the exported GraphModule.""" # --- Embedding: always needed (target embeds input_ids for both target and draft) --- embed_tokens = sub_mod.get_input_embeddings() - sub_gm.get_input_embeddings = types.MethodType( - sub_mod.get_input_embeddings.__func__, sub_gm - ) # Find the submodule path for the embedding for embed_name, subsubmod in sub_mod.named_modules(): if subsubmod is embed_tokens: @@ -138,6 +122,9 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): else: raise RuntimeError("Could not find embedding module in target model.") sub_gm.set_submodule(embed_name, embed_tokens) + sub_gm.get_input_embeddings = types.MethodType( + lambda self, _n=embed_name: self.get_submodule(_n), sub_gm + ) # Add impure node to prevent GC n_embed = sub_gm.graph.get_attr(f"{embed_name}.weight") sub_gm.graph.call_function( @@ -147,20 +134,37 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): # --- lm_head: only if draft model loads it from target --- if self.load_lm_head_from_target: lm_head = sub_mod.get_output_embeddings() - sub_gm.get_output_embeddings = types.MethodType( - sub_mod.get_output_embeddings.__func__, sub_gm - ) for lm_head_name, subsubmod in sub_mod.named_modules(): if subsubmod is lm_head: break else: raise RuntimeError("Could not find lm_head module in target model.") sub_gm.set_submodule(lm_head_name, lm_head) + sub_gm.get_output_embeddings = types.MethodType( + lambda self, _n=lm_head_name: self.get_submodule(_n), sub_gm + ) n_lm_head = sub_gm.graph.get_attr(f"{lm_head_name}.weight") sub_gm.graph.call_function( torch._assert, args=(n_lm_head, "Avoid lm_head getting deleted from graph.") ) + # --- Final normalization: only if target model exposes it (e.g., NemotronH for MTP) --- + if hasattr(sub_mod, "get_final_normalization"): + norm_module = sub_mod.get_final_normalization() + for norm_name, subsubmod in sub_mod.named_modules(): + if subsubmod is norm_module: + break + else: + raise RuntimeError("Could not find final normalization module in target model.") + sub_gm.set_submodule(norm_name, norm_module) + 
sub_gm.get_final_normalization = types.MethodType( + lambda self, _n=norm_name: self.get_submodule(_n), sub_gm + ) + n_norm = sub_gm.graph.get_attr(f"{norm_name}.weight") + sub_gm.graph.call_function( + torch._assert, args=(n_norm, "Avoid final norm getting deleted from graph.") + ) + class DraftModelExportInfo(SubModuleExportInfo): """Export info for the draft model inside EagleWrapper.""" @@ -189,22 +193,34 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): # --- Embedding (only if draft model has its own) --- if not self.load_embedding_from_target: - sub_gm.set_submodule("model.embed_tokens", inner_model.embed_tokens) + embed_tokens = sub_mod.get_input_embeddings() + for embed_name, subsubmod in sub_mod.named_modules(): + if subsubmod is embed_tokens: + break + else: + raise RuntimeError("Could not find embedding module in draft model.") + sub_gm.set_submodule(embed_name, embed_tokens) sub_gm.get_input_embeddings = types.MethodType( - sub_mod.get_input_embeddings.__func__, sub_gm + lambda self, _n=embed_name: self.get_submodule(_n), sub_gm ) - n_embed = sub_gm.graph.get_attr("model.embed_tokens.weight") + n_embed = sub_gm.graph.get_attr(f"{embed_name}.weight") sub_gm.graph.call_function( torch._assert, args=(n_embed, "Avoid draft embedding getting deleted.") ) # --- lm_head (only if draft model has its own) --- if not self.load_lm_head_from_target: - sub_gm.set_submodule("lm_head", sub_mod.lm_head) + lm_head = sub_mod.get_output_embeddings() + for lm_head_name, subsubmod in sub_mod.named_modules(): + if subsubmod is lm_head: + break + else: + raise RuntimeError("Could not find lm_head module in draft model.") + sub_gm.set_submodule(lm_head_name, lm_head) sub_gm.get_output_embeddings = types.MethodType( - sub_mod.get_output_embeddings.__func__, sub_gm + lambda self, _n=lm_head_name: self.get_submodule(_n), sub_gm ) - n_lm_head = sub_gm.graph.get_attr("lm_head.weight") + n_lm_head = sub_gm.graph.get_attr(f"{lm_head_name}.weight") sub_gm.graph.call_function( torch._assert, args=(n_lm_head, "Avoid draft lm_head getting deleted.") ) @@ -267,6 +283,13 @@ def __init__( raise ValueError("speculative_config is required for EagleOneModelFactory.") self.speculative_config = speculative_config + # For MTP, derive Eagle-pipeline fields from MTP-specific fields. 
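+        # MTP checkpoints typically ship the drafter weights inside the target checkpoint
+        # (as mtp.* tensors, see test_nemotron_mtp_model_with_weights), so when no separate
+        # speculative_model is configured we fall back to the target model path, e.g. for an
+        # MTPDecodingConfig(num_nextn_predict_layers=3, mtp_eagle_one_model=True) that sets
+        # no speculative_model.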
+ if isinstance(speculative_config, MTPDecodingConfig): + draft_model_path = speculative_config.speculative_model or model + else: + draft_model_path = speculative_config.speculative_model + if draft_model_path is None: + raise ValueError("speculative_config.speculative_model must be set.") # Create target factory (AutoModelForCausalLM) self.target_factory = AutoModelForCausalLMFactory( @@ -279,9 +302,6 @@ def __init__( ) # Create draft factory (EagleDrafter) - draft_model_path = speculative_config.speculative_model - if draft_model_path is None: - raise ValueError("speculative_config.speculative_model must be set.") self.draft_factory = EagleDrafterFactory( model=str(draft_model_path), model_kwargs=speculative_model_kwargs, @@ -303,6 +323,9 @@ def _build_model(self, device: str) -> nn.Module: max_draft_len=self.speculative_config.max_draft_len, load_embedding_from_target=getattr(draft_config, "load_embedding_from_target", True), load_lm_head_from_target=getattr(draft_config, "load_lm_head_from_target", True), + normalize_target_hidden_state=getattr( + draft_config, "normalize_target_hidden_state", False + ), ) return EagleWrapper( @@ -338,6 +361,11 @@ def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: def get_sharding_config(self) -> Dict[str, Any]: return self.target_factory.get_sharding_config() + # TODO(govind): It's possible that draft models have different quant configs than target models. + # We need to address this possibility. + def get_quant_config(self) -> Dict[str, Any]: + return self.target_factory.get_quant_config() + def get_cache_config_updates(self) -> Dict[str, Any]: return self.target_factory.get_cache_config_updates() diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 5868bfc4f2bc..58092d6571f9 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -1079,7 +1079,10 @@ def instantiate_sampler( spec_config = ad_config.speculative_config # One-model spec dec: model performs sampling internally, returns pre-computed tokens - if spec_config is not None and spec_config.spec_dec_mode.is_eagle3_one_model(): + if spec_config is not None and ( + spec_config.spec_dec_mode.is_eagle3_one_model() + or spec_config.spec_dec_mode.is_mtp_eagle_one_model() + ): sampler_args = TorchSampler.Args( max_seq_len=ad_config.max_seq_len, max_draft_len=max_draft_len, @@ -1171,10 +1174,11 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3() or spec_config.spec_dec_mode.is_eagle3_one_model() + or spec_config.spec_dec_mode.is_mtp_eagle_one_model() ): raise ValueError( "Currently, AutoDeploy only supports speculative decoding in " - "draft_target, eagle3, or eagle3_one_model mode." + "draft_target, eagle3, eagle3_one_model, or mtp_eagle_one_model mode." 
) if spec_config is not None and ad_config.guided_decoding_backend is not None: @@ -1187,6 +1191,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer if ( spec_config is not None and not spec_config.spec_dec_mode.is_eagle3_one_model() + and not spec_config.spec_dec_mode.is_mtp_eagle_one_model() and (spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()) ): draft_model_engine = create_draft_model_engine_maybe( diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 3e66e4c09a9d..8c60753aaaba 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -19,6 +19,8 @@ ResourceHandler, ResourceHandlerDict, SequenceInfo, + SpecCausalConvResourceHandler, + SpecSSMResourceHandler, SSMResourceHandler, StateResourceHandler, ) @@ -272,23 +274,34 @@ def _identify_managed_kv_resources( def _identify_managed_state_resources( self, - ) -> Tuple[Optional[SSMResourceHandler], list, Optional[CausalConvResourceHandler], list]: + ) -> Tuple[ + Optional[SSMResourceHandler], + list, + list, + Optional[CausalConvResourceHandler], + list, + list, + ]: """Identify SSM and Conv resources compatible with MambaHybridCacheManager. - Finds reference handlers for SSM and Conv resources, checks the n_groups constraint, - and collects all compatible resources for each type. + Finds reference handlers for SSM and Conv resources, checks the n_groups + constraint, and collects all compatible resources for each type. Returns: - Tuple of (ssm_ref, ssm_managed, conv_ref, conv_managed) where: + Tuple of (ssm_ref, ssm_managed, ssm_spec, conv_ref, conv_managed, conv_spec) where: - ssm_ref: Reference SSM handler or None - - ssm_managed: List of (name, handler) tuples for compatible SSM resources - - conv_ref: Reference Conv handler or None (may be None if constraint fails) - - conv_managed: List of (name, handler) tuples for compatible Conv resources + - ssm_managed: List of (name, handler) tuples for compatible base SSM resources + - ssm_spec: List of (name, handler) tuples for compatible speculative SSM resources. + This is only nonempty when speculative decoding is enabled. + - conv_ref: Reference Conv handler or None (may be None if the n_groups constraint fails) + - conv_managed: List of (name, handler) tuples for compatible base Conv resources + - conv_spec: List of (name, handler) tuples for compatible speculative Conv resources. + This is only nonempty when speculative decoding is enabled. """ ssm_ref: Optional[SSMResourceHandler] = None conv_ref: Optional[CausalConvResourceHandler] = None - # Find reference handlers for each state resource type + # Find the first base (non-spec) handler of each type as reference. for handler in self._resource_lookup.values(): if isinstance(handler, SSMResourceHandler) and ssm_ref is None: ssm_ref = handler @@ -306,11 +319,44 @@ def _identify_managed_state_resources( ) conv_ref = None # Don't manage Conv via cache manager - # Collect compatible resources for each managed type (using __eq__ for comparison) - ssm_managed = [(n, h) for n, h in self._resource_lookup.items() if ssm_ref == h] - conv_managed = [(n, h) for n, h in self._resource_lookup.items() if conv_ref == h] + # Collect resources compatible with the reference handlers for each managed type, + # using handler equality to match shape and dtype. 
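+        # Spec handlers are compared against a spec handler derived from the same reference
+        # via from_base, so an intermediate buffer only matches when its per-layer dims and
+        # dtype mirror the base SSM/conv layer it accompanies.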
+ ssm_managed = [ + (name, handler) + for name, handler in self._resource_lookup.items() + if isinstance(handler, SSMResourceHandler) and handler == ssm_ref + ] + ssm_spec = [ + (name, handler) + for name, handler in self._resource_lookup.items() + if isinstance(handler, SpecSSMResourceHandler) + and handler == SpecSSMResourceHandler.from_base(ssm_ref) + ] + conv_managed = [ + (name, handler) + for name, handler in self._resource_lookup.items() + if isinstance(handler, CausalConvResourceHandler) and handler == conv_ref + ] + conv_spec = [ + (name, handler) + for name, handler in self._resource_lookup.items() + if isinstance(handler, SpecCausalConvResourceHandler) + and handler == SpecCausalConvResourceHandler.from_base(conv_ref) + ] + + # When speculative decoding is enabled, the backend must supply matching spec buffers. + # When it is not enabled, spec buffers may still be registered by the backend (e.g. + # triton_ssm always registers intermediate_ssm_state_cache) but will not be bound. + if self._spec_config is not None: + assert len(ssm_spec) == len(ssm_managed), ( + f"Mismatched SSM spec layer count: expected {len(ssm_managed)}, got {len(ssm_spec)}" + ) + assert len(conv_spec) == len(conv_managed), ( + f"Mismatched Conv spec layer count: expected {len(conv_managed)}, " + f"got {len(conv_spec)}" + ) - return ssm_ref, ssm_managed, conv_ref, conv_managed + return ssm_ref, ssm_managed, ssm_spec, conv_ref, conv_managed, conv_spec def _prepare_kv_cache_config( self, @@ -435,25 +481,30 @@ def _create_and_assign_state_views( kv_cache_kwargs: Dict, ssm_ref: Optional[SSMResourceHandler], ssm_managed: list, + ssm_spec: list, conv_ref: Optional[CausalConvResourceHandler], conv_managed: list, + conv_spec: list, ) -> Tuple[MambaHybridCacheManager, int]: """Create MambaHybridCacheManager and assign views for state resources. Creates the hybrid cache manager with mamba parameters derived from the reference - handlers, then retrieves and assigns buffer views for all managed SSM and Conv resources. + handlers, then retrieves and assigns buffer views for all managed SSM and Conv resources, + as well as speculative resources if they exist. Args: kv_cache_kwargs: Base kwargs for cache manager (will be extended with mamba params). ssm_ref: Reference SSM handler or None. - ssm_managed: List of (name, handler) tuples for SSM resources. + ssm_managed: List of base SSM resources. + ssm_spec: List of speculative SSM resources. conv_ref: Reference Conv handler or None. - conv_managed: List of (name, handler) tuples for Conv resources. + conv_managed: List of base Conv resources. + conv_spec: List of speculative Conv resources. Returns: Tuple of (manager, num_managed_mamba_layers). """ - # Derive Mamba parameters from reference handlers + # Mamba state params can be derived from reference handlers and number of managed (non-speculative) resources. mamba_params = self._get_mamba_state_params( ssm_ref, len(ssm_managed), conv_ref, len(conv_managed) ) @@ -465,18 +516,38 @@ def _create_and_assign_state_views( **kv_cache_kwargs, ) - # Retrieve and assign views for Mamba-managed resources (up to num_managed_mamba_layers) + # Retrieve and assign views for Mamba-managed resources (up to num_managed_mamba_layers). 
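+        # Intermediate (speculative) views are fetched via get_intermediate_ssm_states /
+        # get_intermediate_conv_states; a None return means the cache manager was not built
+        # with spec support and is surfaced as a RuntimeError below.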
for layer_idx in range(num_managed_mamba_layers): if ssm_managed: + ssm_name = ssm_managed[layer_idx][0] ssm_view = manager.get_ssm_states(layer_idx) - assert ssm_view.is_contiguous(), f"Non-contiguous state {ssm_managed[layer_idx][0]}" - self._caches[ssm_managed[layer_idx][0]] = ssm_view + assert ssm_view.is_contiguous(), f"Non-contiguous state {ssm_name}" + self._caches[ssm_name] = ssm_view + if ssm_spec and self._spec_config is not None: + spec_ssm_name = ssm_spec[layer_idx][0] + spec_view = manager.get_intermediate_ssm_states(layer_idx) + if spec_view is None: + raise RuntimeError( + f"Intermediate SSM state binding returned no view for {spec_ssm_name}. " + "Are we using a backend that supports speculative decoding?" + ) + assert spec_view.is_contiguous(), f"Non-contiguous state {spec_ssm_name}" + self._caches[spec_ssm_name] = spec_view if conv_managed: + conv_name = conv_managed[layer_idx][0] conv_view = manager.get_conv_states(layer_idx) - assert conv_view.is_contiguous(), ( - f"Non-contiguous state {conv_managed[layer_idx][0]}" - ) - self._caches[conv_managed[layer_idx][0]] = conv_view + assert conv_view.is_contiguous(), f"Non-contiguous state {conv_name}" + self._caches[conv_name] = conv_view + if conv_spec and self._spec_config is not None: + spec_conv_name = conv_spec[layer_idx][0] + spec_view = manager.get_intermediate_conv_states(layer_idx) + if spec_view is None: + raise RuntimeError( + f"Intermediate conv state binding returned no view for {spec_conv_name}. " + "Are we using a backend that supports speculative decoding?" + ) + assert spec_view.is_contiguous(), f"Non-contiguous state {spec_conv_name}" + self._caches[spec_conv_name] = spec_view return manager, num_managed_mamba_layers @@ -541,7 +612,9 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: """ # 1. Identify managed resources kv_ref, kv_managed = self._identify_managed_kv_resources() - ssm_ref, ssm_managed, conv_ref, conv_managed = self._identify_managed_state_resources() + ssm_ref, ssm_managed, ssm_spec, conv_ref, conv_managed, conv_spec = ( + self._identify_managed_state_resources() + ) # 2. 
Prepare configuration kv_cache_config = self._prepare_kv_cache_config(max_tokens, kv_managed) @@ -553,7 +626,13 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: # NOTE: +1 for cuda graph padding kv_cache_kwargs["max_batch_size"] = self.info.max_num_state_slots self._kv_cache_manager, _ = self._create_and_assign_state_views( - kv_cache_kwargs, ssm_ref, ssm_managed, conv_ref, conv_managed + kv_cache_kwargs, + ssm_ref, + ssm_managed, + ssm_spec, + conv_ref, + conv_managed, + conv_spec, ) else: # No typed state resources - use pure KVCacheManager @@ -588,15 +667,33 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: num_state_total = sum( 1 for h in self._resource_lookup.values() if isinstance(h, StateResourceHandler) ) - num_ssm_total = sum( + num_ssm_base_total = sum( 1 for h in self._resource_lookup.values() if isinstance(h, SSMResourceHandler) ) - num_conv_total = sum( + num_ssm_spec_total = sum( + 1 for h in self._resource_lookup.values() if isinstance(h, SpecSSMResourceHandler) + ) + num_ssm_total = num_ssm_base_total + num_ssm_spec_total + num_conv_base_total = sum( 1 for h in self._resource_lookup.values() if isinstance(h, CausalConvResourceHandler) ) + num_conv_spec_total = sum( + 1 + for h in self._resource_lookup.values() + if isinstance(h, SpecCausalConvResourceHandler) + ) + num_conv_total = num_conv_base_total + num_conv_spec_total num_state_other = num_state_total - num_ssm_total - num_conv_total - total_managed = len(kv_managed) + len(ssm_managed) + len(conv_managed) + # Count individual cache buffers owned by the cache manager. + # Spec buffers are only cache-manager-owned when spec decoding is enabled. + ssm_managed_count = len(ssm_managed) + ( + len(ssm_spec) if self._spec_config is not None else 0 + ) + conv_managed_count = len(conv_managed) + ( + len(conv_spec) if self._spec_config is not None else 0 + ) + total_managed = len(kv_managed) + ssm_managed_count + conv_managed_count paged_total = sum(1 for h in self._resource_lookup.values() if h.is_paged) kv_total = sum( @@ -613,9 +710,9 @@ def _create_kv_cache_manager(self, max_tokens: Optional[int] = None) -> Dict: "kv_managed": len(kv_managed), "paged_other": paged_other, "ssm_total": num_ssm_total, - "ssm_managed": len(ssm_managed), + "ssm_managed": ssm_managed_count, "conv_total": num_conv_total, - "conv_managed": len(conv_managed), + "conv_managed": conv_managed_count, "state_other": num_state_other, "other": other_total, "max_tokens": max_tokens_final, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py index b3c6380bada6..122a5464d424 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -105,6 +105,13 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: + # Collectives fusion depends on sharding (reads _sharding_transform_container). + # Draft models are not sharded, so skip them. 
+ if getattr(gm, "is_draft", False): + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + patterns = ADPatternMatcherPass() # Dummy shapes for tracing diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py index c2f843c96e34..8c92d9e6f5a9 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -174,13 +174,17 @@ def _apply( num_hidden_layers = len(residual_add_nodes) self.config.set_default_eagle3_layers_to_capture(num_hidden_layers) - residual_add_nodes = { - k: v for k, v in residual_add_nodes.items() if k in self.config.eagle3_layers_to_capture - } + layers_to_capture = self.config.eagle3_layers_to_capture.copy() + if -1 in layers_to_capture: + num_hidden_layers = len(residual_add_nodes) + layers_to_capture.remove(-1) + layers_to_capture.add(num_hidden_layers - 1) + + residual_add_nodes = {k: v for k, v in residual_add_nodes.items() if k in layers_to_capture} - assert residual_add_nodes.keys() == self.config.eagle3_layers_to_capture, ( + assert residual_add_nodes.keys() == layers_to_capture, ( f"Unable to find residual add nodes for layers. " - f"Expected: {self.config.eagle3_layers_to_capture}, Found: {residual_add_nodes.keys()}" + f"Expected: {layers_to_capture}, Found: {residual_add_nodes.keys()}" ) # Replace residual add nodes with special placeholder nodes @@ -237,12 +241,12 @@ def get_source_attention_op(cls) -> OpOverloadPacket: return torch.ops.auto_deploy.residual_add_for_capture @classmethod - def get_cached_attention_op(cls) -> MHACallable: + def get_cached_attention_op(cls, spec_config=None) -> MHACallable: return torch.ops.auto_deploy.cached_residual_add @classmethod def get_cache_initializers( - cls, source_attn_node: Node, cache_config: KvCacheConfig + cls, source_attn_node: Node, cache_config: KvCacheConfig, spec_config=None ) -> ResourceHandlerDict: hidden_size = source_attn_node.meta["val"].shape[-1] hidden_type = source_attn_node.meta["val"].dtype diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index dd64ba8556e1..fff763b437a9 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1062,6 +1062,12 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: + # Draft models are not sharded — they run unsharded inside EagleWrapper. + if getattr(gm, "is_draft", False): + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + local_rank, world_size = shared_config.local_rank, shared_config.world_size assert isinstance(gm, GraphModule), "Expecting GraphModule" config = self.config @@ -1180,6 +1186,12 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: + # Draft models are not sharded — they run unsharded inside EagleWrapper. 
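+        # Returning skipped=True with is_clean and has_valid_shapes declares the draft graph
+        # untouched, so no additional cleanup or shape propagation is requested for it.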
+ if getattr(gm, "is_draft", False): + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + # create a node dict for faster lookup node_dict = {n.name: n for n in gm.graph.nodes} diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index f21913259005..c5f6776eb2fd 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1462,7 +1462,7 @@ def log_two_model_deprecation_warning(self): return self def supports_backend(self, backend: str) -> bool: - return backend == "pytorch" + return backend in ("pytorch", "_autodeploy") @functools.cached_property def num_capture_layers(self) -> int: diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 82fbdfdeb0f7..ca6ace88a5d6 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -392,6 +392,8 @@ nvidia/Nemotron-Super-V3: mtp_enabled: true num_nextn_predict_layers: 3 accuracy: 80.85 + - spec_dec_algo: MTP + accuracy: 92.70 nvidia/Nemotron-3-Nano: - accuracy: 69.37 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 89c45318c76e..3772159f6434 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -22,7 +22,7 @@ from test_common.llm_data import hf_id_to_local_model_dir, llm_models_root from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM -from tensorrt_llm.llmapi import Eagle3DecodingConfig +from tensorrt_llm.llmapi import Eagle3DecodingConfig, MTPDecodingConfig from tensorrt_llm.quantization import QuantAlgo from tensorrt_llm.sampling_params import SamplingParams @@ -112,6 +112,32 @@ def print_memory_usage(label: str): print(f"{'=' * 60}\n") +def _check_acceptance_rate_stats(stats, min_acceptance_rate: float) -> None: + total_drafted = 0 + total_accepted = 0 + num_spec_iterations = 0 + + for stat in stats: + spec_stats = stat.get("specDecodingStats", {}) + num_draft = spec_stats.get("numDraftTokens", 0) + num_accepted = spec_stats.get("numAcceptedTokens", 0) + if num_draft <= 0: + continue + + num_spec_iterations += 1 + total_drafted += num_draft + total_accepted += num_accepted + + accept_rate = total_accepted / total_drafted if total_drafted > 0 else 0.0 + print("Spec dec acceptance rate: " + f"{accept_rate:.2%} ({total_accepted}/{total_drafted} tokens across " + f"{num_spec_iterations} speculative iterations)") + + assert accept_rate >= min_acceptance_rate, ( + f"Acceptance rate {accept_rate:.2%} below threshold {min_acceptance_rate:.0%}" + ) + + def low_memory_overrides(config, max_batch_size=32, free_gpu_memory_fraction=0.4, @@ -199,6 +225,10 @@ def get_default_sampling_params(self): n=beam_width, use_beam_search=beam_width > 1) + def check_acceptance_rate(self, llm, min_acceptance_rate: float): + """Check speculative decoding acceptance rate for the current run.""" + _check_acceptance_rate_stats(llm.get_stats(), min_acceptance_rate) + @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("world_size", [1, 2, 4]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @@ -279,24 +309,7 @@ def check_acceptance_rate(self, llm, min_acceptance_rate: float): llm: The LLM instance with enable_iter_perf_stats=True. min_acceptance_rate: Minimum acceptance rate threshold (default 7%). 
""" - stats = llm.get_stats(timeout=2) - total_drafted = 0 - total_accepted = 0 - - for stat in stats: - spec_stats = stat.get("specDecodingStats", {}) - num_draft = spec_stats.get("numDraftTokens", 0) - num_accepted = spec_stats.get("numAcceptedTokens", 0) - if num_draft > 0: - total_drafted += num_draft - total_accepted += num_accepted - - accept_rate = total_accepted / total_drafted if total_drafted > 0 else 0.0 - print(f"Spec dec acceptance rate: {accept_rate:.2%} " - f"({total_accepted}/{total_drafted} tokens)") - assert accept_rate >= min_acceptance_rate, ( - f"Acceptance rate {accept_rate:.2%} below threshold {min_acceptance_rate:.0%}" - ) + _check_acceptance_rate_stats(llm.get_stats(), min_acceptance_rate) @pytest.mark.skip_less_device_memory(32000) def test_eagle3_one_model(self): @@ -526,6 +539,10 @@ def get_default_sampling_params(self): n=beam_width, use_beam_search=beam_width > 1) + def check_acceptance_rate(self, llm, min_acceptance_rate: float): + """Check speculative decoding acceptance rate for the current run.""" + _check_acceptance_rate_stats(llm.get_stats(), min_acceptance_rate) + @pytest.mark.skip_less_device_memory(180000) @pytest.mark.parametrize("attn_backend", ["flashinfer", "trtllm"]) @pytest.mark.parametrize("enable_attention_dp", [False, True], @@ -569,6 +586,50 @@ def test_accuracy(self, model_id, world_size, enable_attention_dp, print_memory_usage("after evaluation") + @pytest.mark.skip_less_device_memory(180000) + @pytest.mark.parametrize("world_size", [4, 8]) + def test_mtp(self, world_size): + if get_device_count() < world_size: + pytest.skip(f"Not enough devices for world_size={world_size}") + + model_path = self.MODEL_PATHS["bf16"] + kwargs = {} + low_memory_overrides(kwargs) + kwargs["compile_backend"] = "torch-simple" + kwargs["attn_backend"] = "flashinfer" + kwargs["speculative_config"] = MTPDecodingConfig( + num_nextn_predict_layers=6, + mtp_eagle_one_model=True, + speculative_model=model_path, + ) + kwargs["transforms"] = { + "insert_cached_ssm_attention": { + "backend": "triton_ssm" + }, + "insert_cached_causal_conv": { + "backend": "triton_causal_conv" + }, + } + + print( + f"SuperV3 MTP params: world_size={world_size}, model_path={model_path}" + ) + print(f"kwargs: {kwargs}") + + print_memory_usage("test start") + with AutoDeployLLM( + model=model_path, + tokenizer=model_path, + world_size=world_size, + enable_iter_perf_stats=True, + **kwargs, + ) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + self.check_acceptance_rate(llm, min_acceptance_rate=0.45) + + print_memory_usage("after evaluation") + class TestGLM4Flash(LlmapiAccuracyTestHarness): """Accuracy regression tests for GLM-4.7-Flash variants""" diff --git a/tests/integration/defs/examples/test_ad_speculative_decoding.py b/tests/integration/defs/examples/test_ad_speculative_decoding.py index eeb82368c29c..360ef5b4e495 100644 --- a/tests/integration/defs/examples/test_ad_speculative_decoding.py +++ b/tests/integration/defs/examples/test_ad_speculative_decoding.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os import re from dataclasses import dataclass @@ -24,6 +25,7 @@ import torch.nn as nn from build_and_run_ad import ExperimentConfig, main from defs.conftest import llm_models_root +from test_common.llm_data import hf_id_to_local_model_dir from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.masking_utils import create_causal_mask from transformers.modeling_outputs import BaseModelOutputWithPast @@ -33,6 +35,7 @@ from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.llm import LLM from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( + EagleDrafterForCausalLM, EagleWrapper, EagleWrapperConfig, ) @@ -427,7 +430,7 @@ def test_eagle_model_with_weights(): builds the Eagle drafter model based on the checkpoint's model_type: 1. Factory creates config via AutoConfig.from_pretrained - 2. Factory selects Eagle3DrafterForCausalLM based on model_type="llama" + 2. Factory selects EagleDrafterForCausalLM based on model_type="llama" 3. Factory creates model via _from_config 4. Factory loads weights via load_or_random_init -> _load_checkpoint @@ -440,15 +443,6 @@ def test_eagle_model_with_weights(): _, _, eagle_model_path = get_model_paths() eagle_path = Path(eagle_model_path) - if not eagle_path.exists(): - pytest.skip(f"Eagle model not found at {eagle_model_path}") - - # Check for weights - bin_path = eagle_path / "pytorch_model.bin" - safetensors_path = eagle_path / "model.safetensors" - if not bin_path.exists() and not safetensors_path.exists(): - pytest.skip(f"Weights not found at {eagle_model_path}") - # 1. Setup Device device = "cuda" if torch.cuda.is_available() else "cpu" @@ -464,8 +458,8 @@ def test_eagle_model_with_weights(): # Factory flow: # build_model() -> prefetch_checkpoint() -> _build_model() # _build_model() -> _get_model_config() (gets base LlamaConfig) - # _build_model() -> selects Eagle3DrafterForCausalLM for model_type="llama" - # _build_model() -> Eagle3DrafterForCausalLM._from_config(config) + # _build_model() -> selects EagleDrafterForCausalLM for model_type="llama" + # _build_model() -> EagleDrafterForCausalLM._from_config(config) print("Building model via factory.build_model('meta')...") model = factory.build_model("meta") print(f"Model type: {type(model).__name__}") @@ -1049,9 +1043,6 @@ def test_eagle_wrapper_forward(batch_size: int): base_model_path, _, eagle_model_path = get_model_paths() eagle_path = Path(eagle_model_path) - if not eagle_path.exists(): - pytest.skip("Eagle model not found (model missing)") - # Configuration capture_layers = {1, 15, 28} # Layers to capture for Eagle3 num_capture_layers = len(capture_layers) @@ -1308,3 +1299,198 @@ def test_eagle_wrapper_forward(batch_size: int): f" Eagle: {eagle_generated.tolist()}" ) print(f"✓ First {num_tokens_to_check} generated tokens match for all batches!") + + +def _load_valid_safetensors_index(index_path: Path): + """Load a safetensors index JSON, skipping invalid and Git-LFS pointer files.""" + if not index_path.exists(): + return None + + try: + index_text = index_path.read_text(encoding="utf-8") + except OSError: + return None + + if index_text.lstrip().startswith("version https://git-lfs.github.com/spec/v1"): + return None + + try: + index = json.loads(index_text) + except json.JSONDecodeError: + return None + + if not isinstance(index, dict): + return None + + weight_map = index.get("weight_map") + if not isinstance(weight_map, dict): + return None + + return index + + +def _analyze_mtp_weight_loading(model_path: Path, 
model): + """Analyze weight loading for MTP models with safetensors index. + + MTP checkpoints use multiple safetensors files with an index. This function + loads the checkpoint keys from the index and applies the model's + _checkpoint_conversion_mapping to determine which keys will be loaded. + + Args: + model_path: Path to the MTP model directory + model: The instantiated model + + Returns: + Tuple of (loaded_keys, missing_keys, unexpected_keys) + """ + # Load checkpoint keys from safetensors index + index_path = model_path / "model.safetensors.index.json" + index = _load_valid_safetensors_index(index_path) + if index is None: + raise ValueError( + "Expected a valid safetensors index JSON. " + f"Path was missing, malformed, or a Git-LFS pointer: {index_path}" + ) + + # Get MTP-specific checkpoint keys (those starting with "mtp.") + checkpoint_keys_original = [k for k in index["weight_map"].keys() if k.startswith("mtp.")] + if not checkpoint_keys_original: + raise ValueError(f"No mtp.* keys found in safetensors index: {index_path}") + + # Apply _checkpoint_conversion_mapping (same logic as hf.py _remap_param_names_load_hook) + conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None) + checkpoint_keys_remapped = [] + + for key in checkpoint_keys_original: + new_key = key + if conversion_mapping: + for pattern, replacement in conversion_mapping.items(): + new_key = re.sub(pattern, replacement, new_key) + checkpoint_keys_remapped.append(new_key) + + # Get model's expected keys + model_keys = set(model.state_dict().keys()) + checkpoint_keys = set(checkpoint_keys_remapped) + + # Calculate differences + loaded_keys = checkpoint_keys & model_keys + missing_in_checkpoint = model_keys - checkpoint_keys + unexpected_in_checkpoint = checkpoint_keys - model_keys + + return loaded_keys, missing_in_checkpoint, unexpected_in_checkpoint + + +def test_nemotron_mtp_model_with_weights(): + """Test NemotronH MTP model weight loading using EagleDrafterFactory. + + This test verifies that: + 1. EagleDrafterFactory can create EagleDrafterForCausalLM with NemotronH layers + 2. Weights are correctly loaded with the mtp.* -> model.* key mapping + 3. All expected model parameters are loaded from the MTP checkpoint + + The MTP model uses a checkpoint that contains both backbone.* and mtp.* keys. + Only mtp.* keys are loaded. Shared parameters (embed_tokens, lm_head) are NOT + created in the model (load_embedding_from_target=True, load_lm_head_from_target=True), + so they don't appear as missing keys. They are shared from the target model at runtime. + """ + print("\n" + "=" * 80) + print("Test: NemotronH MTP model weight loading (via EagleDrafterFactory)") + print("=" * 80) + + mtp_model_name = "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" + + mtp_model_path = hf_id_to_local_model_dir(mtp_model_name) + mtp_path = Path(mtp_model_path) + index_path = mtp_path / "model.safetensors.index.json" + + # Check for a valid index JSON and verify it has mtp.* keys. + index = _load_valid_safetensors_index(index_path) + assert index is not None, ( + "Expected a valid safetensors index JSON. 
" + f"Path was missing, malformed, or a Git-LFS pointer: {index_path}" + ) + mtp_source_keys = {k for k in index["weight_map"].keys() if k.startswith("mtp.")} + assert mtp_source_keys, f"Expected at least one mtp.* key in {index_path}" + + # Setup device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Create factory - use EagleDrafterFactory for NemotronH MTP + print("Creating EagleDrafterFactory...") + factory = EagleDrafterFactory( + model=mtp_model_path, + skip_loading_weights=False, + ) + + # Build model using factory + print("Building model via factory.build_model('meta')...") + model = factory.build_model("meta") + print(f"Model type: {type(model).__name__}") + print(f"Model config type: {type(model.config).__name__}") + + # Verify model type is EagleDrafterForCausalLM + assert isinstance(model, EagleDrafterForCausalLM), ( + f"Expected EagleDrafterForCausalLM, got {type(model).__name__}" + ) + + # Analyze weight loading + print("\n--- Weight Loading Analysis ---") + loaded_keys, missing_keys, unexpected_keys = _analyze_mtp_weight_loading(mtp_path, model) + + print(f"Total model parameters: {len(loaded_keys) + len(missing_keys)}") + print(f"Total MTP checkpoint keys: {len(loaded_keys) + len(unexpected_keys)}") + print(f"✅ Weights to be loaded: {len(loaded_keys)}") + print(f"⚠️ Missing in checkpoint (should be 0): {len(missing_keys)}") + print(f"⚠️ Unexpected in checkpoint (should be 0): {len(unexpected_keys)}") + + if missing_keys: + print("\nMissing keys (expected - shared from target model):") + for key in sorted(missing_keys): + if "embed_tokens" in key: + print(f" - {key} (shared embedding from target)") + elif "lm_head" in key: + print(f" - {key} (shared lm_head from target)") + else: + print(f" - {key}") + + if unexpected_keys: + print("\nUnexpected keys (should not happen):") + for key in sorted(unexpected_keys): + print(f" - {key}") + + print("--- End Weight Analysis ---\n") + + # Verify expected missing and unexpected keys + # MTP checkpoint does NOT contain embed_tokens or lm_head, but that's OK because: + # - embed_tokens: shared from target model (load_embedding_from_target=True → model doesn't create it) + # - lm_head: shared from target model (load_lm_head_from_target=True → model doesn't create it) + # Since neither parameter is created in the model, they don't appear as missing keys. + # Note: For NemotronH, layers_handle_final_norm=True, so the wrapper doesn't create self.norm. + # The final norm is inside the layers (final_layernorm), which IS in the checkpoint. 
+ expected_missing_keys = set() # All model params are loaded; shared params aren't created + expected_unexpected_keys = set() # All checkpoint keys should be used + + assert missing_keys == expected_missing_keys, ( + f"Unexpected missing keys.\n" + f"Expected: {expected_missing_keys}\n" + f"Got: {missing_keys}\n" + f"Extra missing: {missing_keys - expected_missing_keys}\n" + f"Not missing (but expected): {expected_missing_keys - missing_keys}" + ) + + assert unexpected_keys == expected_unexpected_keys, ( + f"Unexpected keys in checkpoint.\n" + f"Expected: {expected_unexpected_keys}\n" + f"Got: {unexpected_keys}\n" + f"Extra unexpected: {unexpected_keys - expected_unexpected_keys}" + ) + + print("✅ Weight loading analysis matches expected missing/unexpected keys!") + + # Load weights using factory + print("Loading weights via factory.load_or_random_init()...") + factory.load_or_random_init(model, device, disable_preload=True) + print("Weights loaded successfully via factory interface!") + + model.eval() + print("✅ NemotronH MTP model created and weights loaded successfully!") diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index d74631bc958f..f1c7849beb8c 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -562,3 +562,4 @@ llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging # AutoDeploy text generation accuracy tests accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B_Instruct_Eagle3::test_eagle3_one_model +accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[4] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index a1b678e54614..1378ce4efa40 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -235,3 +235,4 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx # AutoDeploy text generation accuracy tests accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B_Instruct_Eagle3::test_eagle3_one_model +accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[4] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 244f2307d5db..7603334c9344 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -324,6 +324,7 @@ l0_dgx_b200: - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[fp8-4-attn_dp_on-trtllm] - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[nvfp4-4-attn_dp_on-trtllm] - accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True] + - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[4] # ------------- AutoDeploy Perf Sanity --------------- - perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws4_1k1k] TIMEOUT (120) - condition: diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 17d9f96e5f7e..8d87773b9675 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -487,4 +487,5 @@ l0_h100: - examples/test_ad_speculative_decoding.py::test_eagle_model_with_weights - 
examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[1] - examples/test_ad_speculative_decoding.py::test_eagle_wrapper_forward[2] + - examples/test_ad_speculative_decoding.py::test_nemotron_mtp_model_with_weights - examples/test_ad_export_onnx.py::test_ad_export_onnx[Qwen/Qwen2.5-3B-Instruct-/tmp/test_ad_export_onnx_qwen2.5-3b-36] diff --git a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py index c7346e79e063..e0a57da550dd 100644 --- a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py @@ -552,6 +552,30 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): "num_hidden_layers": 8, }, }, + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": { + "model_kwargs": { + "num_hidden_layers": 1, + "layers_block_type": ["mamba"], + "hidden_size": 32, + "intermediate_size": 64, + "mamba_num_heads": 4, + "mamba_head_dim": 40, + "n_groups": 2, + "ssm_state_size": 32, + "conv_kernel": 4, + # MoE dimensions (used by the MTP/Eagle drafter's "E" layer) + "n_routed_experts": 4, + "n_shared_experts": 1, + "num_experts_per_tok": 2, + "moe_intermediate_size": 64, + "moe_shared_expert_intermediate_size": 64, + "moe_latent_size": 16, + "n_group": 1, + "topk_group": 1, + "num_attention_heads": 4, + "num_key_value_heads": 2, + }, + }, "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": { "model_kwargs": { "hidden_size": 64, diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py index 7a254b984c63..e449789c8aac 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py @@ -64,6 +64,7 @@ def test_flashinfer_decode_matches_triton(mamba_env): None, # seq_idx_prefill # CACHES ssm_state_cache_triton, + None, # CONSTANTS time_step_limit, chunk_size, diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py index 5e77ee236cd3..0562488c6372 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py @@ -103,6 +103,7 @@ def test_triton_generate_only_with_slot_mapping(mamba_env): None, # seq_idx_prefill # CACHES ssm_state_cache_triton, + None, # CONSTANTS time_step_limit, chunk_size, @@ -203,6 +204,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): seq_idx_prefill, # CACHES ssm_state_cache_triton, + None, # CONSTANTS time_step_limit, chunk_size, diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py index ec2d4e23d506..eff235a7ad0e 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py @@ -12,9 +12,13 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( AttentionDescriptor, + CausalConvResourceHandler, KVPagedResourceHandler, ResourceHandler, SequenceInfo, + SpecCausalConvResourceHandler, + SpecSSMResourceHandler, + SSMResourceHandler, StateResourceHandler, 
UnpagedResourceHandler, ) @@ -274,8 +278,6 @@ def test_kv_paged_handler_eq_different_head_dim_or_dtype(): def test_ssm_handler_eq_same_params(): """Verify SSMResourceHandler __eq__ for same parameters.""" - from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SSMResourceHandler - h1 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16) h2 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16) @@ -284,8 +286,6 @@ def test_ssm_handler_eq_same_params(): def test_ssm_handler_eq_different_params(): """Verify SSMResourceHandler __eq__ returns False for different parameters.""" - from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SSMResourceHandler - h1 = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16) h2 = SSMResourceHandler( num_heads=4, head_dim=64, d_state=16, dtype=torch.bfloat16 @@ -306,10 +306,6 @@ def test_ssm_handler_eq_different_params(): def test_conv_handler_eq_same_params(): """Verify CausalConvResourceHandler __eq__ for same parameters.""" - from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( - CausalConvResourceHandler, - ) - h1 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32) h2 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32) @@ -318,10 +314,6 @@ def test_conv_handler_eq_same_params(): def test_conv_handler_eq_different_params(): """Verify CausalConvResourceHandler __eq__ returns False for different parameters.""" - from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( - CausalConvResourceHandler, - ) - h1 = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32) h2 = CausalConvResourceHandler(conv_dim=512, d_conv=4, dtype=torch.float32) # diff conv_dim h3 = CausalConvResourceHandler(conv_dim=256, d_conv=5, dtype=torch.float32) # diff d_conv @@ -330,3 +322,50 @@ def test_conv_handler_eq_different_params(): assert h1 != h2 assert h1 != h3 assert h1 != h4 + + +def test_spec_ssm_handler_from_base(): + """Verify SpecSSMResourceHandler mirrors base SSM dims and remains a distinct type.""" + base = SSMResourceHandler(num_heads=8, head_dim=64, d_state=16, dtype=torch.bfloat16) + spec = SpecSSMResourceHandler.from_base(base) + matching_spec = SpecSSMResourceHandler(8, 64, 16, dtype=torch.bfloat16) + + assert base.state_shape == (8, 64, 16) + assert spec.state_shape == (8, 64, 16) + assert spec.num_heads == base.num_heads + assert spec.head_dim == base.head_dim + assert spec.d_state == base.d_state + assert spec.dtype == base.dtype + + assert isinstance(spec, SpecSSMResourceHandler) + assert not isinstance(spec, SSMResourceHandler) + assert base != spec + assert spec == matching_spec + + +def test_spec_ssm_handler_from_base_none(): + """Verify SpecSSMResourceHandler.from_base(None) returns None.""" + assert SpecSSMResourceHandler.from_base(None) is None + + +def test_spec_conv_handler_from_base(): + """Verify SpecCausalConvResourceHandler mirrors base conv dims and remains a distinct type.""" + base = CausalConvResourceHandler(conv_dim=256, d_conv=4, dtype=torch.float32) + spec = SpecCausalConvResourceHandler.from_base(base) + matching_spec = SpecCausalConvResourceHandler(256, 4, dtype=torch.float32) + + assert base.state_shape == (256, 3) + assert spec.state_shape == (256, 3) + assert spec.conv_dim == base.conv_dim + assert spec.d_conv == base.d_conv + assert spec.dtype == base.dtype + + assert isinstance(spec, SpecCausalConvResourceHandler) + assert 
not isinstance(spec, CausalConvResourceHandler) + assert base != spec + assert spec == matching_spec + + +def test_spec_conv_handler_from_base_none(): + """Verify SpecCausalConvResourceHandler.from_base(None) returns None.""" + assert SpecCausalConvResourceHandler.from_base(None) is None diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py index ad75cd05f676..c5932aefbe81 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_triton_causal_conv_cached_op.py @@ -102,6 +102,7 @@ def test_generate_only_triton_vs_cuda(conv_env): slot_idx, use_initial_states, conv_state_cache_triton, + None, s, p, d, @@ -189,6 +190,7 @@ def test_context_flattened_triton_vs_cuda(conv_env): slot_idx, use_initial_states, conv_state_cache_triton, + None, s, p, d, @@ -279,6 +281,7 @@ def test_mixed_prefill_decode_triton_vs_cuda(conv_env): slot_idx, use_initial_states, conv_state_cache_triton, + None, s, p, d, @@ -369,6 +372,7 @@ def test_larger_batch_triton_vs_cuda(conv_env): slot_idx, use_initial_states, conv_state_cache_triton, + None, s, p, d, diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_eagle.py b/tests/unittest/auto_deploy/singlegpu/models/test_eagle.py index 85d2d2d2f4c4..dca162dc5b9b 100644 --- a/tests/unittest/auto_deploy/singlegpu/models/test_eagle.py +++ b/tests/unittest/auto_deploy/singlegpu/models/test_eagle.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,9 +23,9 @@ from build_and_run_ad import ExperimentConfig, main from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( - Eagle3DrafterForCausalLM, Eagle3DraftOutput, EagleConfig, + EagleDrafterForCausalLM, ) from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry @@ -58,7 +58,7 @@ class MockEagleConfig(EagleConfig): } -class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM): +class MockEagle3ModelForCausalLM(EagleDrafterForCausalLM): """Test wrapper that provides random hidden states for standalone Eagle testing. In production speculative decoding, real hidden states come from the target model. 
@@ -186,7 +186,7 @@ def test_eagle_model_torch_export(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float16 - # Create model via EagleDrafterFactory (creates Eagle3DrafterForCausalLM) + # Create model via EagleDrafterFactory (creates EagleDrafterForCausalLM) factory = EagleDrafterFactory(model=str(eagle_path), skip_loading_weights=True) model = factory.build_model(device) config = model.config diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py index 11bcd9e6849e..6a5f8d1500c6 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py @@ -13,6 +13,8 @@ CausalConvResourceHandler, KVPagedResourceHandler, SequenceInfo, + SpecCausalConvResourceHandler, + SpecSSMResourceHandler, SSMResourceHandler, StateResourceHandler, UnpagedResourceHandler, @@ -20,6 +22,7 @@ from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.llmapi import DraftTargetDecodingConfig from tensorrt_llm.llmapi.llm_args import KvCacheConfig # ============================================================================= @@ -781,6 +784,75 @@ def test_typed_handlers_inherit_from_state_resource_handler(): assert isinstance(conv_handler, StateResourceHandler) +def test_intermediate_state_resources_bind_via_managed_state_path(paged_kv_cache_config): + """Test that speculative Mamba state handlers bind through the managed cache path.""" + spec_config = DraftTargetDecodingConfig( + max_draft_len=2, + speculative_model="dummy-model", + ) + interface = CachedSequenceInterface( + max_seq_len=128, + max_batch_size=4, + device="cuda", + kv_cache_config=paged_kv_cache_config, + spec_config=spec_config, + ) + + num_heads = 4 + head_dim = 64 + d_state = 16 + conv_dim = head_dim * num_heads + 2 * 2 * d_state + resource_names = [] + + for i in range(2): + resource_names.append( + interface.add_resource( + f"ssm_state_{i}", + SSMResourceHandler( + num_heads=num_heads, + head_dim=head_dim, + d_state=d_state, + dtype=torch.bfloat16, + ), + ) + ) + resource_names.append( + interface.add_resource( + f"intermediate_ssm_state_{i}", + SpecSSMResourceHandler( + num_heads=num_heads, + head_dim=head_dim, + d_state=d_state, + dtype=torch.bfloat16, + ), + ) + ) + resource_names.append( + interface.add_resource( + f"conv_state_{i}", + CausalConvResourceHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32), + ) + ) + resource_names.append( + interface.add_resource( + f"intermediate_conv_state_{i}", + SpecCausalConvResourceHandler( + conv_dim=conv_dim, + d_conv=4, + dtype=torch.float32, + ), + ) + ) + + interface.initialize_resources() + + for resource_name in resource_names: + cache = interface._caches[resource_name] + assert cache is not None + assert cache.is_contiguous() + assert resource_name not in interface._unmanaged_resources + + def test_multiple_ssm_resources_contiguous_views(paged_kv_cache_config): """Test that multiple SSM resources get contiguous views from MambaHybridCacheManager.""" interface = CachedSequenceInterface( diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py 
b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py index 864576b8e3ba..c833b4b680e7 100644 --- a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py +++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,12 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest +import torch from _model_test_utils import get_small_model_config from build_and_run_ad import ExperimentConfig, main -from test_common.llm_data import with_mocked_hf_download_for_single_gpu +from test_common.llm_data import hf_id_to_local_model_dir, with_mocked_hf_download_for_single_gpu + +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.eagle import EagleOneModelFactory +from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig +from tensorrt_llm._torch.auto_deploy.transform.library.hidden_states import ( + DetectHiddenStatesForCapture, +) +from tensorrt_llm._torch.speculative import get_num_extra_kv_tokens +from tensorrt_llm.llmapi import ( + DraftTargetDecodingConfig, + Eagle3DecodingConfig, + KvCacheConfig, + MTPDecodingConfig, +) -from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig + +def get_extra_seq_len_for_kv_cache(llm_args) -> int: + """Mirror the current extra-KV sizing logic used by the runtime.""" + extra = 0 + spec_config = llm_args.speculative_config + if not llm_args.disable_overlap_scheduler: + extra += 1 + if spec_config is not None: + extra += spec_config.tokens_per_gen_step - 1 + + if spec_config is not None: + extra += spec_config.tokens_per_gen_step - 1 + extra += get_num_extra_kv_tokens(spec_config) + + return extra @pytest.mark.skip( @@ -83,3 +113,170 @@ def test_ad_speculative_decoding_smoke(use_hf_speculative_model: bool): assert len(generated_text) > 0, "Generated text should not be empty" print("Speculative decoding smoke test passed!") + + +def test_super_mtp_smoke(): + """Test one-model MTP/Eagle runtime with a tiny Nemotron SuperV3 target.""" + test_prompt = "What is the capital of France?" + model_hub_id = "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16" + model_path = hf_id_to_local_model_dir(model_hub_id) + + experiment_config = get_small_model_config( + model_hub_id, + transforms={ + "insert_cached_causal_conv": {"backend": "triton_causal_conv"}, + "insert_cached_ssm_attention": {"backend": "triton_ssm"}, + }, + ) + experiment_config["args"]["model"] = model_path + experiment_config["args"]["runtime"] = "trtllm" + experiment_config["args"]["world_size"] = 1 + experiment_config["args"]["speculative_config"] = MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + speculative_model=model_path, + ) + # Shrink the Eagle/MTP drafter model to match the target's reduced dimensions. 
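+    # With mtp_eagle_one_model the drafter comes from the same checkpoint as the
+    # target (speculative_model=model_path above), so reusing the target's reduced
+    # model_kwargs keeps the drafter's dimensions consistent with the target's.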
+ experiment_config["args"]["speculative_model_kwargs"] = experiment_config["args"][ + "model_kwargs" + ] + experiment_config["args"]["disable_overlap_scheduler"] = True + experiment_config["args"]["compile_backend"] = "torch-simple" + experiment_config["args"]["max_num_tokens"] = 256 + experiment_config["prompt"]["batch_size"] = 1 + experiment_config["prompt"]["queries"] = test_prompt + + cfg = ExperimentConfig(**experiment_config) + cfg.prompt.sp_kwargs = { + "max_tokens": 64, + "top_k": None, + "temperature": 0.0, + "seed": 42, + } + + results = main(cfg) + + prompts_and_outputs = results["prompts_and_outputs"] + assert len(prompts_and_outputs) == 1 + prompt, _generated_text = prompts_and_outputs[0] + assert prompt == test_prompt + + +def test_kv_cache_extra_seq_len_for_spec_dec(): + """Test that get_extra_seq_len_for_kv_cache computes correct extra capacity.""" + from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + + # Case 1: No spec config, no overlap + args_no_spec = LlmArgs( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + disable_overlap_scheduler=True, + ) + assert get_extra_seq_len_for_kv_cache(args_no_spec) == 0 + + # Case 2: No spec config, with overlap + args_overlap = LlmArgs( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + disable_overlap_scheduler=False, + ) + assert get_extra_seq_len_for_kv_cache(args_overlap) == 1 # overlap adds +1 + + # Case 3: Eagle3 one-model, overlap disabled + spec_config = Eagle3DecodingConfig( + max_draft_len=3, + speculative_model="some/model", + eagle3_one_model=True, + ) + args_eagle = LlmArgs( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + speculative_config=spec_config, + disable_overlap_scheduler=True, + ) + extra = get_extra_seq_len_for_kv_cache(args_eagle) + # Should include max_total_draft_tokens + get_num_extra_kv_tokens (max_draft_len - 1) + assert extra > 0 + assert extra == spec_config.max_total_draft_tokens + (spec_config.max_draft_len - 1) + + # Case 4: Eagle3 one-model, overlap enabled + args_eagle_overlap = LlmArgs( + model="meta-llama/Meta-Llama-3.1-8B-Instruct", + speculative_config=spec_config, + disable_overlap_scheduler=False, + ) + extra_overlap = get_extra_seq_len_for_kv_cache(args_eagle_overlap) + # Should be more than without overlap + assert extra_overlap > extra + + +def test_mtp_autodeploy_uses_eagle_one_model_capture(): + from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + + model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + args = LlmArgs( + model=model, + speculative_config=MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + ), + ) + + assert isinstance(args.speculative_config, MTPDecodingConfig) + assert args.model_factory == "eagle_one_model" + assert args.transforms["detect_hidden_states_for_capture"]["enabled"] is True + assert args.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] == {-1} + + +def test_detect_hidden_states_capture_last_layer_for_mtp_eagle_one_model(): + from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs + + config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct") + + args = LlmArgs( + **config["args"], + speculative_config=MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + speculative_model=config["args"]["model"], + ), + ) + + factory = args.create_factory() + assert isinstance(factory, EagleOneModelFactory) + + model = factory.target_factory.build_model("meta") + input_ids = torch.ones((1, 8), dtype=torch.int64) + position_ids = torch.arange(8, 
dtype=torch.int64).unsqueeze(0) + gm = torch_export_to_gm( + model, + args=(input_ids, position_ids), + ) + + transform = DetectHiddenStatesForCapture( + config=TransformConfig( + stage="pattern_matcher", + eagle3_layers_to_capture={-1}, + ) + ) + + original_residual_nodes = transform.collect_residual_add_nodes(gm) + assert original_residual_nodes + last_layer = max(original_residual_nodes) + last_layer_residual = original_residual_nodes[last_layer] + expected_arg_names = tuple( + arg.name if isinstance(arg, torch.fx.Node) else arg for arg in last_layer_residual.args + ) + + gm, info = transform._apply(gm, None, None, None) + + capture_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" + and node.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + + assert info.num_matches == 1 + assert len(capture_nodes) == 1 + capture_arg_names = tuple( + arg.name if isinstance(arg, torch.fx.Node) else arg for arg in capture_nodes[0].args + ) + assert capture_arg_names == expected_arg_names From b9ba730a631300f3fc5ce49b35916d864afa6b0b Mon Sep 17 00:00:00 2001 From: Michal Guzek Date: Thu, 2 Apr 2026 09:23:21 -0700 Subject: [PATCH 2/8] [TRTLLM-11163][feat] Introduce a fast path (token IDs + MM) for VLMs instead of de-tokenizing already encoded prompt (#11708) Signed-off-by: Michal Guzek --- .../_torch/models/modeling_llava_next.py | 197 +++++++++++++----- tensorrt_llm/inputs/registry.py | 138 +++++++++++- tensorrt_llm/llmapi/llm.py | 34 ++- .../test_lists/test-db/l0_l40s.yml | 1 + .../modeling/test_modeling_llava_next.py | 36 +++- 5 files changed, 341 insertions(+), 65 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index a2c0e540ebf5..92068bd09449 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -20,7 +20,7 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder, BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, - MultimodalPlaceholderPlacement, TextPrompt, + MultimodalPlaceholderPlacement, TextPrompt, TokensPrompt, register_input_processor, support_multimodal_disaggregated) from ...logger import logger @@ -86,6 +86,102 @@ def processor(self) -> AutoProcessor: def dtype(self) -> torch.dtype: return self._dtype + def get_text_with_mm_placeholders(self, mm_counts: Dict[str, int]) -> str: + """ + Return minimal placeholder text for the given multimodal item counts, + so that the HF processor can be called with (dummy_text, mm_data) without error. + Used when processing tokenized prompt + MM data. + + Args: + mm_counts (Dict[str, int]): A mapping of each multimodal modality name (e.g., 'image', 'video') + to the count of items for that modality that need corresponding placeholders in the dummy text. + + Returns: + str: A minimal placeholder string containing the correct number and type of multimodal placeholders, + suitable for passing along with mm_data to the Hugging Face processor. 
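+
+        Example:
+            For mm_counts == {"image": 2} this returns the processor's image token
+            repeated twice (e.g. "<image><image>"; the exact placeholder string
+            depends on the loaded processor).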
+ """ + num_images = mm_counts.get("image", 0) + processor = self.processor + image_token = processor.image_token + return image_token * num_images + + def _expand_image_placeholders_in_token_ids( + self, + prompt_token_ids: List[int], + num_mm_tokens_per_placeholder: List[int], + ) -> Tuple[List[int], List[int], List[int]]: + """ + Shared logic (called by expand_prompt_token_ids_for_mm and get_prompt_token_ids): + replace each image placeholder token in prompt_token_ids + with placeholder_id repeated num_mm_tokens_per_placeholder[i] times. + + Returns: + expanded_ids (List[int]): The new prompt token IDs with each image placeholder replaced by the correct number of MM tokens. + mm_token_lengths (List[int]): Number of MM tokens inserted for each placeholder, in order. + mm_token_offsets (List[int]): Offset (position) in the expanded sequence where each MM token group (for each placeholder) begins. + """ + image_token_id = self.config.image_token_index + placeholder_id = self.vocab_size + 1 + + expanded: List[int] = [] + mm_token_lengths: List[int] = [] + mm_token_offsets: List[int] = [] + image_idx = 0 + for tok in prompt_token_ids: + if tok == image_token_id: + if image_idx >= len(num_mm_tokens_per_placeholder): + raise ValueError( + "More image placeholder tokens in prompt than " + "num_mm_tokens_per_placeholder entries: " + f"found {image_idx + 1} placeholders, " + f"num_mm_tokens_per_placeholder has {len(num_mm_tokens_per_placeholder)} entries." + ) + n = num_mm_tokens_per_placeholder[image_idx] + mm_token_offsets.append(len(expanded)) + expanded.extend([placeholder_id] * n) + mm_token_lengths.append(n) + image_idx += 1 + else: + expanded.append(tok) + + if image_idx != len(num_mm_tokens_per_placeholder): + raise ValueError( + f"Expected {len(num_mm_tokens_per_placeholder)} image placeholders, " + f"found {image_idx}. Ensure the prompt contains the model image " + f"placeholder (token id {image_token_id}).") + return expanded, mm_token_lengths, mm_token_offsets + + def expand_prompt_token_ids_for_mm( + self, + prompt_token_ids: List[int], + num_mm_tokens_per_placeholder: List[int], + hf_processor_mm_kwargs: Optional[Dict[str, Any]] = None, + ) -> List[int]: + """ + Expands MM placeholder tokens in `prompt_token_ids` so that each single placeholder + is replaced by the corresponding number of multimodal feature tokens. + + This is used when processing a tokenized prompt plus multimodal data, without calling the full + HuggingFace processor. + + Subclasses that require the HuggingFace processor or feature extractor (for example, + to determine image-size-dependent token counts) can use `hf_processor_mm_kwargs` if provided. + + Args: + prompt_token_ids (List[int]): The input prompt token IDs with image placeholder tokens. + num_mm_tokens_per_placeholder (List[int]): For each MM placeholder in prompt_token_ids, + specifies the number of MM feature tokens to expand/repeat for that placeholder. + hf_processor_mm_kwargs (Optional[Dict[str, Any]]): Optional dictionary of arguments + to pass to the HuggingFace processor, if needed for token expansion. + + Returns: + List[int]: The prompt token IDs where each MM placeholder token has been + replaced/expanded with the appropriate number of MM feature tokens. 
+ """ + expanded, _, _ = self._expand_image_placeholders_in_token_ids( + prompt_token_ids, num_mm_tokens_per_placeholder) + return expanded + def _postprocess( self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor, List[torch.Tensor]] @@ -169,14 +265,16 @@ def _postprocess( return fused_input_ids, mm_features def get_prompt_token_ids( - self, inputs: TextPrompt, + self, inputs: Union[TextPrompt, TokensPrompt], mm_handles: List[Dict[str, Any]]) -> Tuple[List[int], List[int], List[int]]: """ Build input token ids with multimodal placeholders expanded to the number of MM tokens. + Uses an already tokenized prompt or tokenizes the txt prompt first. + Args: - inputs: Text prompt input container. Must contain a non-empty prompt string. + inputs: Inputs containing an already tokenized text prompt or a text prompt string. mm_handles: List of multimodal embedding handles. Returns: @@ -187,8 +285,13 @@ def get_prompt_token_ids( """ # TODO: Move this function to the base input processor class when extending for more models text_prompt = inputs.get("prompt") - if not text_prompt: - raise ValueError("Text prompt is required but not provided") + prompt_token_ids = inputs.get("prompt_token_ids") + if text_prompt: + prompt_token_ids = self.tokenizer( + text_prompt, return_tensors="pt").input_ids[0].tolist() + elif not prompt_token_ids: + raise ValueError( + "Text prompt or token IDs are required but neither is provided") if not isinstance(mm_handles, list): raise ValueError("mm_handles must be a list") @@ -200,49 +303,27 @@ def get_prompt_token_ids( raise RuntimeError( f"Multimodal embedding {i} hidden size {hidden_size} must match model hidden size {expected_hidden_size}" ) - input_ids = self.tokenizer(text_prompt, - return_tensors="pt").input_ids[0] - - vocab_size = self.config.text_config.vocab_size - image_token_index = self.config.image_token_index - - image_mask = input_ids == image_token_index - image_positions = torch.where(image_mask)[0] - num_images = len(image_positions) - assert num_images == len( - mm_handles), "Number of images must match number of mm_handles" - total_mm_tokens = sum(mm_handle["tensor_size"][0] - for mm_handle in mm_handles) - final_length = len(input_ids) - num_images + total_mm_tokens - # Create output tensor - expanded_ids = torch.empty(final_length, dtype=input_ids.dtype) - placeholder_id = vocab_size + 1 - - # Fill the expanded sequence - write_pos = 0 - image_cnt = 0 - mm_token_length = [] - mm_token_offsets = [] - for read_pos in range(len(input_ids)): - if input_ids[read_pos] == image_token_index: - # Replace with placeholder id - mm_token_num = mm_handles[image_cnt]["tensor_size"][0] - expanded_ids[write_pos:write_pos + mm_token_num] = \ - placeholder_id - mm_token_offsets.append(write_pos) - mm_token_length.append(mm_token_num) - write_pos += mm_token_num - image_cnt += 1 - else: - # Copy text token as-is - expanded_ids[write_pos] = input_ids[read_pos] - write_pos += 1 - assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}" - assert mm_token_length[-1] + mm_token_offsets[ - -1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})" - return expanded_ids.to( - torch.int32).tolist(), mm_token_length, mm_token_offsets + num_mm_tokens_per_image = [h["tensor_size"][0] for h in mm_handles] + expanded_ids, mm_token_length, mm_token_offsets = ( + self._expand_image_placeholders_in_token_ids( + prompt_token_ids, 
num_mm_tokens_per_image)) + + # Final assertions to check the correctness of the expanded ids. + final_length = len(expanded_ids) + expected_final_length = (len(prompt_token_ids) - + len(num_mm_tokens_per_image) + + sum(num_mm_tokens_per_image)) + assert final_length == expected_final_length, ( + f"Write position mismatch: {final_length} != {expected_final_length}" + ) + if mm_token_length: + assert mm_token_length[-1] + mm_token_offsets[-1] <= final_length, ( + f"mm_token_length[-1] + mm_token_offsets[-1] " + f"({mm_token_length[-1] + mm_token_offsets[-1]}) should be less " + f"than or equal to final_length ({final_length})") + + return expanded_ids, mm_token_length, mm_token_offsets def attach_multimodal_embeddings( self, inputs: TextPrompt, @@ -255,17 +336,13 @@ def attach_multimodal_embeddings( It replaces/expands image placeholders in the text with appropriate tokens and prepares the embeddings for model forward pass. Args: - inputs: Text prompt containing image placeholders + inputs: Text prompt containing image placeholders, or prompt_token_ids (list of int) multimodal_embedding: Dictionary containing pre-processed image embedding data Returns: Tuple of (token_ids, extra_processed_inputs) where: - token_ids: List of processed token IDs with image placeholders - extra_processed_inputs: Optional dictionary containing multimodal embeddings """ - text_prompt = inputs.get("prompt") - if not text_prompt: - raise ValueError("Text prompt is required but not provided") - if not isinstance(multimodal_embedding, dict): raise ValueError("multimodal_embedding must be a dictionary") @@ -274,9 +351,19 @@ def attach_multimodal_embeddings( "Only image modality is supported for external multimodal embedding" ) - input_ids = self.tokenizer(text_prompt, - return_tensors="pt").input_ids[0] - mm_features = multimodal_embedding['image'] + prompt_token_ids = inputs.get("prompt_token_ids") + if prompt_token_ids is not None: + # Token IDs already provided (e.g. tokenized+MM path): use directly, skip tokenization. + input_ids = torch.tensor(prompt_token_ids, dtype=torch.long) + else: + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError( + "Either 'prompt' (text) or 'prompt_token_ids' is required") + input_ids = self.tokenizer(text_prompt, + return_tensors="pt").input_ids[0] + + mm_features = multimodal_embedding["image"] fused_input_ids, mm_features = self._postprocess(input_ids, mm_features) multimodal_data = {} multimodal_data["multimodal_embedding"] = mm_features diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 7d2f1bcb2404..5cab205f5385 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -126,6 +126,12 @@ class BaseMultimodalInputProcessor(ABC): This class provides default implementations that work with most AutoProcessor-based models. Specific processors can override these methods if they need custom logic. + + Optional tokenized+MM fast path: to support prompt_token_ids + multi_modal_data + without detokenizing, implement get_text_with_mm_placeholders(mm_counts) and + expand_prompt_token_ids_for_mm(prompt_token_ids, num_mm_tokens_per_placeholder, ...). + If these are not implemented, the pipeline detokenizes the text prompt first and then + processes the multimodal inputs. 
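+
+    With those hooks in place, callers can pass a pre-tokenized prompt directly, e.g.
+    {"prompt_token_ids": ids, "multi_modal_data": {"image": [img]}}, and skip the
+    detokenize/re-tokenize round trip.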
""" def __init__(self, @@ -648,6 +654,55 @@ def create_input_processor( return DefaultInputProcessor(None, None, tokenizer) +def _mm_data_to_counts(mm_data: Dict[str, Any]) -> Dict[str, int]: + """Normalize multimodal data to per-key counts (each value as list length).""" + mm_items = { + k: (v if isinstance(v, list) else [v]) + for k, v in mm_data.items() + } + return {k: len(v) for k, v in mm_items.items()} + + +def _process_multimodal_with_dummy_placeholders( + input_processor: BaseMultimodalInputProcessor, + mm_data: Dict[str, Any], + mm_counts: Dict[str, int], + mm_processor_kwargs: Optional[Dict[str, Any]], + sampling_params: SamplingParams, +) -> ExtraProcessedInputs: + """Run input_processor with dummy text placeholders for multi-modal slots; return extra processed inputs.""" + dummy_text = input_processor.get_text_with_mm_placeholders(mm_counts) + dummy_inputs = TextPrompt( + prompt=dummy_text, + multi_modal_data=mm_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + # input_processor runs the HF processor / vision encoder on mm_data (e.g. images). + # extra_processed_inputs contains the processed MM data keyed under "multimodal_data"; + # it is reused later with the real token IDs so we do not run the vision encoder again. + _, extra_processed_inputs = input_processor(dummy_inputs, sampling_params) + if extra_processed_inputs is None: + return {} + return extra_processed_inputs + + +def _get_single_mm_token_lengths( + mm_data: Dict[str, Any], + input_processor: BaseMultimodalInputProcessor, +) -> Optional[List[int]]: + """Get the single set of MM token lengths (first value from find_mm_token_lengths). Returns None if empty.""" + num_mm_tokens_by_key = find_mm_token_lengths(mm_data, input_processor) + if not num_mm_tokens_by_key: + return None + # find_mm_token_lengths returns Dict[modality, List[int]], e.g. {"image": [2928, 2928]}. + # We need the list of per-item lengths (for find_mm_token_positions), We take the first modality's + # list; multi-modality is not yet supported (see TODO in multimodal_hashing_process). + num_mm_tokens = next(iter(num_mm_tokens_by_key.values())) + if len(num_mm_tokens) <= 0: + return None + return num_mm_tokens + + def create_input_processor_with_hash( input_processor: BaseMultimodalInputProcessor, hash_lib=default_hasher, @@ -663,11 +718,64 @@ def create_input_processor_with_hash( A wrapped processor that modifies prompts before processing. """ - def multimodal_hashing_process( + def tokenized_multimodal_process( inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ - Process the multinmodal hashing for media tokens if possible. + Process prompt_token_ids and multi_modal_data without detokenizing. + + Runs the input processor with dummy text placeholders for multi-modal slots, + then replaces placeholder token IDs with the actual feature token IDs and + delegates to multimodal_hashing_process. + + Args: + inputs: TextPrompt with "prompt_token_ids" and "multi_modal_data" (and optional "mm_processor_kwargs"). + sampling_params: Sampling parameters for the input processor. + + Returns: + (prompt_token_ids, extra_processed_inputs) from multimodal_hashing_process. + ([], None) if multi-modal token lengths cannot be determined. 
+ """ + prompt_token_ids = inputs["prompt_token_ids"] + mm_data = inputs["multi_modal_data"] + mm_counts = _mm_data_to_counts(mm_data) + extra_processed_inputs = _process_multimodal_with_dummy_placeholders( + input_processor, + mm_data, + mm_counts, + inputs.get("mm_processor_kwargs"), + sampling_params, + ) + num_mm_tokens = _get_single_mm_token_lengths(mm_data, input_processor) + if num_mm_tokens is None: + raise ValueError( + "tokenized_multimodal_process: find_mm_token_lengths returned " + "no token lengths for the provided multi_modal_data.") + + expanded_ids = input_processor.expand_prompt_token_ids_for_mm( + prompt_token_ids, + num_mm_tokens, + hf_processor_mm_kwargs=inputs.get("mm_processor_kwargs")) + return multimodal_hashing_process( + inputs, + sampling_params, + precomputed_token_ids=expanded_ids, + precomputed_extra=extra_processed_inputs, + ) + + def multimodal_hashing_process( + inputs: TextPrompt, + sampling_params: SamplingParams, + *, + precomputed_token_ids: Optional[List[int]] = None, + precomputed_extra: Optional[ExtraProcessedInputs] = None, + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Process multimodal hashing for media tokens if possible. + + precomputed_token_ids and precomputed_extra must be provided together or + both be None. When both are provided (tokenized+MM path), skips the + input_processor call and uses them; when both are None, calls input_processor. Supports optional user-provided UUIDs via 'multi_modal_uuids' in inputs. When a UUID is provided for a multimodal item, it will be used as the @@ -680,8 +788,17 @@ def multimodal_hashing_process( mm_uuids = inputs.get('multi_modal_uuids', None) mm_hashes, mm_uuid_list = apply_mm_hashes(mm_data, mm_uuids, hash_lib) - prompt_token_ids, extra_processed_inputs = input_processor( - inputs, sampling_params) + + if precomputed_token_ids is not None and precomputed_extra is not None: + prompt_token_ids = precomputed_token_ids + extra_processed_inputs = precomputed_extra + elif precomputed_token_ids is None and precomputed_extra is None: + prompt_token_ids, extra_processed_inputs = input_processor( + inputs, sampling_params) + else: + raise ValueError( + "precomputed_token_ids and precomputed_extra must be provided " + "together or both be None; got one without the other.") num_mm_tokens = find_mm_token_lengths(mm_data, input_processor) # TODO: here we assume there is only one modality for now @@ -724,6 +841,19 @@ def multimodal_hashing_process( def input_processor_wrapper( inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + # Tokenized prompt + multi_modal_data + if (inputs.get("prompt_token_ids") is not None + and inputs.get("multi_modal_data") is not None + and inputs.get("prompt") is None): + if hasattr(input_processor, + "get_text_with_mm_placeholders") and hasattr( + input_processor, "expand_prompt_token_ids_for_mm"): + try: + return tokenized_multimodal_process(inputs, sampling_params) + except Exception as e: + logger.warning(f"Tokenized+MM path failed: {e}") + raise + try_multimodal_hashing = False # only used for first time use_multimodal_hashing = False # used for subsequent calls modalities = list(set(inputs['multi_modal_data'].keys()) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index caf3adeec83a..4f207cf4403f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -503,11 +503,24 @@ def _preprocess( """ inputs = prompt_inputs(inputs) + # A fast path for token IDs & MM data is 
available for a VLM if the input processor has the following methods. + # TODO: Once all the VLMs support the fast path, remove this flag and modify the remaining logic accordingly. + use_token_ids_for_mm_placeholders = ( + hasattr(self.input_processor, "get_text_with_mm_placeholders") + and hasattr(self.input_processor, "expand_prompt_token_ids_for_mm")) + + # This IF branch is applicable, whenever: + # - multimodal data is present (whether through embeddings or as preprocessed data), AND + # - token IDs are present, AND + # - two methods defining the placeholder token IDs expansion logic are not available. if not inputs.get("prompt") and inputs.get("prompt_token_ids") and ( inputs.get("multi_modal_data") or inputs.get("multi_modal_embeddings")) and not isinstance( - self.input_processor, DefaultInputProcessor): - # VLMs need to process/tokenize the prompt in their own way + self.input_processor, DefaultInputProcessor + ) and not use_token_ids_for_mm_placeholders: + # VLMs need to process/tokenize the prompt in their own way, + # if they don't have the fast path for token IDs & MM data implemented yet. + # TODO: Once all the VLMs support the fast path, we can remove this detokenization step entirely. prompt = self.tokenizer.decode(inputs['prompt_token_ids']) inputs = TextPrompt( prompt=prompt, @@ -529,6 +542,8 @@ def _preprocess( multimodal_params = None prompt = None + # This branch is applicable for Encode --> Prefill handoff scenario, + # in E/P/D/ and E/PD settings. Prefill worker executes this code path. if is_mm_disagg: if not getattr(self.input_processor, "support_mm_disagg", False): raise ValueError( @@ -561,7 +576,11 @@ def _preprocess( multimodal_input=multimodal_input, multimodal_data=multimodal_data, ) - elif "prompt_token_ids" in inputs: + # This condition is to ensure that this branch is not hit for models that expand + # placeholder token IDs with MM data. + elif ("prompt_token_ids" in inputs + and inputs.get("multi_modal_data") is None + and inputs.get("multi_modal_embeddings") is None): prompt_token_ids = inputs['prompt_token_ids'] query_token_ids = inputs.get("query_token_ids", None) multimodal_data = {} @@ -581,7 +600,11 @@ def _preprocess( if multimodal_data: multimodal_params = MultimodalParams( multimodal_data=multimodal_data) - elif "prompt" in inputs: + # This is the fast path for token IDs & MM data, as well as the slow path for text prompt and/or MM data, + # for both encode or aggregated workers. + elif "prompt" in inputs or ("prompt_token_ids" in inputs and + (("multi_modal_data" in inputs + or "multi_modal_embeddings" in inputs))): if 'multi_modal_data' in inputs: # TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash) # to handle/add multimodal hashes, positions, and lengths. Now we only support image modality. @@ -603,7 +626,8 @@ def _preprocess( with nvtx_range_debug("input_processor"): prompt_token_ids, extra_processed_inputs = self.input_processor( inputs, sampling_params) - prompt = inputs['prompt'] + prompt = inputs.get( + "prompt") # This is the text prompt, if present. 
if extra_processed_inputs is not None: query_token_ids = extra_processed_inputs.get('query_token_ids') # Create unified MultimodalParams diff --git a/tests/integration/test_lists/test-db/l0_l40s.yml b/tests/integration/test_lists/test-db/l0_l40s.yml index a80e53c000b7..0bf5f82f7a26 100644 --- a/tests/integration/test_lists/test-db/l0_l40s.yml +++ b/tests/integration/test_lists/test-db/l0_l40s.yml @@ -21,6 +21,7 @@ l0_l40s: - unittest/_torch/modeling -k "modeling_nemotron_nano_v2_vl" - unittest/_torch/modeling -k "modeling_phi4mm" - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all + - unittest/_torch/modeling/test_modeling_llava_next.py::test_llava_next_expand_prompt_token_ids_for_mm - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all - unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all diff --git a/tests/unittest/_torch/modeling/test_modeling_llava_next.py b/tests/unittest/_torch/modeling/test_modeling_llava_next.py index 129f5cbd7a30..6c8fdc994911 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llava_next.py +++ b/tests/unittest/_torch/modeling/test_modeling_llava_next.py @@ -1,15 +1,18 @@ import os from dataclasses import dataclass +from pathlib import Path from typing import List +import pytest from test_modeling_multimodal import MultimodalScenario, TestModelingMultimodal, llm_models_root -from transformers import LlavaNextConfig +from transformers import AutoTokenizer, LlavaNextConfig from transformers import LlavaNextForConditionalGeneration as HFLlavaNextForConditionalGeneration from tensorrt_llm._torch.models.checkpoints.hf.llava_next_weight_mapper import ( LlavaNextHfWeightMapper, ) from tensorrt_llm._torch.models.modeling_llava_next import LlavaNextModel +from tensorrt_llm.inputs import create_input_processor LLAVA_NEXT_7B_CONFIG = { "architectures": ["LlavaNextForConditionalGeneration"], @@ -96,3 +99,34 @@ def get_scenarios(self) -> List[TestLlavaNextScenario]: ), ] return scenarios + + +def test_llava_next_expand_prompt_token_ids_for_mm(): + """Test LlavaNextInputProcessor.expand_prompt_token_ids_for_mm replaces image placeholders correctly.""" + model_path = LLAVA_NEXT_7B_CONFIG["_name_or_path"] + if not Path(model_path).exists(): + pytest.skip(f"LLaVA-Next model not found at {model_path} (set LLM_MODELS_ROOT)") + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + input_processor = create_input_processor(model_path, tokenizer=tokenizer) + + image_token_id = LLAVA_NEXT_7B_CONFIG["image_token_index"] + vocab_size = LLAVA_NEXT_7B_CONFIG["vocab_size"] + placeholder_id = vocab_size + 1 + + # prompt_token_ids: two image placeholders with text tokens in between + prompt_token_ids = [1, 2, image_token_id, 3, image_token_id, 4] + num_mm_tokens_per_placeholder = [10, 20] + + expanded = input_processor.expand_prompt_token_ids_for_mm( + prompt_token_ids, num_mm_tokens_per_placeholder + ) + + # Expected: [1, 2] + 10 * placeholder_id + [3] + 20 * placeholder_id + [4] + expected_len = 2 + 10 + 1 + 20 + 1 + assert len(expanded) == expected_len + assert expanded[:2] == [1, 2] + assert expanded[2:12] == [placeholder_id] * 10 + assert expanded[12] == 3 + assert expanded[13:33] == [placeholder_id] * 20 + assert expanded[33] == 4 From bba6ca42a48ed4c6f85797dfdda2dacbde106ed1 Mon Sep 17 00:00:00 2001 From: Ludwig Schneider Date: Thu, 2 Apr 2026 11:34:23 -0500 Subject: [PATCH 3/8] [TRTLLM-11237][fix] 
[fix] Synchronize NCCL memory allocation error handling (#12125) Signed-off-by: Ludwig Schneider --- cpp/include/tensorrt_llm/common/cudaUtils.h | 13 ++ cpp/tensorrt_llm/common/ncclUtils.cpp | 122 ++++++++++++++++-- cpp/tensorrt_llm/common/ncclUtils.h | 16 +++ cpp/tensorrt_llm/thop/allreduceOp.cpp | 1 - .../_torch/custom_ops/torch_custom_ops.py | 48 +++++-- tensorrt_llm/_torch/distributed/ops.py | 20 +++ .../_torch/pyexecutor/model_loader.py | 19 +++ 7 files changed, 213 insertions(+), 26 deletions(-) diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h index cd58a7abb5d9..c8cc04ac36d6 100644 --- a/cpp/include/tensorrt_llm/common/cudaUtils.h +++ b/cpp/include/tensorrt_llm/common/cudaUtils.h @@ -1424,3 +1424,16 @@ TRTLLM_NAMESPACE_END { \ tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \ } while (0) + +// Warn-only variant: log a warning on failure but do not throw or abort. +// Use for cleanup/secondary operations where a CUDA error is non-fatal (e.g. free on an error path). +#define TLLM_CUDA_CHECK_WARN(stat) \ + do \ + { \ + cudaError_t const _tllm_cuda_warn_err = (stat); \ + if (TLLM_UNLIKELY(_tllm_cuda_warn_err != cudaSuccess)) \ + { \ + TLLM_LOG_WARNING( \ + "CUDA error in %s (%s:%d): %s", #stat, __FILE__, __LINE__, cudaGetErrorString(_tllm_cuda_warn_err)); \ + } \ + } while (0) diff --git a/cpp/tensorrt_llm/common/ncclUtils.cpp b/cpp/tensorrt_llm/common/ncclUtils.cpp index 036cd9801bb6..6b3d4209f7e9 100644 --- a/cpp/tensorrt_llm/common/ncclUtils.cpp +++ b/cpp/tensorrt_llm/common/ncclUtils.cpp @@ -24,6 +24,69 @@ #include #include +namespace +{ + +// RAII guard for cudaMalloc — frees the pointer on destruction, logging a warning on failure. +struct CudaMallocGuard +{ + void* ptr{nullptr}; + + explicit CudaMallocGuard(void* p) noexcept + : ptr(p) + { + } + + ~CudaMallocGuard() + { + if (ptr) + { + TLLM_CUDA_CHECK_WARN(cudaFree(ptr)); + } + } + + void* release() noexcept + { + void* p = ptr; + ptr = nullptr; + return p; + } + + CudaMallocGuard(CudaMallocGuard const&) = delete; + CudaMallocGuard& operator=(CudaMallocGuard const&) = delete; +}; + +// RAII guard for ncclMemAlloc — frees the pointer on destruction, logging a warning on failure. +struct NcclMemGuard +{ + void* ptr{nullptr}; + + explicit NcclMemGuard(void* p) noexcept + : ptr(p) + { + } + + ~NcclMemGuard() + { + if (ptr) + { + TLLM_NCCL_CHECK_WARN(ncclMemFree(ptr)); + } + } + + void* release() noexcept + { + void* p = ptr; + ptr = nullptr; + return p; + } + + NcclMemGuard(NcclMemGuard const&) = delete; + NcclMemGuard& operator=(NcclMemGuard const&) = delete; +}; + +} // namespace + namespace tensorrt_llm::common::nccl_util { @@ -403,28 +466,59 @@ bool NCCLWindowAllocator::isCommValid(ncclComm_t comm) const noexcept NCCLWindowBuffer NCCLWindowAllocator::allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle) { - NCCLWindowBuffer buffer; - buffer.handle = handle; - - // Allocate device memory using ncclMemAlloc - ncclResult_t allocResult = ncclMemAlloc(&buffer.ptr, size); - if (allocResult != ncclSuccess) + // Step 1: Allocate symmetric memory (per-rank, non-collective — can fail asymmetrically). + void* ncclPtr = nullptr; + TLLM_NCCL_CHECK_WARN(ncclMemAlloc(&ncclPtr, size)); + int const localAllocOk = (ncclPtr != nullptr) ? 
1 : 0; + NcclMemGuard ncclGuard{ncclPtr}; // frees ncclPtr on any early return or exception + + // Step 2: ncclCommWindowRegister is collective — if any rank skips it, all other ranks hang. + // Synchronize the per-rank alloc status using a small cudaMalloc flag (not ncclMemAlloc, so + // OOM on symmetric memory does not prevent us from allocating the flag). + int* rankSyncFlag = nullptr; + TLLM_CUDA_CHECK(cudaMalloc(&rankSyncFlag, sizeof(int))); + CudaMallocGuard flagGuard{rankSyncFlag}; // frees rankSyncFlag on any early return or exception + + // Step 3: Populate flag, reduce with min across ranks (0 if any rank failed), then read back. + // H2D failure is non-fatal: warn and continue — device flag may be stale but the allreduce + // must still be reached by all ranks. allreduce and D2H failures are catastrophic (throw). + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TLLM_CUDA_CHECK_WARN(cudaMemcpy(rankSyncFlag, &localAllocOk, sizeof(int), cudaMemcpyHostToDevice)); + TLLM_NCCL_CHECK(ncclAllReduce(rankSyncFlag, rankSyncFlag, 1, ncclInt32, ncclMin, comm, stream)); + TLLM_CUDA_CHECK_WARN(cudaStreamSynchronize(stream)); + + int allAllocOk = 0; + TLLM_CUDA_CHECK(cudaMemcpy(&allAllocOk, rankSyncFlag, sizeof(int), cudaMemcpyDeviceToHost)); + // flagGuard frees rankSyncFlag here at end of its scope + + if (!allAllocOk) { - TLLM_THROW("ncclMemAlloc failed with error: %d", allocResult); + if (localAllocOk) + { + TLLM_LOG_WARNING( + "[NCCLUtil] ncclMemAlloc failed on at least one other rank; " + "freeing local allocation (size=%zu) and aborting window registration on all ranks.", + size); + } + return NCCLWindowBuffer{}; // ncclGuard frees ncclPtr } - buffer.size = size; - // Register the buffer with NCCL as a window - ncclResult_t regResult = ncclCommWindowRegister(comm, buffer.ptr, size, &buffer.window, NCCL_WIN_COLL_SYMMETRIC); + // Step 4: Register with NCCL as a window (collective — all ranks must reach this call). + // Failure here is non-fatal: warn and fall back to regular allreduce. + // ncclGuard frees ncclPtr on return. + ncclWindow_t window = nullptr; + ncclResult_t const regResult = ncclCommWindowRegister(comm, ncclPtr, size, &window, NCCL_WIN_COLL_SYMMETRIC); + TLLM_NCCL_CHECK_WARN(regResult); if (regResult != ncclSuccess) { - ncclMemFree(buffer.ptr); - TLLM_THROW("ncclCommWindowRegister failed with error: %d", regResult); + return NCCLWindowBuffer{}; } + // Step 5: Success — transfer ownership to the returned buffer. + ncclGuard.release(); + NCCLWindowBuffer buffer{ncclPtr, handle, size, window}; TLLM_LOG_TRACE("[NCCLUtil] Allocated and registered NCCL window buffer: handle=%d, ptr=%p, size=%zu, window=%p", - handle, buffer.ptr, size, static_cast(buffer.window)); - + handle, buffer.ptr, buffer.size, static_cast(buffer.window)); return buffer; } diff --git a/cpp/tensorrt_llm/common/ncclUtils.h b/cpp/tensorrt_llm/common/ncclUtils.h index 4ffa73efc23f..f4699d71c334 100644 --- a/cpp/tensorrt_llm/common/ncclUtils.h +++ b/cpp/tensorrt_llm/common/ncclUtils.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h" #if ENABLE_MULTI_DEVICE #include @@ -43,6 +44,21 @@ #if ENABLE_MULTI_DEVICE +// TLLM_NCCL_CHECK (throw on failure) is provided by multiDeviceUtils.h. + +// Warn-only variant: log a warning on NCCL failure but do not throw or abort. +// Use for cleanup/secondary operations where an NCCL error is non-fatal (e.g. 
ncclMemFree on an error path). +#define TLLM_NCCL_CHECK_WARN(cmd) \ + do \ + { \ + ncclResult_t const _tllm_nccl_warn_r = (cmd); \ + if (TLLM_UNLIKELY(_tllm_nccl_warn_r != ncclSuccess)) \ + { \ + TLLM_LOG_WARNING( \ + "NCCL error in %s (%s:%d): %s", #cmd, __FILE__, __LINE__, ncclGetErrorString(_tllm_nccl_warn_r)); \ + } \ + } while (0) + TRTLLM_NAMESPACE_BEGIN namespace common::nccl_util diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 9ebf2b385a2d..e5d875fb52e5 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1427,7 +1427,6 @@ class AllreduceOp bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size) { // If messageSize is greater than maxWorkspaceSize or topology is unsuitable, use NCCL fallback. - // TODO: Use NCCL_SYMMETRIC once the memory allocation issue is resolved. if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported) { return true; diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 9ae4589e32ab..d14dca370c10 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1711,6 +1711,10 @@ def _( class AllReduceRunner(TunableRunner): _prealloc_lock: ClassVar[threading.Lock] = threading.Lock() _prealloc_done: ClassVar[set] = set() + # Set from AllReduce.__init__ via extra_attrs when the model is built. + _prealloc_max_num_tokens: ClassVar[Optional[int]] = None + _prealloc_hidden_size: ClassVar[Optional[int]] = None + _prealloc_dtype: ClassVar[Optional[torch.dtype]] = None tuning_config = TuningConfig( dynamic_tensor_specs=(DynamicTensorSpec( 0, 0, get_last_power_of_2_num_tokens_buckets(8192), @@ -1743,7 +1747,10 @@ def unique_id(self): def _maybe_preallocate_buffers(cls, input_tensor: torch.Tensor, group: List[int], - do_preparation: bool = False) -> None: + do_preparation: bool = False, + max_num_tokens: Optional[int] = None, + hidden_size: Optional[int] = None, + dtype: Optional[torch.dtype] = None) -> None: if not do_preparation: return if not hasattr(torch.ops.trtllm, "preallocate_nccl_window_buffer"): @@ -1758,7 +1765,22 @@ def _maybe_preallocate_buffers(cls, # If capture status can't be queried, avoid prealloc to be safe. return - num_tokens = int(input_tensor.size(0)) + # If max_num_tokens and hidden_size are provided, pre-allocate at 2x + # the model-configured size to give the NCCL window allocator extra + # headroom beyond the nominal max shape. dtype comes from the model + # spec; fall back to the actual input tensor's properties when any + # value is missing. + # The dummy tensor is created here, after the stream-capture guard, + # so it is never allocated inside a CUDA graph context. 
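+        # For example, max_num_tokens=8192 and hidden_size=4096 in bf16 yield a
+        # (16384, 4096) dummy tensor (~128 MiB), i.e. twice the nominal maximum
+        # allreduce input for that configuration.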
+ if max_num_tokens is not None and hidden_size is not None: + prealloc_input = torch.empty( + [2 * max_num_tokens, hidden_size], + dtype=dtype if dtype is not None else input_tensor.dtype, + device=input_tensor.device) + else: + prealloc_input = input_tensor + + num_tokens = int(prealloc_input.size(0)) if num_tokens <= 0: return group_key = tuple(group) @@ -1771,7 +1793,6 @@ def _maybe_preallocate_buffers(cls, logger.debug( "[tunable_allreduce] Pre-allocating NCCL window buffers: " "tokens=%d group=%s", num_tokens, list(group)) - prealloc_input = input_tensor torch.ops.trtllm.preallocate_nccl_window_buffer(prealloc_input, group, 2) @@ -1816,16 +1837,21 @@ def forward( OptimizationProfile(), **kwargs) if AllReduceStrategy.NCCL_SYMMETRIC.value in valid_tactics: - self._maybe_preallocate_buffers(input, - self.group, - do_preparation=True) + self._maybe_preallocate_buffers( + input, + self.group, + do_preparation=True, + max_num_tokens=AllReduceRunner._prealloc_max_num_tokens, + hidden_size=AllReduceRunner._prealloc_hidden_size, + dtype=AllReduceRunner._prealloc_dtype, + ) return input if tactic == -1: - # tactic == -1 means the autotuner cache missed for this shape, - # so we fall back to plain NCCL instead of NCCL_SYMMETRIC. - # NCCL_SYMMETRIC requires ncclMemAlloc which can fail asymmetrically - # across ranks under OOM, causing a deadlock at ncclCommWindowRegister. - tactic = AllReduceStrategy.NCCL.value + # tactic == -1 means the autotuner cache missed for this shape; + # fall back to NCCL_SYMMETRIC. Asymmetric ncclMemAlloc failures are + # handled by a cross-rank barrier in NCCLWindowAllocator, which + # falls back to plain NCCL if allocation fails on any rank. + tactic = AllReduceStrategy.NCCL_SYMMETRIC.value return torch.ops.trtllm.allreduce( input, diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index c91fd64d5c50..dd71bb00ec9e 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -10,6 +10,7 @@ from tensorrt_llm._mnnvl_utils import HelixCpMnnvlMemory, MnnvlMemory from tensorrt_llm._torch.distributed.symm_mem_allreduce import \ SymmetricMemoryAllReduce +from tensorrt_llm._torch.utils import get_model_extra_attrs from tensorrt_llm._utils import mpi_comm, mpi_disabled from tensorrt_llm.bindings import internal as _tllm_internal from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer @@ -691,6 +692,25 @@ def __init__(self, self._disable_mpi = mpi_disabled() self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce + + # Propagate model-level prealloc config to AllReduceRunner once per + # process. extra_attrs is only active during model __init__, so we + # read it here and stash the values as class-level attributes that + # AllReduceRunner.forward can use during the autotuner warm-up phase. 
+ extra_attrs = get_model_extra_attrs() + if extra_attrs: + from tensorrt_llm._torch.custom_ops.torch_custom_ops import \ + AllReduceRunner + max_num_tokens = extra_attrs.get('allreduce_max_num_tokens') + hidden_size = extra_attrs.get('allreduce_hidden_size') + if max_num_tokens is not None: + AllReduceRunner._prealloc_max_num_tokens = max_num_tokens + if hidden_size is not None: + AllReduceRunner._prealloc_hidden_size = hidden_size + prealloc_dtype = extra_attrs.get('allreduce_dtype') + if prealloc_dtype is not None: + AllReduceRunner._prealloc_dtype = prealloc_dtype + if self.mapping.tp_size > 1 and not self.mapping.enable_attention_dp: # Initialize Symmetric Memory AllReduce if needed (before workspace allocation) if self.strategy == AllReduceStrategy.SYMM_MEM: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 9293e96f1a90..b3c57a29f75a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -497,6 +497,25 @@ def _load_and_validate_config( # Store nvfp4 config in extra_attrs for Linear layer access config.extra_attrs[ 'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends + # Store allreduce pre-allocation config for AllReduce module access. + # Use get_text_config() so VLM wrapper configs (e.g. KimiK2VLConfig, + # KimiK25Config) that store the text config under .text_config are + # handled transparently. For flat configs get_text_config() returns + # self, so this is safe for all config types. Still guard with + # try/except for configs that lack hidden_size entirely. + try: + config.extra_attrs[ + 'allreduce_max_num_tokens'] = config.max_num_tokens + config.extra_attrs[ + 'allreduce_hidden_size'] = config.pretrained_config.get_text_config( + ).hidden_size + config.extra_attrs[ + 'allreduce_dtype'] = config.pretrained_config.torch_dtype + except AttributeError as e: + logger.warning( + f"Could not read allreduce pre-allocation config from " + f"{type(config.pretrained_config).__name__}: {e}. " + f"AllReduce pre-allocation will be skipped.") validate_and_set_kv_cache_quant(config, self.llm_args.kv_cache_config.dtype) From d0ac8fc0d43e7d5b3d80b7b7fd8e20dfe91f9e73 Mon Sep 17 00:00:00 2001 From: Stefan Niebler <82932102+stnie@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:54:52 +0200 Subject: [PATCH 4/8] [https://nvbugs/6008710][fix] Adjust prompt logprobs to use the correct prompt token id (#12499) Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 12 ++++- .../_torch/sampler/test_logits_logprobs.py | 50 +++++++++++++++++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 929d95ba3786..094109eabc4e 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -855,8 +855,16 @@ def _compute_pytorch_prompt_logprobs( prompt=cached, generation=None ) # generation logprobs, if requested, is provided directly in response.result.log_probs from the sampler. context_logits = response.result.context_logits - assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested." - prompt_token_ids = generation_result._generation_request.prompt_token_ids + assert context_logits is not None, "context_logits must not be None when prompt_logprobs is requested." 
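+        # context_logits[i] is the distribution over the token at position i + 1, so the
+        # labels below are the prompt shifted left by one with the first sampled token
+        # appended, aligning each logit row with the token it predicts.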
+ result = response.result.get_result() + assert result is not None, "result must not be None when prompt_logprobs is requested." + # Single element list + first_generation_token = result.output_token_ids[0][:1] + assert first_generation_token, "first generation token must not be empty when prompt_logprobs is requested." + # Pass prompt_token_ids with an offset of 1 for correct mapping to the context logits + prompt_token_ids = generation_result._generation_request.prompt_token_ids[ + 1:] + first_generation_token + logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, None, context_logits, None, None, prompt_token_ids) diff --git a/tests/unittest/_torch/sampler/test_logits_logprobs.py b/tests/unittest/_torch/sampler/test_logits_logprobs.py index f409e0f1c18d..fabce63f1627 100644 --- a/tests/unittest/_torch/sampler/test_logits_logprobs.py +++ b/tests/unittest/_torch/sampler/test_logits_logprobs.py @@ -303,7 +303,7 @@ def test_sampled_token_always_in_prompt_logprobs(logprobs_k: int, simple_llm: LL print(f"Prompt token IDs: {output.prompt_token_ids}") logprobs = output.outputs[0].prompt_logprobs - token_ids = output.prompt_token_ids + token_ids = output.prompt_token_ids[1:] + output.outputs[0].token_ids[:1] assert len(logprobs) == len(token_ids), ( f"Expected {len(token_ids)} logprob entries, got {len(logprobs)}" @@ -372,6 +372,7 @@ def check_logprobs( logprobs: TokenLogprobs, logits_cuda: torch.Tensor, case_str: str, + logprobs_offset: int = 0, ): """Checks if the provided logprobs match the logprobs calculated from the logits""" expected_logprobs = torch.nn.functional.log_softmax(logits_cuda, dim=-1).to(device="cpu") @@ -385,7 +386,7 @@ def check_logprobs( processed_ranks_and_logprobs: dict[int, float] = {} for token_id, logprob_obj in token_logprobs.items(): # the sampled token may have any rank > 0 - if token_id != tokens[generation_idx]: + if token_id != tokens[generation_idx + logprobs_offset]: # All other tokens should have a rank <= num_logprobs assert logprob_obj.rank <= num_logprobs, ( f"{case_str} logprob rank is greater than {num_logprobs}" @@ -427,9 +428,10 @@ def check_logprobs( generation_logprobs, generation_logits, "generation", + logprobs_offset=0, ) if prompt_logprobs_k is not None: - context_tokens = output.prompt_token_ids + context_tokens = output.prompt_token_ids + output.outputs[0].token_ids[:1] context_logprobs = output.outputs[0].prompt_logprobs context_logits = output.context_logits.to(device="cuda") check_logprobs( @@ -438,7 +440,49 @@ def check_logprobs( context_logprobs, context_logits, "context", + logprobs_offset=1, # Prompt logprobs are offset by 1 relative to the prompt token ids ) + # The last context logprob dict and the first generation logprob dict should agree on + # the top-n entries (n = min(prompt_logprobs_k, logprobs_k)) and the sampled token's logprob. 
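+        # This holds because prompt logprobs are shifted by one token: the last context
+        # entry is evaluated at the first sampled token, so it lines up with the first
+        # generation logprob dict.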
+ if prompt_logprobs_k is not None and logprobs_k is not None: + last_context_logprob = context_logprobs[-1] + first_generation_logprob = generation_logprobs[0] + less_prompt_logprobs = prompt_logprobs_k <= logprobs_k + expected = last_context_logprob if less_prompt_logprobs else first_generation_logprob + compare = first_generation_logprob if less_prompt_logprobs else last_context_logprob + sampled_token_id = generation_tokens[0] + assert sampled_token_id in last_context_logprob, ( + f"Sampled token {sampled_token_id} is not a valid key in the last entry " + f"of the context logprob dict: {list(last_context_logprob.keys())}" + ) + assert sampled_token_id in first_generation_logprob, ( + f"Sampled token {sampled_token_id} is not a valid key in the first entry " + f"of the generation logprob dict: {list(first_generation_logprob.keys())}" + ) + torch.testing.assert_close( + last_context_logprob[sampled_token_id].logprob, + first_generation_logprob[sampled_token_id].logprob, + msg=( + f"logprob {last_context_logprob[sampled_token_id].logprob} in the last " + f"entry of the context logprob dict does not match the corresponding " + f"logprob {first_generation_logprob[sampled_token_id].logprob} in the " + f"first entry of the generation logprob dict for token {sampled_token_id}" + ), + ) + for token_id, logprob_obj in expected.items(): + assert token_id in compare, ( + f"Token {token_id} is not a valid key in the other dict: {list(compare.keys())}" + ) + expected_logprob = logprob_obj.logprob + compare_logprob = compare[token_id].logprob + torch.testing.assert_close( + expected_logprob, + compare_logprob, + msg=( + f"logprob {expected_logprob} does not match the corresponding " + f"logprob {compare_logprob} in the other dict for token {token_id}" + ), + ) @pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"]) From 2f07e56dfc0557a15a6521fe3b29e5d79f53702f Mon Sep 17 00:00:00 2001 From: yuanjingx87 <197832395+yuanjingx87@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:01:26 -0700 Subject: [PATCH 5/8] [None][infra] Bump tornado and black in container (#12600) Signed-off-by: Yuanjing Xue <197832395+yuanjingx87@users.noreply.github.com> --- constraints.txt | 10 ++++------ docker/Dockerfile.multi | 5 +++-- jenkins/current_image_tags.properties | 8 ++++---- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/constraints.txt b/constraints.txt index 25ae35bdecfc..7dd0d6747765 100644 --- a/constraints.txt +++ b/constraints.txt @@ -1,10 +1,8 @@ # These vulnerabilities were inherited from the base image (pytorch:25.12-py3) and should be removed when the base image # is updated. 
-# WAR against https://github.com/advisories/GHSA-38jv-5279-wg99 -urllib3>=2.6.3 # WAR against https://github.com/advisories/GHSA-8rrh-rw8j-w5fx wheel>=0.46.2 -# WAR against https://github.com/advisories/GHSA-7gcm-g887-7qv7 -protobuf>=6.33.5 -# WAR against https://github.com/advisories/GHSA-6mq8-rvhq-8wgg -aiohttp>=3.13.3 +# WAR against https://github.com/advisories/GHSA-qjxf-f2mg-c6mc +tornado>=6.5.5 +# WAR against https://github.com/advisories/GHSA-3936-cmfr-pm3m +black>=26.3.1 diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index a03fbdbd757e..96b75cd96a43 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -52,9 +52,10 @@ RUN --mount=type=bind,source=docker/common,target=/opt/docker/common \ # Install constraints after install.sh so cleanup() doesn't delete the file mid-RUN COPY constraints.txt /tmp/constraints.txt RUN --mount=type=cache,target=/root/.cache/pip \ + # WAR: uninstall dependencies that has vulnerability + pip3 uninstall -y tornado black nbconvert || true && \ pip3 install --ignore-installed --no-cache-dir -r /tmp/constraints.txt && \ - rm /tmp/constraints.txt && \ - pip3 uninstall -y nbconvert || true + rm /tmp/constraints.txt # Install UCX, NIXL, etcd # TODO: Combine these into the main install.sh script diff --git a/jenkins/current_image_tags.properties b/jenkins/current_image_tags.properties index 98829d692b8f..8d751664640b 100644 --- a/jenkins/current_image_tags.properties +++ b/jenkins/current_image_tags.properties @@ -13,7 +13,7 @@ # images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead. IMAGE_NAME=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm -LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-x86_64-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202603241450-12102 -LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-aarch64-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202603241450-12102 -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py310-trt10.15.1.29-skip-tritondevel-202603241450-12102 -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py312-trt10.15.1.29-skip-tritondevel-202603241450-12102 +LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-x86_64-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604011104-12600 +LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-26.02-py3-sbsa-ubuntu24.04-trt10.15.1.29-skip-tritondevel-202604011104-12600 +LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py310-trt10.15.1.29-skip-tritondevel-202604011104-12600 +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-13.1.1-devel-rocky8-x86_64-rocky8-py312-trt10.15.1.29-skip-tritondevel-202604011104-12600 From 11c40bb3303ec9ba7d69641c0914979310f7d924 Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:09:22 -0700 Subject: [PATCH 6/8] [TRTLLM-11043][feat] Add global pool support for suffix automaton speculative decoding (#12130) Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> --- .../suffixAutomaton/suffixAutomaton.h | 80 +++- .../suffixAutomaton/suffixAutomatonKernels.cu | 207 +++++++++ .../suffixAutomaton/suffixAutomatonKernels.h | 176 +------ 
.../suffixAutomaton/suffixAutomatonParams.h | 292 ++++++++++++ .../nanobind/suffixAutomaton/bindings.cpp | 99 ++-- .../_torch/pyexecutor/model_engine.py | 6 +- tensorrt_llm/_torch/speculative/eagle3.py | 4 +- tensorrt_llm/_torch/speculative/mtp.py | 4 +- tensorrt_llm/_torch/speculative/pard.py | 4 +- .../_torch/speculative/sa_enhancer.py | 85 +++- tensorrt_llm/_torch/speculative/sa_worker.py | 32 +- .../_torch/speculative/suffix_automaton.py | 291 ++++++++++-- tensorrt_llm/_torch/speculative/utils.py | 21 +- tensorrt_llm/llmapi/__init__.py | 9 +- tensorrt_llm/llmapi/llm_args.py | 110 +++-- .../defs/accuracy/test_llm_api_pytorch.py | 91 +++- .../test_lists/qa/llm_function_core.txt | 3 + .../speculative/test_suffix_automaton.py | 436 ++++++++++++++++++ tests/unittest/_torch/speculative/test_sa.py | 202 ++++---- 19 files changed, 1707 insertions(+), 445 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h index 4be4630253a5..3d0f93aa7df7 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomaton.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -375,6 +375,84 @@ struct SuffixAutomaton return SAOptional(); } + /** + * @brief Find the longest suffix of an external token sequence that appears + * as a substring in this SA's text, then return its continuation position. + * + * Uses the standard longest-common-substring algorithm: process suffix tokens + * in forward order through the SA, following suffix links on mismatch. + * + * Time complexity: O(suffixLen) amortized. Each token either advances the + * match or triggers suffix link fallbacks. Since matchedLen increases at most + * suffixLen times and never goes below 0, total suffix link hops is bounded + * by suffixLen. 
+ * + * @param suffix Pointer to the suffix tokens (forward order: oldest to newest) + * @param suffixLen Number of tokens in the suffix + * @return Optional LookupResult with continuation position and match length + */ + SA_CUDA_CALLABLE SAOptional lookupWithSuffix(Token const* suffix, int suffixLen) const + { + if (mStates.empty() || suffixLen <= 0) + { + return SAOptional(); + } + + NodeIndex state = NodeIndex(0); + int matchedLen = 0; + + for (int i = 0; i < suffixLen; i++) + { + Token token = suffix[i]; + + while (state != NodeIndex(0) && mStates.at(state, token) == nullptr) + { + state = *mStates.at(state).link; + matchedLen = mStates.at(state).len; + } + + NodeIndex const* nextPtr = mStates.at(state, token); + if (nextPtr != nullptr) + { + state = *nextPtr; + matchedLen++; + } + } + + if (matchedLen == 0 || state == NodeIndex(0)) + { + return SAOptional(); + } + + while (state != NodeIndex(0)) + { + auto& nodeData = mStates.at(state); + SAOptional posOpt = nodeData.pos; + + if (posOpt.hasValue()) + { + TextIndex pos = *posOpt; + if (+pos + 1 < +mTokens.size()) + { + LookupResult result; + result.pos = TextIndex(+pos + 1); + result.len = matchedLen; + return SAOptional(result); + } + } + + auto linkOpt = nodeData.link; + if (!linkOpt.hasValue()) + { + break; + } + state = *linkOpt; + matchedLen = mStates.at(state).len; + } + + return SAOptional(); + } + SA_CUDA_CALLABLE void getDraftTokens(Token::ValueType* buf, int bufLen, TextIndex startPos) const { int availableLen = +mTokens.size() - +startPos; diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu index de8f4a4afcee..6a114b88e62b 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.cu @@ -177,6 +177,213 @@ void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& pa params.acceptedTokensIn, params.acceptedLensIn); } +// ===================================================================== +// Global search kernels (cross-request pattern sharing) +// ===================================================================== + +// Kernel 1: Extend all SAs with accepted tokens. +// Separate kernel ensures all mutations complete before cross-SA reads. +__global__ void suffixAutomatonGlobalExtendKernel(int batchSize, int draftLength, int maxSlots, size_t stateSize, + void* slotsMemory, int const* batchIndices, int const* acceptedTokensIn, int const* acceptedLensIn) +{ + int reqIdx = blockIdx.x; + if (reqIdx >= batchSize) + { + return; + } + + int ownSlotIdx = batchIndices[reqIdx]; + assert(ownSlotIdx >= 0 && ownSlotIdx < maxSlots); + uint8_t* slotMemory = static_cast(slotsMemory) + static_cast(ownSlotIdx) * stateSize; + SuffixAutomaton* ownSlot = reinterpret_cast(slotMemory); + + int numNewTokens = acceptedLensIn[reqIdx]; + assert(numNewTokens >= 0 && numNewTokens <= draftLength + 1); + + for (int j = 0; j < numNewTokens; j++) + { + ownSlot->extend(Token(acceptedTokensIn[reqIdx * (draftLength + 1) + j])); + } +} + +// Per-thread match result for shared-memory parallel reduction +struct SlotMatch +{ + int matchLen; + int continuationLen; + int isOwnSlot; + int slotIdx; + TextIndex pos; +}; + +// kMaxGlobalSuffixLen is defined in suffixAutomatonParams.h. +// With maxNgramSize == -1, longer sequences are silently truncated to that limit. 
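The lookupWithSuffix routine above is the textbook longest-common-substring walk over a
suffix automaton: consume the query suffix token by token, falling back along suffix links
on a mismatch, then follow suffix links once more until a state with a recorded occurrence
position is found. For reference, a minimal host-side Python sketch of the same walk
follows; the dict-based state representation (next/link/len/pos) and the function name are
illustrative stand-ins, not the CUDA data layout used by the kernels.

    # Sketch only. Each state is assumed to be a dict with:
    #   "next": {token: state index}, "link": suffix link (None for the root),
    #   "len":  length of the longest string represented by the state,
    #   "pos":  an occurrence end position in the indexed text (or None).
    def lookup_with_suffix(states, text_len, suffix):
        """Longest suffix of `suffix` occurring in the indexed text.

        Returns (continuation_pos, match_len), or None if there is no match.
        """
        state, matched = 0, 0
        for tok in suffix:  # forward order: oldest to newest
            while state != 0 and tok not in states[state]["next"]:
                state = states[state]["link"]   # fall back along suffix links
                matched = states[state]["len"]  # match shrinks to the link state's len
            if tok in states[state]["next"]:
                state = states[state]["next"][tok]
                matched += 1
        if matched == 0 or state == 0:
            return None
        # Walk suffix links until a state with a recorded occurrence whose
        # continuation position (pos + 1) still lies inside the text is found.
        while state != 0:
            pos = states[state]["pos"]
            if pos is not None and pos + 1 < text_len:
                return pos + 1, matched
            if states[state]["link"] is None:
                break
            state = states[state]["link"]
            matched = states[state]["len"]
        return None

The amortized O(suffixLen) bound quoted in the doc comment follows because matched grows by
at most one per token and every suffix-link fallback strictly decreases it.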
+ +// Kernel 2: Search all active SAs in parallel, reduce to best match per request. +// All SAs are read-only (const) — launched after the extend kernel on the same stream. +__global__ void suffixAutomatonGlobalSearchKernel(int batchSize, int draftLength, int maxNgramSize, int maxSlots, + size_t stateSize, void const* slotsMemory, int const* batchIndices, int const* activeSlotMask, int* matchLenOut, + int* matchSlotOut, int* draftTokensOut) +{ + extern __shared__ SlotMatch sharedMatches[]; + + int reqIdx = blockIdx.x; + int slotIdx = threadIdx.x; + + if (reqIdx >= batchSize) + { + return; + } + + int ownSlotIdx = batchIndices[reqIdx]; + assert(ownSlotIdx >= 0 && ownSlotIdx < maxSlots); + + // Step 1: Extract suffix from own SA into shared memory + __shared__ Token sharedSuffix[kMaxGlobalSuffixLen]; + __shared__ int suffixLen; + + if (slotIdx == 0) + { + uint8_t const* slotMem = static_cast(slotsMemory) + static_cast(ownSlotIdx) * stateSize; + SuffixAutomaton const* ownSlot = reinterpret_cast(slotMem); + + int maxSuffixLen = (maxNgramSize > 0) ? maxNgramSize : kMaxGlobalSuffixLen; + int textLen = +ownSlot->mTokens.size(); + suffixLen = (maxSuffixLen < textLen) ? maxSuffixLen : textLen; + + for (int i = 0; i < suffixLen; i++) + { + sharedSuffix[i] = ownSlot->mTokens.at(TextIndex(textLen - suffixLen + i)); + } + } + __syncthreads(); + + // Step 2: Each thread searches one slot + SlotMatch myMatch = {0, 0, 0, -1, TextIndex(0)}; + + if (slotIdx < maxSlots && activeSlotMask[slotIdx]) + { + uint8_t const* slotMem = static_cast(slotsMemory) + static_cast(slotIdx) * stateSize; + SuffixAutomaton const* slot = reinterpret_cast(slotMem); + + auto result = slot->lookupWithSuffix(sharedSuffix, suffixLen); + if (result.hasValue()) + { + myMatch.matchLen = result->len; + myMatch.continuationLen = +slot->mTokens.size() - +result->pos; + myMatch.isOwnSlot = (slotIdx == ownSlotIdx) ? 1 : 0; + myMatch.slotIdx = slotIdx; + myMatch.pos = result->pos; + } + } + + sharedMatches[slotIdx] = myMatch; + __syncthreads(); + + // Step 3: Parallel reduction — three-level comparison: + // 1. Prefer longer match (higher matchLen) + // 2. Among equal matchLen, prefer own slot + // 3. Among equal matchLen and same locality, prefer longer continuation + // Requires blockDim.x to be a power of 2 (guaranteed by nextPowerOf2 in the host launcher). 
+    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1)
+    {
+        if (slotIdx < stride)
+        {
+            auto& current = sharedMatches[slotIdx];
+            auto& candidate = sharedMatches[slotIdx + stride];
+            bool replace = false;
+            if (candidate.matchLen > current.matchLen)
+            {
+                replace = true;
+            }
+            else if (candidate.matchLen == current.matchLen && candidate.matchLen > 0)
+            {
+                if (candidate.isOwnSlot > current.isOwnSlot)
+                {
+                    replace = true;
+                }
+                else if (candidate.isOwnSlot == current.isOwnSlot
+                    && candidate.continuationLen > current.continuationLen)
+                {
+                    replace = true;
+                }
+            }
+            if (replace)
+            {
+                current = candidate;
+            }
+        }
+        __syncthreads();
+    }
+
+    // Step 4: Thread 0 writes output
+    if (slotIdx == 0)
+    {
+        SlotMatch best = sharedMatches[0];
+
+        if (best.matchLen > 0 && best.slotIdx >= 0)
+        {
+            matchLenOut[reqIdx] = best.matchLen;
+            matchSlotOut[reqIdx] = best.slotIdx;
+
+            uint8_t const* slotMem
+                = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(best.slotIdx) * stateSize;
+            SuffixAutomaton const* slot = reinterpret_cast<SuffixAutomaton const*>(slotMem);
+            slot->getDraftTokens(&draftTokensOut[reqIdx * draftLength], draftLength, best.pos);
+        }
+        else
+        {
+            matchLenOut[reqIdx] = 0;
+            matchSlotOut[reqIdx] = -1;
+        }
+    }
+}
+
+namespace
+{
+
+int nextPowerOf2(int v)
+{
+    v--;
+    v |= v >> 1;
+    v |= v >> 2;
+    v |= v >> 4;
+    v |= v >> 8;
+    v |= v >> 16;
+    v++;
+    return (v < 1) ? 1 : v;
+}
+
+} // anonymous namespace
+
+void invokeSuffixAutomatonGlobalSearch(SuffixAutomatonGlobalSearchParams const& params, cudaStream_t stream)
+{
+    params.checkParams();
+
+    int batchSize = params.batchSize;
+    int maxSlots = params.maxSlots;
+    if (batchSize > maxSlots)
+    {
+        batchSize = maxSlots;
+    }
+
+    size_t stateSize = getSuffixAutomatonStateSize(params.maxSeqLen);
+
+    // Kernel 1: Extend all SAs (1 thread per block, 1 block per request)
+    suffixAutomatonGlobalExtendKernel<<<batchSize, 1, 0, stream>>>(batchSize, params.draftLength, maxSlots, stateSize,
+        params.slots, params.batchIndices, params.acceptedTokensIn, params.acceptedLensIn);
+
+    // Kernel 2: Global search + reduce (N threads per block, 1 block per request)
+    int threadsPerBlock = nextPowerOf2(maxSlots);
+    threadsPerBlock = (threadsPerBlock < 1024) ? threadsPerBlock : 1024;
+
+    size_t sharedMemSize = static_cast<size_t>(threadsPerBlock) * sizeof(SlotMatch);
+
+    suffixAutomatonGlobalSearchKernel<<<batchSize, threadsPerBlock, sharedMemSize, stream>>>(batchSize,
+        params.draftLength, params.maxNgramSize, maxSlots, stateSize, params.slots, params.batchIndices,
+        params.activeSlotMask, params.matchLenOut, params.matchSlotOut, params.draftTokensOut);
+}
+
 size_t getSuffixAutomatonStateSize(size_t maxSeqLen)
 {
     return SuffixAutomaton::getRequiredMemorySize(maxSeqLen);
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
index bd7beea3409a..a03a82b534a9 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonKernels.h
@@ -20,175 +20,13 @@
 
 #pragma once
 
+// Full SA class definition — needed by .cu files that operate on SuffixAutomaton
+// objects. This header redefines cudaStream_t when __CUDACC__ is not defined
+// (via saCudaCallable.h), so only include this header from CUDA translation units.
#include "suffixAutomaton.h" -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/config.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include -#include - -TRTLLM_NAMESPACE_BEGIN - -namespace kernels::speculative_decoding::suffix_automaton -{ - -//! \brief Parameters for the suffix automaton extend kernel -struct SuffixAutomatonExtendParams -{ - //! Number of sequences in the batch - int batchSize{0}; - - //! Number of draft tokens to generate per sequence - int draftLength{0}; - - //! Maximum number of slots in the workspace - int maxSlots{0}; - - //! Maximum sequence length (runtime configurable) - int maxSeqLen{0}; - - //! Pointer to the suffix automaton workspace on GPU (raw bytes) - void* slots{nullptr}; - - //! Batch indices mapping external batch idx to workspace slot [batchSize] - int const* batchIndices{nullptr}; - - //! Output: match lengths for each sequence [batchSize] - int* matchLenOut{nullptr}; - - //! Output: draft tokens for each sequence [batchSize, draftLength] - int* draftTokensOut{nullptr}; - - //! Input: accepted tokens for each sequence [batchSize, draftLength + 1] - int const* acceptedTokensIn{nullptr}; - - //! Input: number of accepted tokens for each sequence [batchSize] - int const* acceptedLensIn{nullptr}; - - void checkParams() const - { - TLLM_CHECK(batchSize > 0); - TLLM_CHECK(draftLength > 0); - TLLM_CHECK(maxSlots > 0); - TLLM_CHECK(maxSeqLen > 0); - TLLM_CHECK(slots != nullptr); - TLLM_CHECK(batchIndices != nullptr); - TLLM_CHECK(matchLenOut != nullptr); - TLLM_CHECK(draftTokensOut != nullptr); - TLLM_CHECK(acceptedTokensIn != nullptr); - TLLM_CHECK(acceptedLensIn != nullptr); - } -}; - -//! \brief Invokes the suffix automaton extend kernel -//! -//! This kernel updates the suffix automaton states for each sequence in the batch -//! with the newly accepted tokens, then performs a lookup to find the longest -//! suffix match and generates draft tokens based on that match. -//! -//! \param params The parameters for the kernel -//! \param stream The CUDA stream to run the kernel on -void invokeSuffixAutomatonExtend(SuffixAutomatonExtendParams const& params, cudaStream_t stream); - -//! \brief Parameters for the suffix automaton extend kernel with ngram support -struct SuffixAutomatonExtendNgramParams -{ - //! Number of sequences in the batch - int batchSize{0}; - - //! Number of draft tokens to generate per sequence - int draftLength{0}; - //! Maximum ngram size for matching, or -1 for longest match mode - int maxNgramSize{-1}; +// Param structs and function declarations — shared with nanobind bindings. +#include "suffixAutomatonParams.h" - //! Maximum number of slots in the workspace - int maxSlots{0}; - - //! Maximum sequence length (runtime configurable) - int maxSeqLen{0}; - - //! Pointer to the suffix automaton workspace on GPU (raw bytes) - void* slots{nullptr}; - - //! Batch indices mapping external batch idx to workspace slot [batchSize] - int const* batchIndices{nullptr}; - - //! Output: match lengths for each sequence [batchSize] - int* matchLenOut{nullptr}; - - //! Output: draft tokens for each sequence [batchSize, draftLength] - int* draftTokensOut{nullptr}; - - //! Input: accepted tokens for each sequence [batchSize, draftLength + 1] - int const* acceptedTokensIn{nullptr}; - - //! 
Input: number of accepted tokens for each sequence [batchSize] - int const* acceptedLensIn{nullptr}; - - void checkParams() const - { - TLLM_CHECK(batchSize > 0); - TLLM_CHECK(draftLength > 0); - TLLM_CHECK(maxSlots > 0); - TLLM_CHECK(maxSeqLen > 0); - TLLM_CHECK(slots != nullptr); - TLLM_CHECK(batchIndices != nullptr); - TLLM_CHECK(matchLenOut != nullptr); - TLLM_CHECK(draftTokensOut != nullptr); - TLLM_CHECK(acceptedTokensIn != nullptr); - TLLM_CHECK(acceptedLensIn != nullptr); - } -}; - -//! \brief Invokes the suffix automaton extend kernel with ngram support -//! -//! This kernel updates the suffix automaton states for each sequence in the batch -//! with the newly accepted tokens, then performs lookup based on maxNgramSize: -//! - If maxNgramSize == -1: finds the longest suffix match -//! - If maxNgramSize > 0: tries ngram sizes from maxNgramSize down to 1 until a match is found -//! -//! This kernel is CUDA graph compatible. -//! -//! \param params The parameters for the kernel -//! \param stream The CUDA stream to run the kernel on -void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& params, cudaStream_t stream); - -//! \brief Get the size in bytes of a single SuffixAutomaton state for a given max sequence length -//! \param maxSeqLen Maximum sequence length -//! \return Size in bytes -size_t getSuffixAutomatonStateSize(size_t maxSeqLen); - -//! \brief Initialize a SuffixAutomaton at the given memory location -//! \param memory Pointer to allocated memory for the SuffixAutomaton -//! \param maxSeqLen Maximum sequence length -void initAutomaton(void* memory, size_t maxSeqLen); - -//! \brief Build a suffix automaton by extending with the given tokens -//! \param sa Pointer to an initialized SuffixAutomaton -//! \param tokens Array of token IDs -//! \param numTokens Number of tokens in the array -void buildAutomatonFromTokens(SuffixAutomaton* sa, int const* tokens, int numTokens); - -//! \brief Relocate a SuffixAutomaton's internal pointers for GPU copy -//! -//! WARNING: This function MUTATES the SuffixAutomaton in-place, making it SINGLE-USE. -//! It rebases internal pointers from oldBase to newBase directly in the host buffer. -//! After this call, the host buffer's internal pointer graph is relative to newBase -//! (the GPU destination), so the host buffer is effectively corrupted for any subsequent -//! use (e.g., copying to a different GPU slot). If a caller invokes this twice with -//! the same host buffer but a different newBase, the second operation will produce -//! incorrectly relocated pointers. -//! -//! To copy the same automaton state to multiple GPU slots, callers must rebuild the -//! automaton from scratch via initAutomaton() + buildAutomatonFromTokens() for each -//! destination. -//! -//! \param sa Pointer to the SuffixAutomaton (mutated in-place) -//! \param oldBase The current base address (host address) -//! 
\param newBase The target base address (GPU address) -void relocateAutomaton(SuffixAutomaton* sa, void* oldBase, void* newBase); - -} // namespace kernels::speculative_decoding::suffix_automaton - -TRTLLM_NAMESPACE_END +#include "tensorrt_llm/common/cudaUtils.h" +#include diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h new file mode 100644 index 000000000000..abd46df7bbe5 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h @@ -0,0 +1,292 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Adapted from Baseten's sa_spec library (Apache-2.0) + * https://github.com/basetenlabs/sa_spec + */ + +// Lightweight header containing only param structs and function declarations. +// Safe to include from non-CUDA translation units (e.g. nanobind bindings) +// because it does NOT include suffixAutomaton.h, which redefines cudaStream_t +// to int via saCudaCallable.h when __CUDACC__ is not defined. + +#pragma once + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/config.h" +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::speculative_decoding::suffix_automaton +{ + +// Forward declaration — full definition lives in suffixAutomaton.h (CUDA-only). +struct SuffixAutomaton; + +// Max suffix tokens loaded into shared memory per block in the global search kernel. +// Caps the suffix length used for cross-SA matching to bound shared-memory usage. +static constexpr int kMaxGlobalSuffixLen = 64; + +// ===================================================================== +// Param structs +// ===================================================================== + +//! \brief Parameters for the suffix automaton extend kernel +struct SuffixAutomatonExtendParams +{ + //! Number of sequences in the batch + int batchSize{0}; + + //! Number of draft tokens to generate per sequence + int draftLength{0}; + + //! Maximum number of slots in the workspace + int maxSlots{0}; + + //! Maximum sequence length (runtime configurable) + int maxSeqLen{0}; + + //! Pointer to the suffix automaton workspace on GPU (raw bytes) + void* slots{nullptr}; + + //! Batch indices mapping external batch idx to workspace slot [batchSize] + int const* batchIndices{nullptr}; + + //! Output: match lengths for each sequence [batchSize] + int* matchLenOut{nullptr}; + + //! Output: draft tokens for each sequence [batchSize, draftLength] + int* draftTokensOut{nullptr}; + + //! Input: accepted tokens for each sequence [batchSize, draftLength + 1] + int const* acceptedTokensIn{nullptr}; + + //! 
Input: number of accepted tokens for each sequence [batchSize] + int const* acceptedLensIn{nullptr}; + + void checkParams() const + { + TLLM_CHECK(batchSize > 0); + TLLM_CHECK(draftLength > 0); + TLLM_CHECK(maxSlots > 0); + TLLM_CHECK(maxSeqLen > 0); + TLLM_CHECK(slots != nullptr); + TLLM_CHECK(batchIndices != nullptr); + TLLM_CHECK(matchLenOut != nullptr); + TLLM_CHECK(draftTokensOut != nullptr); + TLLM_CHECK(acceptedTokensIn != nullptr); + TLLM_CHECK(acceptedLensIn != nullptr); + } +}; + +//! \brief Parameters for the suffix automaton extend kernel with ngram support +struct SuffixAutomatonExtendNgramParams +{ + //! Number of sequences in the batch + int batchSize{0}; + + //! Number of draft tokens to generate per sequence + int draftLength{0}; + + //! Maximum ngram size for matching, or -1 for longest match mode + int maxNgramSize{-1}; + + //! Maximum number of slots in the workspace + int maxSlots{0}; + + //! Maximum sequence length (runtime configurable) + int maxSeqLen{0}; + + //! Pointer to the suffix automaton workspace on GPU (raw bytes) + void* slots{nullptr}; + + //! Batch indices mapping external batch idx to workspace slot [batchSize] + int const* batchIndices{nullptr}; + + //! Output: match lengths for each sequence [batchSize] + int* matchLenOut{nullptr}; + + //! Output: draft tokens for each sequence [batchSize, draftLength] + int* draftTokensOut{nullptr}; + + //! Input: accepted tokens for each sequence [batchSize, draftLength + 1] + int const* acceptedTokensIn{nullptr}; + + //! Input: number of accepted tokens for each sequence [batchSize] + int const* acceptedLensIn{nullptr}; + + void checkParams() const + { + TLLM_CHECK(batchSize > 0); + TLLM_CHECK(draftLength > 0); + TLLM_CHECK(maxSlots > 0); + TLLM_CHECK(maxSeqLen > 0); + TLLM_CHECK(slots != nullptr); + TLLM_CHECK(batchIndices != nullptr); + TLLM_CHECK(matchLenOut != nullptr); + TLLM_CHECK(draftTokensOut != nullptr); + TLLM_CHECK(acceptedTokensIn != nullptr); + TLLM_CHECK(acceptedLensIn != nullptr); + } +}; + +//! \brief Parameters for the global search kernel (cross-request pattern sharing) +//! +//! Limitation: maxSlots must be <= 1024 because the search kernel maps one +//! CUDA thread per slot within a single block. The suffix matched per request +//! is capped at 64 tokens (kMaxGlobalSuffixLen) due to shared-memory storage. +struct SuffixAutomatonGlobalSearchParams +{ + //! Number of sequences in the batch + int batchSize{0}; + + //! Number of draft tokens to generate per sequence + int draftLength{0}; + + //! Maximum ngram size for matching, or -1 for longest match mode + int maxNgramSize{-1}; + + //! Maximum number of slots in the workspace (must be <= 1024) + int maxSlots{0}; + + //! Maximum sequence length (runtime configurable) + int maxSeqLen{0}; + + //! Pointer to the suffix automaton workspace on GPU (raw bytes) + void* slots{nullptr}; + + //! Batch indices mapping external batch idx to workspace slot [batchSize] + int const* batchIndices{nullptr}; + + //! Active slot mask: 1=active, 0=inactive [maxSlots] + int const* activeSlotMask{nullptr}; + + //! Output: match lengths for each sequence [batchSize] + int* matchLenOut{nullptr}; + + //! Output: source slot index for the best match [batchSize] + int* matchSlotOut{nullptr}; + + //! Output: draft tokens for each sequence [batchSize, draftLength] + int* draftTokensOut{nullptr}; + + //! Input: accepted tokens for each sequence [batchSize, draftLength + 1] + int const* acceptedTokensIn{nullptr}; + + //! 
Input: number of accepted tokens for each sequence [batchSize] + int const* acceptedLensIn{nullptr}; + + void checkParams() const + { + TLLM_CHECK(batchSize > 0); + TLLM_CHECK(draftLength > 0); + TLLM_CHECK(maxSlots > 0); + TLLM_CHECK_WITH_INFO( + maxSlots <= 1024, "Global search kernel supports at most 1024 slots (one CUDA thread per slot)"); + TLLM_CHECK_WITH_INFO(maxNgramSize == -1 || (maxNgramSize >= 1 && maxNgramSize <= kMaxGlobalSuffixLen), + "maxNgramSize must be -1 (longest match) or in [1, %d], got %d", kMaxGlobalSuffixLen, maxNgramSize); + TLLM_CHECK(maxSeqLen > 0); + TLLM_CHECK(slots != nullptr); + TLLM_CHECK(batchIndices != nullptr); + TLLM_CHECK(activeSlotMask != nullptr); + TLLM_CHECK(matchLenOut != nullptr); + TLLM_CHECK(matchSlotOut != nullptr); + TLLM_CHECK(draftTokensOut != nullptr); + TLLM_CHECK(acceptedTokensIn != nullptr); + TLLM_CHECK(acceptedLensIn != nullptr); + } +}; + +// ===================================================================== +// Function declarations +// ===================================================================== + +//! \brief Invokes the suffix automaton extend kernel +//! +//! This kernel updates the suffix automaton states for each sequence in the batch +//! with the newly accepted tokens, then performs a lookup to find the longest +//! suffix match and generates draft tokens based on that match. +//! +//! \param params The parameters for the kernel +//! \param stream The CUDA stream to run the kernel on +void invokeSuffixAutomatonExtend(SuffixAutomatonExtendParams const& params, cudaStream_t stream); + +//! \brief Invokes the suffix automaton extend kernel with ngram support +//! +//! This kernel updates the suffix automaton states for each sequence in the batch +//! with the newly accepted tokens, then performs lookup based on maxNgramSize: +//! - If maxNgramSize == -1: finds the longest suffix match +//! - If maxNgramSize > 0: tries ngram sizes from maxNgramSize down to 1 until a match is found +//! +//! This kernel is CUDA graph compatible. +//! +//! \param params The parameters for the kernel +//! \param stream The CUDA stream to run the kernel on +void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& params, cudaStream_t stream); + +//! \brief Invokes the global search kernel for cross-request pattern sharing +//! +//! This launches two kernels on the same stream: +//! 1. Extend kernel: updates each request's SA with accepted tokens +//! 2. Search kernel: each request searches all active SA states in parallel +//! and reduces to the best match (longest match -> own slot -> longest continuation) +//! +//! The kernel launch boundary acts as a device-wide barrier, ensuring all extends +//! complete before any search begins. CUDA graph compatible. +//! +//! \param params The parameters for the kernel +//! \param stream The CUDA stream to run the kernel on +void invokeSuffixAutomatonGlobalSearch(SuffixAutomatonGlobalSearchParams const& params, cudaStream_t stream); + +//! \brief Get the size in bytes of a single SuffixAutomaton state for a given max sequence length +//! \param maxSeqLen Maximum sequence length +//! \return Size in bytes +size_t getSuffixAutomatonStateSize(size_t maxSeqLen); + +//! \brief Initialize a SuffixAutomaton at the given memory location +//! \param memory Pointer to allocated memory for the SuffixAutomaton +//! \param maxSeqLen Maximum sequence length +void initAutomaton(void* memory, size_t maxSeqLen); + +//! \brief Build a suffix automaton by extending with the given tokens +//! 
\param sa Pointer to an initialized SuffixAutomaton +//! \param tokens Array of token IDs +//! \param numTokens Number of tokens in the array +void buildAutomatonFromTokens(SuffixAutomaton* sa, int const* tokens, int numTokens); + +//! \brief Relocate a SuffixAutomaton's internal pointers for GPU copy +//! +//! WARNING: This function MUTATES the SuffixAutomaton in-place, making it SINGLE-USE. +//! It rebases internal pointers from oldBase to newBase directly in the host buffer. +//! After this call, the host buffer's internal pointer graph is relative to newBase +//! (the GPU destination), so the host buffer is effectively corrupted for any subsequent +//! use (e.g., copying to a different GPU slot). If a caller invokes this twice with +//! the same host buffer but a different newBase, the second operation will produce +//! incorrectly relocated pointers. +//! +//! To copy the same automaton state to multiple GPU slots, callers must rebuild the +//! automaton from scratch via initAutomaton() + buildAutomatonFromTokens() for each +//! destination. +//! +//! \param sa Pointer to the SuffixAutomaton (mutated in-place) +//! \param oldBase The current base address (host address) +//! \param newBase The target base address (GPU address) +void relocateAutomaton(SuffixAutomaton* sa, void* oldBase, void* newBase); + +} // namespace kernels::speculative_decoding::suffix_automaton + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp b/cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp index 595fce72db0a..80de395656c6 100644 --- a/cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/suffixAutomaton/bindings.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,56 +26,13 @@ #include #include -// Forward declarations - we don't include the suffix automaton headers directly -// because they contain macros that redefine cudaStream_t to int for non-CUDA compilers. -// These functions are implemented in the CUDA-compiled suffixAutomatonKernels.cu -// Note: Must use the _v1 inline namespace to match the TRTLLM_NAMESPACE_BEGIN macro -namespace tensorrt_llm::_v1::kernels::speculative_decoding::suffix_automaton -{ -// Forward declaration of the opaque SuffixAutomaton type -struct SuffixAutomaton; +// Include only the params header (structs + function declarations). +// We cannot include the full suffixAutomatonKernels.h because it transitively +// includes suffixAutomaton.h → saCudaCallable.h, which redefines cudaStream_t +// to int when __CUDACC__ is not defined (this file is compiled as C++, not CUDA). 
+#include "tensorrt_llm/kernels/speculativeDecoding/suffixAutomaton/suffixAutomatonParams.h" -struct SuffixAutomatonExtendParams -{ - int batchSize{0}; - int draftLength{0}; - int maxSlots{0}; - int maxSeqLen{0}; - void* slots{nullptr}; - int const* batchIndices{nullptr}; - int* matchLenOut{nullptr}; - int* draftTokensOut{nullptr}; - int const* acceptedTokensIn{nullptr}; - int const* acceptedLensIn{nullptr}; -}; - -void invokeSuffixAutomatonExtend(SuffixAutomatonExtendParams const& params, cudaStream_t stream); - -struct SuffixAutomatonExtendNgramParams -{ - int batchSize{0}; - int draftLength{0}; - int maxNgramSize{-1}; - int maxSlots{0}; - int maxSeqLen{0}; - void* slots{nullptr}; - int const* batchIndices{nullptr}; - int* matchLenOut{nullptr}; - int* draftTokensOut{nullptr}; - int const* acceptedTokensIn{nullptr}; - int const* acceptedLensIn{nullptr}; -}; - -void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& params, cudaStream_t stream); -size_t getSuffixAutomatonStateSize(size_t maxSeqLen); - -// Functions for building automatons - these are implemented in the .cu file -void initAutomaton(void* memory, size_t maxSeqLen); -void buildAutomatonFromTokens(SuffixAutomaton* sa, int const* tokens, int numTokens); -void relocateAutomaton(SuffixAutomaton* sa, void* oldBase, void* newBase); -} // namespace tensorrt_llm::_v1::kernels::speculative_decoding::suffix_automaton - -namespace sa = tensorrt_llm::_v1::kernels::speculative_decoding::suffix_automaton; +namespace sa = tensorrt_llm::kernels::speculative_decoding::suffix_automaton; namespace tensorrt_llm::nanobind::suffix_automaton { @@ -165,6 +122,48 @@ void initBindings(nb::module_& m) "If max_ngram_size == -1, uses longest match. " "If max_ngram_size > 0, tries ngram sizes from max down to 1."); + // Export the global search function (cross-request pattern sharing) + m.def( + "invoke_global_search", + [](int batchSize, int draftLength, int maxNgramSize, int maxSlots, int maxSeqLen, at::Tensor slots, + at::Tensor batchIndices, at::Tensor activeSlotMask, at::Tensor matchLenOut, at::Tensor matchSlotOut, + at::Tensor draftTokensOut, at::Tensor acceptedTokensIn, at::Tensor acceptedLensIn) + { + TORCH_CHECK(slots.is_cuda(), "slots must be a CUDA tensor"); + TORCH_CHECK(batchIndices.is_cuda(), "batchIndices must be a CUDA tensor"); + TORCH_CHECK(activeSlotMask.is_cuda(), "activeSlotMask must be a CUDA tensor"); + TORCH_CHECK(matchLenOut.is_cuda(), "matchLenOut must be a CUDA tensor"); + TORCH_CHECK(matchSlotOut.is_cuda(), "matchSlotOut must be a CUDA tensor"); + TORCH_CHECK(draftTokensOut.is_cuda(), "draftTokensOut must be a CUDA tensor"); + TORCH_CHECK(acceptedTokensIn.is_cuda(), "acceptedTokensIn must be a CUDA tensor"); + TORCH_CHECK(acceptedLensIn.is_cuda(), "acceptedLensIn must be a CUDA tensor"); + TORCH_CHECK(maxSeqLen > 0, "maxSeqLen must be positive"); + + sa::SuffixAutomatonGlobalSearchParams params; + params.batchSize = batchSize; + params.draftLength = draftLength; + params.maxNgramSize = maxNgramSize; + params.maxSlots = maxSlots; + params.maxSeqLen = maxSeqLen; + params.slots = slots.data_ptr(); + params.batchIndices = batchIndices.data_ptr(); + params.activeSlotMask = static_cast(activeSlotMask.data_ptr()); + params.matchLenOut = matchLenOut.data_ptr(); + params.matchSlotOut = matchSlotOut.data_ptr(); + params.draftTokensOut = draftTokensOut.data_ptr(); + params.acceptedTokensIn = static_cast(acceptedTokensIn.data_ptr()); + params.acceptedLensIn = static_cast(acceptedLensIn.data_ptr()); + + cudaStream_t 
stream = at::cuda::getCurrentCUDAStream().stream(); + sa::invokeSuffixAutomatonGlobalSearch(params, stream); + }, + nb::arg("batch_size"), nb::arg("draft_length"), nb::arg("max_ngram_size"), nb::arg("max_slots"), + nb::arg("max_seq_len"), nb::arg("slots"), nb::arg("batch_indices"), nb::arg("active_slot_mask"), + nb::arg("match_len_out"), nb::arg("match_slot_out"), nb::arg("draft_tokens_out"), nb::arg("accepted_tokens_in"), + nb::arg("accepted_lens_in"), + "Invoke global search across all active SA states. " + "Launches extend + search kernels on the same stream for cross-request pattern sharing."); + // Helper function to allocate workspace for suffix automaton states m.def( "allocate_workspace", diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 88ec88a5115e..7c2ce12a7f4b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -3617,9 +3617,9 @@ def _prepare_inputs( f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.") # Initialize SA state for new requests (MTP+SA, EAGLE3+SA, PARD+SA, etc.) - use_sa_spec = (self.spec_config is not None - and getattr(self.spec_config, 'use_sa_spec', False)) - if use_sa_spec and resource_manager is not None and self.mapping.is_last_pp_rank( + has_sa_enhancer = (self.spec_config is not None and getattr( + self.spec_config, 'sa_config', None) is not None) + if has_sa_enhancer and resource_manager is not None and self.mapping.is_last_pp_rank( ): from tensorrt_llm._torch.speculative.suffix_automaton import \ SuffixAutomatonManager diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 41411b533e5f..0211f3e7c989 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -397,8 +397,8 @@ def __init__(self, self.spec_config = spec_config self.mapping = mapping self.sa_enhancer: Optional[SADraftEnhancer] = None - if getattr(spec_config, 'use_sa_spec', False): - self.sa_enhancer = SADraftEnhancer(spec_config.sa_spec_threshold) + if getattr(spec_config, 'sa_config', None) is not None: + self.sa_enhancer = SADraftEnhancer(spec_config.sa_config.threshold) @property def max_draft_len(self) -> int: diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index cda368e972cc..8baf2da76615 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -274,8 +274,8 @@ def __init__(self, self.model_config = model_config self.is_thop = False self.sa_enhancer: Optional[SADraftEnhancer] = None - if spec_config.use_sa_spec: - self.sa_enhancer = SADraftEnhancer(spec_config.sa_spec_threshold) + if spec_config.sa_config is not None: + self.sa_enhancer = SADraftEnhancer(spec_config.sa_config.threshold) @property def max_draft_len(self) -> int: diff --git a/tensorrt_llm/_torch/speculative/pard.py b/tensorrt_llm/_torch/speculative/pard.py index f4da0e6c9d26..628de48022d7 100644 --- a/tensorrt_llm/_torch/speculative/pard.py +++ b/tensorrt_llm/_torch/speculative/pard.py @@ -89,8 +89,8 @@ def __init__( self.spec_config = spec_config self.mapping = mapping self.sa_enhancer: Optional[SADraftEnhancer] = None - if getattr(spec_config, "use_sa_spec", False): - self.sa_enhancer = SADraftEnhancer(spec_config.sa_spec_threshold) + if getattr(spec_config, "sa_config", None) is not None: + self.sa_enhancer = SADraftEnhancer(spec_config.sa_config.threshold) logger.info( f"PARDWorker initialized with 
use_separate_draft_kv_cache={use_separate_draft_kv_cache}" ) diff --git a/tensorrt_llm/_torch/speculative/sa_enhancer.py b/tensorrt_llm/_torch/speculative/sa_enhancer.py index dec2c4bbe107..1a2d73b1c227 100644 --- a/tensorrt_llm/_torch/speculative/sa_enhancer.py +++ b/tensorrt_llm/_torch/speculative/sa_enhancer.py @@ -33,18 +33,44 @@ class SADraftEnhancer: override draft tokens) so that any worker (MTP, EAGLE3, PARD, etc.) can opt into SA enhancement. + The SA extend+search kernels are launched on a dedicated CUDA side-stream + so they overlap with the compute-heavy draft model forward passes on the + main stream. Results are synchronized lazily — only when + ``maybe_override_all_draft_tokens`` is called after the draft loop. + Usage: - 1. Construct once during worker ``__init__`` when ``use_sa_spec`` is True. + 1. Construct once during worker ``__init__`` when ``sa_config`` is set. 2. Call ``extend_and_prepare`` after ``sample_and_accept_draft_tokens``. 3. Call ``maybe_override_all_draft_tokens`` once after all draft layers have finished, so that neural draft layers never see SA tokens. """ - def __init__(self, sa_spec_threshold: int): - self.sa_spec_threshold = sa_spec_threshold + def __init__(self, threshold: int): + self.threshold = threshold self.sa_match_len: Optional[torch.Tensor] = None self.sa_draft_tokens: Optional[torch.Tensor] = None self.sa_spec_index: int = 0 + self._sa_stream: Optional[torch.cuda.Stream] = None + self._sa_event: Optional[torch.cuda.Event] = None + self._num_gens: int = 0 + + def _ensure_stream(self) -> None: + if self._sa_stream is None: + self._sa_stream = torch.cuda.Stream() + self._sa_event = torch.cuda.Event() + + def _ensure_buffers(self, num_gens: int, max_draft_len: int) -> None: + """Pre-allocate / reuse GPU buffers for SA match results.""" + if self.sa_match_len is None or self.sa_match_len.shape[0] < num_gens: + self.sa_match_len = torch.zeros((num_gens,), dtype=torch.int32, device="cuda") + if ( + self.sa_draft_tokens is None + or self.sa_draft_tokens.shape[0] < num_gens + or self.sa_draft_tokens.shape[1] < max_draft_len + ): + self.sa_draft_tokens = torch.zeros( + (num_gens, max_draft_len), dtype=torch.int32, device="cuda" + ) def extend_and_prepare( self, @@ -59,7 +85,8 @@ def extend_and_prepare( """Extend SA states with accepted tokens and prepare override buffers. Must be called after ``sample_and_accept_draft_tokens`` and before the - draft generation loop. + draft generation loop. The SA kernels are launched on a side-stream so + the caller can immediately proceed with the draft model forward passes. Args: sa_manager: The SuffixAutomatonManager instance. @@ -71,10 +98,8 @@ def extend_and_prepare( num_contexts: Number of context requests in the batch. max_draft_len: Number of draft positions to produce. """ - self.sa_match_len = torch.zeros((num_gens,), dtype=torch.int32, device="cuda") - self.sa_draft_tokens = torch.zeros( - (num_gens, max_draft_len), dtype=torch.int32, device="cuda" - ) + self._ensure_buffers(num_gens, max_draft_len) + self._num_gens = num_gens self.sa_spec_index = 0 if num_gens > 0: @@ -86,14 +111,30 @@ def extend_and_prepare( # slice + .contiguous() below compacts memory so the stride # matches the kernel expectation. 
gen_accepted = accepted_tokens[num_contexts:, : max_draft_len + 1].contiguous() - match_len, draft_tokens_sa = sa_manager.extend( - gen_request_ids, - gen_accepted, - num_accepted_tokens[num_contexts:], - max_draft_len, - ) - self.sa_match_len.copy_(match_len) - self.sa_draft_tokens.copy_(draft_tokens_sa) + + self._ensure_stream() + main_stream = torch.cuda.current_stream() + self._sa_stream.wait_stream(main_stream) + + with torch.cuda.stream(self._sa_stream): + if sa_manager.enable_global_pool: + match_len, draft_tokens_sa = sa_manager.extend_global( + gen_request_ids, + gen_accepted, + num_accepted_tokens[num_contexts:], + max_draft_len, + ) + else: + match_len, draft_tokens_sa = sa_manager.extend( + gen_request_ids, + gen_accepted, + num_accepted_tokens[num_contexts:], + max_draft_len, + ) + self.sa_match_len[:num_gens].copy_(match_len) + self.sa_draft_tokens[:num_gens, :max_draft_len].copy_(draft_tokens_sa) + + self._sa_event.record(self._sa_stream) def maybe_override_all_draft_tokens( self, @@ -110,11 +151,13 @@ def maybe_override_all_draft_tokens( Returns: The (potentially overridden) draft tokens tensor. """ - if self.sa_match_len is not None and self.sa_match_len.shape[0] > 0: + if self.sa_match_len is not None and self._num_gens > 0: + if self._sa_event is not None: + torch.cuda.current_stream().wait_event(self._sa_event) + + n = self._num_gens K = draft_tokens.shape[1] - mask = ( - (self.sa_match_len >= self.sa_spec_threshold).unsqueeze(1).expand_as(draft_tokens) - ) - draft_tokens = torch.where(mask, self.sa_draft_tokens[:, :K], draft_tokens) + mask = (self.sa_match_len[:n] >= self.threshold).unsqueeze(1).expand_as(draft_tokens) + draft_tokens = torch.where(mask, self.sa_draft_tokens[:n, :K], draft_tokens) return draft_tokens diff --git a/tensorrt_llm/_torch/speculative/sa_worker.py b/tensorrt_llm/_torch/speculative/sa_worker.py index 7d7c20231214..f014423686f5 100644 --- a/tensorrt_llm/_torch/speculative/sa_worker.py +++ b/tensorrt_llm/_torch/speculative/sa_worker.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -281,15 +281,27 @@ def _generate_draft_tokens( # No SA manager available, throw error raise ValueError("No SA manager available") - # extend_ngram is CUDA graph compatible - # It extends SA states with accepted tokens and performs pattern matching - match_len, draft_tokens = sa_manager.extend_ngram( - request_ids, - accepted_tokens, - num_accepted_tokens, - max_draft_len, - max_ngram_size=self._max_matching_ngram_size, - ) + if sa_manager.enable_global_pool: + match_len, draft_tokens = sa_manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=self._max_matching_ngram_size, + ) + else: + match_len, draft_tokens = sa_manager.extend_ngram( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=self._max_matching_ngram_size, + ) + + # Gate draft tokens by match_len: zero out rows where no match was found, + # so stale or uninitialized draft tokens are never used (CUDA graph safe). 
+ mask = (match_len > 0).unsqueeze(1).to(draft_tokens.dtype) + draft_tokens = draft_tokens * mask return draft_tokens # [batch_size, max_draft_len] GPU tensor diff --git a/tensorrt_llm/_torch/speculative/suffix_automaton.py b/tensorrt_llm/_torch/speculative/suffix_automaton.py index 1c2ab0673476..1eb7be0bcea5 100644 --- a/tensorrt_llm/_torch/speculative/suffix_automaton.py +++ b/tensorrt_llm/_torch/speculative/suffix_automaton.py @@ -18,6 +18,7 @@ """ import logging +from collections import OrderedDict from dataclasses import dataclass from typing import Dict, List, Optional, Set @@ -33,6 +34,9 @@ logger = logging.getLogger(__name__) +_DEFAULT_POOL_SIZE = 64 + + @dataclass class SAConfig: """Configuration for suffix automaton speculative decoding.""" @@ -46,6 +50,27 @@ class SAConfig: # Minimum match length to use SA draft tokens threshold: int = 4 + # Enable global pool search across all active SA states + enable_global_pool: bool = False + + # Explicit global pool size (None = use default heuristic) + global_pool_size: Optional[int] = None + + @property + def effective_pool_size(self) -> int: + """Actual number of SA slots to allocate. + + When global pool is enabled but no explicit size is given, + defaults to max(_DEFAULT_POOL_SIZE, max_slots) — a fixed-size + pool independent of batch size, floored to at least max_slots + so there's always room for the full batch. + """ + if self.global_pool_size is not None: + return self.global_pool_size + if self.enable_global_pool: + return max(_DEFAULT_POOL_SIZE, self.max_slots) + return self.max_slots + class SuffixAutomatonManager(BaseResourceManager): """ @@ -76,36 +101,61 @@ def __init__( "Please ensure the native bindings are properly built." ) + from tensorrt_llm.llmapi.llm_args import SADecodingConfig, SAEnhancerConfig + # SA configuration - sa_config = ( - config - if isinstance(config, SAConfig) - else SAConfig( + if isinstance(config, SAConfig): + sa_config = config + elif isinstance(config, SAEnhancerConfig): + sa_config = SAConfig( max_seq_len=max_seq_len, max_slots=max_num_requests, - threshold=getattr(config, "sa_spec_threshold", 4), + threshold=config.threshold, + enable_global_pool=config.enable_global_pool, + ) + elif isinstance(config, SADecodingConfig): + sa_config = SAConfig( + max_seq_len=max_seq_len, + max_slots=max_num_requests, + enable_global_pool=config.enable_global_pool, + global_pool_size=config.global_pool_size, + ) + else: + raise TypeError( + f"SuffixAutomatonManager received unsupported config type " + f"{type(config).__name__}. Expected SAConfig, SAEnhancerConfig," + f" or SADecodingConfig." ) - ) self.config = sa_config self.max_num_requests = max_num_requests self.max_seq_len = sa_config.max_seq_len + self.enable_global_pool = sa_config.enable_global_pool + + # Pool sizing: effective_pool_size returns max_num_requests when + # global pool is off, or max(64, max_num_requests) / explicit + # value when on. All slot-indexed sizing uses pool_size. 
+ self.pool_size = sa_config.effective_pool_size + if self.pool_size < max_num_requests: + raise ValueError( + f"global_pool_size ({self.pool_size}) must be >= " + f"max_batch_size ({max_num_requests})" + ) # Calculate per-state size based on max_seq_len self.state_size = _sa_native.get_state_size(self.max_seq_len) - # Log memory usage - total_memory_mb = max_num_requests * self.state_size / 1024 / 1024 logger.info( - f"Allocating {max_num_requests} SA slots with max_seq_len={self.max_seq_len} " - f"({self.state_size / 1024 / 1024:.1f} MB/slot, {total_memory_mb:.1f} MB total)" + f"SA pool: {self.pool_size} slots " + f"({self.pool_size - max_num_requests} retained capacity, " + f"{self.pool_size * self.state_size / 1024 / 1024:.1f} MB total)" ) # Request ID -> slot index mapping self._request_to_slot: Dict[int, int] = {} - # Free slots for reuse - self._free_slots: List[int] = list(range(max_num_requests)) + # Free slots now range over the full pool + self._free_slots: List[int] = list(range(self.pool_size)) # Host-side SA states as pinned memory tensors self._host_states_native: Dict[int, torch.Tensor] = {} @@ -123,14 +173,34 @@ def __init__( self._gpu_draft_tokens: Optional[torch.Tensor] = None self._gpu_batch_indices: Optional[torch.Tensor] = None + # Global pool buffers (allocated lazily in _ensure_workspace) + self._gpu_active_slot_mask: Optional[torch.Tensor] = None + self._gpu_match_slot: Optional[torch.Tensor] = None + self._pending_mask_updates: Dict[int, int] = {} + # Track which requests have been initialized (for prepare_resources) self._initialized_requests: Set[int] = set() - # Reserved slot for CUDA graph dummy requests — shared by all dummies - # so they never consume slots from the real pool. - self._dummy_slot_index: int = max_num_requests + # Dummy slot lives right after the pool — always use pool_size, + # not max_num_requests, so dummies never collide with pool slots. + self._dummy_slot_index: int = self.pool_size self._dummy_request_ids: Set[int] = set() + # Pre-allocated CPU staging buffers for prepare() to avoid + # repeated pinned-memory allocation every round. + self._cpu_batch_indices: Optional[torch.Tensor] = None + self._cpu_nondummy_mask: Optional[torch.Tensor] = None + + # Retained slots: completed requests whose SA states remain in the + # pool for cross-request search. OrderedDict preserves insertion + # (completion) order for FIFO eviction. + # key: slot index + # value: original request_id (for debugging/logging) + self._retained_slots: OrderedDict[int, int] = OrderedDict() + + # Track which slots have active (in-flight) requests. + self._active_slots: Set[int] = set() + def _ensure_workspace(self, max_draft_len: int): """Ensure GPU workspace is allocated with sufficient capacity. @@ -141,7 +211,7 @@ def _ensure_workspace(self, max_draft_len: int): ValueError: If called with max_draft_len larger than previously allocated. """ if not self._workspace_allocated: - # First allocation - create all buffers + # Batch-indexed buffers: sized to max_num_requests (max batch size) self._gpu_match_len = torch.zeros( (self.max_num_requests,), dtype=torch.int32, device="cuda" ) @@ -159,11 +229,18 @@ def _ensure_workspace(self, max_draft_len: int): (self.max_num_requests,), dtype=torch.int32, device="cuda" ) - # Allocate one extra slot beyond max_num_requests for the shared - # CUDA graph dummy (slot index = max_num_requests). 
- self._gpu_slots = _sa_native.allocate_workspace( - self.max_num_requests + 1, self.max_seq_len - ) + # Slot-indexed buffers: pool_size + 1 (pool slots + dummy slot). + # When global pool is off, pool_size == max_num_requests. + self._gpu_slots = _sa_native.allocate_workspace(self.pool_size + 1, self.max_seq_len) + + # Global pool buffers + if self.enable_global_pool: + self._gpu_active_slot_mask = torch.zeros( + (self.pool_size + 1,), dtype=torch.int32, device="cuda" + ) + self._gpu_match_slot = torch.zeros( + (self.max_num_requests,), dtype=torch.int32, device="cuda" + ) self._allocated_max_draft_len = max_draft_len self._workspace_allocated = True @@ -180,6 +257,26 @@ def _ensure_workspace(self, max_draft_len: int): # --- Core SA operations --- + def _allocate_slot(self) -> int: + """Allocate a slot, evicting the oldest retained request if needed.""" + if self._free_slots: + return self._free_slots.pop() + + if self._retained_slots: + slot, old_rid = self._retained_slots.popitem(last=False) # FIFO + if self._gpu_slots is not None: + _sa_native.clear_slot(self._gpu_slots, slot, self.max_seq_len) + self._pending_mask_updates[slot] = 0 + logger.debug( + f"Evicted retained SA slot {slot} (request {old_rid}) to make room for new request" + ) + return slot + + raise RuntimeError( + f"No free or retained slots available. pool_size={self.pool_size}, " + f"active={len(self._active_slots)}" + ) + def add_request(self, request_id: int, context_tokens: List[int]): """ Add a new request and build its initial suffix automaton from context. @@ -196,12 +293,9 @@ def add_request(self, request_id: int, context_tokens: List[int]): self._pending_copies.add(request_id) return - if not self._free_slots: - raise RuntimeError("No free slots available for new request") - - # Allocate a slot - slot = self._free_slots.pop() + slot = self._allocate_slot() self._request_to_slot[request_id] = slot + self._active_slots.add(slot) # Build SA state on host using native code self._host_states_native[request_id] = _sa_native.build_automaton_host( @@ -209,26 +303,44 @@ def add_request(self, request_id: int, context_tokens: List[int]): ) self._pending_copies.add(request_id) + if self.enable_global_pool: + self._pending_mask_updates[slot] = 1 + def remove_request(self, request_id: int): - """Remove a request and free its resources.""" + """Remove a request, retaining its SA state for cross-request search + when global pool is enabled and the pool has retention capacity.""" if request_id not in self._request_to_slot: return slot = self._request_to_slot.pop(request_id) if request_id in self._dummy_request_ids: - # Dummy slot is reserved; never return it to the free pool. self._dummy_request_ids.discard(request_id) + # Dummy slot is reserved; never retain or free. + return + + self._active_slots.discard(slot) + + # If the GPU copy was never flushed, the slot contains stale data — + # skip retention and free immediately to avoid searching garbage. + stale = request_id in self._pending_copies + + if self.enable_global_pool and self.pool_size > self.max_num_requests and not stale: + # Retain: keep SA state alive for cross-request search. + # Active mask stays ON — the slot is still searchable. 
+ self._retained_slots[slot] = request_id else: + # Free immediately + if self.enable_global_pool: + self._pending_mask_updates[slot] = 0 self._free_slots.append(slot) + if self._gpu_slots is not None: + _sa_native.clear_slot(self._gpu_slots, slot, self.max_seq_len) self._host_states_native.pop(request_id, None) self._pending_copies.discard(request_id) self._initialized_requests.discard(request_id) - if self._gpu_slots is not None: - _sa_native.clear_slot(self._gpu_slots, slot, self.max_seq_len) - def prepare(self, request_ids: List[int], max_draft_len: int): """ Prepare batch indices for the upcoming extend() call. @@ -254,28 +366,43 @@ def prepare(self, request_ids: List[int], max_draft_len: int): ) self._pending_copies.clear() + # Flush deferred active-slot-mask updates. add_request/remove_request + # queue updates because the GPU tensor may not exist yet at that point; + # here _ensure_workspace has already run so the tensor is available. + if self._pending_mask_updates and self._gpu_active_slot_mask is not None: + for slot, value in self._pending_mask_updates.items(): + self._gpu_active_slot_mask[slot] = value + self._pending_mask_updates.clear() + # Map each request ID to its slot. Unknown IDs (e.g. CUDA graph # warmup dummies that skipped the context phase) are routed to the # reserved dummy slot so the kernel still runs on valid memory. + num_requests = len(request_ids) slots = [self._request_to_slot.get(rid, self._dummy_slot_index) for rid in request_ids] - batch_indices = torch.tensor( - slots, - dtype=torch.int32, - pin_memory=prefer_pinned(), - ) - # Build a non-dummy mask (1 = real, 0 = dummy) on CPU, then copy to - # the pre-allocated GPU buffer. extend() will use this mask via a - # simple element-wise multiply which is CUDA-graph-safe. - nondummy_mask = torch.tensor( - [0 if s == self._dummy_slot_index else 1 for s in slots], - dtype=torch.int32, - pin_memory=prefer_pinned(), - ) - num_requests = len(request_ids) + # Reuse pre-allocated pinned CPU buffers to avoid costly + # pinned-memory allocation every round. + if self._cpu_batch_indices is None or self._cpu_batch_indices.shape[0] < num_requests: + buf_size = self.max_num_requests + self._cpu_batch_indices = torch.zeros( + buf_size, dtype=torch.int32, pin_memory=prefer_pinned() + ) + self._cpu_nondummy_mask = torch.zeros( + buf_size, dtype=torch.int32, pin_memory=prefer_pinned() + ) + + batch_indices = self._cpu_batch_indices[:num_requests] + nondummy_mask = self._cpu_nondummy_mask[:num_requests] + for i, s in enumerate(slots): + batch_indices[i] = s + nondummy_mask[i] = 0 if s == self._dummy_slot_index else 1 + self._gpu_batch_indices[:num_requests].copy_(batch_indices, non_blocking=True) self._gpu_nondummy_mask[:num_requests].copy_(nondummy_mask, non_blocking=True) - torch.cuda.synchronize() + # Stream-ordered: the non_blocking copies above are on the current + # stream, so any kernel launched on the same stream afterwards will + # see the updated values. No device-wide sync needed. 
+ torch.cuda.current_stream().synchronize() def extend( self, @@ -326,7 +453,7 @@ def extend( _sa_native.invoke_extend( batch_size, max_draft_len, - self.max_num_requests + 1, + self.pool_size + 1, self.max_seq_len, self._gpu_slots, self._gpu_batch_indices[:batch_size], @@ -386,11 +513,70 @@ def extend_ngram( batch_size, max_draft_len, max_ngram_size, - self.max_num_requests + 1, + self.pool_size + 1, + self.max_seq_len, + self._gpu_slots, + self._gpu_batch_indices[:batch_size], + match_len, + draft_tokens, + accepted_tokens, + num_accepted_tokens, + ) + + return match_len, draft_tokens + + def extend_global( + self, + request_ids: List[int], + accepted_tokens: torch.Tensor, + num_accepted_tokens: torch.Tensor, + max_draft_len: int, + max_ngram_size: int = -1, + ) -> tuple: + """ + Extend SA states and search across all active SAs for the best match. + + Each request's SA is extended with accepted tokens, then all active + SA states are searched in parallel to find the longest match across + the pool. CUDA graph compatible (two kernel launches on same stream). + + Args: + request_ids: List of request IDs in the batch + accepted_tokens: [batch_size, max_draft_len + 1] accepted token tensor + num_accepted_tokens: [batch_size] number of accepted tokens per request + max_draft_len: Maximum draft length + max_ngram_size: Max ngram size for suffix extraction (-1 = full) + + Returns: + Tuple of (match_len, draft_tokens) tensors + """ + self._ensure_workspace(max_draft_len) + + batch_size = len(request_ids) + + match_len = self._gpu_match_len[:batch_size] + match_slot = self._gpu_match_slot[:batch_size] + draft_tokens = self._gpu_draft_tokens[:batch_size, :max_draft_len] + + if accepted_tokens.dtype != torch.int32: + accepted_tokens = accepted_tokens.to(torch.int32) + if num_accepted_tokens.dtype != torch.int32: + num_accepted_tokens = num_accepted_tokens.to(torch.int32) + + # Zero out dummy entries (see extend() for rationale). + num_accepted_tokens = num_accepted_tokens * self._gpu_nondummy_mask[:batch_size] + + _sa_native.invoke_global_search( + batch_size, + max_draft_len, + max_ngram_size, + self.pool_size + 1, self.max_seq_len, self._gpu_slots, self._gpu_batch_indices[:batch_size], + self._gpu_active_slot_mask, match_len, + match_slot, draft_tokens, accepted_tokens, num_accepted_tokens, @@ -421,7 +607,7 @@ def add_dummy_requests(self, request_ids: List[int]): """Add dummy requests for CUDA graph padding. Dummy requests are mapped to a single reserved slot - (index = max_num_requests) that lives outside the real slot pool. + (index = pool_size) that lives outside the real slot pool. This prevents CUDA graph padding from exhausting slots that real requests need. 
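For orientation, the retained-slot mechanics above reduce to a small state machine over slots (free -> active -> retained -> FIFO-evicted). The sketch below is illustrative only: the class and method names are invented, and it omits the GPU slot buffers, the deferred active-mask updates, the stale-copy check, and the enable_global_pool flag that the real manager also handles.

    from collections import OrderedDict

    class SlotPoolSketch:
        """Toy model of the slot pool; not the TensorRT-LLM implementation."""

        def __init__(self, pool_size: int, max_num_requests: int):
            self.pool_size = pool_size
            self.max_num_requests = max_num_requests
            self.free = list(range(pool_size))       # unused slots
            self.active = {}                         # request_id -> slot
            self.retained = OrderedDict()            # slot -> finished request_id (FIFO)

        def _allocate(self) -> int:
            if self.free:
                return self.free.pop()
            if self.retained:
                slot, _old_rid = self.retained.popitem(last=False)  # evict oldest retained
                return slot
            raise RuntimeError("no free or retained slots")

        def add_request(self, rid: int) -> None:
            self.active[rid] = self._allocate()

        def remove_request(self, rid: int) -> None:
            slot = self.active.pop(rid)
            if self.pool_size > self.max_num_requests:
                self.retained[slot] = rid            # keep searchable for other requests
            else:
                self.free.append(slot)               # no retention capacity: free immediately

    pool = SlotPoolSketch(pool_size=4, max_num_requests=2)
    pool.add_request(0)
    pool.add_request(1)
    pool.remove_request(0)                           # request 0's state stays in the pool
    assert len(pool.retained) == 1 and len(pool.free) == 2

An ordered mapping keyed by slot gives O(1) FIFO eviction via popitem(last=False), which matches the popitem(last=False) call used by _allocate_slot above.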
@@ -441,8 +627,10 @@ def shutdown(self): torch.cuda.synchronize() self._request_to_slot.clear() - self._free_slots = list(range(self.max_num_requests)) + self._free_slots = list(range(self.pool_size)) self._dummy_request_ids.clear() + self._retained_slots.clear() + self._active_slots.clear() self._host_states_native.clear() self._pending_copies.clear() @@ -452,6 +640,9 @@ def shutdown(self): self._gpu_match_len = None self._gpu_draft_tokens = None self._gpu_batch_indices = None + self._gpu_active_slot_mask = None + self._gpu_match_slot = None + self._pending_mask_updates.clear() self._workspace_allocated = False self._allocated_max_draft_len = 0 diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index d834dbd665df..33f377063994 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -155,8 +155,9 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): spec_dec_mode = spec_config.spec_dec_mode if spec_dec_mode.is_mtp_eagle_one_model(): sa_manager = None - if getattr(spec_config, 'use_sa_spec', False): - sa_manager = SuffixAutomatonManager(spec_config, max_num_requests, + sa_cfg = getattr(spec_config, 'sa_config', None) + if sa_cfg is not None: + sa_manager = SuffixAutomatonManager(sa_cfg, max_num_requests, max_seq_len) if spec_config.use_relaxed_acceptance_for_thinking or sa_manager is not None: return MTPHiddenStatesManager( @@ -170,8 +171,9 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): return None if spec_dec_mode.is_mtp_one_model(): sa_manager = None - if getattr(spec_config, 'use_sa_spec', False): - sa_manager = SuffixAutomatonManager(spec_config, max_num_requests, + sa_cfg = getattr(spec_config, 'sa_config', None) + if sa_cfg is not None: + sa_manager = SuffixAutomatonManager(sa_cfg, max_num_requests, max_seq_len) return MTPHiddenStatesManager( spec_config, @@ -182,8 +184,9 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): ) if spec_dec_mode.is_eagle3_one_model(): sa_manager = None - if getattr(spec_config, 'use_sa_spec', False): - sa_manager = SuffixAutomatonManager(spec_config, max_num_requests, + sa_cfg = getattr(spec_config, 'sa_config', None) + if sa_cfg is not None: + sa_manager = SuffixAutomatonManager(sa_cfg, max_num_requests, max_seq_len) if sa_manager is not None: return Eagle3ResourceManager( @@ -215,9 +218,9 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): max_num_tokens, ) if spec_dec_mode.is_pard(): - if getattr(spec_config, 'use_sa_spec', False): - return SuffixAutomatonManager(spec_config, max_num_requests, - max_seq_len) + sa_cfg = getattr(spec_config, 'sa_config', None) + if sa_cfg is not None: + return SuffixAutomatonManager(sa_cfg, max_num_requests, max_seq_len) return None if spec_dec_mode.is_ngram(): return NGramPoolManager(spec_config, max_num_requests) diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index eb711e48788c..2b5a63dd3d85 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -15,10 +15,10 @@ LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, - SADecodingConfig, SaveHiddenStatesDecodingConfig, - SchedulerConfig, SkipSoftmaxAttentionConfig, - TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, - UserProvidedDecodingConfig) + SADecodingConfig, SAEnhancerConfig, + SaveHiddenStatesDecodingConfig, SchedulerConfig, + 
SkipSoftmaxAttentionConfig, TorchCompileConfig, + TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig) from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo, QuantConfig) from .mm_encoder import MultimodalEncoder @@ -63,6 +63,7 @@ 'NGramDecodingConfig', 'PARDDecodingConfig', 'SADecodingConfig', + 'SAEnhancerConfig', 'UserProvidedDecodingConfig', 'TorchCompileConfig', 'DraftTargetDecodingConfig', diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index c5f6776eb2fd..62006752ce10 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1180,19 +1180,37 @@ def is_linear_tree(self) -> bool: return False +class SAEnhancerConfig(StrictBaseModel): + """Configuration for the Suffix Automaton (SA) draft enhancer. + + Use this to combine SA pattern-matching drafting with another speculative + decoding method (Eagle3, MTP, PARD). When provided as ``sa_config`` on a + decoding config, SA drafting is enabled and may override neural draft + tokens when the suffix match length meets the *threshold*. + + For standalone SA speculative decoding (no neural drafter), use + :class:`SADecodingConfig` instead. + """ + + threshold: PositiveInt = Field( + default=4, + description="Minimum suffix match length required for the SA output " + "to override neural draft tokens.") + enable_global_pool: bool = Field( + default=False, + description="When True, each request searches all active SA states " + "for the longest match, not just its own. Improves acceptance rates " + "when requests share common patterns.") + + class Eagle3DecodingConfig(EagleDecodingConfig): decoding_type: Literal["Eagle3"] = "Eagle3" - # Suffix Automaton speculative decoding settings - use_sa_spec: Optional[bool] = Field( - default=False, + sa_config: Optional[SAEnhancerConfig] = Field( + default=None, status="beta", - description="Combine with Suffix Automaton Decoding") - sa_spec_threshold: PositiveInt = Field( - default=4, - description="The threshold for the Suffix Automaton Decoding. If the" - " length of the suffix match exceeds the threshold, use" - " the suffix automaton output for the next draft tokens.") + description="Optional Suffix Automaton configuration. When set, " + "combines SA drafting with Eagle3 speculative decoding.") class SaveHiddenStatesDecodingConfig(DecodingBaseConfig): @@ -1329,17 +1347,37 @@ def supports_backend(self, backend: str) -> bool: class SADecodingConfig(DecodingBaseConfig): - """ - Configuration for Suffix Automaton (SA) speculative decoding (one-model design). + """Configuration for standalone Suffix Automaton (SA) speculative decoding. - Uses a GPU-native suffix automaton for pattern matching. Drafting runs inside - the target model forward; supports CUDA graph and overlap scheduler. + Uses a GPU-native suffix automaton for pattern matching. Drafting runs + inside the target model forward; supports CUDA graph and overlap scheduler. + + To combine SA with a neural drafter (Eagle3, MTP, PARD) instead of using + it standalone, pass :class:`SAEnhancerConfig` via ``sa_config``. """ decoding_type: Literal["SA"] = "SA" max_matching_ngram_size: int = Field( default=-1, description="Positive value (e.g., 3): fixed-size ngram matching. " "-1: longest possible match via suffix automaton. 0 is invalid.") + enable_global_pool: bool = Field( + default=False, + description="When True, each request searches all active SA states " + "for the longest match, not just its own. Improves acceptance rates " + "when requests share common patterns. 
" + "Limitations: at most 1024 concurrent slots; suffix matching is " + "capped at 64 tokens per request.") + + global_pool_size: Optional[PositiveInt] = Field( + default=None, + description="Number of SA slots in the global pool. " + "When None and enable_global_pool=True, defaults to " + "max(64, max_batch_size) — a fixed-size pool independent of batch size. " + "When set explicitly, must be >= max_batch_size. " + "Completed requests' SA states are retained in the pool for " + "cross-request search until the pool is full, at which point " + "the oldest completed request is evicted. " + "Only effective when enable_global_pool=True.") @model_validator(mode='after') def validate_sa_config(self): @@ -1347,8 +1385,20 @@ def validate_sa_config(self): raise ValueError( "max_matching_ngram_size must be > 0 (fixed ngram) or -1 (longest match). " "Got 0.") + if self.enable_global_pool and self.max_matching_ngram_size != -1 and not ( + 1 <= self.max_matching_ngram_size <= 64): + raise ValueError( + "max_matching_ngram_size must be -1 (longest match) or in [1, 64] " + "when enable_global_pool is True. " + f"Got {self.max_matching_ngram_size}.") if self.max_draft_len is None or self.max_draft_len <= 0: raise ValueError("max_draft_len must be > 0 for SA") + if self.global_pool_size is not None: + if self.global_pool_size < 1: + raise ValueError("global_pool_size must be >= 1") + if not self.enable_global_pool: + raise ValueError( + "global_pool_size requires enable_global_pool=True") self.max_total_draft_tokens = self.max_draft_len return self @@ -1415,16 +1465,11 @@ class MTPDecodingConfig(DecodingBaseConfig): "When using EAGLE-style MTP, use faster one-model implementation (drafter as submodule) vs two-model." ) - # Suffix Automaton speculative decoding settings - use_sa_spec: Optional[bool] = Field( - default=False, + sa_config: Optional[SAEnhancerConfig] = Field( + default=None, status="beta", - description="Combine with Suffix Automaton Decoding") - sa_spec_threshold: PositiveInt = Field( - default=4, - description="The threshold for the Suffix Automaton Decoding. If the" - " length of the suffix match exceeds the threshold, use" - " the suffix automaton output for the next draft tokens.") + description="Optional Suffix Automaton configuration. When set, " + "combines SA drafting with MTP speculative decoding.") # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers` # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine. @@ -1504,16 +1549,11 @@ class PARDDecodingConfig(DecodingBaseConfig): decoding_type: Literal["PARD"] = "PARD" - # Suffix Automaton speculative decoding settings - use_sa_spec: Optional[bool] = Field( - default=False, + sa_config: Optional[SAEnhancerConfig] = Field( + default=None, status="beta", - description="Combine with Suffix Automaton Decoding") - sa_spec_threshold: PositiveInt = Field( - default=4, - description="The threshold for the Suffix Automaton Decoding. If the" - " length of the suffix match exceeds the threshold, use" - " the suffix automaton output for the next draft tokens.") + description="Optional Suffix Automaton configuration. 
When set, " + "combines SA drafting with PARD speculative decoding.") @model_validator(mode="after") def set_max_total_draft_tokens(self): @@ -3644,6 +3684,14 @@ def validate_speculative_config(self): if isinstance(self.speculative_config, PARDDecodingConfig): assert self.speculative_config.max_draft_len > 0, "PARD max_draft_len must be > 0" + if isinstance(self.speculative_config, SADecodingConfig): + pool_size = self.speculative_config.global_pool_size + if pool_size is not None and self.max_batch_size is not None: + if pool_size < self.max_batch_size: + raise ValueError( + f"global_pool_size ({pool_size}) must be >= " + f"max_batch_size ({self.max_batch_size})") + if isinstance(self.speculative_config, SaveHiddenStatesDecodingConfig): logger.warning( diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index fb85d4e81b53..d2cdf7e0cd03 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -59,7 +59,8 @@ def patched_start_mpi_pool(self): DeepSeekSparseAttentionConfig, Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, SADecodingConfig, SamplingParams, - SchedulerConfig, SkipSoftmaxAttentionConfig, TorchCompileConfig) + SchedulerConfig, SkipSoftmaxAttentionConfig, SAEnhancerConfig, + TorchCompileConfig) # isort: on from tensorrt_llm.quantization import QuantAlgo @@ -373,7 +374,35 @@ def test_eagle3_sa(self): spec_config = Eagle3DecodingConfig(max_draft_len=4, speculative_model=eagle_model_dir, eagle3_one_model=True, - use_sa_spec=True) + sa_config=SAEnhancerConfig()) + + with LLM(model=target_model_dir, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + + @skip_pre_hopper + def test_eagle3_sa_global_pool(self): + """Accuracy test for EAGLE3 One-Model + Suffix Automaton with global pool enabled.""" + max_batch_size = 32 + pytorch_config = dict( + max_batch_size=max_batch_size, + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=max_batch_size, + enable_padding=True), + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + + eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + + spec_config = Eagle3DecodingConfig( + max_draft_len=4, + speculative_model=eagle_model_dir, + eagle3_one_model=True, + sa_config=SAEnhancerConfig(enable_global_pool=True)) with LLM(model=target_model_dir, **pytorch_config, @@ -428,7 +457,35 @@ def test_pard_sa(self): spec_config = PARDDecodingConfig(max_draft_len=4, speculative_model=pard_model_dir, - use_sa_spec=True) + sa_config=SAEnhancerConfig()) + + with LLM(model=target_model_dir, + **pytorch_config, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + + @skip_pre_hopper + @pytest.mark.skip(reason="PARD accuracy issue with batch size > 1") + def test_pard_sa_global_pool(self): + """Accuracy test for PARD + Suffix Automaton with global pool enabled.""" + max_batch_size = 32 + pytorch_config = dict( + max_batch_size=max_batch_size, + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=max_batch_size, + 
enable_padding=True), + ) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + + pard_model_dir = f"{llm_models_root()}/PARD-Llama-3.2-1B" + target_model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct" + + spec_config = PARDDecodingConfig( + max_draft_len=4, + speculative_model=pard_model_dir, + sa_config=SAEnhancerConfig(enable_global_pool=True)) with LLM(model=target_model_dir, **pytorch_config, @@ -467,7 +524,8 @@ def test_ngram(self): task.evaluate(llm) @skip_pre_hopper - def test_suffix_automaton(self): + @parametrize_with_ids("enable_global_pool", [False, True]) + def test_suffix_automaton(self, enable_global_pool): max_bs = 16 pytorch_config = dict( @@ -482,6 +540,7 @@ def test_suffix_automaton(self): spec_config = SADecodingConfig( max_draft_len=4, max_matching_ngram_size=-1, # longest match via suffix automaton + enable_global_pool=enable_global_pool, ) with LLM(model=self.MODEL_PATH, @@ -1685,7 +1744,29 @@ def test_bfloat16_mtp_sa(self): cuda_graph_config=CudaGraphConfig(), ) mtp_config = MTPDecodingConfig(num_nextn_predict_layers=2, - use_sa_spec=True) + sa_config=SAEnhancerConfig()) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, + max_num_tokens=8192, + **pytorch_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, extra_acc_spec="use_sa_spec") + + @pytest.mark.skip_less_device_memory(60000) + def test_bfloat16_mtp_sa_global_pool(self): + """Accuracy test for MTP + Suffix Automaton with global pool enabled.""" + max_batch_size = 32 + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) + pytorch_config = dict( + max_batch_size=max_batch_size, + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=max_batch_size, + enable_padding=True), + ) + mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=2, + sa_config=SAEnhancerConfig(enable_global_pool=True)) with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config, max_num_tokens=8192, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index f1c7849beb8c..229a7f1e180e 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -9,6 +9,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_a accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3_sa_global_pool accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] @@ -37,6 +38,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard_sa_global_pool 
accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype @@ -97,6 +99,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_sched accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-enable_chunked_prefill=True] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_python_scheduler[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-enable_chunked_prefill=True] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_mtp_sa_global_pool accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_2_model_mtp accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] diff --git a/tests/torch/speculative/test_suffix_automaton.py b/tests/torch/speculative/test_suffix_automaton.py index 668da1c6c3fb..ee5a7f03b365 100644 --- a/tests/torch/speculative/test_suffix_automaton.py +++ b/tests/torch/speculative/test_suffix_automaton.py @@ -570,6 +570,425 @@ def test_extend_ngram_cuda_graph(self): manager.shutdown() +class TestExtendGlobal: + """Tests for extend_global() — cross-request pattern sharing.""" + + def test_extend_global_cross_request_match(self): + """Request B finds a pattern from Request A's context.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + # Request 0: context has [1, 2, 3, 4, 5] + manager.add_request(0, [1, 2, 3, 4, 5]) + # Request 1: context has [10, 20, 1, 2, 3] — ends with same [1, 2, 3] + manager.add_request(1, [10, 20, 1, 2, 3]) + + request_ids = [0, 1] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + # Extend with token 6 for req 0, token 4 for req 1 (so req 1 ends [.., 3, 4]) + accepted_tokens = torch.tensor( + [[6, 0, 0, 0, 0], [4, 0, 0, 0, 0]], + dtype=torch.int32, + device="cuda", + ) + num_accepted_tokens = torch.tensor([1, 1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + print(f"global cross-request: match_len={match_len}, draft_tokens={draft_tokens}") + + # Request 1's SA has [10, 20, 1, 2, 3, 4]. 
lookupWithSuffix on + # Request 0's SA [1, 2, 3, 4, 5, 6] processes tokens: + # 10 → no match, 20 → no match, 1 → match(1), 2 → match(2), + # 3 → match(3), 4 → match(4) + # Match [1, 2, 3, 4] (len=4) from req 0 → continuation is [5, 6] + match_len_1 = match_len[1].item() + assert match_len_1 == 4, f"Request 1 should match [1,2,3,4] (len=4), got {match_len_1}" + draft_1 = draft_tokens[1, :max_draft_len].cpu().tolist() + assert draft_1[0] == 5, f"Expected continuation starting with 5, got {draft_1}" + assert draft_1[1] == 6, f"Expected second draft token 6, got {draft_1}" + + manager.shutdown() + + def test_extend_global_prefers_own_slot(self): + """When match lengths are equal, prefer the requesting SA's own slot.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + # Request 0: [1, 2, 3, 100, 1, 2] — has [1, 2] with continuation [3, 100, ...] + manager.add_request(0, [1, 2, 3, 100, 1, 2]) + # Request 1: [1, 2, 3, 200, 1, 2] — has [1, 2] with continuation [3, 200, ...] + manager.add_request(1, [1, 2, 3, 200, 1, 2]) + + request_ids = [0, 1] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + # Extend both with token 3 + accepted_tokens = torch.tensor( + [[3, 0, 0, 0, 0], [3, 0, 0, 0, 0]], + dtype=torch.int32, + device="cuda", + ) + num_accepted_tokens = torch.tensor([1, 1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + print(f"global prefer-own: match_len={match_len}, draft_tokens={draft_tokens}") + + # Request 0 should prefer its own SA (continuation 100) over request 1's (200) + draft_0 = draft_tokens[0].cpu().tolist() + assert draft_0[0] == 100, f"Request 0 should use own slot continuation (100), got {draft_0}" + + # Request 1 should prefer its own SA (continuation 200) over request 0's (100) + draft_1 = draft_tokens[1].cpu().tolist() + assert draft_1[0] == 200, f"Request 1 should use own slot continuation (200), got {draft_1}" + + manager.shutdown() + + def test_extend_global_no_match(self): + """No match across any SA returns match_len=0.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + manager.add_request(0, [1, 2, 3]) + manager.add_request(1, [4, 5, 6]) + + request_ids = [0, 1] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + # Token 99 doesn't exist in any SA + accepted_tokens = torch.tensor( + [[99, 0, 0, 0, 0], [99, 0, 0, 0, 0]], + dtype=torch.int32, + device="cuda", + ) + num_accepted_tokens = torch.tensor([1, 1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + print(f"global no-match: match_len={match_len}, draft_tokens={draft_tokens}") + + assert match_len[0].item() == 0, "Request 0 should have no match" + assert match_len[1].item() == 0, "Request 1 should have no match" + draft_0 = draft_tokens[0, :max_draft_len].cpu().tolist() + assert draft_0 == [0] * max_draft_len, f"Expected zeroed draft for request 0, got {draft_0}" + draft_1 = draft_tokens[1, :max_draft_len].cpu().tolist() + assert draft_1 == [0] * max_draft_len, f"Expected zeroed draft for request 1, got {draft_1}" + + manager.shutdown() + + def test_extend_global_active_slot_mask(self): + """Removed requests 
should not be searchable via the active slot mask.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + # Request 0 has pattern [1, 2, 3, 4, 5] + manager.add_request(0, [1, 2, 3, 4, 5]) + # Request 1 has [10, 20, 1, 2] + manager.add_request(1, [10, 20, 1, 2]) + + # Remove request 0 — its slot mask should be cleared + manager.remove_request(0) + + request_ids = [1] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + # Request 1 extends with token 3: suffix [1, 2, 3] should NOT match + # against removed request 0's SA + accepted_tokens = torch.tensor([[3, 0, 0, 0, 0]], dtype=torch.int32, device="cuda") + num_accepted_tokens = torch.tensor([1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + print(f"global mask test: match_len={match_len}") + + # Req 1's SA is [10, 20, 1, 2, 3] (all unique tokens). lookupWithSuffix + # on its own SA matches the entire sequence (len=5) but pos=4 has no + # continuation (pos+1 == mTokens.size()), so it returns empty. + # Req 0's slot is masked out, so it's never searched. + assert match_len[0].item() == 0, ( + f"Expected match_len=0 (no continuation in own SA, removed slot masked), " + f"got {match_len[0].item()}" + ) + + manager.shutdown() + + def test_extend_global_single_request(self): + """Global search with a single request behaves like local search.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + context_tokens = [0, 1, 2, 1, 2] + manager.add_request(0, context_tokens) + + request_ids = [0] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + accepted_tokens = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.int32, device="cuda") + num_accepted_tokens = torch.tensor([1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + print(f"global single request: match_len={match_len}, draft_tokens={draft_tokens}") + + # Same as local: [0, 1, 2, 1, 2, 1] → longest suffix [1, 2, 1] matches at pos 1-3 + match_len_val = match_len[0].item() + assert match_len_val == 3, f"Expected match_len=3, got {match_len_val}" + assert draft_tokens[0, 0].item() == 2, ( + f"Expected continuation 2, got {draft_tokens[0, 0].item()}" + ) + + manager.shutdown() + + def test_extend_global_cuda_graph(self): + """Test that extend_global works with CUDA graph capture.""" + config = SAConfig(max_seq_len=1024, max_slots=16, enable_global_pool=True) + manager = SuffixAutomatonManager(config, max_num_requests=16) + + manager.add_request(0, [1, 2, 3, 4, 5, 1, 2, 3]) + + request_ids = [0] + max_draft_len = 4 + manager.prepare(request_ids, max_draft_len) + + accepted_tokens = torch.tensor([[4, 0, 0, 0, 0]], dtype=torch.int32, device="cuda") + num_accepted_tokens = torch.tensor([1], dtype=torch.int32, device="cuda") + + # Warmup + for _ in range(3): + manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + # Reset state after warmup + manager.remove_request(0) + manager.add_request(0, [1, 2, 3, 4, 5, 1, 2, 3]) + manager.prepare(request_ids, max_draft_len) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + match_len, 
draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len, + max_ngram_size=-1, + ) + + g.replay() + + print(f"global CUDA graph: match_len={match_len}, draft_tokens={draft_tokens}") + + match_len_val = match_len[0].item() + assert match_len_val >= 1, f"Expected match after CUDA graph replay, got {match_len_val}" + + manager.shutdown() + + +class TestRetainedPool: + """Tests for retained slot pool (completed requests stay searchable).""" + + def test_retained_slot_is_searchable(self): + """Completed request's SA stays searchable by active requests.""" + # pool_size=4 > max_num_requests=2 → retention capacity of 2 + config = SAConfig( + max_seq_len=1024, + max_slots=2, + enable_global_pool=True, + global_pool_size=4, + ) + manager = SuffixAutomatonManager(config, max_num_requests=2) + + # Request A: context has [1, 2, 3, 4, 5] + manager.add_request(0, [1, 2, 3, 4, 5]) + # Request B: context has [10, 20, 30] + manager.add_request(1, [10, 20, 30]) + + # Flush A's state to GPU + manager.prepare([0, 1], max_draft_len=4) + + # Complete request A — should be retained, not freed + manager.remove_request(0) + assert 0 not in manager._request_to_slot + assert len(manager._retained_slots) == 1 + + # Request C arrives, ends with [1, 2, 3] + manager.add_request(2, [50, 60, 1, 2, 3]) + + request_ids = [1, 2] + manager.prepare(request_ids, max_draft_len=4) + + # Extend request C with token 4 — should match retained A's [1,2,3,4] + accepted_tokens = torch.tensor( + [[99, 0, 0, 0, 0], [4, 0, 0, 0, 0]], + dtype=torch.int32, + device="cuda", + ) + num_accepted_tokens = torch.tensor([1, 1], dtype=torch.int32, device="cuda") + + match_len, draft_tokens = manager.extend_global( + request_ids, + accepted_tokens, + num_accepted_tokens, + max_draft_len=4, + max_ngram_size=-1, + ) + + # Request C (index 1) should find a match from retained A + match_len_c = match_len[1].item() + assert match_len_c >= 3, ( + f"Request C should match retained A's pattern (len>=3), got {match_len_c}" + ) + draft_c = draft_tokens[1].cpu().tolist() + assert draft_c[0] == 5, f"Expected continuation token 5 from A, got {draft_c}" + + manager.shutdown() + + def test_eviction_fifo_order(self): + """Oldest retained slot is evicted first when pool is full.""" + # pool_size=4, max_batch=2 → 2 retained slot capacity + config = SAConfig( + max_seq_len=1024, + max_slots=2, + enable_global_pool=True, + global_pool_size=4, + ) + manager = SuffixAutomatonManager(config, max_num_requests=2) + # Initial: free=[0,1,2,3], active={}, retained={} + + manager.add_request(0, [1, 2, 3]) + manager.add_request(1, [4, 5, 6]) + manager.prepare([0, 1], max_draft_len=4) + # free=[0,1], active={2,3}, retained={} (slots allocated from end) + + # Complete A → retained + manager.remove_request(0) + assert len(manager._retained_slots) == 1 + assert len(manager._active_slots) == 1 + + # Complete B → retained + manager.remove_request(1) + assert len(manager._retained_slots) == 2 # A and B both retained + assert len(manager._active_slots) == 0 + + # Requests C and D fill both free slots + manager.add_request(2, [7, 8, 9]) + manager.add_request(3, [10, 11, 12]) + # free=[], active={slot_c, slot_d}, retained={slot_a: 0, slot_b: 1} + assert len(manager._free_slots) == 0 + assert len(manager._active_slots) == 2 + assert len(manager._retained_slots) == 2 + + # Request E arrives — pool full, must evict oldest retained (A) + manager.add_request(4, [13, 14, 15]) + assert len(manager._retained_slots) == 1 + retained_rids = 
list(manager._retained_slots.values()) + assert retained_rids == [1], f"Expected B (rid=1) retained, got {retained_rids}" + + manager.shutdown() + + def test_active_never_evicted(self): + """Active (in-flight) requests must never be evicted.""" + # pool_size=2 = max_batch=2 → 0 retained capacity → no retention + config = SAConfig( + max_seq_len=1024, + max_slots=2, + enable_global_pool=True, + global_pool_size=2, + ) + manager = SuffixAutomatonManager(config, max_num_requests=2) + + manager.add_request(0, [1, 2, 3]) + manager.add_request(1, [4, 5, 6]) + manager.prepare([0, 1], max_draft_len=4) + + # Complete A — pool_size == max_num_requests, so no retention + manager.remove_request(0) + assert len(manager._retained_slots) == 0 + assert len(manager._free_slots) == 1 + + manager.shutdown() + + def test_no_retention_when_global_pool_disabled(self): + """With global pool off, remove_request always frees immediately.""" + config = SAConfig( + max_seq_len=1024, + max_slots=4, + enable_global_pool=False, + ) + manager = SuffixAutomatonManager(config, max_num_requests=4) + + manager.add_request(0, [1, 2, 3]) + manager.prepare([0], max_draft_len=4) + manager.remove_request(0) + + assert len(manager._retained_slots) == 0 + assert len(manager._free_slots) == 4 + + manager.shutdown() + + def test_stale_request_not_retained(self): + """Request removed before GPU copy is flushed should not be retained.""" + config = SAConfig( + max_seq_len=1024, + max_slots=2, + enable_global_pool=True, + global_pool_size=4, + ) + manager = SuffixAutomatonManager(config, max_num_requests=2) + + # Add but don't prepare (GPU copy still pending) + manager.add_request(0, [1, 2, 3]) + assert 0 in manager._pending_copies + + # Remove before prepare — should NOT be retained (stale GPU data) + manager.remove_request(0) + assert len(manager._retained_slots) == 0 + assert len(manager._free_slots) == 4 # slot returned to free list + + manager.shutdown() + + class TestNativeKernel: """Tests for native kernel.""" @@ -614,10 +1033,27 @@ def test_native_kernel(self): test.test_extend_ngram_batch() test.test_extend_ngram_cuda_graph() + print("\n--- extend_global tests ---") + test = TestExtendGlobal() + test.test_extend_global_cross_request_match() + test.test_extend_global_prefers_own_slot() + test.test_extend_global_no_match() + test.test_extend_global_active_slot_mask() + test.test_extend_global_single_request() + test.test_extend_global_cuda_graph() + print("\n--- CUDA graph compatibility tests ---") test = TestCUDAGraphCompatibility() test.test_cuda_graph_capture() + print("\n--- Retained pool tests ---") + test = TestRetainedPool() + test.test_retained_slot_is_searchable() + test.test_eviction_fifo_order() + test.test_active_never_evicted() + test.test_no_retention_when_global_pool_disabled() + test.test_stale_request_not_retained() + print("\n" + "=" * 60) print("All tests passed!") print("=" * 60) diff --git a/tests/unittest/_torch/speculative/test_sa.py b/tests/unittest/_torch/speculative/test_sa.py index f8c1bb305efe..f077c7277eda 100644 --- a/tests/unittest/_torch/speculative/test_sa.py +++ b/tests/unittest/_torch/speculative/test_sa.py @@ -12,22 +12,6 @@ from utils.llm_data import llm_models_root -def get_perf_metrics(result): - """Extract performance metrics from result using built-in request_perf_metrics.""" - metrics = {} - if result.outputs and result.outputs[0].request_perf_metrics: - perf = result.outputs[0].request_perf_metrics - timing = perf.timing_metrics - # Convert timedelta to seconds - metrics['arrival_time'] 
= timing.arrival_time.total_seconds() - metrics['first_token_time'] = timing.first_token_time.total_seconds() - metrics['last_token_time'] = timing.last_token_time.total_seconds() - # Calculate TTFT and E2E latency - metrics['ttft'] = metrics['first_token_time'] - metrics['arrival_time'] - metrics['e2e'] = metrics['last_token_time'] - metrics['arrival_time'] - return metrics - - # Test parameter combinations: # - disable_overlap_scheduler: Controls scheduler mode (False=overlap enabled) # - use_cuda_graph: Whether to use CUDA graph capture @@ -52,12 +36,14 @@ def get_perf_metrics(result): @pytest.mark.high_cuda_memory def test_llama_sa(disable_overlap_scheduler: bool, use_cuda_graph: bool, attn_backend: str, max_matching_ngram_size: int): - """Test SA (Suffix Automaton) speculative decoding correctness and acceptance rate. + """Test SA (Suffix Automaton) speculative decoding acceptance rate. Verifies: - 1. Speculative decoding produces identical results to baseline - 2. SA drafting produces draft tokens that get accepted - 3. Multi-token acceptance occurs (acceptanceLength > 1) + 1. SA drafting produces draft tokens that get accepted + 2. Multi-token acceptance occurs (acceptanceLength > 1) + + Output correctness is validated by integration accuracy tests in + tests/integration/defs/accuracy/test_llm_api_pytorch.py. """ total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 20: @@ -97,18 +83,13 @@ def test_llama_sa(disable_overlap_scheduler: bool, use_cuda_graph: bool, "16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, " "34, 35,", ] - # Enable perf metrics collection via return_perf_metrics=True sampling_params = SamplingParams(max_tokens=64, ignore_eos=True, - temperature=0, - return_perf_metrics=True) + temperature=0) - # Run with speculative decoding llm_spec = LLM(**llm_common_config, speculative_config=spec_config) - results_spec = llm_spec.generate(prompts, sampling_params) - generated_text_spec = [result.outputs[0].text for result in results_spec] + llm_spec.generate(prompts, sampling_params) - # Get spec decoding stats before shutdown stats = llm_spec.get_stats(timeout=5) iterations_with_spec = [] for stat in stats: @@ -117,31 +98,9 @@ def test_llama_sa(disable_overlap_scheduler: bool, use_cuda_graph: bool, if spec_stats.get('numDraftTokens', 0) > 0: iterations_with_spec.append(spec_stats) - # Get perf metrics using built-in request_perf_metrics - spec_metrics = get_perf_metrics(results_spec[0]) if results_spec else {} - llm_spec.shutdown() - # Run reference without speculative decoding - llm_ref = LLM(**llm_common_config) - results_ref = llm_ref.generate(prompts, sampling_params) - generated_text_ref = [result.outputs[0].text for result in results_ref] - - # Get perf metrics for reference - ref_metrics = get_perf_metrics(results_ref[0]) if results_ref else {} - - llm_ref.shutdown() - - # Verify 1: Identical results (correctness) - for i, (text_spec, - text_ref) in enumerate(zip(generated_text_spec, - generated_text_ref)): - assert text_spec == text_ref, ( - f"Prompt {i}: Spec decode result differs from baseline.\n" - f"Spec: {text_spec}\nRef: {text_ref}") - print(f"Correctness verified: spec decode matches baseline") - - # Verify 2: Spec decoding stats show drafting occurred + # Verify 1: Spec decoding stats show drafting occurred assert len(iterations_with_spec) > 0, ( f"SA should have iterations with specDecodingStats. 
" f"Got {len(stats)} total stats but 0 with draft tokens.") @@ -164,7 +123,7 @@ def test_llama_sa(disable_overlap_scheduler: bool, use_cuda_graph: bool, f"SA should accept some draft tokens. " f"Got {total_accepted} accepted out of {total_draft} drafted") - # Verify 3: Multi-token acceptance (acceptanceLength > 1) + # Verify 2: Multi-token acceptance (acceptanceLength > 1) has_multi_token_acceptance = any(s['acceptanceLength'] > 1.0 for s in iterations_with_spec) print(f" Has multi-token acceptance: {has_multi_token_acceptance}") @@ -173,41 +132,6 @@ def test_llama_sa(disable_overlap_scheduler: bool, use_cuda_graph: bool, "Expected at least one iteration with acceptanceLength > 1 " "for repetitive pattern") - # Print performance comparison using built-in metrics - print("\n" + "=" * 70) - print("PERFORMANCE COMPARISON (using request_perf_metrics)") - print("=" * 70) - print( - f"Config: overlap_scheduler={'enabled' if not disable_overlap_scheduler else 'disabled'}, " - f"cuda_graph={'enabled' if use_cuda_graph else 'disabled'}") - print("-" * 70) - print(f"{'Metric':<30} {'Spec Decoding':<20} {'Reference':<20}") - print("-" * 70) - - # Print TTFT (Time to First Token) - ttft_spec = spec_metrics.get('ttft', None) - ttft_ref = ref_metrics.get('ttft', None) - ttft_spec_str = f"{ttft_spec*1000:.2f} ms" if ttft_spec else "N/A" - ttft_ref_str = f"{ttft_ref*1000:.2f} ms" if ttft_ref else "N/A" - print(f"{'TTFT':<30} {ttft_spec_str:<20} {ttft_ref_str:<20}") - - # Print E2E latency - e2e_spec = spec_metrics.get('e2e', None) - e2e_ref = ref_metrics.get('e2e', None) - e2e_spec_str = f"{e2e_spec*1000:.2f} ms" if e2e_spec else "N/A" - e2e_ref_str = f"{e2e_ref*1000:.2f} ms" if e2e_ref else "N/A" - print(f"{'E2E Latency':<30} {e2e_spec_str:<20} {e2e_ref_str:<20}") - - # Calculate and print speedup - if e2e_spec and e2e_ref and e2e_spec > 0: - speedup = e2e_ref / e2e_spec - print("-" * 70) - print(f"{'Speedup (E2E)':<30} {speedup:.2f}x") - print("=" * 70 + "\n") - - # Synchronize CUDA to catch any async memory errors before test completes. - # This ensures errors are attributed to this test rather than propagating - # to subsequent tests. torch.cuda.synchronize() @@ -231,5 +155,111 @@ def test_sa_config_invalid_zero(): ) +def test_sa_config_global_pool(): + """Test SADecodingConfig with enable_global_pool.""" + config = SADecodingConfig( + max_draft_len=4, + enable_global_pool=True, + ) + assert config.enable_global_pool is True + + config_off = SADecodingConfig( + max_draft_len=4, + enable_global_pool=False, + ) + assert config_off.enable_global_pool is False + + # Default should be False + config_default = SADecodingConfig(max_draft_len=4) + assert config_default.enable_global_pool is False + + +@pytest.mark.parametrize("disable_overlap_scheduler,use_cuda_graph", [ + [False, False], + [False, True], +]) +@pytest.mark.high_cuda_memory +def test_llama_sa_global_pool(disable_overlap_scheduler: bool, + use_cuda_graph: bool): + """Test SA speculative decoding with global pool enabled. + + Verifies that SA drafting with global pool produces draft tokens that + get accepted. Output correctness is validated by integration accuracy + tests in tests/integration/defs/accuracy/test_llm_api_pytorch.py. 
+ """ + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 20: + pytest.skip("Not enough memory to load target model") + + print( + f"\nTest config: disable_overlap_scheduler={disable_overlap_scheduler}, " + f"use_cuda_graph={use_cuda_graph}, enable_global_pool=True") + + max_batch_size = 2 + max_draft_len = 4 + kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=8192) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1, 2]) if use_cuda_graph else None + + llm_common_config = dict( + model=llm_models_root() / "llama-3.1-model" / "Meta-Llama-3.1-8B", + backend='pytorch', + attn_backend='TRTLLM', + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + max_num_tokens=2048, + enable_iter_perf_stats=True, + ) + + spec_config = SADecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=-1, + enable_global_pool=True, + ) + + # Use two prompts with similar patterns so global pool can help + prompts = [ + "Count from 1 to 50: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, " + "14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,", + "Count from 1 to 50: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, " + "14, 15, 16, 17, 18, 19, 20, 21, 22, 23,", + ] + sampling_params = SamplingParams(max_tokens=64, + ignore_eos=True, + temperature=0) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + llm_spec.generate(prompts, sampling_params) + + stats = llm_spec.get_stats(timeout=5) + iterations_with_spec = [] + for stat in stats: + if 'specDecodingStats' in stat: + spec_stats = stat['specDecodingStats'] + if spec_stats.get('numDraftTokens', 0) > 0: + iterations_with_spec.append(spec_stats) + + llm_spec.shutdown() + + # Verify 1: Spec decoding stats show drafting occurred + assert len(iterations_with_spec) > 0, ( + "SA global pool should have iterations with specDecodingStats.") + + total_draft = sum(s['numDraftTokens'] for s in iterations_with_spec) + total_accepted = sum(s['numAcceptedTokens'] for s in iterations_with_spec) + + print(f"Global pool spec decoding stats:") + print(f" Iterations with drafting: {len(iterations_with_spec)}") + print(f" Total draft tokens: {total_draft}") + print(f" Total accepted tokens: {total_accepted}") + + assert total_draft > 0, "SA global pool should produce draft tokens" + assert total_accepted > 0, "SA global pool should accept some draft tokens" + + torch.cuda.synchronize() + + if __name__ == "__main__": unittest.main() From 4c97a03c79f994c8fc2ceae4adcef979fb738152 Mon Sep 17 00:00:00 2001 From: Shi Xiaowei <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 3 Apr 2026 03:42:04 +0800 Subject: [PATCH 7/8] [https://nvbugs/5979673][fix] improve NIXL agent import error diagnostics (#12446) --- requirements-dev.txt | 2 +- .../_torch/disaggregation/nixl/_agent_py.py | 47 ++++++++++++------- .../_torch/disaggregation/nixl/agent.py | 30 ++++++++---- .../test_agent_multi_backends.py | 37 +++++++++++++++ 4 files changed, 90 insertions(+), 26 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index c4141694c98b..a22e69c05225 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -39,7 +39,7 @@ opentelemetry-semantic-conventions-ai>=0.4.1 fuzzywuzzy==0.18.0 aiperf==0.6.0 nanobind>=2.9.0 -nixl==0.8.0 +nixl==0.9.0 cupti-python>=13.0,<13.2 nvidia-cuda-cupti>=13.0,<13.2 cxxfilt diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py 
b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py index cd01f3024a9e..dcfa28210f81 100644 --- a/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py @@ -4,6 +4,7 @@ from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.logger import logger # Import base classes for type compatibility from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus @@ -36,9 +37,11 @@ def wait(self, timeout_ms=None): while status in (TransferState.PENDING, TransferState.PROCESSING): status = TransferState(self.agent.check_xfer_state(self.handle)) if status == TransferState.ERROR: - return False # Transfer failed + logger.error("NIXL transfer entered ERROR state (agent=%s).", self.agent.name) + return False if timeout is not None and (time.time() - start_time > timeout): - return False # Timeout + logger.warning("NIXL transfer wait timed out after %s ms.", timeout_ms) + return False time.sleep(sleep_time) sleep_time = min(sleep_time * 2, max_sleep_time) return status == TransferState.DONE @@ -61,23 +64,25 @@ def __init__(self, name: str, use_prog_thread: bool = True, num_threads: int = 1 ) self.agent = nixl_agent(name, agent_config) - def register_memory(self, descs: RegMemoryDescs): + def _get_validated_reg_descs(self, descs: RegMemoryDescs): if not descs.descs: raise ValueError("descs.descs must not be empty") - if isinstance(descs.descs[0], tuple): - assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}" + if isinstance(descs.descs[0], tuple) and len(descs.descs[0]) != 4: + raise ValueError( + f"Expected 4 elements per desc, got {len(descs.descs[0])}: {descs.descs[0]}" + ) reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) - assert reg_descs is not None, "Failed to get reg_descs" - self.agent.register_memory(reg_descs) + if reg_descs is None: + raise RuntimeError( + f"nixl get_reg_descs returned None for type={descs.type}, count={len(descs.descs)}" + ) + return reg_descs + + def register_memory(self, descs: RegMemoryDescs): + self.agent.register_memory(self._get_validated_reg_descs(descs)) def deregister_memory(self, descs: RegMemoryDescs): - if not descs.descs: - raise ValueError("descs.descs must not be empty") - if isinstance(descs.descs[0], tuple): - assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}" - reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) - assert reg_descs is not None, "Failed to get reg_descs" - self.agent.deregister_memory(reg_descs) + self.agent.deregister_memory(self._get_validated_reg_descs(descs)) def load_remote_agent(self, name: str, agent_desc: bytes): self.agent.add_remote_agent(agent_desc) @@ -97,9 +102,15 @@ def notify_sync_message(self, name: str, sync_message: str): @nvtx_range("NixlTransferAgent.submit_transfer_requests") def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type) + if src_xfer_descs is None: + raise RuntimeError( + f"nixl get_xfer_descs returned None for src type={request.src_descs.type}" + ) dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type) - assert src_xfer_descs is not None, "Failed to get src_xfer_descs" - assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs" + if dst_xfer_descs is None: + raise RuntimeError( + f"nixl get_xfer_descs 
returned None for dst type={request.dst_descs.type}" + ) sync_message = "" if request.sync_message is None else request.sync_message handle = self.agent.initialize_xfer( request.op, @@ -110,5 +121,7 @@ def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ) status = self.agent.transfer(handle) if status == "ERROR": - raise RuntimeError("NIXL transfer initialization failed.") + raise RuntimeError( + f"NIXL transfer failed: op={request.op}, remote={request.remote_name}" + ) return NixlTransferStatus(self.agent, handle) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/agent.py b/tensorrt_llm/_torch/disaggregation/nixl/agent.py index 5f6f3db1547d..8be3c7c5775e 100644 --- a/tensorrt_llm/_torch/disaggregation/nixl/agent.py +++ b/tensorrt_llm/_torch/disaggregation/nixl/agent.py @@ -13,31 +13,45 @@ """ -def _load_agent(module_name, required_attributes): +def _load_agent( + module_name: str, required_attributes: list[str] +) -> tuple[object, ImportError | None]: try: module = __import__(module_name, fromlist=required_attributes, level=0) if all(hasattr(module, attr) for attr in required_attributes): - return module + return module, None + missing = [a for a in required_attributes if not hasattr(module, a)] + err = ImportError(f"Module {module_name} is missing required attributes: {missing}") + logger.warning("%s", err) + return None, err except ImportError as e: - logger.info("Failed to import module: %s. Error: %s", module_name, str(e)) - return None + logger.warning("Failed to import module: %s. Error: %s", module_name, str(e)) + return None, e NixlTransferStatus, NixlTransferAgent = None, None if use_pure_python_transfer_agent(): - _py_agent = _load_agent( + _py_agent, _py_agent_err = _load_agent( module_name="tensorrt_llm._torch.disaggregation.nixl._agent_py", required_attributes=["NixlTransferAgent", "NixlTransferStatus"], ) - assert _py_agent is not None, "Failed to load pure Python NIXL Transfer Agent." + if _py_agent is None: + raise ImportError( + "Failed to load pure Python NIXL Transfer Agent." + + (f" Caused by: {_py_agent_err}" if _py_agent_err else "") + ) NixlTransferStatus = _py_agent.NixlTransferStatus NixlTransferAgent = _py_agent.NixlTransferAgent else: - _cpp_agent = _load_agent( + _cpp_agent, _cpp_agent_err = _load_agent( module_name="tensorrt_llm._torch.disaggregation.nixl._agent_cpp", required_attributes=["BindingsNixlTransferAgent", "BindingsNixlTransferStatus"], ) - assert _cpp_agent is not None, "Failed to load C++ NIXL Transfer Agent bindings." + if _cpp_agent is None: + raise ImportError( + "Failed to load C++ NIXL Transfer Agent bindings." + + (f" Caused by: {_cpp_agent_err}" if _cpp_agent_err else "") + ) NixlTransferStatus = _cpp_agent.BindingsNixlTransferStatus NixlTransferAgent = _cpp_agent.BindingsNixlTransferAgent diff --git a/tests/unittest/disaggregated/test_agent_multi_backends.py b/tests/unittest/disaggregated/test_agent_multi_backends.py index 0a95bad03bc3..5bf1da73d34b 100644 --- a/tests/unittest/disaggregated/test_agent_multi_backends.py +++ b/tests/unittest/disaggregated/test_agent_multi_backends.py @@ -4,6 +4,43 @@ import pytest +def test_load_agent_missing_module(): + """_load_agent returns (None, ImportError) for a non-existent module. + + Regression test: previously a missing nixl package caused an AssertionError + at module import time, making pytest exit with code 2 (collection failure) + instead of a clear ImportError with a descriptive message. 
+ """ + from tensorrt_llm._torch.disaggregation.nixl.agent import _load_agent + + agent, err = _load_agent("_trtllm_nonexistent_module_xyz_", ["SomeClass"]) + assert agent is None + assert isinstance(err, ImportError), f"Expected ImportError, got {type(err)}: {err}" + assert "No module named" in str(err) or "_trtllm_nonexistent_module_xyz_" in str(err) + + +def test_load_agent_missing_attributes(): + """_load_agent returns (None, ImportError) and logs a warning when attributes are missing.""" + from tensorrt_llm._torch.disaggregation.nixl.agent import _load_agent + + # 'os' exists but has no NixlTransferAgent attribute + agent, err = _load_agent("os", ["NixlTransferAgent"]) + assert agent is None + assert isinstance(err, ImportError), f"Expected ImportError, got {type(err)}: {err}" + assert "NixlTransferAgent" in str(err) + + +def test_load_agent_success(): + """_load_agent returns (module, None) on success.""" + from tensorrt_llm._torch.disaggregation.nixl.agent import _load_agent + + agent, err = _load_agent("os", ["path", "getcwd"]) + assert agent is not None + assert err is None + assert hasattr(agent, "path") + assert hasattr(agent, "getcwd") + + @pytest.mark.parametrize("use_py_nixl", ["0", "1"]) def test_run_with_different_env(use_py_nixl): os.environ["TRTLLM_USE_PY_NIXL_KVCACHE"] = use_py_nixl From 7aa781861fc7e2c32a660a6aaa31b019fda39995 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:29:22 -0700 Subject: [PATCH 8/8] [None][feat] Add triton paged attention for AutoDeploy (#12642) Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../custom_ops/attention/__init__.py | 3 +- .../attention/triton_paged_attention.py | 1185 +++++++++++++++++ .../defs/accuracy/test_llm_api_autodeploy.py | 8 +- .../test_lists/test-db/l0_b200.yml | 1 + .../test_lists/test-db/l0_h100.yml | 1 + .../attention/test_triton_paged_attention.py | 690 ++++++++++ 6 files changed, 1886 insertions(+), 2 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py create mode 100644 tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py index 438271839738..b615a7fb8ba3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py @@ -22,7 +22,7 @@ - trtllm_attention: TRT-LLM thop.attention-based optimized attention - triton_attention: Triton-based attention implementations - triton_attention_with_kv_cache: Triton attention with KV cache support -- triton_attention_with_paged_kv_cache: Triton attention with paged KV cache +- triton_paged_attention: Triton paged attention (two-stage flash-decode) with HND layout - onnx_attention: Placeholder ops for ONNX export of attention mechanisms """ @@ -34,5 +34,6 @@ "triton_attention", "triton_attention_with_kv_cache", "triton_attention_with_paged_kv_cache", + "triton_paged_attention", "onnx_attention", ] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py new file mode 100644 index 000000000000..b33e70405c2e --- /dev/null +++ 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py @@ -0,0 +1,1185 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton Paged Attention + +This module provides a Triton-based paged attention implementation with: +- Combined KV cache with HND layout: [num_blocks, 2, num_kv_heads, page_size, head_dim] +- Context/prefill kernel with causal masking +""" + +import math +from typing import List, Literal, Optional, Tuple + +import flashinfer +import torch +import triton +import triton.language as tl +from torch._ops import OpOverloadPacket +from torch._subclasses import FakeTensor +from torch.fx import Node + +from tensorrt_llm.llmapi.llm_args import KvCacheConfig + +from ...utils.logger import ad_logger +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + Constant, + KVPagedResourceHandler, + MHACallable, + PrepareMetadataCallable, + ResourceHandlerDict, +) + +KV_LAYOUT: Literal["HND", "NHD"] = "HND" + +# Cache SM count to avoid repeated get_device_properties calls +_NUM_SMS: Optional[int] = None + + +def _get_num_sms() -> int: + """Get the number of SMs on the current GPU (cached).""" + global _NUM_SMS + if _NUM_SMS is None: + _NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count + return _NUM_SMS + + +def _get_sm_scale(head_dim: int, scale: Optional[float]) -> float: + """Get softmax scale, computing default if not provided.""" + return scale if scale is not None else 1.0 / math.sqrt(head_dim) + + +@triton.jit +def _update_paged_kv_cache_kernel( + # Input K, V + k_ptr, + v_ptr, + # Metadata + batch_indices_ptr, + positions_ptr, + # KV cache + kv_cache_ptr, + # Page table + kv_indices_ptr, + kv_indptr_ptr, + # Constants + NUM_TOKENS: tl.constexpr, + N_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + # Strides for kv_cache: [num_blocks, 2, num_kv_heads, page_size, head_dim] + cache_stride_block: tl.constexpr, + cache_stride_kv: tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_token: tl.constexpr, +): + """Update combined KV cache with new tokens.""" + token_id = tl.program_id(axis=0) + head_id = tl.program_id(axis=1) + + if token_id >= NUM_TOKENS: + return + + batch_idx = tl.load(batch_indices_ptr + token_id) + position = tl.load(positions_ptr + token_id) + + page_idx_in_seq = position // PAGE_SIZE + offset_in_page = position % PAGE_SIZE + + page_start = tl.load(kv_indptr_ptr + batch_idx) + physical_page = tl.load(kv_indices_ptr + page_start + page_idx_in_seq) + + head_offsets = tl.arange(0, HEAD_DIM) + kv_offset = token_id * N_KV_HEADS * HEAD_DIM + head_id * HEAD_DIM + head_offsets + + k = tl.load(k_ptr + kv_offset) + v = tl.load(v_ptr + kv_offset) + + # Compute cache offset (use int64 to avoid overflow when physical_page * stride > 2^31) + cache_base = ( + 
physical_page.to(tl.int64) * cache_stride_block + + head_id * cache_stride_head + + offset_in_page.to(tl.int64) * cache_stride_token + + head_offsets + ) + + tl.store(kv_cache_ptr + cache_base, k) + tl.store(kv_cache_ptr + cache_base + cache_stride_kv, v) + + +def update_paged_kv_cache( + k: torch.Tensor, + v: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, +) -> None: + """Update the combined paged KV cache with new K, V tensors.""" + num_tokens, n_kv_heads, head_dim = k.shape + page_size = kv_cache.shape[3] + + if num_tokens == 0: + return + + grid = (num_tokens, n_kv_heads) + _update_paged_kv_cache_kernel[grid]( + k, + v, + batch_indices, + positions, + kv_cache, + kv_indices, + kv_indptr, + NUM_TOKENS=num_tokens, + N_KV_HEADS=n_kv_heads, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + cache_stride_block=kv_cache.stride(0), + cache_stride_kv=kv_cache.stride(1), + cache_stride_head=kv_cache.stride(2), + cache_stride_token=kv_cache.stride(3), + ) + + +# ============================================================================= +# FLASH DECODE - HELPERS +# ============================================================================= + + +def _get_num_splits(max_seq_len: int, batch_size: int, n_kv_heads: int, page_size: int) -> int: + """Compute optimal number of KV splits for FlashDecoding. + + With GQA batching, the grid is (batch, n_kv_heads, num_splits). + We want enough blocks to saturate the GPU. + """ + if max_seq_len <= 0: + return 1 + + num_sms = _get_num_sms() + existing_parallelism = batch_size * n_kv_heads + + # Already enough parallelism + if existing_parallelism >= num_sms * 2: + return 1 + + # Target ~4 waves of thread blocks + target_blocks = num_sms * 4 + num_splits = max(1, (target_blocks + existing_parallelism - 1) // existing_parallelism) + + # Cap splits so each block has at least 2 pages of work. + # With fewer pages, the per-block overhead (Q load, accumulator init, + # partial_o/lse store, plus stage2 reduction cost) dominates the useful + # compute (page-loop iterations). 2 pages is a conservative lower bound + # to keep the overhead-to-work ratio acceptable. 
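For intuition, the split heuristic above can also be evaluated offline. The following is a minimal standalone sketch that mirrors the logic of _get_num_splits; the SM count and problem sizes are illustrative assumptions, not values taken from this patch.

import math

def expected_num_splits(max_seq_len, batch_size, n_kv_heads, page_size, num_sms):
    # Mirrors _get_num_splits: saturate the GPU with ~4 waves of blocks,
    # keep >= 2 pages of work per block, round to a power of 2, cap at 128.
    existing = batch_size * n_kv_heads
    if max_seq_len <= 0 or existing >= num_sms * 2:
        return 1
    target_blocks = num_sms * 4
    splits = max(1, (target_blocks + existing - 1) // existing)
    max_splits = max(1, (max_seq_len // page_size) // 2)
    splits = min(splits, max_splits)
    if splits > 1:
        splits = 2 ** math.ceil(math.log2(splits))
    return min(splits, 128)

# Assumed example: batch=4, 8 KV heads, 4K context, 16-token pages, 132 SMs
# -> 32 splits, i.e. a (4, 8, 32) stage-1 grid of 1024 blocks (~8 per SM).
print(expected_num_splits(4096, 4, 8, 16, 132))  # 32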
+ max_pages = max_seq_len // page_size + max_splits = max(1, max_pages // 2) + num_splits = min(num_splits, max_splits) + + # Round to next power of 2 for Triton compile caching + if num_splits > 1: + num_splits = 2 ** math.ceil(math.log2(num_splits)) + + return min(num_splits, 128) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=2, num_stages=2), + triton.Config({}, num_warps=2, num_stages=3), + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["HEAD_DIM", "PAGE_SIZE", "HEAD_RATIO_PADDED"], +) +@triton.jit +def _flash_decode_stage1_kernel( + # Query input + q_ptr, + # KV cache (combined) + kv_cache_ptr, + # Page table + kv_indices_ptr, + kv_indptr_ptr, + kv_last_page_len_ptr, + # Intermediate outputs + partial_o_ptr, + partial_lse_ptr, + # Q strides: [batch, n_heads, head_dim] + q_stride_batch: tl.constexpr, + q_stride_head: tl.constexpr, + # Partial output strides: [batch, n_heads, num_splits, head_dim] + po_stride_batch: tl.constexpr, + po_stride_head: tl.constexpr, + po_stride_split: tl.constexpr, + # Partial LSE strides: [batch, n_heads, num_splits] + plse_stride_batch: tl.constexpr, + plse_stride_head: tl.constexpr, + plse_stride_split: tl.constexpr, + # Cache strides: [num_blocks, 2, n_kv_heads, page_size, head_dim] + cache_stride_block: tl.constexpr, + cache_stride_kv: tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_token: tl.constexpr, + # Constants + SM_SCALE: tl.constexpr, + N_HEADS: tl.constexpr, + N_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_RATIO: tl.constexpr, + HEAD_RATIO_PADDED: tl.constexpr, + NUM_SPLITS: tl.constexpr, +): + """ + Key optimizations: + - Loads KV once for HEAD_RATIO Q heads + - Iterates by page for contiguous memory access + - Splits KV sequence across blocks for GPU utilization + """ + batch_id = tl.program_id(axis=0) + kv_head_id = tl.program_id(axis=1) + split_id = tl.program_id(axis=2) + + # Get sequence info from page table + kv_page_start = tl.load(kv_indptr_ptr + batch_id) + kv_page_end = tl.load(kv_indptr_ptr + batch_id + 1) + num_pages = kv_page_end - kv_page_start + last_page_len = tl.load(kv_last_page_len_ptr + batch_id) + + # Compute this split's page range (page-aligned splits) + pages_per_split = (num_pages + NUM_SPLITS - 1) // NUM_SPLITS + page_split_start = split_id * pages_per_split + page_split_end = tl.minimum(page_split_start + pages_per_split, num_pages) + + dhead_offsets = tl.arange(0, HEAD_DIM) + # Use padded range for Triton power-of-2 requirement; mask out-of-bounds heads + head_local = tl.arange(0, HEAD_RATIO_PADDED) + head_ids = kv_head_id * HEAD_RATIO + head_local + head_mask = head_local < HEAD_RATIO + + # Handle inactive splits (beyond the sequence's pages) + if page_split_start >= num_pages: + # Store zeros + -inf LSE for valid HEAD_RATIO Q heads only + po_offsets = ( + batch_id * po_stride_batch + + head_ids[:, None] * po_stride_head + + split_id * po_stride_split + + dhead_offsets[None, :] + ) + tl.store( + partial_o_ptr + po_offsets, + tl.zeros([HEAD_RATIO_PADDED, HEAD_DIM], dtype=tl.float32), + mask=head_mask[:, None], + ) + plse_offsets = ( + batch_id * plse_stride_batch + + head_ids * plse_stride_head + + split_id * plse_stride_split + ) + tl.store( + partial_lse_ptr + plse_offsets, + tl.zeros([HEAD_RATIO_PADDED], dtype=tl.float32) + float("-inf"), + mask=head_mask, + ) + return + + # Load Q for 
HEAD_RATIO heads sharing this KV head: [HEAD_RATIO_PADDED, HEAD_DIM] + # Padded rows get zeros, producing zero attention scores (harmless, never stored) + q_offsets = ( + batch_id * q_stride_batch + head_ids[:, None] * q_stride_head + dhead_offsets[None, :] + ) + q_all = tl.load(q_ptr + q_offsets, mask=head_mask[:, None], other=0.0) + + acc = tl.zeros([HEAD_RATIO_PADDED, HEAD_DIM], dtype=tl.float32) + m_i = tl.zeros([HEAD_RATIO_PADDED], dtype=tl.float32) + float("-inf") + l_i = tl.zeros([HEAD_RATIO_PADDED], dtype=tl.float32) + + num_pages_this_split = page_split_end - page_split_start + for local_page_idx in range(num_pages_this_split): + page_idx = page_split_start + local_page_idx + physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) + + # Determine valid tokens in this page + is_last_page_of_seq = page_idx == (num_pages - 1) + valid_tokens = tl.where(is_last_page_of_seq, last_page_len, PAGE_SIZE) + + page_offsets = tl.arange(0, PAGE_SIZE) + page_mask = page_offsets < valid_tokens + + # Compute cache offset (use int64 to avoid overflow when physical_page * stride > 2^31) + cache_base = ( + physical_page.to(tl.int64) * cache_stride_block + + kv_head_id * cache_stride_head + + page_offsets[:, None] * cache_stride_token + + dhead_offsets[None, :] + ) + page_mask_2d = page_mask[:, None] + + k = tl.load( + kv_cache_ptr + cache_base, mask=page_mask_2d, other=0.0 + ) # [PAGE_SIZE, HEAD_DIM] + v = tl.load( + kv_cache_ptr + cache_base + cache_stride_kv, + mask=page_mask_2d, + other=0.0, + ) # [PAGE_SIZE, HEAD_DIM] + + # [HEAD_RATIO_PADDED, HEAD_DIM] @ [HEAD_DIM, PAGE_SIZE] -> [HEAD_RATIO_PADDED, PAGE_SIZE] + attn = tl.dot(q_all, tl.trans(k)) * SM_SCALE + attn = tl.where(page_mask[None, :], attn, float("-inf")) + + # Online softmax update (vectorized over HEAD_RATIO_PADDED) + m_ij = tl.max(attn, axis=1) # [HEAD_RATIO_PADDED] + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(attn - m_i_new[:, None]) # [HEAD_RATIO_PADDED, PAGE_SIZE] + + # [HEAD_RATIO_PADDED, PAGE_SIZE] @ [PAGE_SIZE, HEAD_DIM] -> [HEAD_RATIO_PADDED, HEAD_DIM] + acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_i_new + + # Finalize: normalize and compute LSE + l_i_safe = tl.where(l_i == 0.0, 1.0, l_i) + partial_o_val = acc / l_i_safe[:, None] # [HEAD_RATIO_PADDED, HEAD_DIM] + lse_val = m_i + tl.log(l_i_safe) # [HEAD_RATIO_PADDED] + + # Store results for valid HEAD_RATIO Q heads only (masked 2D store) + po_offsets = ( + batch_id * po_stride_batch + + head_ids[:, None] * po_stride_head + + split_id * po_stride_split + + dhead_offsets[None, :] + ) + tl.store(partial_o_ptr + po_offsets, partial_o_val, mask=head_mask[:, None]) + + plse_offsets = ( + batch_id * plse_stride_batch + head_ids * plse_stride_head + split_id * plse_stride_split + ) + tl.store(partial_lse_ptr + plse_offsets, lse_val, mask=head_mask) + + +@triton.jit +def _flash_decode_stage2_kernel( + # Partial results + partial_o_ptr, + partial_lse_ptr, + # Final output + o_ptr, + # Partial output strides: [batch, n_heads, num_splits, head_dim] + po_stride_batch: tl.constexpr, + po_stride_head: tl.constexpr, + po_stride_split: tl.constexpr, + # Partial LSE strides: [batch, n_heads, num_splits] + plse_stride_batch: tl.constexpr, + plse_stride_head: tl.constexpr, + plse_stride_split: tl.constexpr, + # Output strides: [batch, n_heads, head_dim] + o_stride_batch: tl.constexpr, + o_stride_head: tl.constexpr, + # Constants + HEAD_DIM: tl.constexpr, + NUM_SPLITS: tl.constexpr, +): + """ 
+ Each program combines results from all splits for one (batch, head) pair. + """ + batch_id = tl.program_id(axis=0) + head_id = tl.program_id(axis=1) + + dhead_offsets = tl.arange(0, HEAD_DIM) + + # Find global maximum LSE across splits for numerical stability + global_max_lse = float("-inf") + for split_id in range(NUM_SPLITS): + plse_offset = ( + batch_id * plse_stride_batch + head_id * plse_stride_head + split_id * plse_stride_split + ) + lse = tl.load(partial_lse_ptr + plse_offset) + global_max_lse = tl.maximum(global_max_lse, lse) + + # Guard: if all splits had -inf LSE (empty sequence), output zeros + o_offset = batch_id * o_stride_batch + head_id * o_stride_head + dhead_offsets + if global_max_lse == float("-inf"): + tl.store(o_ptr + o_offset, tl.zeros([HEAD_DIM], dtype=tl.float32)) + return + + # Weighted combination: weight_i = exp(lse_i - global_max) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + total_weight = 0.0 + + for split_id in range(NUM_SPLITS): + plse_offset = ( + batch_id * plse_stride_batch + head_id * plse_stride_head + split_id * plse_stride_split + ) + lse = tl.load(partial_lse_ptr + plse_offset) + weight = tl.exp(lse - global_max_lse) + + po_base = batch_id * po_stride_batch + head_id * po_stride_head + split_id * po_stride_split + partial_o = tl.load(partial_o_ptr + po_base + dhead_offsets) + + acc += weight * partial_o + total_weight += weight + + # Normalize and store + total_weight = tl.where(total_weight == 0.0, 1.0, total_weight) + o = acc / total_weight + tl.store(o_ptr + o_offset, o) + + +def triton_paged_decode( + q: torch.Tensor, + kv_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + kv_last_page_len: torch.Tensor, + sm_scale: float, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Optimized paged decode with GQA batching + FlashDecoding + page-aligned iteration. 
+ + Args: + q: Query tensor [batch_size, n_heads, head_dim] + kv_cache: Combined cache [num_blocks, 2, n_kv_heads, page_size, head_dim] + kv_indices: Physical page indices (flattened) + kv_indptr: Cumulative page counts [batch_size + 1] + kv_last_page_len: Valid tokens in last page [batch_size] + sm_scale: Softmax scale factor + out: Optional output tensor [batch_size, n_heads, head_dim] + + Returns: + Output tensor [batch_size, n_heads, head_dim] + """ + batch_size, n_heads, head_dim = q.shape + _, _, n_kv_heads, page_size, _ = kv_cache.shape + head_ratio = n_heads // n_kv_heads + head_ratio_padded = max(1, 2 ** math.ceil(math.log2(head_ratio))) if head_ratio > 1 else 1 + + max_pages = kv_indices.shape[0] + max_seq_len = max_pages * page_size + + output = out if out is not None else torch.empty_like(q) + + if batch_size == 0: + return output + + num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size) + + # Allocate intermediate buffers for split-K + partial_o = torch.empty( + batch_size, + n_heads, + num_splits, + head_dim, + dtype=torch.float32, + device=q.device, + ) + partial_lse = torch.empty( + batch_size, + n_heads, + num_splits, + dtype=torch.float32, + device=q.device, + ) + + # Stage 1: GQA-batched parallel KV processing + _flash_decode_stage1_kernel[(batch_size, n_kv_heads, num_splits)]( + q, + kv_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + partial_o, + partial_lse, + # Q strides + q.stride(0), + q.stride(1), + # Partial output strides + partial_o.stride(0), + partial_o.stride(1), + partial_o.stride(2), + # Partial LSE strides + partial_lse.stride(0), + partial_lse.stride(1), + partial_lse.stride(2), + # Cache strides + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + kv_cache.stride(3), + # Constants + SM_SCALE=sm_scale, + N_HEADS=n_heads, + N_KV_HEADS=n_kv_heads, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + HEAD_RATIO=head_ratio, + HEAD_RATIO_PADDED=head_ratio_padded, + NUM_SPLITS=num_splits, + ) + + # Stage 2: Combine partial results + _flash_decode_stage2_kernel[(batch_size, n_heads)]( + partial_o, + partial_lse, + output, + # Partial output strides + partial_o.stride(0), + partial_o.stride(1), + partial_o.stride(2), + # Partial LSE strides + partial_lse.stride(0), + partial_lse.stride(1), + partial_lse.stride(2), + # Output strides + output.stride(0), + output.stride(1), + # Constants + HEAD_DIM=head_dim, + NUM_SPLITS=num_splits, + ) + + return output + + +# ============================================================================= +# TRITON KERNELS - CONTEXT/PREFILL (page-aligned, causal skip, autotuned) +# ============================================================================= +@triton.autotune( + configs=[ + triton.Config({"Q_BLOCK": 64}, num_stages=2, num_warps=2), + triton.Config({"Q_BLOCK": 64}, num_stages=2, num_warps=4), + triton.Config({"Q_BLOCK": 64}, num_stages=4, num_warps=4), + triton.Config({"Q_BLOCK": 128}, num_stages=2, num_warps=4), + triton.Config({"Q_BLOCK": 128}, num_stages=2, num_warps=8), + triton.Config({"Q_BLOCK": 128}, num_stages=3, num_warps=8), + ], + key=["HEAD_DIM", "PAGE_SIZE"], +) +@triton.jit +def _paged_context_kernel( + # Inputs + q_ptr, + kv_cache_ptr, + # Metadata + qo_indptr_ptr, + kv_indptr_ptr, + kv_indices_ptr, + kv_last_page_len_ptr, + seq_len_with_cache_ptr, + # Output + o_ptr, + # Strides + q_stride_token: tl.constexpr, + q_stride_head: tl.constexpr, + o_stride_token: tl.constexpr, + o_stride_head: tl.constexpr, + cache_stride_block: tl.constexpr, + cache_stride_kv: 
tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_token: tl.constexpr, + # Autotuned + Q_BLOCK: tl.constexpr, + # Constants + SM_SCALE: tl.constexpr, + N_HEADS: tl.constexpr, + N_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Context/prefill attention with paged KV cache, causal skip, and page-aligned iteration. + + Grid: (num_seq, n_heads, num_q_blocks) + + Optimizations: + - Page-aligned iteration: 1 scalar page table load per page, no div/mod, + contiguous KV memory access within each page. + - Causal skip: pages entirely beyond the last Q position are skipped, + saving ~50% of KV loads on average for causal attention. + - Autotuned Q_BLOCK, num_stages, num_warps for best tile/pipeline config. + """ + batch_id = tl.program_id(axis=0) + head_id = tl.program_id(axis=1) + q_block_id = tl.program_id(axis=2) + + HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS + kv_head_id = head_id // HEAD_RATIO + + q_start = tl.load(qo_indptr_ptr + batch_id) + q_end = tl.load(qo_indptr_ptr + batch_id + 1) + q_len = q_end - q_start + + kv_page_start = tl.load(kv_indptr_ptr + batch_id) + kv_page_end = tl.load(kv_indptr_ptr + batch_id + 1) + num_kv_pages = kv_page_end - kv_page_start + total_kv_len = tl.load(seq_len_with_cache_ptr + batch_id) + + cache_len = total_kv_len - q_len + + q_block_start = q_block_id * Q_BLOCK + q_offsets = q_block_start + tl.arange(0, Q_BLOCK) + q_mask = q_offsets < q_len + + if tl.sum(q_mask.to(tl.int32)) == 0: + return + + dhead_offsets = tl.arange(0, HEAD_DIM) + q_load_offsets = ( + (q_start + q_offsets[:, None]) * q_stride_token + + head_id * q_stride_head + + dhead_offsets[None, :] + ) + q_load_mask = q_mask[:, None] + q = tl.load(q_ptr + q_load_offsets, mask=q_load_mask, other=0.0) + + acc = tl.zeros([Q_BLOCK, HEAD_DIM], dtype=tl.float32) + m_i = tl.zeros([Q_BLOCK], dtype=tl.float32) - float("inf") + l_i = tl.zeros([Q_BLOCK], dtype=tl.float32) + + page_offsets = tl.arange(0, PAGE_SIZE) + + # Two-phase page loop: + # Phase 1 (full pages): pages entirely before the first Q position need no causal mask. + # First Q position in KV coords = q_block_start + cache_len. + # A page ending at (page_idx+1)*PAGE_SIZE - 1 is fully attended if that's <= first Q pos. + # Phase 2 (boundary pages): remaining pages up to last Q position need causal masking. 
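To make the two-phase split described in the comment above concrete, here is a small worked example of how pages fall into the full, boundary, and skipped buckets for one Q block; all numbers are illustrative.

# Assumed shapes: 16-token pages, a 64-row Q block at q_block_start=0,
# 32 cached tokens plus 128 new tokens (total_kv_len = 160 -> 10 pages).
PAGE_SIZE, Q_BLOCK = 16, 64
cache_len, q_block_start, total_kv_len = 32, 0, 160

first_q_kv_pos = q_block_start + cache_len                   # 32
max_q_pos = q_block_start + Q_BLOCK - 1 + cache_len          # 95
num_kv_pages = (total_kv_len + PAGE_SIZE - 1) // PAGE_SIZE   # 10
num_full_pages = first_q_kv_pos // PAGE_SIZE                 # 2

full = list(range(num_full_pages))                           # [0, 1]: no masks needed
boundary = [p for p in range(num_full_pages, num_kv_pages)
            if p * PAGE_SIZE <= max_q_pos]                   # [2, 3, 4, 5]: causal + validity masks
skipped = [p for p in range(num_full_pages, num_kv_pages)
           if p * PAGE_SIZE > max_q_pos]                     # [6, 7, 8, 9]: beyond this Q block
print(full, boundary, skipped)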
+ first_q_kv_pos = q_block_start + cache_len + max_q_pos = q_block_start + Q_BLOCK - 1 + cache_len + + # Number of full pages (all tokens in these pages are attended by all Q tokens) + num_full_pages = first_q_kv_pos // PAGE_SIZE + + # Check if this is a full Q block (no q_mask needed) + is_full_q_block = (q_block_start + Q_BLOCK) <= q_len + + # Phase 1: Full pages — no causal mask, no validity mask + # Process one page at a time with a clean inner loop + kv_head_offset = kv_head_id * cache_stride_head + local_kv = page_offsets[:, None] * cache_stride_token + dhead_offsets[None, :] + + for page_idx in range(num_full_pages): + physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) + + # Use int64 to avoid overflow when physical_page * stride > 2^31 + page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset + k_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr( + base=kv_cache_ptr + page_base + cache_stride_kv, + shape=(PAGE_SIZE, HEAD_DIM), + strides=(cache_stride_token, 1), + offsets=(0, 0), + block_shape=(PAGE_SIZE, HEAD_DIM), + order=(1, 0), + ) + k = tl.load(k_block_ptr) + v = tl.load(v_block_ptr) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + + if not is_full_q_block: + qk = tl.where(q_mask[:, None], qk, float("-inf")) + + m_ij = tl.max(qk, axis=1) + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) + acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_i_new + + # Phase 2: Boundary pages — need causal mask and validity mask + # Pre-compute q_positions outside loop (invariant across pages) + q_positions_2d = q_offsets[:, None] + cache_len + + for page_idx in range(num_full_pages, num_kv_pages): + kv_base_pos = page_idx * PAGE_SIZE + + # Causal skip: if entire page is beyond last Q position, skip it. 
+ if kv_base_pos <= max_q_pos: + physical_page = tl.load(kv_indices_ptr + kv_page_start + page_idx) + valid_tokens = tl.minimum(PAGE_SIZE, total_kv_len - kv_base_pos) + page_mask = page_offsets < valid_tokens + + # Use int64 to avoid overflow when physical_page * stride > 2^31 + page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset + page_mask_2d = page_mask[:, None] + k = tl.load(kv_cache_ptr + page_base + local_kv, mask=page_mask_2d, other=0.0) + v = tl.load( + kv_cache_ptr + page_base + local_kv + cache_stride_kv, + mask=page_mask_2d, + other=0.0, + ) + + qk = tl.dot(q, tl.trans(k)) * SM_SCALE + kv_positions = kv_base_pos + page_offsets[None, :] + causal_mask = q_positions_2d >= kv_positions + full_mask = q_mask[:, None] & causal_mask & page_mask[None, :] + qk = tl.where(full_mask, qk, float("-inf")) + + m_ij = tl.max(qk, axis=1) + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + p = tl.exp(qk - m_i_new[:, None]) + acc = tl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_i_new + + l_i = tl.where(l_i == 0.0, 1.0, l_i) + o = acc / l_i[:, None] + o_store_offsets = ( + (q_start + q_offsets[:, None]) * o_stride_token + + head_id * o_stride_head + + dhead_offsets[None, :] + ) + tl.store(o_ptr + o_store_offsets, o, mask=q_load_mask) + + +@triton.jit +def _fast_gather_sdpa_kernel( + kv_cache_ptr, + kv_indices_ptr, + out_k_ptr, + out_v_ptr, + # Strides + cache_stride_block: tl.constexpr, + cache_stride_kv: tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_token: tl.constexpr, + out_stride_seq: tl.constexpr, + out_stride_head: tl.constexpr, + out_stride_token: tl.constexpr, + # Constants + MAX_PAGES: tl.constexpr, + N_KV_HEADS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + """Gather scattered pages into separate K, V buffers in SDPA layout. + + Grid: (total_pages, N_KV_HEADS) + Each program copies one page for one KV head into contiguous K and V + outputs shaped [num_seq, n_kv_heads, max_kv_len, head_dim]. + No precomputed mapping needed — seq_id and local_page computed from global index. 
+ """ + page_global_idx = tl.program_id(0) + kv_head_id = tl.program_id(1) + + # Compute seq_id and local_page from global page index + seq_id = page_global_idx // MAX_PAGES + local_page = page_global_idx % MAX_PAGES + + physical_page = tl.load(kv_indices_ptr + page_global_idx) + + token_offsets = tl.arange(0, PAGE_SIZE) + head_offsets = tl.arange(0, HEAD_DIM) + + # Source: kv_cache[physical_page, 0/1, kv_head_id, :, :] + src_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_id * cache_stride_head + src_offsets = token_offsets[:, None] * cache_stride_token + head_offsets[None, :] + + k_data = tl.load(kv_cache_ptr + src_base + src_offsets) + v_data = tl.load(kv_cache_ptr + src_base + cache_stride_kv + src_offsets) + + # Destination: out_k/v[seq_id, kv_head_id, local_page*PAGE_SIZE + :, :] + local_token_start = local_page * PAGE_SIZE + dst_base = ( + seq_id * out_stride_seq + + kv_head_id * out_stride_head + + (local_token_start + token_offsets[:, None]) * out_stride_token + + head_offsets[None, :] + ) + + tl.store(out_k_ptr + dst_base, k_data) + tl.store(out_v_ptr + dst_base, v_data) + + +def triton_paged_context( + q: torch.Tensor, + kv_cache: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + kv_last_page_len: torch.Tensor, + seq_len_with_cache: torch.Tensor, + sm_scale: float, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Context/prefill attention with paged KV cache.""" + total_tokens, n_heads, head_dim = q.shape + _, _, n_kv_heads, page_size, _ = kv_cache.shape + num_seq = qo_indptr.shape[0] - 1 + + output = out if out is not None else torch.empty_like(q) + + if num_seq == 0 or total_tokens == 0: + return output + + # Compute max_q_len without GPU sync for single-sequence batches (most common + # in serving). For multi-sequence batches, we must use .item() because + # total_tokens // num_seq gives the average, not the max — variable-length + # sequences can produce wrong results (under-launched Q blocks or wrong SDPA reshape). + if num_seq == 1: + max_q_len = total_tokens + else: + q_lens = qo_indptr[1:] - qo_indptr[:-1] + max_q_len = int(q_lens.max().item()) + + # Adaptive dispatch: gather + cuDNN SDPA for seq>=512 (outperforms paged kernel), + # paged Triton kernel for shorter sequences where gather overhead dominates. 
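Since the adaptive-dispatch comment above refers to a gather + SDPA path, here is a pure-PyTorch restatement of that gather that can serve as a debugging reference. The function name is illustrative, and it assumes every sequence owns exactly max_pages consecutive entries of kv_indices, the same condition the dispatch check below requires.

import torch

def gather_kv_reference(kv_cache, kv_indices, num_seq, max_pages):
    # kv_cache: [num_blocks, 2, n_kv_heads, page_size, head_dim] (HND layout)
    _, _, n_kv_heads, page_size, head_dim = kv_cache.shape
    pages = kv_cache[kv_indices.long()]          # [num_seq * max_pages, 2, H, P, D]
    pages = pages.view(num_seq, max_pages, 2, n_kv_heads, page_size, head_dim)
    # Collapse (page, token-in-page) into one KV-length axis per sequence and head,
    # which is what _fast_gather_sdpa_kernel produces in parallel, one page per program.
    kv = pages.permute(2, 0, 3, 1, 4, 5).reshape(
        2, num_seq, n_kv_heads, max_pages * page_size, head_dim
    )
    return kv[0], kv[1]  # K, V shaped [num_seq, n_kv_heads, max_kv_len, head_dim]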
+ # Compute max_pages from max_q_len without GPU sync + # (assumes pure prefill where q_len == kv_len for each seq) + max_pages = (max_q_len + page_size - 1) // page_size + total_expected_pages = num_seq * max_pages + use_sdpa = ( + max_q_len >= 512 + and num_seq <= 64 + and max_pages > 0 + and kv_indices.shape[0] == total_expected_pages # all seqs same page count + ) + + if use_sdpa: + # Fast Triton gather: scattered pages → separate K, V in SDPA layout + # Single alloc for both K and V, single kernel to fill + max_kv_len = max_pages * page_size + kv_buf = torch.empty( + 2, + num_seq, + n_kv_heads, + max_kv_len, + head_dim, + dtype=kv_cache.dtype, + device=kv_cache.device, + ) + k_sdpa = kv_buf[0] + v_sdpa = kv_buf[1] + _fast_gather_sdpa_kernel[(total_expected_pages, n_kv_heads)]( + kv_cache, + kv_indices, + k_sdpa, + v_sdpa, + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + kv_cache.stride(3), + k_sdpa.stride(0), + k_sdpa.stride(1), + k_sdpa.stride(2), + MAX_PAGES=max_pages, + N_KV_HEADS=n_kv_heads, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + ) + + # SDPA with GQA + o_sdpa = torch.nn.functional.scaled_dot_product_attention( + q.view(num_seq, max_q_len, n_heads, head_dim).transpose(1, 2), + k_sdpa, + v_sdpa, + scale=sm_scale, + is_causal=True, + enable_gqa=True, + ) + output.view(num_seq, max_q_len, n_heads, head_dim).copy_(o_sdpa.permute(0, 2, 1, 3)) + else: + # Use paged kernel (better for small workloads) + def grid_paged(meta): + q_block = meta["Q_BLOCK"] + num_q_blocks = (max_q_len + q_block - 1) // q_block + return (num_seq, n_heads, num_q_blocks) + + _paged_context_kernel[grid_paged]( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + seq_len_with_cache, + output, + q.stride(0), + q.stride(1), + output.stride(0), + output.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + kv_cache.stride(3), + SM_SCALE=sm_scale, + N_HEADS=n_heads, + N_KV_HEADS=n_kv_heads, + HEAD_DIM=head_dim, + PAGE_SIZE=page_size, + ) + + return output + + +@torch.library.custom_op("auto_deploy::triton_paged_prepare_metadata", mutates_args=()) +def prepare_triton_paged_metadata( + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + seq_len_with_cache: torch.Tensor, +) -> List[torch.Tensor]: + """Prepare metadata for Triton paged attention.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo + + batch_info = BatchInfo(batch_info_host) + num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() + num_seq = num_prefill + num_decode + num_tokens = num_prefill_tokens + num_decode + + qo_indptr = cu_seqlen[: num_seq + 1] + + batch_indices, positions = flashinfer.get_batch_indices_positions( + qo_indptr, seq_len_with_cache[:num_seq], num_tokens + ) + + return batch_indices, positions + + +@prepare_triton_paged_metadata.register_fake +def prepare_triton_paged_metadata_fake( + position_ids: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen: torch.Tensor, + seq_len_with_cache: torch.Tensor, +): + num_tokens = position_ids.shape[0] * position_ids.shape[1] + return ( + torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), + torch.empty(num_tokens, dtype=torch.int32, device=position_ids.device), + ) + + +@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=("kv_cache",)) +def triton_paged_mha_with_cache( + # Q, K, V + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + # STANDARD METADATA + 
batch_info_host: torch.Tensor, + cu_seqlen_host: torch.Tensor, + cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc: torch.Tensor, + last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + # EXTRA METADATA + triton_batch_indices: torch.Tensor, + triton_positions: torch.Tensor, + # CACHES - combined KV cache + kv_cache: torch.Tensor, + # CONSTANTS + scale: Optional[float], +) -> torch.Tensor: + """Triton paged attention with mixed batch support.""" + head_dim = kv_cache.shape[-1] + q_shape_og = q.shape + b, s = q_shape_og[:2] + + q = q.reshape(b * s, -1, head_dim).contiguous() + k = k.reshape(b * s, -1, head_dim).contiguous() + v = v.reshape(b * s, -1, head_dim).contiguous() + + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo + + batch_info = BatchInfo(batch_info_host) + num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() + num_seq = num_prefill + num_decode + num_total_tokens = num_prefill_tokens + num_decode + + sm_scale = _get_sm_scale(head_dim, scale) + + # Update KV cache with new tokens + update_paged_kv_cache( + k[:num_total_tokens], + v[:num_total_tokens], + triton_batch_indices[:num_total_tokens], + triton_positions[:num_total_tokens], + kv_cache, + cache_loc, + cu_num_pages[: num_seq + 1], + ) + + y = torch.empty_like(q) + + # Process prefill tokens if any + if num_prefill > 0: + cu_seqlen = cu_seqlen_host[: num_prefill + 1].to(q.device, non_blocking=True) + seq_len_with_cache = seq_len_with_cache_host[:num_prefill].to(q.device, non_blocking=True) + triton_paged_context( + q[:num_prefill_tokens], + kv_cache, + cu_seqlen, + cu_num_pages[: num_prefill + 1], + cache_loc, + last_page_len[:num_prefill], + seq_len_with_cache, + sm_scale, + out=y[:num_prefill_tokens], + ) + + # Process decode tokens if any + if num_decode > 0: + triton_paged_decode( + q[num_prefill_tokens:num_total_tokens], + kv_cache, + cache_loc, + cu_num_pages[num_prefill : num_seq + 1], + last_page_len[num_prefill:num_seq], + sm_scale, + out=y[num_prefill_tokens:num_total_tokens], + ) + + return y.view(q_shape_og) + + +@triton_paged_mha_with_cache.register_fake +def triton_paged_mha_with_cache_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + batch_info_host: torch.Tensor, + cu_seqlen_host: torch.Tensor, + cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc: torch.Tensor, + last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + triton_batch_indices: torch.Tensor, + triton_positions: torch.Tensor, + kv_cache: torch.Tensor, + scale: Optional[float], +) -> torch.Tensor: + return torch.empty_like(q.contiguous()) + + +@AttentionRegistry.register("triton_paged") +class TritonPagedAttention(AttentionDescriptor): + """Descriptor for Triton Paged Attention backend. + + Optimized with GQA head batching, FlashDecoding, and page-aligned iteration. 
+ """ + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + return 3 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + return torch.ops.auto_deploy.torch_attention + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return torch.ops.auto_deploy.triton_paged_mha_with_cache.default + + @classmethod + def get_standard_metadata_args(cls) -> List[str]: + return [ + "batch_info_host", + "cu_seqlen_host", + "cu_num_pages", + "cu_num_pages_host", + "cache_loc", + "last_page_len", + "last_page_len_host", + "seq_len_with_cache_host", + ] + + @classmethod + def get_prepare_extra_metadata_info( + cls, any_source_attn_node: Node + ) -> Tuple[Optional[PrepareMetadataCallable], int, List[Constant]]: + return ( + torch.ops.auto_deploy.triton_paged_prepare_metadata.default, + 2, + [], + ) + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: KvCacheConfig + ) -> ResourceHandlerDict: + k_fake: FakeTensor = source_attn_node.args[1].meta["val"] + num_kv_heads = k_fake.shape[2] + head_dim = k_fake.shape[3] + + return { + "kv_cache": KVPagedResourceHandler( + num_kv_heads, + head_dim, + dtype=cls.resolve_cache_dtype(cache_config.dtype, k_fake.dtype), + kv_factor=2, + kv_layout=KV_LAYOUT, + ) + } + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + layout = source_attn_node.kwargs.get("layout", None) + if ( + layout is None + and len(source_attn_node.args) > 0 + and isinstance(source_attn_node.args[-1], str) + ): + layout = source_attn_node.args[-1] + if layout != "bsnd": + raise RuntimeError( + f"Expected torch_attention layout='bsnd' but got {layout!r} " + f"for node: {source_attn_node.format_node()}" + ) + + attn_mask, dropout_p, is_causal = extract_op_args( + source_attn_node, "attn_mask", "dropout_p", "is_causal" + ) + if attn_mask is not None or dropout_p != 0.0 or not is_causal: + ad_logger.debug( + "Unsupported attention arguments for " + f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}" + ) + + if len(source_attn_node.args) > 6: + scale = source_attn_node.args[6] + else: + scale = source_attn_node.kwargs.get("scale", None) + + if not (isinstance(scale, float) or scale is None): + ad_logger.warning(f"Provided {scale=}, is not a float. 
Using default scale instead.") + scale = None + + return [scale] diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 3772159f6434..ce69ecba516c 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -181,6 +181,11 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness): "max_seq_len": 2048, "compile_backend": "torch-simple", }, + "triton_paged": { + "max_batch_size": 128, + "max_seq_len": 8192, + "compile_backend": "torch-cudagraph", + }, } def get_default_kwargs(self, @@ -232,7 +237,8 @@ def check_acceptance_rate(self, llm, min_acceptance_rate: float): @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("world_size", [1, 2, 4]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) - @pytest.mark.parametrize("attn_backend", ["flashinfer", "trtllm", "torch"]) + @pytest.mark.parametrize("attn_backend", + ["flashinfer", "trtllm", "torch", "triton_paged"]) def test_auto_dtype(self, world_size, enable_chunked_prefill, attn_backend): kwargs = self.get_default_kwargs(enable_chunked_prefill, attn_backend) sampling_params = self.get_default_sampling_params() diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index ff8bd51e5ad5..8d1a900aa0c8 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -265,6 +265,7 @@ l0_b200: backend: autodeploy tests: - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[trtllm-False-1] + - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[triton_paged-False-1] - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[torch-True-1] - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_auto_dtype[trtllm-False] - accuracy/test_llm_api_autodeploy.py::TestGLM4Flash::test_nvfp4[False] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 8d87773b9675..58750e43b3fe 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -470,6 +470,7 @@ l0_h100: - unittest/auto_deploy/singlegpu/utils - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[trtllm-False-1] - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[trtllm-True-1] + - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[triton_paged-False-1] - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B_Instruct_Eagle3::test_eagle3_one_model - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[trtllm-triton_ssm-False] - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[trtllm-flashinfer_ssm-False] diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py new file mode 100644 index 000000000000..5a14303e14a7 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py @@ -0,0 +1,690 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Triton Paged Attention. + +Tests the Triton paged attention kernels and compares against FlashInfer for correctness. +""" + +import math + +import pytest +import torch + +# Skip all tests if CUDA is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +def create_paged_kv_cache( + num_blocks: int, + page_size: int, + n_kv_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", +) -> torch.Tensor: + """Create an empty paged KV cache with HND layout. + + Shape: [num_blocks, 2, n_kv_heads, page_size, head_dim] + """ + return torch.zeros(num_blocks, 2, n_kv_heads, page_size, head_dim, dtype=dtype, device=device) + + +def create_page_table( + batch_size: int, + max_pages_per_seq: int, + num_blocks: int, + device: str = "cuda", +) -> tuple: + """Create page table metadata. + + Returns: + kv_indices: Flattened page indices + kv_indptr: Cumulative page counts [batch_size + 1] + kv_last_page_len: Valid tokens in last page [batch_size] + """ + # Assign sequential pages to each sequence + kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + all_indices = [] + + for i in range(batch_size): + num_pages = min(max_pages_per_seq, num_blocks - sum(kv_indptr[: i + 1].tolist())) + pages = list(range(int(kv_indptr[i].item()), int(kv_indptr[i].item()) + num_pages)) + all_indices.extend(pages) + kv_indptr[i + 1] = kv_indptr[i] + num_pages + + kv_indices = torch.tensor(all_indices, dtype=torch.int32, device=device) + kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) + + return kv_indices, kv_indptr, kv_last_page_len + + +class TestTritonPagedDecodeKernel: + """Tests for the FlashDecoding paged decode kernel (stage1 + stage2).""" + + @pytest.mark.parametrize("batch_size", [1, 4, 8]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [64, 256, 512]) + def test_decode_kernel_vs_pytorch_reference( + self, batch_size: int, n_heads: int, n_kv_heads: int, head_dim: int, seq_len: int + ): + """Test decode kernel against PyTorch SDPA reference.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + + # Create Q for decode (single token per sequence) + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + + # Create K, V for the full sequence + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + # Flatten for cache update + k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + + # Create metadata for cache update + batch_indices = torch.repeat_interleave( + 
torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + # Create page table + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + # Create and fill cache + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache, kv_indices, kv_indptr + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + # Run Triton kernel + output_triton = triton_paged_decode( + q, kv_cache, kv_indices, kv_indptr, kv_last_page_len, sm_scale + ) + + # Compute PyTorch reference + # Q: [B, n_heads, head_dim] -> [B, n_heads, 1, head_dim] + # K: [B, seq_len, n_kv_heads, head_dim] -> [B, n_kv_heads, seq_len, head_dim] + # V: [B, seq_len, n_kv_heads, head_dim] -> [B, n_kv_heads, seq_len, head_dim] + q_ref = q.unsqueeze(2) # [B, n_heads, 1, head_dim] + k_ref = k.transpose(1, 2) # [B, n_kv_heads, seq_len, head_dim] + v_ref = v.transpose(1, 2) # [B, n_kv_heads, seq_len, head_dim] + + # Handle GQA by expanding K, V + head_ratio = n_heads // n_kv_heads + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + output_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=sm_scale, is_causal=False + ) + output_ref = output_ref.squeeze(2) # [B, n_heads, head_dim] + + # Compare + torch.testing.assert_close(output_triton.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + +class TestTritonPagedContextKernel: + """Tests for the context/prefill kernel.""" + + @pytest.mark.parametrize("batch_size", [1, 2]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [32, 64, 128, 512]) + def test_context_kernel_vs_pytorch_reference( + self, batch_size: int, n_heads: int, n_kv_heads: int, head_dim: int, seq_len: int + ): + """Test context kernel against PyTorch SDPA reference.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context, + update_paged_kv_cache, + ) + + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 5 + total_tokens = batch_size * seq_len + + # Create inputs (flattened) + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + # Create metadata + qo_indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda" + )[: batch_size + 1] + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + 
kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + seq_len_with_cache = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + # Create batch_indices and positions for cache update + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + # Create and fill cache + kv_cache = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + sm_scale = 1.0 / math.sqrt(head_dim) + + output = triton_paged_context( + q, + kv_cache, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + seq_len_with_cache, + sm_scale, + ) + + assert output.shape == q.shape + + # PyTorch SDPA reference (causal) + q_ref = q.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2) + k_ref = k.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + v_ref = v.view(batch_size, seq_len, n_kv_heads, head_dim).transpose(1, 2) + + head_ratio = n_heads // n_kv_heads + if head_ratio > 1: + k_ref = k_ref.repeat_interleave(head_ratio, dim=1) + v_ref = v_ref.repeat_interleave(head_ratio, dim=1) + + output_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=sm_scale, is_causal=True + ) + output_ref = output_ref.transpose(1, 2).reshape(total_tokens, n_heads, head_dim) + + torch.testing.assert_close(output.float(), output_ref.float(), rtol=1e-2, atol=1e-2) + + +class TestCacheUpdate: + """Tests for the KV cache update kernel.""" + + def test_cache_update_writes_correct_values(self): + """Test that cache update writes K, V to correct locations.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + update_paged_kv_cache, + ) + + batch_size = 2 + seq_len = 8 + n_kv_heads = 4 + head_dim = 32 + page_size = 4 + num_blocks = 10 + + # Create K, V with known values + k = torch.arange( + batch_size * seq_len * n_kv_heads * head_dim, dtype=torch.float16, device="cuda" + ).reshape(batch_size * seq_len, n_kv_heads, head_dim) + v = k + 1000 # Offset to distinguish K from V + + # Create metadata + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + kv_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda") + kv_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda") + + # Create empty cache + kv_cache = torch.zeros( + num_blocks, 2, n_kv_heads, page_size, head_dim, dtype=torch.float16, device="cuda" + ) + + # Run update + update_paged_kv_cache(k, v, batch_indices, positions, kv_cache, kv_indices, kv_indptr) + + # Verify: check first token of sequence 0 + # Token 0 should be at page 0, offset 0 + expected_k = k[0] # First token's K + actual_k = kv_cache[0, 0, :, 0, :] # Page 0, K (idx 0), all heads, offset 0 + torch.testing.assert_close(actual_k, expected_k) + + expected_v = v[0] # First token's V + actual_v = kv_cache[0, 1, :, 0, :] # Page 0, V (idx 1), all heads, offset 0 + torch.testing.assert_close(actual_v, expected_v) + + +class TestTritonPagedMHAIntegration: + """Integration tests for triton_paged_mha_with_cache and prepare_triton_paged_metadata. 
+ + These test the full integration layer including BatchInfo parsing, + metadata preparation, KV cache update, and mixed prefill/decode dispatch. + This would have caught the batch_info_host 12-element format change. + """ + + @staticmethod + def _make_batch_info( + num_prefill: int, + num_prefill_tokens: int, + num_decode: int, + max_context_length: int = 8192, + max_blocks_per_seq: int = 256, + max_batch_size: int = 8, + ) -> torch.Tensor: + """Create a 12-element batch_info_host tensor.""" + bi = torch.zeros(12, dtype=torch.int, pin_memory=True) + bi[0] = num_prefill + bi[1] = num_prefill_tokens + bi[2] = 0 # num_extend + bi[3] = 0 # num_extend_tokens + bi[4] = num_decode + bi[5] = num_decode # num_decode_tokens = num_decode (1 token each) + bi[6] = max_context_length + bi[7] = max_blocks_per_seq + bi[8] = 1 # block_offset_multiplier + bi[9] = max_batch_size + bi[10] = 0 # num_tokens_to_gather + bi[11] = 0 # gather_required + return bi + + def test_batch_info_12_element_format(self): + """Test that triton_paged_mha_with_cache handles 12-element batch_info_host. + + Regression test: batch_info_host changed from 3-element to 12-element tensor. + The old code did `num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()` + which would crash with ValueError: too many values to unpack. + """ + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_mha_with_cache, + ) + + n_heads, n_kv_heads, head_dim, page_size = 32, 8, 128, 16 + seq_len = 32 + num_pages = (seq_len + page_size - 1) // page_size + num_blocks = num_pages + 16 + + # Prefill-only batch: 1 sequence, 32 tokens + q = torch.randn(1, seq_len, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(1, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(1, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + batch_info_host = self._make_batch_info( + num_prefill=1, num_prefill_tokens=seq_len, num_decode=0 + ) + cu_seqlen_host = torch.tensor([0, seq_len], dtype=torch.int32) + cu_num_pages = torch.tensor([0, num_pages], dtype=torch.int32, device="cuda") + cu_num_pages_host = cu_num_pages.cpu() + cache_loc = torch.arange(num_pages, dtype=torch.int32, device="cuda") + last_page_len = torch.tensor( + [seq_len % page_size or page_size], dtype=torch.int32, device="cuda" + ) + last_page_len_host = last_page_len.cpu() + seq_len_with_cache_host = torch.tensor([seq_len], dtype=torch.int32) + kv_cache = torch.zeros( + num_blocks, 2, n_kv_heads, page_size, head_dim, dtype=torch.float16, device="cuda" + ) + + # Prepare metadata (this also uses batch_info_host) + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + prepare_triton_paged_metadata, + ) + + position_ids = torch.arange(seq_len, device="cuda") + batch_indices, positions = prepare_triton_paged_metadata( + position_ids, + batch_info_host, + cu_seqlen_host.to("cuda", non_blocking=True), + seq_len_with_cache_host.to("cuda", non_blocking=True), + ) + + # Run the full MHA with cache (should not crash) + output = triton_paged_mha_with_cache( + q, + k, + v, + batch_info_host, + cu_seqlen_host, + cu_num_pages, + cu_num_pages_host, + cache_loc, + last_page_len, + last_page_len_host, + seq_len_with_cache_host, + batch_indices, + positions, + kv_cache, + scale=None, + ) + + assert output.shape == q.shape + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + + 
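For sizing intuition, the combined HND cache allocated in the integration test above has an easily computed footprint. The arithmetic below is a worked example for those exact test shapes (num_blocks=18, 8 KV heads, page_size=16, head_dim=128, fp16), not a tuning recommendation.

# kv_cache: [num_blocks, 2, n_kv_heads, page_size, head_dim], fp16 (2 bytes/elem)
num_blocks, n_kv_heads, page_size, head_dim, bytes_per_elem = 18, 8, 16, 128, 2
bytes_per_block = 2 * n_kv_heads * page_size * head_dim * bytes_per_elem  # 65,536 B = 64 KiB
total_bytes = num_blocks * bytes_per_block                                # ~1.1 MiB for this toy test
print(bytes_per_block, total_bytes)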
def test_batch_info_with_extend_requests(self): + """Test that extend requests are absorbed into prefill counts.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo + + bi = torch.zeros(12, dtype=torch.int, pin_memory=True) + bi[0] = 1 # num_prefill + bi[1] = 32 # num_prefill_tokens + bi[2] = 2 # num_extend + bi[3] = 64 # num_extend_tokens + bi[4] = 3 # num_decode + + batch_info = BatchInfo(bi) + num_prefill, num_prefill_tokens, num_decode = batch_info.get_absorbed_info() + + # Extend should be absorbed into prefill + assert num_prefill == 3 # 1 + 2 + assert num_prefill_tokens == 96 # 32 + 64 + assert num_decode == 3 + + def test_prepare_metadata_with_12_element_batch_info(self): + """Test prepare_triton_paged_metadata with 12-element batch_info_host.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + prepare_triton_paged_metadata, + ) + + batch_info_host = self._make_batch_info(num_prefill=1, num_prefill_tokens=7, num_decode=0) + position_ids = torch.arange(7, device="cuda") + cu_seqlen = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + seq_len_with_cache = torch.tensor([7], dtype=torch.int32, device="cuda") + + # Should not raise ValueError + batch_indices, positions = prepare_triton_paged_metadata( + position_ids, batch_info_host, cu_seqlen, seq_len_with_cache + ) + + assert batch_indices.shape[0] == 7 + assert positions.shape[0] == 7 + assert (batch_indices == 0).all() + assert (positions == torch.arange(7, device="cuda")).all() + + +class TestFlashInferComparison: + """Tests comparing Triton implementation against FlashInfer.""" + + @pytest.mark.parametrize("batch_size", [1, 4, 8]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [64, 128, 512]) + def test_decode_vs_flashinfer( + self, batch_size: int, n_heads: int, n_kv_heads: int, head_dim: int, seq_len: int + ): + """Compare decode output against FlashInfer.""" + import flashinfer + + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_decode, + update_paged_kv_cache, + ) + + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 10 + + # Create shared K, V data + k = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + v = torch.randn( + batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda" + ) + + # Query for decode + q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda") + + # Page table metadata + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + # ===== Triton ===== + kv_cache_triton = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim) + v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim) + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", 
dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + update_paged_kv_cache( + k_flat, v_flat, batch_indices, positions, kv_cache_triton, kv_indices, kv_indptr + ) + output_triton = triton_paged_decode( + q, kv_cache_triton, kv_indices, kv_indptr, kv_last_page_len, sm_scale + ) + + # ===== FlashInfer ===== + kv_cache_fi = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + # Use FlashInfer's cache append + fi_batch_indices = batch_indices.clone() + fi_positions = positions.clone() + flashinfer.page.append_paged_kv_cache( + append_key=k_flat, + append_value=v_flat, + batch_indices=fi_batch_indices, + positions=fi_positions, + paged_kv_cache=kv_cache_fi, + kv_indices=kv_indices, + kv_indptr=kv_indptr, + kv_last_page_len=kv_last_page_len, + kv_layout="HND", + ) + + # Use FlashInfer decode + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda") + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace, "HND", use_tensor_cores=True + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + n_heads, + n_kv_heads, + head_dim, + page_size, + q_data_type=q.dtype, + kv_data_type=kv_cache_fi.dtype, + sm_scale=sm_scale, + ) + output_fi = wrapper.run(q, kv_cache_fi) + + # Compare + torch.testing.assert_close(output_triton.float(), output_fi.float(), rtol=1e-2, atol=1e-2) + + @pytest.mark.skipif( + not pytest.importorskip("flashinfer", reason="FlashInfer not installed"), + reason="FlashInfer not installed", + ) + @pytest.mark.parametrize("batch_size", [1, 2]) + @pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (32, 8)]) + @pytest.mark.parametrize("head_dim", [64, 128]) + @pytest.mark.parametrize("seq_len", [64, 128, 512]) + def test_prefill_vs_flashinfer( + self, batch_size: int, n_heads: int, n_kv_heads: int, head_dim: int, seq_len: int + ): + """Compare prefill output against FlashInfer.""" + import flashinfer + + from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import ( + triton_paged_context, + update_paged_kv_cache, + ) + + page_size = 16 + + num_pages_per_seq = (seq_len + page_size - 1) // page_size + num_blocks = batch_size * num_pages_per_seq + 10 + total_tokens = batch_size * seq_len + + # Create inputs + q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(total_tokens, n_kv_heads, head_dim, dtype=torch.float16, device="cuda") + + # Metadata + qo_indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda" + )[: batch_size + 1] + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_seq, + num_pages_per_seq, + dtype=torch.int32, + device="cuda", + )[: batch_size + 1] + kv_indices = torch.arange( + 0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda" + ) + last_token_in_page = seq_len % page_size + kv_last_page_len = torch.full( + (batch_size,), + last_token_in_page if last_token_in_page > 0 else page_size, + dtype=torch.int32, + device="cuda", + ) + seq_len_with_cache = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + batch_indices = torch.repeat_interleave( + torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len + ) + positions = torch.tile( + torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,) + ) + + sm_scale = 1.0 / math.sqrt(head_dim) + + # ===== 
Triton ===== + kv_cache_triton = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + update_paged_kv_cache( + k, v, batch_indices, positions, kv_cache_triton, kv_indices, kv_indptr + ) + output_triton = triton_paged_context( + q, + kv_cache_triton, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + seq_len_with_cache, + sm_scale, + ) + + # ===== FlashInfer ===== + kv_cache_fi = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim) + flashinfer.page.append_paged_kv_cache( + append_key=k, + append_value=v, + batch_indices=batch_indices.clone(), + positions=positions.clone(), + paged_kv_cache=kv_cache_fi, + kv_indices=kv_indices, + kv_indptr=kv_indptr, + kv_last_page_len=kv_last_page_len, + kv_layout="HND", + ) + + workspace = torch.empty(320 * 1024 * 1024, dtype=torch.uint8, device="cuda") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "HND") + wrapper.plan( + qo_indptr.cpu(), + kv_indptr.cpu(), + kv_indices, + kv_last_page_len.cpu(), + n_heads, + n_kv_heads, + head_dim, + page_size, + causal=True, + q_data_type=q.dtype, + kv_data_type=kv_cache_fi.dtype, + sm_scale=sm_scale, + seq_lens=seq_len_with_cache.cpu(), + ) + output_fi = wrapper.run(q, kv_cache_fi) + + # Compare + torch.testing.assert_close(output_triton.float(), output_fi.float(), rtol=1e-2, atol=1e-2)
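+
+
+# A minimal illustrative sketch of the two conventions exercised above; the helper
+# below is hypothetical and the numbers are made up, but the layouts follow
+# _make_batch_info and the paged-KV metadata built in the FlashInfer comparison tests.
+def _layout_sketch() -> None:
+    # 12-element batch_info_host: the old 3-way unpack fails, indexed reads work.
+    bi = torch.zeros(12, dtype=torch.int)
+    bi[0], bi[1], bi[4] = 1, 32, 3  # num_prefill, num_prefill_tokens, num_decode
+    try:
+        num_prefill, num_prefill_tokens, num_decode = bi.tolist()  # pre-fix pattern
+    except ValueError:
+        pass  # too many values to unpack: the regression guarded by the tests above
+    num_prefill, num_prefill_tokens, num_decode = int(bi[0]), int(bi[1]), int(bi[4])
+    assert (num_prefill, num_prefill_tokens, num_decode) == (1, 32, 3)
+
+    # Paged-KV metadata: 2 sequences of 20 tokens with page_size=16 each use
+    # ceil(20 / 16) = 2 pages, so kv_indptr = [0, 2, 4]; the last page of each
+    # sequence holds 20 % 16 = 4 tokens, so kv_last_page_len = [4, 4].
+    seq_len, page_size, batch_size = 20, 16, 2
+    num_pages_per_seq = (seq_len + page_size - 1) // page_size
+    kv_indptr = torch.arange(0, (batch_size + 1) * num_pages_per_seq, num_pages_per_seq)
+    kv_last_page_len = torch.full((batch_size,), seq_len % page_size or page_size)
+    assert kv_indptr.tolist() == [0, 2, 4]
+    assert kv_last_page_len.tolist() == [4, 4]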