Draft - Don't Review - AD Deepseek-V3-Lite and mla enablement #12089

MrGeva wants to merge 14 commits into NVIDIA:main from
Conversation
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
📝 Walkthrough

This pull request introduces TRT-LLM MLA (Multi-head Latent Attention) backend support for Auto-Deploy, including configuration entries for MLA attention transforms, public API exports for MLA components, a comprehensive MLA implementation with host-side metadata planning and paged cache management, updated test configurations, and an extensive new unit test suite.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (CPU)
    participant Planner as MLA Planner
    participant Metadata as Metadata Prep
    participant Cache as Paged KV Cache
    participant GPU as GPU (Attention Op)
    Host->>Planner: Initialize with batch info, seq lengths, cache config
    Planner->>Planner: Plan buffers for decode/prefill paths
    Planner->>Metadata: Populate planning state, block offsets, pool pointers
    Metadata->>Host: Return prepared metadata (cu_num_pages, page indices, seq info)
    Host->>GPU: Send input (Q, K_scaled, V) + prepared metadata
    GPU->>Cache: Lookup/manage paged KV cache via block offsets
    GPU->>GPU: Execute MLA attention (prefill or decode phase)
    GPU->>Cache: Write latent cache to paged KV cache
    GPU->>Host: Return attention output
```

```mermaid
sequenceDiagram
    participant Input as Forward Input
    participant Router as Path Router
    participant Prefill as Prefill Handler
    participant Decode as Decode Handler
    participant SDPA as SDPA Fallback
    participant Cache as Latent Cache
    participant Output as Output Projection
    Input->>Router: Check context_only or generation_only flag
    alt Prefill Path
        Router->>Prefill: Handle prefill phase
        Prefill->>Prefill: Absorb weights if needed
        Prefill->>SDPA: Use SDPA or chunked attention
        SDPA->>Cache: Write latent cache to paged KV
        Cache->>Output: Project via W_v
    else Decode Path
        Router->>Decode: Handle decode phase
        Decode->>Cache: Read latent cache from paged KV
        Decode->>Decode: Execute MLA kernel
        Decode->>Cache: Update latent cache
        Cache->>Output: Project via W_v
    end
    Output->>Output: Return final attention output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py`:
- Around line 16-25: This file
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py is missing the
required NVIDIA Apache-2.0 source header; add the standard NVIDIA/Apache-2.0
license header (including the correct copyright year of the latest meaningful
modification) at the very top of the file before any imports/exports, preserving
the existing import of TrtllmMLAAttention and trtllm_mla_with_cache and the
__all__ export list (TorchBackendMLAAttention, FlashInferMLAAttention,
TrtllmMLAAttention, torch_mla, torch_backend_mla_with_cache,
flashinfer_mla_with_cache, trtllm_mla_with_cache).
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py`:
- Around line 138-174: reset currently returns early if self.workspace exists,
causing reuse of buffers across different devices or capacity needs; update
reset(self, device, max_batch, max_blocks_per_seq) to detect mismatches (device
or any capacity smaller than requested) by comparing device/type/shape of
self.workspace, self.block_offsets, self.block_ids_per_seq and any FP8 scale /
decode scratch buffers and stored capacity attrs (introduce e.g. self._device,
self._max_batch, self._max_blocks_per_seq if not present), and if mismatched
free/clear existing tensors (set them to None) and proceed with full
reallocation so buffers are recreated on the correct device and with correct
sizes.
- Around line 671-687: The cached prefill path is dropping the MLA scaling
because _handle_prefill() ignores its scale parameter and calls
_call_thop_attention_mla() with a hardcoded 1.0; update both calls inside
_handle_prefill() to forward the incoming scale parameter instead of 1.0
(preserve the same variable name used in _handle_prefill). Also modify
get_constants() to read scale from positional args as well as kwargs on
source_attn_node (e.g., check source_attn_node.args at the expected index before
falling back to source_attn_node.kwargs["scale"]) so models that pass scale
positionally behave identically to torch_mla. Ensure callers and tests still
pass the same scale through the call chain and that the default remains 1.0 when
no scale is provided.
- Around line 703-706: Remove the unnecessary device-wide stall: delete the
torch.cuda.synchronize() call located after the "Multiple prefill sequences"
comment in trtllm_mla.py (the call to torch.cuda.synchronize() is the offending
symbol). Do not replace it with a global sync; simply remove it so the default
CUDA stream ordering is relied upon (this preserves correctness and allows
CUDA-graph capture to work).
In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py`:
- Around line 541-542: The helper's default MLA backend
(mla_attn_backend="flashinfer_mla") no longer matches the shipped default;
change the helper's default to mla_attn_backend="trtllm_mla" so generated
configs set insert_cached_mla_attention.backend to trtllm_mla, and confirm tests
like test_nvfp4() will exercise the new backend without needing explicit
parametrization.
In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py`:
- Around line 125-134: The test helper currently constructs MLA cache and offset
math using a hardcoded factor of 2 (kv dimension [max_num_pages, 2,
num_kv_heads, page_size, latent_dim] and block_offset_multiplier = 2), which
does not match production's TrtllmMLAAttention registration (kv_factor=1,
kv_layout="HND"); update the helper in test_trtllm_mla_op.py to use kv_factor=1
by replacing the hardcoded second-dimension 2 with 1 (i.e., [max_num_pages, 1,
num_kv_heads, page_size, latent_dim]) and change any corresponding
block_offset_multiplier or other hardcoded 2s in that file (the duplicated
occurrences around the later helper/lines) to 1 so the test cache layout matches
the production MLA allocator contract.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: de77708a-0254-43d7-99a4-1f53ba1a9fc1
📒 Files selected for processing (5)

- examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py
```python
from .trtllm_mla import TrtllmMLAAttention, trtllm_mla_with_cache

__all__ = [
    "TorchBackendMLAAttention",
    "FlashInferMLAAttention",
    "TrtllmMLAAttention",
    "torch_mla",
    "torch_backend_mla_with_cache",
    "flashinfer_mla_with_cache",
    "trtllm_mla_with_cache",
]
```
Add the standard NVIDIA/Apache header before exporting new public symbols.
This module is under tensorrt_llm/ and is being modified, but it still has no source header/license block. Please add the required header at the top in this PR. As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified".
```python
def reset(self, device: torch.device, max_batch: int, max_blocks_per_seq: int) -> None:
    """One-time allocation of ALL persistent buffers."""
    if self.workspace is not None:
        return

    self.workspace = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device)
    self.host_pool_mapping = torch.zeros(1, 2, dtype=torch.int32, device="cpu", pin_memory=True)
    self.host_total_kv_lens = torch.zeros(2, dtype=torch.int64, device="cpu", pin_memory=True)
    self.host_request_types = torch.zeros(
        max_batch, dtype=torch.int32, device="cpu", pin_memory=True
    )
    self.block_offsets = torch.zeros(
        1, max_batch, 2, max_blocks_per_seq, dtype=torch.int32, device=device
    )
    self.host_past_kv_lengths = torch.zeros(
        max_batch, dtype=torch.int32, device="cpu", pin_memory=True
    )
    self.host_context_lengths = torch.zeros(
        max_batch, dtype=torch.int32, device="cpu", pin_memory=True
    )
    self.block_ids_per_seq = torch.zeros(
        max_batch, max_blocks_per_seq, dtype=torch.int32, device=device
    )
    self.cu_q_seqlens: Optional[torch.Tensor] = None
    self.cu_kv_seqlens: Optional[torch.Tensor] = None
    self.fmha_scheduler_counter = torch.zeros(1, dtype=torch.int32, device=device)

    self.decode_page_idx = torch.zeros(max_batch, dtype=torch.int64, device=device)
    self.decode_slot_idx = torch.zeros(max_batch, dtype=torch.int64, device=device)
    self._decode_page_idx_host = torch.zeros(
        max_batch, dtype=torch.int64, device="cpu", pin_memory=True
    )
    self._decode_slot_idx_host = torch.zeros(
        max_batch, dtype=torch.int64, device="cpu", pin_memory=True
    )

def ensure_decode_buffers(
```
Reinitialize the singleton when device or capacities change.
reset() bails out as soon as workspace exists, but this planner is module-global. A later MLA session on a different GPU or with larger max_batch / max_blocks_per_seq will reuse wrong-device or undersized tensors (workspace, block_offsets, block_ids_per_seq, FP8 scale buffers, decode scratch), which is a hard runtime bug. The unit helper in tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py masks this by manually calling __init__() before every invocation.
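The guard the review asks for can be sketched torch-free. Below is a hypothetical illustration of the reuse-vs-reallocate logic only: buffers are modeled as plain dicts, and the attribute names `_device`, `_max_batch`, and `_max_blocks_per_seq` follow the review suggestion rather than the actual code.

```python
class PlannerSketch:
    """Capacity-aware reset() sketch; not the real TrtllmMLAAttention planner."""

    def __init__(self):
        self.workspace = None
        self.block_offsets = None
        self._device = None
        self._max_batch = 0
        self._max_blocks_per_seq = 0

    def reset(self, device, max_batch, max_blocks_per_seq):
        # Reuse only when buffers exist on the same device and are at least
        # as large as requested; otherwise drop them and reallocate.
        if (
            self.workspace is not None
            and self._device == device
            and self._max_batch >= max_batch
            and self._max_blocks_per_seq >= max_blocks_per_seq
        ):
            return
        # Stand-ins for tensor allocations on the requested device.
        self.workspace = {"device": device, "bytes": 256 * 1024 * 1024}
        self.block_offsets = {"device": device, "shape": (1, max_batch, 2, max_blocks_per_seq)}
        self._device = device
        self._max_batch = max_batch
        self._max_blocks_per_seq = max_blocks_per_seq
```

A smaller request on the same device reuses the existing buffers; a device change or larger capacity triggers a full reallocation, which is exactly the behavior the early `return` in the current code prevents.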
```python
_call_thop_attention_mla(
    q,
    k,
    v,
    output,
    None,
    None,
    False,
    1,
    num_heads,
    num_heads,
    qk_head_dim,
    tokens_per_block,
    max_num_requests,
    max_context_length,
    1.0,
    quant_mode,
```
Custom MLA scale parameter is dropped on the cached prefill path.
The _handle_prefill() method accepts a scale parameter but ignores it, hardcoding 1.0 in both calls to _call_thop_attention_mla() at lines 686 and 738. Additionally, get_constants() only reads scale from source_attn_node.kwargs, ignoring cases where torch_mla receives scale as a positional argument. This causes silent divergence from torch_mla behavior for models using non-default MLA scale values.
Forward the scale parameter in both _call_thop_attention_mla calls, and update get_constants() to also check positional arguments.
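The args-or-kwargs lookup suggested for get_constants() can be sketched as follows. This is a hedged illustration: `SCALE_ARG_IDX = 7` is a placeholder index, not the real `torch_mla` schema, and `SimpleNamespace` stands in for an FX node.

```python
from types import SimpleNamespace

SCALE_ARG_IDX = 7  # hypothetical position of `scale` in the op signature

def get_scale(node, default=1.0):
    # Prefer an explicit positional argument, then the kwarg, then the default.
    if len(node.args) > SCALE_ARG_IDX:
        return node.args[SCALE_ARG_IDX]
    return node.kwargs.get("scale", default)

# Three call styles that should all resolve consistently:
node_kwarg = SimpleNamespace(args=(), kwargs={"scale": 0.1})
node_positional = SimpleNamespace(args=tuple(range(7)) + (0.2,), kwargs={})
node_default = SimpleNamespace(args=(), kwargs={})
```

With this shape, models that pass `scale` positionally resolve the same value as those passing it by keyword, and the default stays 1.0 when no scale is provided.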
```python
# Multiple prefill sequences: process each independently.
# With use_paged_context_fmha=False the kernel treats all tokens as one
# sequence, so we must call per-sequence to avoid cross-sequence attention.
torch.cuda.synchronize()
```
Remove the unnecessary device-wide synchronize from multi-sequence prefill.
The default CUDA stream is ordered; torch.cuda.synchronize() stalls the host and breaks CUDA-graph capture without providing any correctness benefit. Each iteration's kernel launch is sequentially ordered on the same stream.
Suggested fix

```diff
-    torch.cuda.synchronize()
     token_offset = 0
```
```python
    mla_attn_backend="flashinfer_mla"):
    config = {
```
Make the GLM test helper default match the shipped config.
examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml now defaults insert_cached_mla_attention.backend to trtllm_mla, but this helper still defaults to flashinfer_mla. As a result, test_nvfp4() keeps exercising the old backend unless it is parameterized explicitly.
```python
# HND paged cache: [num_pages, 2, num_kv_heads, page_size, latent_dim]
kv_cache = torch.zeros(
    max_num_pages,
    2,
    num_kv_heads,
    page_size,
    latent_dim,
    dtype=dtype,
    device=device,
)
```
Align the test cache layout with the production MLA allocator.
TrtllmMLAAttention.get_cache_initializers() registers MLA cache as kv_factor=1 / kv_layout="HND" in tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py, but these helpers build [num_pages, 2, ...] caches and a block_offset_multiplier of 2. That means the new tests are validating a different memory contract than the one AutoDeploy will actually use, so a real layout regression can still ship green. Please switch this helper, and the duplicated hardcoded 2s lower in the file, to the same kv_factor=1 contract as production.
Suggested alignment for the root helper

```diff
-    # HND paged cache: [num_pages, 2, num_kv_heads, page_size, latent_dim]
+    # HND paged cache: [num_pages, 1, num_kv_heads, page_size, latent_dim]
     kv_cache = torch.zeros(
         max_num_pages,
-        2,
+        1,
         num_kv_heads,
         page_size,
         latent_dim,
         dtype=dtype,
         device=device,
     )
@@
-    # block_offset_multiplier: for HND with kv_factor=2, each page occupies 2 slots
-    block_offset_multiplier = 2
+    # block_offset_multiplier: for MLA kv_factor=1, each page occupies 1 slot
+    block_offset_multiplier = 1
```

Also applies to: 201-203
Five fixes were required to make both `trtllm_mla` and `flashinfer_mla` Auto-Deploy backends work on B200 (SM100) GPUs. The issues were all latent bugs that only manifest on Blackwell due to differences in kernel support, memory allocation patterns, and hardware constraints compared to H100 (SM90).
Fix 1: Triton RMS Norm stride bug (root cause of `trtllm_mla` crash)

File: tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_rms_norm.py

Bug: The Triton RMS norm kernel uses a single `input_row_stride` for both reading the input tensor and writing the output tensor (line 48: `out_ptr = output + prog_id * input_row_stride`). When the input is a non-contiguous slice (e.g. `kv_a_proj_output[:, :512]` from a `[N, 576]` tensor), `input_row_stride = 576` but the output from `torch.empty_like()` is contiguous with stride 512. The kernel writes at stride-576 offsets into a stride-512 buffer, causing massive out-of-bounds writes.
For DeepSeek-V3-Lite with 4032 tokens: the kernel writes up to element 2,321,847 into a buffer of only 2,064,384 elements, a 257K-element overshoot that corrupts adjacent GPU memory. This manifests as an "illegal memory access" in the next Triton kernel launch (layer 1's `kv_a_layernorm`), making the root cause extremely hard to trace.
Fix: Force contiguity before calling the kernel:
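The actual patch is elided above; a hypothetical sketch of the guard pattern looks like the following. `FakeTensor` is a stand-in so the sketch runs without torch; in the real file the argument would be a `torch.Tensor` and the guard would sit just before the Triton kernel launch.

```python
class FakeTensor:
    """Minimal stand-in exposing the contiguity API of torch.Tensor."""

    def __init__(self, contiguous):
        self._contiguous = contiguous
        self.copied = False  # records whether a copy was made

    def is_contiguous(self):
        return self._contiguous

    def contiguous(self):
        # torch.Tensor.contiguous() returns a packed copy when needed.
        out = FakeTensor(True)
        out.copied = True
        return out

def ensure_contiguous(x):
    # Copy only when necessary so the common (already contiguous) path is free.
    return x if x.is_contiguous() else x.contiguous()
```

The key property is that a contiguous input passes through untouched, while a strided slice is packed into a fresh buffer whose row stride matches the output of `torch.empty_like()`.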
Note: This bug is architecture-independent but only triggers on B200 because the `trtllm_mla` SDPA prefill path (SM100-only) creates batch sizes large enough for the out-of-bounds overshoot to hit live memory. On H100 the thop prefill kernel is used instead, so the non-contiguous `kv_a_layernorm` input never passes through this Triton kernel at problematic scales. The FlashInfer RMS norm variant (`flashinfer_rmsnorm`) is immune because it calls `.reshape(-1, ...)`, which implicitly copies non-contiguous tensors.

Fix 2: FlashInfer RoPE kernel crash on SM100
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rope/flashinfer_rope.py

Bug: FlashInfer's `BatchQKApplyRotaryPosIdsCosSinCache` kernel produces illegal memory accesses on SM100.

Fix: Added a pure PyTorch RoPE fallback (`_apply_rope_pytorch`) that is used when `get_sm_version() >= 100`. Supports both NeOX and interleaved rotation styles.
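For illustration, here is a minimal pure-Python RoPE in the NeOX half-rotation style, the kind of computation such a fallback performs. It is a sketch on a single head vector; the real `_apply_rope_pytorch` operates on batched torch tensors and also handles the interleaved style.

```python
import math

def rope_neox(x, pos, theta=10000.0):
    """Rotate a single head vector x (even length) by position pos, NeOX style."""
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        freq = theta ** (-2 * i / d)  # per-pair rotation frequency
        c, s = math.cos(pos * freq), math.sin(pos * freq)
        # NeOX pairs element i with element i + half (split-halves layout).
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x1 * s + x2 * c
    return out
```

Position 0 is the identity, and each pair undergoes a pure rotation, so the vector norm is preserved at every position.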
Fix 3: Position ID memory corruption across layers

File: tensorrt_llm/_torch/auto_deploy/transform/library/rope.py

Bug: The RoPE graph transform cached a single `position_ids` node and reused it across all layers. On SM100, intermediate kernel launches could overwrite the shared buffer, causing subsequent layers to index the `cos_sin_cache` with corrupted position values (triggering `vectorized_gather_kernel: index out of bounds`).

Fix: Removed the per-layer caching of `position_ids` in both `_optimize_explicit` (DeepSeek path) and `_get_position_ids` (general path). Each layer now gets its own fresh `arange` node, preventing cross-layer corruption.
Fix 4: FlashInfer prefill kernel crash on SM100

File: tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py

Bug: FlashInfer's `BatchPrefillWithRaggedKVCache` kernel crashes with MLA's non-standard `head_dim_qk` (e.g. 192) on SM100.

Fix: On SM100+, all prefill operations are routed through the `BatchMLAPagedAttentionWrapper` (the "chunked prefill" / paged path), which operates in compressed latent space and doesn't hit the incompatible kernel:
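The routing decision reduces to a version check. The sketch below is a hypothetical stand-in: the function name and return labels are illustrative, not the actual AutoDeploy API.

```python
def choose_prefill_path(sm_version):
    # On SM100+ the ragged-KV prefill kernel is unsafe for MLA's non-standard
    # head_dim_qk, so all prefill goes through the paged MLA wrapper.
    if sm_version >= 100:
        return "paged_mla"
    return "ragged_prefill"
```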
Fix 5: trtllm_mla B200 decode support

File: tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py

Multiple issues addressed:

- Power-of-2 head ratio: The trtllm-gen FMHA MLA decode kernel on SM100 requires `num_heads // num_kv_heads` to be a power of 2. Added `_mla_padded_num_heads()` and zero-padded Q/output buffers.
- cuBLAS bf16 GEMM failure: `CUBLAS_STATUS_EXECUTION_FAILED` in `cublasGemmStridedBatchedEx` during CUDA graph capture on SM100. Worked around by casting `torch.bmm` operands to `float()` and back for the weight-absorption and output-projection GEMMs.
- SDPA prefill fallback: The thop MLA context kernel hits illegal memory accesses on SM100. Added SDPA fallback (`_batched_fresh_prefill_sdpa` and `_handle_chunked_prefill`) for all prefill on SM100+.
- Direct cache indexing: Switched from `kv_cache_block_offsets` (which includes interleaved pool strides) to `block_ids_per_seq` for direct page indexing, simplifying both read and write paths.
- Pre-computed decode metadata: `cu_kv_decode` and cache write indices (`page_idx`, `slot_idx`) are computed on the host during `plan()` and copied to GPU, eliminating per-layer GPU kernel launches.
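Two of the decode-path helpers described above can be sketched in plain Python. Both are hypothetical reconstructions from the prose, not the real implementations: the padding helper mirrors what `_mla_padded_num_heads()` is described to do, and the index math assumes the new token for a sequence lands at logical position `past_kv_len` (the real code then maps pages through `block_ids_per_seq`).

```python
def padded_num_heads(num_heads, num_kv_heads):
    """Pad num_heads so the heads-per-KV-head ratio is a power of 2 (SM100 requirement)."""
    ratio = -(-num_heads // num_kv_heads)  # ceil division
    pow2 = 1
    while pow2 < ratio:
        pow2 *= 2
    return pow2 * num_kv_heads

def decode_write_indices(past_kv_lens, page_size):
    """Host-side precompute of (page_idx, slot_idx) for each sequence's next token."""
    # Token at position n lives in page n // page_size, slot n % page_size.
    return [(n // page_size, n % page_size) for n in past_kv_lens]
```

Precomputing these on the host during `plan()` and copying them over once is what removes the per-layer GPU kernel launches mentioned above.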
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.