diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index a92f69aba79..5afd1d36947 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1,5 +1,6 @@ import functools import math +import os import weakref from typing import List, Optional, Union, cast @@ -1125,29 +1126,6 @@ def yarn_get_mscale(scale=1, mscale=1): mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) q_scaling = 1.0 / (mscale * mscale) - if not self.is_dsa: - self.mha = create_attention( - config.attn_backend, - self.layer_idx, - self.num_heads_tp, - head_dim=self.qk_head_dim, - num_kv_heads=self.num_key_value_heads_tp, - pos_embd_params=pos_embd_params, - quant_config=quant_config, - q_scaling=q_scaling, - is_mla_enable=True, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - predicted_tokens_per_seq=self.predicted_tokens_per_seq, - skip_create_weights_in_init=config.skip_create_weights_in_init, - sparse_attention_config=config.sparse_attention_config, - ) - else: - self.mha = None - self.mqa = create_attention( config.attn_backend, self.layer_idx, @@ -1186,6 +1164,48 @@ def yarn_get_mscale(scale=1, mscale=1): is_neox=pos_embd_params.is_neox, ) + # Short-sequence MHA optimization for DSA models: + # For short prefill sequences, use MHA (kv_b_proj expansion + standard + # attention) instead of the absorption path, which has overhead from + # extra BMMs and larger head_dim (kv_lora_rank + qk_rope_head_dim). + # Only active when rope_fusion is True (DSA with TrtllmAttention). + _threshold_str = os.environ.get('TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD', + '0') + try: + self.short_seq_mha_threshold = int(_threshold_str) + except ValueError as err: + raise ValueError( + f"TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD must be an integer, " + f"got '{_threshold_str}'") from err + + # MHA attention backend: used by non-DSA (standard MLA) and optionally + # by DSA for the short-seq path (dense attention, no sparse config). + _short_seq_mha = (self.is_dsa and self.short_seq_mha_threshold > 0 + and not self.apply_rotary_emb) + if not self.is_dsa or _short_seq_mha: + self.mha = create_attention( + config.attn_backend, + self.layer_idx, + self.num_heads_tp, + head_dim=self.qk_head_dim, + num_kv_heads=self.num_key_value_heads_tp, + pos_embd_params=pos_embd_params, + quant_config=quant_config, + q_scaling=q_scaling, + is_mla_enable=True, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + predicted_tokens_per_seq=self.predicted_tokens_per_seq, + skip_create_weights_in_init=config.skip_create_weights_in_init, + sparse_attention_config=(None if _short_seq_mha else + config.sparse_attention_config), + ) + else: + self.mha = None + self.llama_4_scaling = False if hasattr(config.pretrained_config, 'llama_4_scaling'): self.llama_4_scaling = True @@ -1198,9 +1218,11 @@ def yarn_get_mscale(scale=1, mscale=1): self.create_weights() def create_weights(self): - # self.mha/mqa has no weights but has states that are related to quant_config, - # which could be modified after __init__ - if not self.is_dsa: + # self.mha/mqa has no weights but has states that are related to + # quant_config, which could be modified after __init__. + # self.mha is non-None for non-DSA models (standard MHA) and for DSA + # models when the short-seq MHA optimization is active. + if self.mha is not None: self.mha.update_quant_config(self.quant_config) self.mqa.update_quant_config(self.quant_config) @@ -1344,11 +1366,8 @@ def forward_impl(self, position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. - all_reduce_params (Optional[AllReduceParams]): The all reduce parameters. + output (torch.Tensor): Pre-allocated output tensor, written in-place. latent_cache_gen (Optional[torch.Tensor]): The latent cache used in generation. - - Returns: - torch.Tensor: The output tensor. """ # split q, k, v into context and gen batches num_contexts = attn_metadata.num_contexts @@ -1450,11 +1469,9 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor], position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. - - Returns: - torch.Tensor: The output tensor. + output (torch.Tensor): Pre-allocated output tensor, written in-place. """ - assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode" + assert self.mqa is not None, "DSA is only supported in MQA mode" # split q, k, v into context and gen batches num_contexts = attn_metadata.num_contexts num_generations = attn_metadata.num_generations @@ -1484,14 +1501,29 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor], # TODO: fuse wq_b + (indexer) wlq here q = self.q_b_proj(q) - # Indexer - topk_indices = self.indexer( - qr, - hidden_states, - attn_metadata, - position_ids, - indexer_k=indexer_k, # indexer K proj - ) + + # Check if the short-seq MHA path will handle context, in which case + # the indexer (topk_indices) is not needed for context tokens. + # The MHA path handles cached tokens via forward_context(), which + # dispatches to forward_context_with_cached_kv or + # forward_context_with_chunked_prefill as needed. + use_short_mha_for_ctx = (num_contexts > 0 + and self._should_use_short_mha( + attn_metadata, position_ids)) + + # Skip the indexer entirely when the short MHA path handles all + # context tokens and there are no generation tokens. + if use_short_mha_for_ctx and num_generations == 0: + topk_indices = None + else: + # Indexer + topk_indices = self.indexer( + qr, + hidden_states, + attn_metadata, + position_ids, + indexer_k=indexer_k, # indexer K proj + ) assert q.shape[ 0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}" @@ -1514,7 +1546,9 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor], attn_metadata, output[:num_ctx_tokens, :], latent_cache_ctx, - topk_indices=topk_indices[:num_ctx_tokens, :], + topk_indices=topk_indices[:num_ctx_tokens, :] + if topk_indices is not None else None, + position_ids=position_ids, ) if num_generations > 0: @@ -1546,6 +1580,10 @@ def forward_context_default( output: torch.Tensor, latent_cache: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Dense MHA context path: expand KV via kv_b_proj and run attention. + + Used by non-DSA models and as the short-seq MHA fallback for DSA models. + """ kv = self.kv_b_proj(compressed_kv) k_nope, v = kv.split( [ @@ -1559,6 +1597,9 @@ def forward_context_default( maybe_compiled_copy_( k[..., :self.qk_nope_head_dim], k_nope.view(-1, self.num_heads_tp, self.qk_nope_head_dim)) + # When rope_fusion=True (apply_rotary_emb=False), the rope portion + # of k is left uninitialized here; the fused attention kernel + # handles k_pe RoPE via latent_cache instead. if self.apply_rotary_emb: k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1, self.qk_rope_head_dim) @@ -1577,6 +1618,23 @@ def forward_context_default( return attn_output + def _should_use_short_mha(self, attn_metadata: AttentionMetadata, + position_ids: Optional[torch.Tensor]) -> bool: + """Check if the short-seq MHA optimization should be used for context. + + Uses max_ctx_kv_len (max total KV length per context sequence, + including cached tokens) when available, to correctly account for + chunked context where the full attention span exceeds the threshold + even if the new token count is small. Falls back to num_ctx_tokens + (total new context tokens) when max_ctx_kv_len is not set. + """ + if not (self.short_seq_mha_threshold > 0 and not self.apply_rotary_emb + and self.mapping.cp_size == 1 and position_ids is not None): + return False + effective_len = getattr(attn_metadata, 'max_ctx_kv_len', + attn_metadata.num_ctx_tokens) + return effective_len <= self.short_seq_mha_threshold + def forward_context_dsa( self, q: torch.Tensor, @@ -1586,7 +1644,38 @@ def forward_context_dsa( output: torch.Tensor, latent_cache: Optional[torch.Tensor] = None, topk_indices: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Run context-phase attention for DSA models. + + Dispatches to the short-seq MHA path (forward_context) when the max + per-sequence KV length (including cached tokens) is within the + threshold, or falls through to the absorption/sparse MLA path + otherwise. forward_context() further dispatches to the appropriate + handler (forward_context_default, forward_context_with_cached_kv, or + forward_context_with_chunked_prefill) based on cached-KV state. + + Args: + q: Query tensor, shape [num_ctx_tokens, num_heads * qk_head_dim]. + compressed_kv: Latent KV, shape [num_ctx_tokens, kv_lora_rank]. + k_pe: RoPE key portion, shape [num_ctx_tokens, qk_rope_head_dim]. + attn_metadata: Attention metadata for the current batch. + output: Pre-allocated output tensor, written in-place. + latent_cache: Concatenated [compressed_kv, k_pe] for KV cache. + topk_indices: Sparse routing indices from the indexer (None when + the short-seq MHA path is used). + position_ids: Token position IDs (required for short-seq MHA). + """ + # Short-sequence MHA: bypass absorption path for short prefills, + # using kv_b_proj expansion + standard attention instead. + # See __init__ comment for rationale. topk_indices is not used + # because dense attention is faster than sparse routing at this scale. + # forward_context() handles cached tokens by dispatching to + # forward_context_with_cached_kv or forward_context_with_chunked_prefill. + if self._should_use_short_mha(attn_metadata, position_ids): + return self.forward_context(q, compressed_kv, k_pe, position_ids, + attn_metadata, output, latent_cache) + if get_sm_version() >= 100: return self.forward_absorption_context(q, compressed_kv, @@ -1929,10 +2018,13 @@ def forward_context( self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim, q.dtype, q.device) if trtllm_attention.is_chunked_prefill_for_mla_context( - attn_metadata): + attn_metadata) and get_sm_version() >= 100: return self.forward_context_with_chunked_prefill( q, compressed_kv, latent_cache, attn_metadata, output) - elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata): + elif trtllm_attention.has_cached_kv_for_mla_context( + attn_metadata + ) or trtllm_attention.is_chunked_prefill_for_mla_context( + attn_metadata): return self.forward_context_with_cached_kv( q, latent_cache, attn_metadata, output) return self.forward_context_default(q, compressed_kv, k_pe, diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index abd257fb14b..5308f03f201 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -365,6 +365,10 @@ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5846166) perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con1_ctx1_pp8_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5846166) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True] SKIP (https://nvbugs/5879614) +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=True] SKIP (https://nvbugs/5893116) +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5875522) +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5875522) +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5875522) accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463) unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460) diff --git a/tests/unittest/_torch/attention/sparse/test_short_seq_mha.py b/tests/unittest/_torch/attention/sparse/test_short_seq_mha.py new file mode 100644 index 00000000000..c0d5d8a7852 --- /dev/null +++ b/tests/unittest/_torch/attention/sparse/test_short_seq_mha.py @@ -0,0 +1,665 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test short-sequence MHA optimization path in MLA. + +Covers: pure prefill correctness (vs reference), threshold boundary condition, +threshold=0 disables MHA, above-threshold fallback to standard path, +A/B comparison vs absorption path, and chunked context with cached KV. +""" + +import math +import os +from dataclasses import dataclass +from types import SimpleNamespace +from typing import List, Tuple + +import pytest +import torch + +import tensorrt_llm +import tensorrt_llm.bindings +from tensorrt_llm._torch.attention_backend.interface import ( + AttentionRuntimeFeatures, + PositionalEmbeddingParams, + RopeParams, +) +from tensorrt_llm._torch.attention_backend.sparse.dsa import DSACacheManager +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.attention import MLA +from tensorrt_llm._utils import get_sm_version, str_dtype_to_binding, torch_dtype_to_str +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.functional import PositionEmbeddingType, RopeEmbeddingUtils +from tensorrt_llm.llmapi.llm_args import DeepSeekSparseAttentionConfig +from tensorrt_llm.mapping import Mapping + +# DSACacheManager creates background ThreadPoolExecutor threads. +pytestmark = pytest.mark.threadleak(enabled=False) + +# --------------------------------------------------------------------------- +# Model constants (DeepSeek V3-like) +# --------------------------------------------------------------------------- +NUM_HEADS = 128 +Q_LORA_RANK = 512 +KV_LORA_RANK = 512 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +V_HEAD_DIM = 128 +QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM +HIDDEN_SIZE = 2048 +MAX_POSITION_EMBEDDINGS = 4096 +TOKENS_PER_BLOCK = 64 +NUM_LAYERS = 1 +LAYER_IDX = 0 +TOPK_TOKENS = 2048 +NN_INIT_STD = 0.02 + +# --------------------------------------------------------------------------- +# Parametrized test specs +# --------------------------------------------------------------------------- +# (name, seq_lens, threshold_offset): threshold = sum(seq_lens) + offset. +# offset=0 tests the boundary condition (threshold == total tokens). +PREFILL_SPECS = [ + ("single", [16], 100), + ("multi_boundary", [8, 12, 16], 0), +] + +# (name, [(cached_tokens, new_tokens), ...], chunk_size or None) +# chunk_size=None -> cached-KV path; chunk_size=int -> chunked-prefill path. +CHUNKED_SPECS = [ + ("cached_kv", [(64, 32), (32, 16)], None), + ("chunked_prefill", [(64, 16), (64, 8)], 32), +] + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- +def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_embedding(x, cos_sin): + original_dtype = x.dtype + cos, sin = cos_sin.chunk(2, dim=-2) + cos = cos.squeeze(1) + sin = sin.squeeze(1) + x_interleaved = x.unflatten(-1, [-1, 2]).transpose(-2, -1).flatten(start_dim=-2) + cos_expanded = cos.view(cos.shape[0], *([1] * (x.ndim - 2)), cos.shape[-1]) + sin_expanded = sin.view(sin.shape[0], *([1] * (x.ndim - 2)), sin.shape[-1]) + x_rotated = (x_interleaved * cos_expanded) + (_rotate_half(x_interleaved) * sin_expanded) + return x_rotated.to(original_dtype) + + +def _reference_attention( + q, compressed_kv, k_pe, kv_b_proj_weight, rope_cos_sin, seq_lens, softmax_scale +): + """kv_b_proj expansion + RoPE + causal SDPA reference.""" + results = [] + offset = 0 + for slen in seq_lens: + qs = q[offset : offset + slen].view(-1, NUM_HEADS, QK_HEAD_DIM) + cs = rope_cos_sin[:slen] + + q_nope = qs[..., :QK_NOPE_HEAD_DIM] + q_pe = _apply_rotary_embedding(qs[..., QK_NOPE_HEAD_DIM:], cs) + q_full = torch.cat([q_nope, q_pe], dim=-1) + + k_pe_roped = _apply_rotary_embedding(k_pe[offset : offset + slen], cs) + kv = torch.nn.functional.linear(compressed_kv[offset : offset + slen], kv_b_proj_weight) + k_nope, v = kv.split([NUM_HEADS * QK_NOPE_HEAD_DIM, NUM_HEADS * V_HEAD_DIM], -1) + k_full = torch.cat( + [ + k_nope.view(-1, NUM_HEADS, QK_NOPE_HEAD_DIM), + k_pe_roped.view(-1, 1, QK_ROPE_HEAD_DIM).expand(-1, NUM_HEADS, -1), + ], + dim=-1, + ) + v_r = v.view(-1, NUM_HEADS, V_HEAD_DIM) + + attn_out = torch.nn.functional.scaled_dot_product_attention( + q_full.unsqueeze(0).transpose(1, 2), + k_full.unsqueeze(0).transpose(1, 2), + v_r.unsqueeze(0).transpose(1, 2), + is_causal=True, + scale=softmax_scale, + ) + results.append(attn_out.transpose(1, 2).squeeze(0).reshape(slen, NUM_HEADS * V_HEAD_DIM)) + offset += slen + return torch.cat(results, dim=0) + + +# --------------------------------------------------------------------------- +# Setup helpers +# --------------------------------------------------------------------------- +@dataclass +class RopeConfig: + hidden_size: int + num_attention_heads: int + rope_scaling: dict + max_position_embeddings: int + rope_theta: float + qk_rope_head_dim: int + model_type: str + + +def _make_rope_config(): + return RopeConfig( + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + max_position_embeddings=MAX_POSITION_EMBEDDINGS, + rope_theta=10000.0, + qk_rope_head_dim=QK_ROPE_HEAD_DIM, + model_type="deepseek_v2", + ) + + +def _make_rope_cos_sin(rope_config, device): + return ( + torch.tensor( + RopeEmbeddingUtils.create_sinusoidal_positions_yarn( + rope_config.max_position_embeddings, + rope_config.qk_rope_head_dim, + rope_config.rope_theta, + rope_config.rope_scaling["factor"], + rope_config.rope_scaling["original_max_position_embeddings"], + rope_config.rope_scaling["beta_fast"], + rope_config.rope_scaling["beta_slow"], + rope_config.rope_scaling["mscale"], + rope_config.rope_scaling["mscale_all_dim"], + )[1], + dtype=torch.float32, + device=device, + ) + .reshape(rope_config.max_position_embeddings, -1, 2) + .transpose(-2, -1) + ) + + +def _compute_softmax_scale(rope_config): + def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + mscale_all_dim = rope_config.rope_scaling["mscale_all_dim"] + scaling_factor = rope_config.rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + q_scaling = 1.0 / (mscale * mscale) + return 1.0 / (math.sqrt(QK_HEAD_DIM) * q_scaling) + + +def _build_mla(rope_config, device, threshold): + """Build an MLA module with DSA config and the given threshold.""" + mapping = Mapping(world_size=1, tp_size=1, rank=0) + sparse_config = DeepSeekSparseAttentionConfig( + index_n_heads=64, + index_head_dim=128, + index_topk=TOPK_TOKENS, + ) + pretrained_config = SimpleNamespace(rms_norm_eps=1e-6) + model_config = ModelConfig( + mapping=mapping, + sparse_attention_config=sparse_config, + pretrained_config=pretrained_config, + ) + pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.yarn, + rope=RopeParams.from_config(rope_config), + is_neox=False, + ) + + old_val = os.environ.get("TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD") + os.environ["TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD"] = str(threshold) + try: + mla = MLA( + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + num_key_value_heads=1, + qk_nope_head_dim=QK_NOPE_HEAD_DIM, + qk_rope_head_dim=QK_ROPE_HEAD_DIM, + v_head_dim=V_HEAD_DIM, + q_lora_rank=Q_LORA_RANK, + kv_lora_rank=KV_LORA_RANK, + predicted_tokens_per_seq=1, + max_position_embeddings=MAX_POSITION_EMBEDDINGS, + bias=False, + pos_embd_params=pos_embd_params, + layer_idx=LAYER_IDX, + dtype=torch.bfloat16, + config=model_config, + ).to(device) + finally: + if old_val is None: + os.environ.pop("TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD", None) + else: + os.environ["TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD"] = old_val + + # mla.mqa (DSATrtllmAttention) is not an nn.Module, so mla.to(device) + # does not move its children. Explicitly move the indexer. + if hasattr(mla, "mqa") and hasattr(mla.mqa, "indexer"): + mla.mqa.indexer.to(device) + + return mla, mapping, sparse_config, model_config + + +def _init_mla_weights(mla): + """Initialize MLA weights deterministically in the loaded layout. + + The loaded layout (as produced by modeling_deepseekv3.py) is: + [all_heads_k_nope, all_heads_v] along the output dimension. + """ + with torch.no_grad(): + dev = mla.kv_b_proj.weight.device + dt = mla.kv_b_proj.weight.dtype + + k_nope_weight = torch.empty(NUM_HEADS, QK_NOPE_HEAD_DIM, KV_LORA_RANK, dtype=dt, device=dev) + k_nope_weight.normal_(mean=0.0, std=NN_INIT_STD) + + v_weight = torch.empty(NUM_HEADS, V_HEAD_DIM, KV_LORA_RANK, dtype=dt, device=dev) + v_weight.normal_(mean=0.0, std=NN_INIT_STD) + + mla.kv_b_proj.weight.data = torch.cat( + [ + k_nope_weight.reshape(NUM_HEADS * QK_NOPE_HEAD_DIM, KV_LORA_RANK), + v_weight.reshape(NUM_HEADS * V_HEAD_DIM, KV_LORA_RANK), + ], + dim=0, + ) + mla.k_b_proj_trans.data = k_nope_weight.transpose(1, 2).contiguous() + mla.v_b_proj.data = v_weight.contiguous() + + mla.mqa.indexer.wq_b.weight.normal_(mean=0.0, std=NN_INIT_STD) + mla.mqa.indexer.wk.weight.normal_(mean=0.0, std=NN_INIT_STD) + mla.mqa.indexer.weights_proj.weight.normal_(mean=0.0, std=NN_INIT_STD) + + +def _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device): + """Build a DSACacheManager for the given batch.""" + kv_cache_manager = DSACacheManager( + KvCacheConfig(max_tokens=16384, enable_block_reuse=False), + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, + num_layers=NUM_LAYERS, + num_kv_heads=1, + head_dim=KV_LORA_RANK + QK_ROPE_HEAD_DIM, + tokens_per_block=TOKENS_PER_BLOCK, + max_seq_len=max(seq_lens), + max_batch_size=len(seq_lens), + mapping=mapping, + dtype=str_dtype_to_binding(torch_dtype_to_str(torch.bfloat16)), + sparse_attn_config=sparse_config, + model_config=model_config, + ) + for req_idx, seq_len in enumerate(seq_lens): + kv_cache_manager.add_dummy_requests( + request_ids=[req_idx], + token_nums=[seq_len], + is_gen=False, + prepare_resource=True, + ) + return kv_cache_manager + + +def _make_inputs(seq_lens, device, dtype=torch.bfloat16): + """Generate random MLA inputs for the given sequence lengths.""" + total = sum(seq_lens) + q = torch.randn(total, NUM_HEADS * QK_HEAD_DIM, dtype=dtype, device=device) + compressed_kv = torch.randn(total, KV_LORA_RANK, dtype=dtype, device=device) + k_pe = torch.randn(total, QK_ROPE_HEAD_DIM, dtype=dtype, device=device) + latent_cache = torch.cat([compressed_kv, k_pe], dim=-1) + position_ids = torch.cat([torch.arange(s, device=device, dtype=torch.int32) for s in seq_lens]) + return q, compressed_kv, k_pe, latent_cache, position_ids + + +def _make_metadata( + attn_cls, + seq_lens, + kv_cache_manager, + mapping, + sparse_config, + cached_per_seq=None, + runtime_features=None, +): + """Build and prepare attention metadata. + + When cached_per_seq is provided, enables the cached-KV context path. + """ + num_ctx = len(seq_lens) + kwargs = {} + if cached_per_seq is not None: + kwargs["enable_context_mla_with_cached_kv"] = True + if runtime_features is not None: + kwargs["runtime_features"] = runtime_features + metadata = attn_cls.Metadata( + seq_lens=torch.tensor(seq_lens, dtype=torch.int), + request_ids=list(range(num_ctx)), + max_num_requests=num_ctx, + num_contexts=num_ctx, + prompt_lens=seq_lens, + max_num_tokens=sum(seq_lens), + kv_cache_manager=kv_cache_manager, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=cached_per_seq or [0] * num_ctx, + ), + mapping=mapping, + sparse_attention_config=sparse_config, + **kwargs, + ) + metadata.prepare() + return metadata + + +def _run_forward( + mla, q, compressed_kv, k_pe, latent_cache, position_ids, metadata, topk_indices=None +): + """Run forward_context_dsa on cloned inputs and return the output tensor.""" + output = torch.empty(q.shape[0], NUM_HEADS * V_HEAD_DIM, dtype=q.dtype, device=q.device) + mla.forward_context_dsa( + q=q.clone(), + compressed_kv=compressed_kv.clone(), + k_pe=k_pe.clone(), + attn_metadata=metadata, + output=output, + latent_cache=latent_cache.clone(), + topk_indices=topk_indices, + position_ids=position_ids.clone(), + ) + return output + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +@pytest.mark.parametrize( + "name,seq_lens,threshold_offset", + PREFILL_SPECS, + ids=[s[0] for s in PREFILL_SPECS], +) +def test_forward_context_short_mha(name: str, seq_lens: List[int], threshold_offset: int): + """Short-seq MHA output vs standalone reference (kv_b_proj + RoPE + SDPA).""" + device = torch.device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + rope_config = _make_rope_config() + threshold = sum(seq_lens) + threshold_offset + mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold) + _init_mla_weights(mla) + + kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) + attn_cls = get_attention_backend("TRTLLM", sparse_config) + q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) + metadata = _make_metadata(attn_cls, seq_lens, kv_mgr, mapping, sparse_config) + + output = _run_forward(mla, q, compressed_kv, k_pe, latent_cache, position_ids, metadata) + + rope_cos_sin = _make_rope_cos_sin(rope_config, device) + softmax_scale = _compute_softmax_scale(rope_config) + ref = _reference_attention( + q, compressed_kv, k_pe, mla.kv_b_proj.weight.data, rope_cos_sin, seq_lens, softmax_scale + ) + + torch.testing.assert_close(output, ref, rtol=0.05, atol=0.05) + kv_mgr.shutdown() + + +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +def test_threshold_zero_disables_mha(): + """With threshold=0, the short MHA path is NOT active.""" + device = torch.device("cuda") + rope_config = _make_rope_config() + mla, _, _, _ = _build_mla(rope_config, device, threshold=0) + + assert mla.short_seq_mha_threshold == 0 + assert mla.mha is None + + +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +def test_standard_path_when_exceeds_threshold(): + """When total tokens > threshold, the standard absorption path is used.""" + device = torch.device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + rope_config = _make_rope_config() + seq_lens = [32] + mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold=16) + _init_mla_weights(mla) + + kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) + attn_cls = get_attention_backend("TRTLLM", sparse_config) + q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) + + total_tokens = sum(seq_lens) + hidden_states = torch.randn(total_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + qr = torch.randn(total_tokens, Q_LORA_RANK, dtype=torch.bfloat16, device=device) + + metadata = _make_metadata(attn_cls, seq_lens, kv_mgr, mapping, sparse_config) + topk_indices = mla.mqa.indexer( + qr, + hidden_states, + metadata, + position_ids, + indexer_k=mla.mqa.indexer.wk(hidden_states), + ) + + output = _run_forward( + mla, q, compressed_kv, k_pe, latent_cache, position_ids, metadata, topk_indices + ) + assert torch.isfinite(output).all() + kv_mgr.shutdown() + + +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +def test_agrees_with_absorption_path(): + """Short MHA and absorption paths produce numerically close results.""" + device = torch.device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + rope_config = _make_rope_config() + seq_lens = [16] + total_tokens = sum(seq_lens) + + mla_short, mapping, sparse_config, model_config = _build_mla( + rope_config, device, threshold=total_tokens + 100 + ) + mla_absorb, _, _, _ = _build_mla(rope_config, device, threshold=0) + + # Share identical weights between the two modules. + _init_mla_weights(mla_short) + with torch.no_grad(): + for (_, ps), (_, pa) in zip(mla_short.named_parameters(), mla_absorb.named_parameters()): + pa.data.copy_(ps.data) + for (_, ps), (_, pa) in zip( + mla_short.mqa.indexer.named_parameters(), + mla_absorb.mqa.indexer.named_parameters(), + ): + pa.data.copy_(ps.data) + + q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) + hidden_states = torch.randn(total_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + qr = torch.randn(total_tokens, Q_LORA_RANK, dtype=torch.bfloat16, device=device) + attn_cls = get_attention_backend("TRTLLM", sparse_config) + + def _run(mla_module): + kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) + meta = _make_metadata(attn_cls, seq_lens, kv_mgr, mapping, sparse_config) + use_short = ( + mla_module.short_seq_mha_threshold > 0 + and total_tokens <= mla_module.short_seq_mha_threshold + ) + topk = None + if not use_short: + topk = mla_module.mqa.indexer( + qr.clone(), + hidden_states.clone(), + meta, + position_ids.clone(), + indexer_k=mla_module.mqa.indexer.wk(hidden_states.clone()), + ) + out = _run_forward( + mla_module, q, compressed_kv, k_pe, latent_cache, position_ids, meta, topk + ) + kv_mgr.shutdown() + return out + + out_short = _run(mla_short) + out_absorb = _run(mla_absorb) + torch.testing.assert_close(out_short, out_absorb, rtol=0.08, atol=0.08) + + +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +@pytest.mark.parametrize( + "name,chunk_specs,chunk_size", + CHUNKED_SPECS, + ids=[s[0] for s in CHUNKED_SPECS], +) +def test_chunked_correctness(name: str, chunk_specs: List[Tuple[int, int]], chunk_size): + """Chunked context (cached KV or chunked prefill) matches single-pass prefill.""" + device = torch.device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + rope_config = _make_rope_config() + cached_per_seq = [c for c, _ in chunk_specs] + new_per_seq = [n for _, n in chunk_specs] + total_per_seq = [c + n for c, n in chunk_specs] + threshold = sum(total_per_seq) + 100 + + mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold) + _init_mla_weights(mla) + attn_cls = get_attention_backend("TRTLLM", sparse_config) + + q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(total_per_seq, device) + + # Build chunk index tensors. + c1_idx, c2_idx = [], [] + offset = 0 + for c, n in chunk_specs: + c1_idx.extend(range(offset, offset + c)) + c2_idx.extend(range(offset + c, offset + c + n)) + offset += c + n + c1_idx = torch.tensor(c1_idx, dtype=torch.long, device=device) + c2_idx = torch.tensor(c2_idx, dtype=torch.long, device=device) + + # Reference: single-pass full prefill. + kv_ref = _build_kv_cache_manager(mapping, sparse_config, model_config, total_per_seq, device) + meta_ref = _make_metadata(attn_cls, total_per_seq, kv_ref, mapping, sparse_config) + out_ref = _run_forward(mla, q, compressed_kv, k_pe, latent_cache, position_ids, meta_ref) + + # Chunk 1: pure prefill (populates KV cache). + kv_chunked = _build_kv_cache_manager( + mapping, sparse_config, model_config, total_per_seq, device + ) + meta_c1 = _make_metadata(attn_cls, cached_per_seq, kv_chunked, mapping, sparse_config) + pos_c1 = torch.cat([torch.arange(c, device=device, dtype=torch.int32) for c in cached_per_seq]) + _run_forward( + mla, q[c1_idx], compressed_kv[c1_idx], k_pe[c1_idx], latent_cache[c1_idx], pos_c1, meta_c1 + ) + + # Chunk 2: cached KV or chunked prefill path. + runtime_features = None + if chunk_size is not None: + runtime_features = AttentionRuntimeFeatures( + chunked_prefill=True, + chunk_size=chunk_size, + chunked_prefill_buffer_batch_size=1, + ) + meta_c2 = _make_metadata( + attn_cls, + new_per_seq, + kv_chunked, + mapping, + sparse_config, + cached_per_seq=cached_per_seq, + runtime_features=runtime_features, + ) + pos_c2 = torch.cat( + [torch.arange(c, c + n, device=device, dtype=torch.int32) for c, n in chunk_specs] + ) + out_c2 = _run_forward( + mla, q[c2_idx], compressed_kv[c2_idx], k_pe[c2_idx], latent_cache[c2_idx], pos_c2, meta_c2 + ) + + torch.testing.assert_close(out_c2, out_ref[c2_idx], rtol=0.08, atol=0.08) + + kv_ref.shutdown() + kv_chunked.shutdown() + + +@pytest.mark.skipif(get_sm_version() < 90, reason="MLA requires SM90+") +def test_chunked_context_rejects_when_kv_exceeds_threshold(): + """When max KV length (cached + new) exceeds threshold, short MHA is not used. + + Regression test: previously, only new token count was checked against + the threshold. With cached tokens, the effective attention span can far + exceed the threshold even when new_tokens is small. + """ + device = torch.device("cuda") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + rope_config = _make_rope_config() + # 1 sequence: 64 cached + 32 new = 96 total KV. + # threshold=80: chunk 1 (64 tokens) passes, chunk 2's max KV (96) exceeds. + cached_per_seq = [64] + new_per_seq = [32] + total_per_seq = [96] + threshold = 80 + + mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold) + _init_mla_weights(mla) + attn_cls = get_attention_backend("TRTLLM", sparse_config) + + q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(total_per_seq, device) + + # Chunk 1: populate KV cache with cached tokens. + kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, total_per_seq, device) + meta_c1 = _make_metadata(attn_cls, cached_per_seq, kv_mgr, mapping, sparse_config) + c1_idx = torch.arange(cached_per_seq[0], dtype=torch.long, device=device) + pos_c1 = torch.arange(cached_per_seq[0], device=device, dtype=torch.int32) + _run_forward( + mla, q[c1_idx], compressed_kv[c1_idx], k_pe[c1_idx], latent_cache[c1_idx], pos_c1, meta_c1 + ) + + # Chunk 2: cached KV path — threshold check matters here. + meta_c2 = _make_metadata( + attn_cls, new_per_seq, kv_mgr, mapping, sparse_config, cached_per_seq=cached_per_seq + ) + pos_c2 = torch.arange(cached_per_seq[0], total_per_seq[0], device=device, dtype=torch.int32) + + # max_ctx_kv_len (96) > threshold (80) -> short MHA should NOT be used. + assert not mla._should_use_short_mha(meta_c2, pos_c2) + + # With threshold large enough for the full KV -> short MHA IS used. + mla.short_seq_mha_threshold = total_per_seq[0] + 100 + assert mla._should_use_short_mha(meta_c2, pos_c2) + + kv_mgr.shutdown()