Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 137 additions & 45 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import math
import os
import weakref
from typing import List, Optional, Union, cast

Expand Down Expand Up @@ -1125,29 +1126,6 @@ def yarn_get_mscale(scale=1, mscale=1):
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
q_scaling = 1.0 / (mscale * mscale)

if not self.is_dsa:
self.mha = create_attention(
config.attn_backend,
self.layer_idx,
self.num_heads_tp,
head_dim=self.qk_head_dim,
num_kv_heads=self.num_key_value_heads_tp,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)
else:
self.mha = None

self.mqa = create_attention(
config.attn_backend,
self.layer_idx,
Expand Down Expand Up @@ -1186,6 +1164,48 @@ def yarn_get_mscale(scale=1, mscale=1):
is_neox=pos_embd_params.is_neox,
)

# Short-sequence MHA optimization for DSA models:
# For short prefill sequences, use MHA (kv_b_proj expansion + standard
# attention) instead of the absorption path, which has overhead from
# extra BMMs and larger head_dim (kv_lora_rank + qk_rope_head_dim).
# Only active when rope_fusion is True (DSA with TrtllmAttention).
_threshold_str = os.environ.get('TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD',
'0')
try:
self.short_seq_mha_threshold = int(_threshold_str)
except ValueError as err:
raise ValueError(
f"TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD must be an integer, "
f"got '{_threshold_str}'") from err

# MHA attention backend: used by non-DSA (standard MLA) and optionally
# by DSA for the short-seq path (dense attention, no sparse config).
_short_seq_mha = (self.is_dsa and self.short_seq_mha_threshold > 0
and not self.apply_rotary_emb)
if not self.is_dsa or _short_seq_mha:
self.mha = create_attention(
config.attn_backend,
self.layer_idx,
self.num_heads_tp,
head_dim=self.qk_head_dim,
num_kv_heads=self.num_key_value_heads_tp,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=(None if _short_seq_mha else
config.sparse_attention_config),
)
else:
self.mha = None

self.llama_4_scaling = False
if hasattr(config.pretrained_config, 'llama_4_scaling'):
self.llama_4_scaling = True
Expand All @@ -1198,9 +1218,11 @@ def yarn_get_mscale(scale=1, mscale=1):
self.create_weights()

def create_weights(self):
# self.mha/mqa has no weights but has states that are related to quant_config,
# which could be modified after __init__
if not self.is_dsa:
# self.mha/mqa has no weights but has states that are related to
# quant_config, which could be modified after __init__.
# self.mha is non-None for non-DSA models (standard MHA) and for DSA
# models when the short-seq MHA optimization is active.
if self.mha is not None:
self.mha.update_quant_config(self.quant_config)
self.mqa.update_quant_config(self.quant_config)

Expand Down Expand Up @@ -1344,11 +1366,8 @@ def forward_impl(self,
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
attn_metadata (AttentionMetadata): The attention metadata.
all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
output (torch.Tensor): Pre-allocated output tensor, written in-place.
latent_cache_gen (Optional[torch.Tensor]): The latent cache used in generation.

Returns:
torch.Tensor: The output tensor.
"""
# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
Expand Down Expand Up @@ -1450,11 +1469,9 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
position_ids (Optional[torch.IntTensor]): The position IDs.
hidden_states (torch.Tensor): The hidden states.
attn_metadata (AttentionMetadata): The attention metadata.

Returns:
torch.Tensor: The output tensor.
output (torch.Tensor): Pre-allocated output tensor, written in-place.
"""
assert self.mha is None and self.mqa is not None, "DSA is only supported in MQA mode"
assert self.mqa is not None, "DSA is only supported in MQA mode"
# split q, k, v into context and gen batches
num_contexts = attn_metadata.num_contexts
num_generations = attn_metadata.num_generations
Expand Down Expand Up @@ -1484,14 +1501,29 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],

# TODO: fuse wq_b + (indexer) wlq here
q = self.q_b_proj(q)
# Indexer
topk_indices = self.indexer(
qr,
hidden_states,
attn_metadata,
position_ids,
indexer_k=indexer_k, # indexer K proj
)

# Check if the short-seq MHA path will handle context, in which case
# the indexer (topk_indices) is not needed for context tokens.
# The MHA path handles cached tokens via forward_context(), which
# dispatches to forward_context_with_cached_kv or
# forward_context_with_chunked_prefill as needed.
use_short_mha_for_ctx = (num_contexts > 0
and self._should_use_short_mha(
attn_metadata, position_ids))

# Skip the indexer entirely when the short MHA path handles all
# context tokens and there are no generation tokens.
if use_short_mha_for_ctx and num_generations == 0:
topk_indices = None
else:
# Indexer
topk_indices = self.indexer(
qr,
hidden_states,
attn_metadata,
position_ids,
indexer_k=indexer_k, # indexer K proj
)

assert q.shape[
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
Expand All @@ -1514,7 +1546,9 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
attn_metadata,
output[:num_ctx_tokens, :],
latent_cache_ctx,
topk_indices=topk_indices[:num_ctx_tokens, :],
topk_indices=topk_indices[:num_ctx_tokens, :]
if topk_indices is not None else None,
position_ids=position_ids,
)

if num_generations > 0:
Expand Down Expand Up @@ -1546,6 +1580,10 @@ def forward_context_default(
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Dense MHA context path: expand KV via kv_b_proj and run attention.

Used by non-DSA models and as the short-seq MHA fallback for DSA models.
"""
kv = self.kv_b_proj(compressed_kv)
k_nope, v = kv.split(
[
Expand All @@ -1559,6 +1597,9 @@ def forward_context_default(
maybe_compiled_copy_(
k[..., :self.qk_nope_head_dim],
k_nope.view(-1, self.num_heads_tp, self.qk_nope_head_dim))
# When rope_fusion=True (apply_rotary_emb=False), the rope portion
# of k is left uninitialized here; the fused attention kernel
# handles k_pe RoPE via latent_cache instead.
if self.apply_rotary_emb:
k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
self.qk_rope_head_dim)
Expand All @@ -1577,6 +1618,23 @@ def forward_context_default(

return attn_output

def _should_use_short_mha(self, attn_metadata: AttentionMetadata,
                          position_ids: Optional[torch.Tensor]) -> bool:
    """Decide whether the context phase should take the short-seq MHA path.

    Eligibility requires all of: a positive threshold, rope fusion active
    (``apply_rotary_emb`` is False), no context parallelism
    (``cp_size == 1``), and position IDs being available. The length
    compared against the threshold is ``max_ctx_kv_len`` (max per-sequence
    KV span, including cached tokens) when the metadata provides it, so a
    chunked context whose full attention span exceeds the threshold is
    excluded even if its new-token count is small; otherwise it falls back
    to ``num_ctx_tokens`` (total new context tokens).
    """
    eligible = (self.short_seq_mha_threshold > 0
                and not self.apply_rotary_emb
                and self.mapping.cp_size == 1
                and position_ids is not None)
    if not eligible:
        return False
    # Prefer the full KV span (cached + new) when the metadata exposes it.
    seq_span = getattr(attn_metadata, 'max_ctx_kv_len',
                       attn_metadata.num_ctx_tokens)
    return seq_span <= self.short_seq_mha_threshold

def forward_context_dsa(
self,
q: torch.Tensor,
Expand All @@ -1586,7 +1644,38 @@ def forward_context_dsa(
output: torch.Tensor,
latent_cache: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run context-phase attention for DSA models.

Dispatches to the short-seq MHA path (forward_context) when the max
per-sequence KV length (including cached tokens) is within the
threshold, or falls through to the absorption/sparse MLA path
otherwise. forward_context() further dispatches to the appropriate
handler (forward_context_default, forward_context_with_cached_kv, or
forward_context_with_chunked_prefill) based on cached-KV state.

Args:
q: Query tensor, shape [num_ctx_tokens, num_heads * qk_head_dim].
compressed_kv: Latent KV, shape [num_ctx_tokens, kv_lora_rank].
k_pe: RoPE key portion, shape [num_ctx_tokens, qk_rope_head_dim].
attn_metadata: Attention metadata for the current batch.
output: Pre-allocated output tensor, written in-place.
latent_cache: Concatenated [compressed_kv, k_pe] for KV cache.
topk_indices: Sparse routing indices from the indexer (None when
the short-seq MHA path is used).
position_ids: Token position IDs (required for short-seq MHA).
"""
# Short-sequence MHA: bypass absorption path for short prefills,
# using kv_b_proj expansion + standard attention instead.
# See __init__ comment for rationale. topk_indices is not used
# because dense attention is faster than sparse routing at this scale.
# forward_context() handles cached tokens by dispatching to
# forward_context_with_cached_kv or forward_context_with_chunked_prefill.
if self._should_use_short_mha(attn_metadata, position_ids):
return self.forward_context(q, compressed_kv, k_pe, position_ids,
attn_metadata, output, latent_cache)

if get_sm_version() >= 100:
return self.forward_absorption_context(q,
compressed_kv,
Expand Down Expand Up @@ -1929,10 +2018,13 @@ def forward_context(
self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim,
q.dtype, q.device)
if trtllm_attention.is_chunked_prefill_for_mla_context(
attn_metadata):
attn_metadata) and get_sm_version() >= 100:
return self.forward_context_with_chunked_prefill(
q, compressed_kv, latent_cache, attn_metadata, output)
elif trtllm_attention.has_cached_kv_for_mla_context(attn_metadata):
elif trtllm_attention.has_cached_kv_for_mla_context(
attn_metadata
) or trtllm_attention.is_chunked_prefill_for_mla_context(
attn_metadata):
return self.forward_context_with_cached_kv(
q, latent_cache, attn_metadata, output)
return self.forward_context_default(q, compressed_kv, k_pe,
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5846166)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con1_ctx1_pp8_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5846166)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True] SKIP (https://nvbugs/5879614)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=True] SKIP (https://nvbugs/5893116)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5875522)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5875522)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5875522)
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
Expand Down
Loading
Loading