Draft - Don't Review - AD Deepseek-V3-Lite and MLA enablement #12089

Open

MrGeva wants to merge 14 commits into NVIDIA:main from nv-auto-deploy:eg/ds_with_mla_enablement

Conversation


@MrGeva MrGeva commented Mar 10, 2026

Five fixes were required to make both trtllm_mla and flashinfer_mla Auto-Deploy
backends work on B200 (SM100) GPUs. The issues were all latent bugs that only
manifest on Blackwell due to differences in kernel support, memory allocation
patterns, and hardware constraints compared to H100 (SM90).


Fix 1: Triton RMS Norm stride bug (root cause of trtllm_mla crash)

File: tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_rms_norm.py

Bug: The Triton RMS norm kernel uses a single input_row_stride for both
reading the input tensor and writing the output tensor (line 48:
out_ptr = output + prog_id * input_row_stride). When the input is a
non-contiguous slice (e.g. kv_a_proj_output[:, :512] from a [N, 576]
tensor), input_row_stride = 576 but the output from torch.empty_like() is
contiguous with stride 512. The kernel writes at stride-576 offsets into a
stride-512 buffer, causing massive out-of-bounds writes.

For DeepSeek-V3-Lite with 4032 tokens: the kernel writes up to element
2,321,847 into a buffer of only 2,064,384 elements — a 257K element overshoot
that corrupts adjacent GPU memory. This manifests as an "illegal memory access"
in the next Triton kernel launch (layer 1's kv_a_layernorm), making the
root cause extremely hard to trace.

Fix: Force contiguity before calling the kernel:

hidden_states = hidden_states.reshape(-1, feat_size).contiguous()

Note: This bug is architecture-independent but only triggers on B200
because the trtllm_mla SDPA prefill path (SM100-only) creates batch sizes
large enough for the out-of-bounds overshoot to hit live memory. On H100 the
thop prefill kernel is used instead, so the non-contiguous kv_a_layernorm
input never passes through this Triton kernel at problematic scales. The
FlashInfer RMS norm variant (flashinfer_rmsnorm) is immune because it calls
.reshape(-1, ...) which implicitly copies non-contiguous tensors.
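
The overshoot arithmetic can be checked with a simplified flat-index model of the kernel's addressing (plain Python, illustrative only — the real kernel writes in blocks, so the exact last offset differs slightly from the 2,321,847 quoted above, but the ~257K-element scale of the overshoot is the same):

```python
# Hypothetical flat-index model of the buggy addressing: each row is written
# at `row * input_row_stride`, but the output buffer from torch.empty_like()
# on the sliced input is contiguous and only `num_rows * feat_size` long.
def max_write_index(num_rows: int, input_row_stride: int, feat_size: int) -> int:
    # Flat index of the last element the kernel writes (final row, final column).
    return (num_rows - 1) * input_row_stride + (feat_size - 1)

num_rows = 4032          # tokens in the DeepSeek-V3-Lite repro
input_row_stride = 576   # row stride of the [N, 576] parent tensor
feat_size = 512          # width of the kv_a_layernorm slice

out_buffer_len = num_rows * feat_size  # contiguous output: 2,064,384 elements
overshoot = max_write_index(num_rows, input_row_stride, feat_size) - out_buffer_len
assert overshoot > 0  # roughly a quarter-million elements past the end
```

With `.contiguous()` applied first, the effective row stride becomes 512 and the last write index falls exactly on the final element of the output buffer.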


Fix 2: FlashInfer RoPE kernel crash on SM100

File: tensorrt_llm/_torch/auto_deploy/custom_ops/rope/flashinfer_rope.py

Bug: FlashInfer's BatchQKApplyRotaryPosIdsCosSinCache kernel produces
illegal memory accesses on SM100.

Fix: Added a pure PyTorch RoPE fallback (_apply_rope_pytorch) that is
used when get_sm_version() >= 100. Supports both NeOX and interleaved
rotation styles.
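
For reference, the NeOX-style rotation such a fallback computes can be sketched in a few lines of NumPy (an illustration of the math, not the actual _apply_rope_pytorch code; the interleaved style differs only in how the rotated pairs are picked):

```python
import numpy as np

def rope_neox(x: np.ndarray, pos: np.ndarray, theta: float = 10000.0) -> np.ndarray:
    """NeOX-style RoPE: the head dim is split in half and each (x1[i], x2[i])
    pair is rotated by an angle depending on position and dimension index.
    x: [seq, head_dim], pos: [seq]."""
    seq, dim = x.shape
    half = dim // 2
    freqs = theta ** (-np.arange(half) / half)   # per-dimension inverse frequencies
    ang = pos[:, None] * freqs[None, :]          # [seq, half] rotation angles
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8)).astype(np.float32)
out = rope_neox(q, np.arange(4))
# Each pair is rotated by an orthogonal 2x2 matrix, so row norms are preserved.
assert np.allclose(np.linalg.norm(out, axis=-1), np.linalg.norm(q, axis=-1), atol=1e-4)
```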


Fix 3: Position ID memory corruption across layers

File: tensorrt_llm/_torch/auto_deploy/transform/library/rope.py

Bug: The RoPE graph transform cached a single position_ids node and
reused it across all layers. On SM100, intermediate kernel launches could
overwrite the shared buffer, causing subsequent layers to index the
cos_sin_cache with corrupted position values (triggering
vectorized_gather_kernel: index out of bounds).

Fix: Removed the per-layer caching of position_ids in both
_optimize_explicit (DeepSeek path) and _get_position_ids (general path).
Each layer now gets its own fresh arange node, preventing cross-layer
corruption.
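
The failure mode is the classic shared-mutable-buffer aliasing hazard; sketched abstractly in plain Python (lists standing in for graph nodes, not the actual FX transform):

```python
# Shared node: every "layer" reads the same position_ids object, so a stray
# overwrite between layers is visible to all later readers.
shared = list(range(8))
layers_shared = [shared for _ in range(3)]
layers_shared[0][3] = -1          # simulate corruption after layer 0 runs
assert layers_shared[2][3] == -1  # layer 2 now indexes with the corrupted value

# Fresh node per layer (the fix): each layer materializes its own arange,
# so corruption of one copy cannot leak into another layer's indices.
layers_fresh = [list(range(8)) for _ in range(3)]
layers_fresh[0][3] = -1
assert layers_fresh[2][3] == 3    # unaffected
```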


Fix 4: FlashInfer prefill kernel crash on SM100

File: tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py

Bug: FlashInfer's BatchPrefillWithRaggedKVCache kernel crashes with
MLA's non-standard head_dim_qk (e.g. 192) on SM100.

Fix: On SM100+, all prefill operations are routed through the
BatchMLAPagedAttentionWrapper (the "chunked prefill" / paged path) which
operates in compressed latent space and doesn't hit the incompatible kernel:

use_paged_prefill = is_chunked_prefill or get_sm_version() >= 100
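
In other words, the routing collapses to a small predicate (sketch with the SM version passed in explicitly; the PR reads it via get_sm_version()):

```python
def should_use_paged_prefill(is_chunked_prefill: bool, sm_version: int) -> bool:
    # On SM100+ the ragged-KV prefill kernel is incompatible with MLA's
    # head_dim_qk of 192, so all prefill is routed through the paged MLA
    # wrapper, which operates in compressed latent space.
    return is_chunked_prefill or sm_version >= 100

assert should_use_paged_prefill(False, 100)      # B200: forced onto paged path
assert should_use_paged_prefill(True, 90)        # chunked prefill on H100
assert not should_use_paged_prefill(False, 90)   # H100 fresh prefill: ragged kernel OK
```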

Fix 5: trtllm_mla B200 decode support

File: tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py

Multiple issues addressed:

  • Power-of-2 head ratio: The trtllm-gen FMHA MLA decode kernel on SM100
    requires num_heads // num_kv_heads to be a power of 2. Added
    _mla_padded_num_heads() and zero-padded Q/output buffers.

  • cuBLAS bf16 GEMM failure: CUBLAS_STATUS_EXECUTION_FAILED in
    cublasGemmStridedBatchedEx during CUDA graph capture on SM100. Worked
    around by casting torch.bmm operands to float() and back for the
    weight-absorption and output-projection GEMMs.

  • SDPA prefill fallback: The thop MLA context kernel hits illegal memory
    accesses on SM100. Added SDPA fallback (_batched_fresh_prefill_sdpa and
    _handle_chunked_prefill) for all prefill on SM100+.

  • Direct cache indexing: Switched from kv_cache_block_offsets (which
    includes interleaved pool strides) to block_ids_per_seq for direct page
    indexing, simplifying both read and write paths.

  • Pre-computed decode metadata: cu_kv_decode and cache write indices
    (page_idx, slot_idx) are computed on the host during plan() and
    copied to GPU, eliminating per-layer GPU kernel launches.
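
The power-of-2 padding rule from the first bullet can be sketched as a standalone helper (illustrative — the PR's _mla_padded_num_heads may differ in details):

```python
def mla_padded_num_heads(num_heads: int, num_kv_heads: int) -> int:
    """Round num_heads up so that num_heads // num_kv_heads is a power of 2,
    as the trtllm-gen FMHA MLA decode kernel on SM100 requires. The extra
    Q/output head slots are zero-padded by the caller."""
    ratio = -(-num_heads // num_kv_heads)  # ceiling division
    pow2 = 1
    while pow2 < ratio:
        pow2 *= 2
    return pow2 * num_kv_heads

# A ratio that is already a power of 2 is left alone; 24 heads over 1 KV head
# would be padded up to 32.
assert mla_padded_num_heads(16, 1) == 16
assert mla_padded_num_heads(24, 1) == 32
```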

Summary by CodeRabbit

  • New Features

    • Introduced Multi-head Latent Attention (MLA) backend support for model inference.
    • MLA attention transform configuration now available in GLM-4 Flash and Nemotron deployment pipelines.
    • Extended model deployment pipeline with MLA-specific buffer management and metadata handling capabilities.
  • Tests

    • Added comprehensive unit and integration tests validating MLA attention operations across various configurations and batch scenarios.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
@MrGeva MrGeva requested review from a team as code owners March 10, 2026 19:46

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough

Walkthrough

This pull request introduces TRT-LLM MLA (Multi-head Latent Attention) backend support for Auto-Deploy, including configuration entries for MLA attention transforms, public API exports for MLA components, comprehensive MLA implementation with host-side metadata planning and paged cache management, updated test configurations, and an extensive new unit test suite.

Changes

Cohort / File(s) — Summary

  • Configuration and API Exports — examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml, tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py: Added the insert_cached_mla_attention transform entry to the configuration and exported the TrtllmMLAAttention class and trtllm_mla_with_cache function from the MLA package.
  • Core MLA Implementation — tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py: Introduced a comprehensive TRT-LLM MLA attention backend featuring an MLA planner with persistent buffers for decode/prefill paths, host-side metadata preparation, a custom attention op with prefill/decode handlers, latent cache utilities, chunked/SDPA fallbacks, and AttentionDescriptor subclass integration.
  • Test Updates and New Suite — tests/integration/defs/accuracy/test_llm_api_autodeploy.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py: Refactored test configurations to use the mla_attn_backend parameter with the insert_cached_mla_attention transform; added a comprehensive new test suite validating TRT-LLM MLA ops across prefill, decode, multi-step generation, and mixed-batch scenarios (requires CUDA SM 8.0+).

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host (CPU)
    participant Planner as MLA Planner
    participant Metadata as Metadata Prep
    participant Cache as Paged KV Cache
    participant GPU as GPU (Attention Op)
    
    Host->>Planner: Initialize with batch info, seq lengths, cache config
    Planner->>Planner: Plan buffers for decode/prefill paths
    Planner->>Metadata: Populate planning state, block offsets, pool pointers
    Metadata->>Host: Return prepared metadata (cu_num_pages, page indices, seq info)
    Host->>GPU: Send input (Q, K_scaled, V) + prepared metadata
    GPU->>Cache: Lookup/manage paged KV cache via block offsets
    GPU->>GPU: Execute MLA attention (prefill or decode phase)
    GPU->>Cache: Write latent cache to paged KV cache
    GPU->>Host: Return attention output
sequenceDiagram
    participant Input as Forward Input
    participant Router as Path Router
    participant Prefill as Prefill Handler
    participant Decode as Decode Handler
    participant SDPA as SDPA Fallback
    participant Cache as Latent Cache
    participant Output as Output Projection
    
    Input->>Router: Check context_only or generation_only flag
    alt Prefill Path
        Router->>Prefill: Handle prefill phase
        Prefill->>Prefill: Absorb weights if needed
        Prefill->>SDPA: Use SDPA or chunked attention
        SDPA->>Cache: Write latent cache to paged KV
        Cache->>Output: Project via W_v
    else Decode Path
        Router->>Decode: Handle decode phase
        Decode->>Cache: Read latent cache from paged KV
        Decode->>Decode: Execute MLA kernel
        Decode->>Cache: Update latent cache
        Cache->>Output: Project via W_v
    end
    Output->>Output: Return final attention output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 71.11%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: The title is vague and generic, using non-descriptive terms like "Draft" and "Deepseek-V3-Lite and mla enablement" without clearly conveying the primary change. Resolution: provide a clear, specific title following the template [JIRA/NVBugs/Issue/None][type] Summary, e.g. [None][feat] Add TRT-LLM MLA attention backend for auto-deploy with GLM-4 and Nemotron support.

✅ Passed checks (1 passed)

  • Description check — ✅ Passed: The PR description provides comprehensive technical details about five specific fixes for SM100 GPU support, with clear explanations of root causes and solutions.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py`:
- Around line 16-25: This file
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py is missing the
required NVIDIA Apache-2.0 source header; add the standard NVIDIA/Apache-2.0
license header (including the correct copyright year of the latest meaningful
modification) at the very top of the file before any imports/exports, preserving
the existing import of TrtllmMLAAttention and trtllm_mla_with_cache and the
__all__ export list (TorchBackendMLAAttention, FlashInferMLAAttention,
TrtllmMLAAttention, torch_mla, torch_backend_mla_with_cache,
flashinfer_mla_with_cache, trtllm_mla_with_cache).

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py`:
- Around line 138-174: reset currently returns early if self.workspace exists,
causing reuse of buffers across different devices or capacity needs; update
reset(self, device, max_batch, max_blocks_per_seq) to detect mismatches (device
or any capacity smaller than requested) by comparing device/type/shape of
self.workspace, self.block_offsets, self.block_ids_per_seq and any FP8 scale /
decode scratch buffers and stored capacity attrs (introduce e.g. self._device,
self._max_batch, self._max_blocks_per_seq if not present), and if mismatched
free/clear existing tensors (set them to None) and proceed with full
reallocation so buffers are recreated on the correct device and with correct
sizes.
- Around line 671-687: The cached prefill path is dropping the MLA scaling
because _handle_prefill() ignores its scale parameter and calls
_call_thop_attention_mla() with a hardcoded 1.0; update both calls inside
_handle_prefill() to forward the incoming scale parameter instead of 1.0
(preserve the same variable name used in _handle_prefill). Also modify
get_constants() to read scale from positional args as well as kwargs on
source_attn_node (e.g., check source_attn_node.args at the expected index before
falling back to source_attn_node.kwargs["scale"]) so models that pass scale
positionally behave identically to torch_mla. Ensure callers and tests still
pass the same scale through the call chain and that the default remains 1.0 when
no scale is provided.
- Around line 703-706: Remove the unnecessary device-wide stall: delete the
torch.cuda.synchronize() call located after the "Multiple prefill sequences"
comment in trtllm_mla.py (the call to torch.cuda.synchronize() is the offending
symbol). Do not replace it with a global sync; simply remove it so the default
CUDA stream ordering is relied upon (this preserves correctness and allows
CUDA-graph capture to work).

In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py`:
- Around line 541-542: The helper's default MLA backend
(mla_attn_backend="flashinfer_mla") no longer matches the shipped default;
change the helper's default to mla_attn_backend="trtllm_mla" so generated
configs set insert_cached_mla_attention.backend to trtllm_mla, and confirm tests
like test_nvfp4() will exercise the new backend without needing explicit
parametrization.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py`:
- Around line 125-134: The test helper currently constructs MLA cache and offset
math using a hardcoded factor of 2 (kv dimension [max_num_pages, 2,
num_kv_heads, page_size, latent_dim] and block_offset_multiplier = 2), which
does not match production's TrtllmMLAAttention registration (kv_factor=1,
kv_layout="HND"); update the helper in test_trtllm_mla_op.py to use kv_factor=1
by replacing the hardcoded second-dimension 2 with 1 (i.e., [max_num_pages, 1,
num_kv_heads, page_size, latent_dim]) and change any corresponding
block_offset_multiplier or other hardcoded 2s in that file (the duplicated
occurrences around the later helper/lines) to 1 so the test cache layout matches
the production MLA allocator contract.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: de77708a-0254-43d7-99a4-1f53ba1a9fc1

📥 Commits

Reviewing files that changed from the base of the PR and between 3ce0ec8 and 3cc3011.

📒 Files selected for processing (5)
  • examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py

Comment on lines +16 to +25
from .trtllm_mla import TrtllmMLAAttention, trtllm_mla_with_cache

__all__ = [
"TorchBackendMLAAttention",
"FlashInferMLAAttention",
"TrtllmMLAAttention",
"torch_mla",
"torch_backend_mla_with_cache",
"flashinfer_mla_with_cache",
"trtllm_mla_with_cache",

⚠️ Potential issue | 🟠 Major

Add the standard NVIDIA/Apache header before exporting new public symbols.

This module is under tensorrt_llm/ and is being modified, but it still has no source header/license block. Please add the required header at the top in this PR. As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified".


Comment on lines +138 to +174
def reset(self, device: torch.device, max_batch: int, max_blocks_per_seq: int) -> None:
"""One-time allocation of ALL persistent buffers."""
if self.workspace is not None:
return

self.workspace = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device)
self.host_pool_mapping = torch.zeros(1, 2, dtype=torch.int32, device="cpu", pin_memory=True)
self.host_total_kv_lens = torch.zeros(2, dtype=torch.int64, device="cpu", pin_memory=True)
self.host_request_types = torch.zeros(
max_batch, dtype=torch.int32, device="cpu", pin_memory=True
)
self.block_offsets = torch.zeros(
1, max_batch, 2, max_blocks_per_seq, dtype=torch.int32, device=device
)
self.host_past_kv_lengths = torch.zeros(
max_batch, dtype=torch.int32, device="cpu", pin_memory=True
)
self.host_context_lengths = torch.zeros(
max_batch, dtype=torch.int32, device="cpu", pin_memory=True
)
self.block_ids_per_seq = torch.zeros(
max_batch, max_blocks_per_seq, dtype=torch.int32, device=device
)
self.cu_q_seqlens: Optional[torch.Tensor] = None
self.cu_kv_seqlens: Optional[torch.Tensor] = None
self.fmha_scheduler_counter = torch.zeros(1, dtype=torch.int32, device=device)

self.decode_page_idx = torch.zeros(max_batch, dtype=torch.int64, device=device)
self.decode_slot_idx = torch.zeros(max_batch, dtype=torch.int64, device=device)
self._decode_page_idx_host = torch.zeros(
max_batch, dtype=torch.int64, device="cpu", pin_memory=True
)
self._decode_slot_idx_host = torch.zeros(
max_batch, dtype=torch.int64, device="cpu", pin_memory=True
)

def ensure_decode_buffers(

⚠️ Potential issue | 🔴 Critical

Reinitialize the singleton when device or capacities change.

reset() bails out as soon as workspace exists, but this planner is module-global. A later MLA session on a different GPU or with larger max_batch / max_blocks_per_seq will reuse wrong-device or undersized tensors (workspace, block_offsets, block_ids_per_seq, FP8 scale buffers, decode scratch), which is a hard runtime bug. The unit helper in tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py masks this by manually calling __init__() before every invocation.


Comment on lines +671 to +687
_call_thop_attention_mla(
q,
k,
v,
output,
None,
None,
False,
1,
num_heads,
num_heads,
qk_head_dim,
tokens_per_block,
max_num_requests,
max_context_length,
1.0,
quant_mode,

⚠️ Potential issue | 🟠 Major

Custom MLA scale parameter is dropped on the cached prefill path.

The _handle_prefill() method accepts a scale parameter but ignores it, hardcoding 1.0 in both calls to _call_thop_attention_mla() at lines 686 and 738. Additionally, get_constants() only reads scale from source_attn_node.kwargs, ignoring cases where torch_mla receives scale as a positional argument. This causes silent divergence from torch_mla behavior for models using non-default MLA scale values.

Forward the scale parameter in both _call_thop_attention_mla calls, and update get_constants() to also check positional arguments.


Comment on lines +703 to +706
# Multiple prefill sequences: process each independently.
# With use_paged_context_fmha=False the kernel treats all tokens as one
# sequence, so we must call per-sequence to avoid cross-sequence attention.
torch.cuda.synchronize()

⚠️ Potential issue | 🟠 Major


Remove the unnecessary device-wide synchronize from multi-sequence prefill.

The default CUDA stream is ordered; torch.cuda.synchronize() stalls the host and breaks CUDA-graph capture without providing any correctness benefit. Each iteration's kernel launch is sequentially ordered on the same stream.

Suggested fix
-        torch.cuda.synchronize()
         token_offset = 0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Multiple prefill sequences: process each independently.
# With use_paged_context_fmha=False the kernel treats all tokens as one
# sequence, so we must call per-sequence to avoid cross-sequence attention.
torch.cuda.synchronize()
# Multiple prefill sequences: process each independently.
# With use_paged_context_fmha=False the kernel treats all tokens as one
# sequence, so we must call per-sequence to avoid cross-sequence attention.
token_offset = 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py` around lines
703 - 706, remove the unnecessary device-wide stall: delete the
torch.cuda.synchronize() call located after the "Multiple prefill sequences"
comment. Do not replace it with another synchronization; simply remove it and
rely on default CUDA stream ordering, which preserves correctness and keeps
CUDA-graph capture working.

Comment on lines +541 to 542
mla_attn_backend="flashinfer_mla"):
config = {

⚠️ Potential issue | 🟡 Minor

Make the GLM test helper default match the shipped config.

examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml now defaults insert_cached_mla_attention.backend to trtllm_mla, but this helper still defaults to flashinfer_mla. As a result, test_nvfp4() keeps exercising the old backend unless it is parameterized explicitly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py` around lines 541
- 542, the helper's default MLA backend (mla_attn_backend="flashinfer_mla") no
longer matches the shipped default. Change the helper's default to
mla_attn_backend="trtllm_mla" so generated configs set
insert_cached_mla_attention.backend to trtllm_mla, and confirm that tests like
test_nvfp4() exercise the new backend without explicit parametrization.
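One lightweight guard against this kind of drift is to derive the helper's default from a single shipped constant rather than a duplicated literal. A hypothetical sketch (the constant and helper names below are illustrative, not the repo's actual code):

```python
# Hypothetical sketch: keep a test helper's default backend in lockstep with
# the shipped config default so the two cannot drift silently.
SHIPPED_MLA_BACKEND = "trtllm_mla"  # assumed default from glm-4.7-flash.yaml

def get_default_config(mla_attn_backend=SHIPPED_MLA_BACKEND):
    # The helper default is derived from the shipped value, not hardcoded.
    return {"insert_cached_mla_attention": {"backend": mla_attn_backend}}
```

A test that calls `get_default_config()` with no arguments then exercises the same backend the shipped YAML selects, while explicit parametrization (`get_default_config("flashinfer_mla")`) still covers the alternative.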

Comment on lines +125 to +134
# HND paged cache: [num_pages, 2, num_kv_heads, page_size, latent_dim]
kv_cache = torch.zeros(
max_num_pages,
2,
num_kv_heads,
page_size,
latent_dim,
dtype=dtype,
device=device,
)

⚠️ Potential issue | 🟠 Major

Align the test cache layout with the production MLA allocator.

TrtllmMLAAttention.get_cache_initializers() registers MLA cache as kv_factor=1 / kv_layout="HND" in tensorrt_llm/_torch/auto_deploy/custom_ops/mla/trtllm_mla.py, but these helpers build [num_pages, 2, ...] caches and a block_offset_multiplier of 2. That means the new tests are validating a different memory contract than the one AutoDeploy will actually use, so a real layout regression can still ship green. Please switch this helper, and the duplicated hardcoded 2s lower in the file, to the same kv_factor=1 contract as production.

Suggested alignment for the root helper
-    # HND paged cache: [num_pages, 2, num_kv_heads, page_size, latent_dim]
+    # HND paged cache: [num_pages, 1, num_kv_heads, page_size, latent_dim]
     kv_cache = torch.zeros(
         max_num_pages,
-        2,
+        1,
         num_kv_heads,
         page_size,
         latent_dim,
         dtype=dtype,
         device=device,
     )
@@
-    # block_offset_multiplier: for HND with kv_factor=2, each page occupies 2 slots
-    block_offset_multiplier = 2
+    # block_offset_multiplier: for MLA kv_factor=1, each page occupies 1 slot
+    block_offset_multiplier = 1

Also applies to: 201-203

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_trtllm_mla_op.py`
around lines 125 - 134, the test helper constructs the MLA cache and offset
math with a hardcoded factor of 2 (kv dimension [max_num_pages, 2,
num_kv_heads, page_size, latent_dim] and block_offset_multiplier = 2), which
does not match production's TrtllmMLAAttention registration (kv_factor=1,
kv_layout="HND"). Update the helper to use kv_factor=1 by replacing the
hardcoded second dimension 2 with 1 (i.e., [max_num_pages, 1, num_kv_heads,
page_size, latent_dim]) and change the corresponding block_offset_multiplier
and any other hardcoded 2s in the file (including the duplicated occurrences
in the later helper) to 1, so the test cache layout matches the production
MLA allocator contract.
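For reference, the kv_factor=1 contract can be sketched in plain Python (helper names below are hypothetical); the stride arithmetic shows why the block-offset multiplier must equal kv_factor:

```python
# Hypothetical sketch of the HND MLA cache contract with kv_factor=1:
# shape [num_pages, kv_factor, num_kv_heads, page_size, latent_dim].
# MLA stores one compressed latent entry per token, so there are no
# separate K and V planes and each page occupies a single slot.
def mla_cache_shape(max_num_pages, num_kv_heads, page_size, latent_dim, kv_factor=1):
    return (max_num_pages, kv_factor, num_kv_heads, page_size, latent_dim)

def flat_offset(page, head, slot, num_kv_heads, page_size, latent_dim, kv_factor=1):
    # Row-major strides for the shape above; block_offset_multiplier == kv_factor.
    page_stride = kv_factor * num_kv_heads * page_size * latent_dim
    head_stride = page_size * latent_dim
    return page * page_stride + head * head_stride + slot * latent_dim
```

With kv_factor=1 and num_kv_heads=1, page 1 starts at `page_size * latent_dim` elements; a hardcoded factor of 2 doubles that stride and indexes past the data the production allocator actually wrote.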

MrGeva added 3 commits March 10, 2026 23:08
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>