
[Fmha] Add head_dim=512 support for trtllm attention kernels#2959

Open
djmmoss wants to merge 16 commits into flashinfer-ai:main from djmmoss:dmoss/trtllm-fmha-head-dim-512

Conversation

@djmmoss
Collaborator

@djmmoss djmmoss commented Apr 2, 2026

Add support for head_dim=512 in the trtllm FMHA kernel selection.

Changes

  • Add SDPA-based reference implementation for head_dim > 256 in tests (FlashInfer FA2/FA3 kernels don't support head_dim > 256)
  • Add test_trtllm_batch_prefill_head_dim_512 and test_trtllm_batch_decode_head_dim_512 covering BF16, FP16, FP8, and NVFP4 dtypes
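The reference-selection logic described above can be sketched in a few lines. This is a hypothetical rendering with illustrative names; the actual routing lives in the test helpers of tests/attention/test_trtllm_gen_attention.py:

```python
# Sketch of the reference routing described in the PR: FlashInfer's FA2/FA3
# reference kernels only support head_dim <= 256, so larger head dims fall
# back to a pure-PyTorch SDPA reference. Names here are illustrative.
FA_MAX_HEAD_DIM = 256  # assumed kernel limit, per the PR description

def select_reference(head_dim: int) -> str:
    """Pick the reference implementation for a given head dimension."""
    if head_dim > FA_MAX_HEAD_DIM:
        return "sdpa_paged_reference"   # paged-KV SDPA fallback added by this PR
    return "flashinfer_reference"       # existing FA2/FA3-backed reference

print(select_reference(128))  # existing path
print(select_reference(512))  # new head_dim=512 path
```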

Summary by CodeRabbit

  • Bug Fixes

    • Improved kernel selection for large head dimensions (>256) and adjusted sparse-attention behavior to use the correct Top-K limit.
    • Conditional NaN checks to avoid false positives for certain low-precision dtypes.
    • Updated artifact/version identifiers and checksums for FMHA assets.
  • Tests

    • Added tests exercising head_dim=512 scenarios for prefill and decode.
    • Added a paged-KV reference implementation for attention validation.
  • Chores

    • Download logic now supports optional credentials from environment variables for authenticated fetches.
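The optional-credentials behavior can be sketched as below. This is a minimal stand-in, not the exact flashinfer code; the real implementation attaches the resulting tuple to a requests.Session as session.auth:

```python
import os

def basic_auth_from_env(env=os.environ):
    """Return (user, token) for HTTP basic auth when both URM_USER and
    URM_TOKEN are set in the environment, else None. Sketch of the
    optional-credentials logic described above; the real code sets
    session.auth on a requests.Session used for cubin downloads."""
    user = env.get("URM_USER")
    token = env.get("URM_TOKEN")
    if user and token:
        return (user, token)
    return None  # unauthenticated fetch when either variable is missing

# Auth is only enabled when BOTH variables are present.
print(basic_auth_from_env({"URM_USER": "ci"}))                    # None
print(basic_auth_from_env({"URM_USER": "ci", "URM_TOKEN": "t"}))  # ('ci', 't')
```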

@coderabbitai
Contributor

coderabbitai Bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR updates the TRTLLM-Gen FMHA artifact path and checksum, adds/adjusts kernel selection heuristics for large head dimensions and sparse-attention flags, introduces a new paged-SDPA reference and tests for head_dim=512, adjusts MLA NaN checks, and adds session auth for remote cubin downloads.

Changes

Cohort / File(s) Summary
FMHA Artifact
flashinfer/artifacts.py
Replaced ArtifactPath.TRTLLM_GEN_FMHA directory identifier and updated CheckSumHash.TRTLLM_GEN_FMHA SHA256 checksum.
FMHA Kernel Selection & Params
include/flashinfer/trtllm/fmha/fmhaKernels.cuh, include/flashinfer/trtllm/fmha/kernelParams.h, csrc/fmhaReduction.cu
Changed kernel hash input from mSparseMla to mSparseAttn != 0, added heuristic to cap headDimPerCtaV to 256 for large headDimV in non-MLA/non-SwapsMmaAbForGeneration paths, replaced mSparseMlaTopK with mSparseAttnTopK, and switched reduction launch flag to mSparseAttn != 0. Review layout/ABI alignment and sparse-topK usage.
JIT cubin downloader
flashinfer/jit/cubin_loader.py
When creating a default requests session, read URM_USER/URM_TOKEN env vars and set session.auth for HTTP basic auth if present.
Tests — Attention
tests/attention/test_trtllm_gen_attention.py
Added sdpa_paged_reference (paged KV SDPA reference), routed prefill/decode reference to it for head_dim>256, moved/ensured sink assignment, and added parameterized tests for head_dim=512.
Tests — MLA
tests/attention/test_trtllm_gen_mla.py
Adjusted NaN assertions: skip output NaN check for torch.float8_e4m3fn, kept reference NaN check unconditional; minor formatting tweak.
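The headDimPerCtaV cap mentioned in the kernel-selection row can be sketched as follows. This is a hypothetical Python rendering of C++ logic in fmhaKernels.cuh; function and parameter names are illustrative:

```python
def cap_head_dim_per_cta_v(head_dim_v: int, is_mla: bool,
                           swaps_mma_ab_for_generation: bool) -> int:
    """Sketch of the heuristic described above: for large V head dims on
    non-MLA paths that do not swap MMA A/B for generation, cap the
    per-CTA V head dimension at 256 so existing tile shapes apply."""
    if head_dim_v > 256 and not is_mla and not swaps_mma_ab_for_generation:
        return 256
    return head_dim_v

print(cap_head_dim_per_cta_v(512, False, False))  # capped to 256
print(cap_head_dim_per_cta_v(128, False, False))  # unchanged
```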

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • aleozlx
  • cyx-6
  • yzh119
  • sricketts
  • samuellees
  • saltyminty
  • bkryu

Poem

🐰 I hopped through kernels, checksums in my paw,

head dims trimmed to keep things in awe,
paged KV danced and tests began to sing,
auth keys snug so downloads take wing,
a tiny rabbit cheers—let the CI bells ring! 🎉

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check — ⚠️ Warning — The description explains what changes were made and why, but lacks the structured sections from the repository template (checklist items, related issues, reviewer notes). Resolution: add the required sections from the PR template: Related Issues link, Pre-commit Checks confirmation, Tests checklist items, and Reviewer Notes section.
  • Docstring Coverage — ⚠️ Warning — Docstring coverage is 41.18%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed — The title accurately captures the main change: adding head_dim=512 support for trtllm FMHA attention kernels.



Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for head dimensions greater than 256 by capping the head dimension per CTA and providing a PyTorch-based SDPA reference fallback for testing. It also refactors the handling of KV scaling factor strides by deriving them from KV data strides and removing explicit stride parameters from the runner configuration. Review feedback highlights the removal of critical technical documentation regarding TMA layout constraints and hardware requirements, as well as a potential device mismatch bug in the new test reference function.

      // cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1.
      // - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual
      //   strides are read from the tensor and passed to the TMA descriptor.
      // Create the TMA shape/stride for K.


Severity: medium

The detailed documentation regarding TMA layout requirements and hardware constraints for K/V data tensors has been removed. This information is valuable for understanding the underlying TMA configuration (e.g., the requirement for stride 1 on dim 0) and should be preserved. Additionally, the updated comment incorrectly implies this function is only for 'K', whereas it is used for both 'K' and 'V'.

  // Create the TMA shape/stride for K/V data tensors.
  //
  // Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim]
  //   - head_dim (last dim) MUST have stride 1. This is a TMA hardware constraint:
  //     cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1.
  //   - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual
  //     strides are read from the tensor and passed to the TMA descriptor.

      //   kSfStrideBatch) and can differ from the KV data strides.
      // - cuTensorMapEncodeTiled requires all non-dim0 strides to be multiples of 16 bytes, so
      //   sfStrideHeads and sfStrideBatch must each be a multiple of 16.
      // Create the TMA shape/stride for KV scaling factors.


Severity: medium

The detailed documentation regarding TMA layout requirements for KV scaling factors (NVFP4) has been removed. This comment explained the 16-byte box width requirement and the reshaping logic (merging page_size and head_dim // 16), which is crucial for maintaining this complex TMA configuration. Furthermore, the new implementation derives SF strides from KV strides (stride / 16), which implicitly assumes a specific layout and alignment that was previously explicitly documented and configurable.

  // Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache).
  //
  // Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16]
  //   - The last two dims (page_size, head_dim // 16) MUST be contiguous (stride[-1] = 1,
  //     stride[-2] = head_dim // 16). This is because we reshape them into
  //     (16, page_size * head_dim / 16 / 16) with hardcoded stride[1] = 16 to satisfy TMA's
  //     16-byte box width requirement. Each scale factor is 1 byte (FP8), and head_dim // 16
  //     can be < 16 (e.g., 8 for head_dim=128), so we must merge with page_size to reach 16.
  //   - The head and batch/page strides are derived from the KV data strides and must
  //     be multiples of 16 bytes to satisfy cuTensorMapEncodeTiled requirements.

Comment on lines +500 to +502

    q_indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(q_lens, dim=0)]
    )


Severity: medium

Creating torch.tensor([0]) without specifying a device will create it on the default device (usually CPU). If q_lens is on a different device (e.g., GPU), torch.cat will raise a RuntimeError. It is safer to create the zero tensor on the same device and with the same dtype as q_lens to ensure compatibility.

Suggested change

    q_indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(q_lens, dim=0)]
    )

    q_indptr = torch.cat(
        [torch.zeros(1, dtype=q_lens.dtype, device=q_lens.device), torch.cumsum(q_lens, dim=0)]
    )
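The indptr being built here is a standard exclusive prefix sum over per-batch query lengths; a torch-free sketch of what the `torch.cat([zeros, cumsum])` expression computes:

```python
def build_indptr(lens):
    """Exclusive prefix sum: indptr[b] is the start offset of batch b in
    the packed query tensor, and indptr[-1] is the total token count.
    Equivalent to torch.cat([torch.zeros(1), torch.cumsum(lens, dim=0)])."""
    indptr = [0]
    for n in lens:
        indptr.append(indptr[-1] + n)
    return indptr

print(build_indptr([3, 1, 4]))  # [0, 3, 4, 8]
```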

@djmmoss djmmoss force-pushed the dmoss/trtllm-fmha-head-dim-512 branch 3 times, most recently from b1492ab to d1fb9e1 Compare April 2, 2026 19:29
@djmmoss djmmoss force-pushed the dmoss/trtllm-fmha-head-dim-512 branch 2 times, most recently from 70505f3 to ba9b7ec Compare April 3, 2026 23:51
@djmmoss djmmoss marked this pull request as ready for review April 3, 2026 23:51

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 479-570: The SDPA reference sdpa_paged_reference() and the tests
routing to it don't cover all requested attention modes: callers
(_test_trtllm_batch_prefill and _test_trtllm_batch_decode) currently send heads
>256 to the SDPA fallback before checking enable_sink, and sdpa_paged_reference
only implements plain causal/windowed attention and treats window_left==0 as
disabled instead of using -1 sentinel. Fix by moving the enable_sink check so
kernels that request sink attention (enable_sink True) do not fall back to
sdpa_paged_reference (ensure callers check enable_sink before routing by
head_dim), and update sdpa_paged_reference to support the sentinel window_left
== -1 (treat -1 as unlimited) and to honor sink-mode masks if needed; reference
symbols: sdpa_paged_reference, _test_trtllm_batch_prefill,
_test_trtllm_batch_decode, enable_sink, window_left.
- Around line 1658-1665: The parametrization for q_dtype/kv_dtype/o_dtype is
missing FP4 (nvfp4) variants for the 512-dim test cases; update the
pytest.mark.parametrize tuples used around the 512-dim matrix tests (the list
containing ("bf16","bf16","bf16"), ("fp16","fp16","fp16"), ("fp8","fp8","fp8"),
("fp8","fp8","bf16")) to also include ("fp8","fp8","nvfp4") and
("fp8","nvfp4","fp8"); make the same additions to the other identical
parametrization block referenced in the comment (the second block around lines
1728-1734) so both 512-dim test matrices exercise FP4 paths.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fb64a718-9c12-4030-be3f-4ffebf35a5e8

📥 Commits

Reviewing files that changed from the base of the PR and between ee3ca01 and ba9b7ec.

📒 Files selected for processing (6)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Comment on lines +479 to +570
def sdpa_paged_reference(
    ref_q: torch.Tensor,
    ref_kv_cache: torch.Tensor,
    q_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    page_size: int,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    kv_layout: str,
    window_left: int,
):
    """Pure PyTorch SDPA reference for head dims unsupported by FlashInfer kernels.

    ref_kv_cache layout:
        HND: [num_pages, 2, num_kv_heads, page_size, head_dim]
        NHD: [num_pages, 2, page_size, num_kv_heads, head_dim]
    """
    sm_scale = 1.0 / (head_dim**0.5)
    batch_size = q_lens.shape[0]
    q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, dim=0)])
    outputs = []
    for b in range(batch_size):
        q_start = q_indptr[b].item()
        q_end = q_indptr[b + 1].item()
        q_b = ref_q[q_start:q_end]  # [q_len, num_qo_heads, head_dim]
        s_len = seq_lens[b].item()
        num_pages = (s_len + page_size - 1) // page_size

        # Gather KV from paged cache
        page_ids = page_table[b, :num_pages]
        kv_pages = ref_kv_cache[page_ids]
        k_pages = kv_pages[:, 0]  # K half
        v_pages = kv_pages[:, 1]  # V half

        if kv_layout == "HND":
            # k_pages: [num_pages, num_kv_heads, page_size, head_dim]
            # transpose to [num_pages, page_size, num_kv_heads, head_dim] then flatten pages
            k_flat = k_pages.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim)[
                :s_len
            ]
            v_flat = v_pages.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim)[
                :s_len
            ]
        else:  # NHD
            # k_pages: [num_pages, page_size, num_kv_heads, head_dim]
            k_flat = k_pages.reshape(-1, num_kv_heads, head_dim)[:s_len]
            v_flat = v_pages.reshape(-1, num_kv_heads, head_dim)[:s_len]

        # k_flat, v_flat: [s_len, num_kv_heads, head_dim]
        q_len = q_b.shape[0]
        head_grp = num_qo_heads // num_kv_heads

        # Expand KV for GQA: [s_len, num_qo_heads, head_dim]
        k_exp = (
            k_flat.unsqueeze(2)
            .expand(-1, num_kv_heads, head_grp, -1)
            .reshape(s_len, num_qo_heads, head_dim)
        )
        v_exp = (
            v_flat.unsqueeze(2)
            .expand(-1, num_kv_heads, head_grp, -1)
            .reshape(s_len, num_qo_heads, head_dim)
        )

        # Transpose to [num_qo_heads, seq_len, head_dim] for SDPA
        q_t = q_b.transpose(0, 1).float()  # [num_qo_heads, q_len, head_dim]
        k_t = k_exp.transpose(0, 1).float()  # [num_qo_heads, s_len, head_dim]
        v_t = v_exp.transpose(0, 1).float()  # [num_qo_heads, s_len, head_dim]

        # Build causal mask: query position i can attend to kv position j if j <= (s_len - q_len) + i
        kv_offset = s_len - q_len
        q_pos = torch.arange(q_len, device=q_b.device).unsqueeze(1) + kv_offset
        k_pos = torch.arange(s_len, device=q_b.device).unsqueeze(0)
        causal_mask = k_pos <= q_pos  # [q_len, s_len]
        if window_left > 0:
            causal_mask = causal_mask & (q_pos - k_pos <= window_left)
        attn_mask = causal_mask.unsqueeze(0).expand(num_qo_heads, -1, -1)

        out_b = torch.nn.functional.scaled_dot_product_attention(
            q_t,
            k_t,
            v_t,
            attn_mask=attn_mask,
            scale=sm_scale,
        )
        outputs.append(
            out_b.transpose(0, 1).to(ref_q.dtype)
        )  # [q_len, num_qo_heads, head_dim]

    return torch.cat(outputs, dim=0)


⚠️ Potential issue | 🟡 Minor

Make the SDPA fallback match all requested attention modes.

The head_dim > 256 branches route here before checking enable_sink, but sdpa_paged_reference() only models plain causal/windowed attention. That makes a future 512-dim sink case compare the kernel against a no-sink reference. Also, window_left == 0 currently behaves like “disabled” because the mask is only clamped when window_left > 0; -1 is the sentinel.

Small guard/fix

 def sdpa_paged_reference(
     ref_q: torch.Tensor,
     ref_kv_cache: torch.Tensor,
@@
-        if window_left > 0:
+        if window_left >= 0:
             causal_mask = causal_mask & (q_pos - k_pos <= window_left)
@@
     if head_dim > 256:
+        assert not enable_sink, "SDPA fallback does not model attention sinks yet"
         output_ref = sdpa_paged_reference(

Apply the same guard in both _test_trtllm_batch_prefill(...) and _test_trtllm_batch_decode(...).

Also applies to: 676-693, 1116-1133
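The sentinel semantics at issue can be checked without torch. This sketch assumes the convention described in the review: window_left == -1 means the window is disabled, while window_left == 0 must still clamp attention to the current position:

```python
def allowed(q_pos: int, k_pos: int, window_left: int) -> bool:
    """Causal mask with an optional left window, mirroring the mask logic
    in sdpa_paged_reference. -1 disables the window; any value >= 0
    (including 0) must clamp it -- hence `window_left >= 0`, not `> 0`."""
    if k_pos > q_pos:  # causal: never attend to future positions
        return False
    if window_left >= 0 and q_pos - k_pos > window_left:
        return False
    return True

# window_left=0: a token attends only to itself
print([allowed(2, j, 0) for j in range(4)])   # [False, False, True, False]
# window_left=-1: full causal attention
print([allowed(2, j, -1) for j in range(4)])  # [True, True, True, False]
```

With the buggy `window_left > 0` test, the window_left=0 case would silently degrade to full causal attention instead of self-only attention.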


Comment on lines +1658 to +1665
@pytest.mark.parametrize(
    "q_dtype,kv_dtype,o_dtype",
    [
        ("bf16", "bf16", "bf16"),
        ("fp16", "fp16", "fp16"),
        ("fp8", "fp8", "fp8"),
        ("fp8", "fp8", "bf16"),
    ],

@coderabbitai coderabbitai Bot Apr 4, 2026


⚠️ Potential issue | 🟠 Major

The new 512-dim matrices still miss the FP4 variants.

If this PR is meant to validate the 512-dim FP4 paths too, add the ("fp8", "fp8", "nvfp4") and ("fp8", "nvfp4", "fp8") cases here. Right now the added coverage only exercises bf16/fp16/fp8 for head_dim=512.

Add the missing dtype pairs to both 512-dim matrices
 @pytest.mark.parametrize(
     "q_dtype,kv_dtype,o_dtype",
     [
         ("bf16", "bf16", "bf16"),
         ("fp16", "fp16", "fp16"),
         ("fp8", "fp8", "fp8"),
         ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "nvfp4"),
+        ("fp8", "nvfp4", "fp8"),
     ],
 )

Also applies to: 1728-1734


Collaborator


+1, the PR description says nvfp4 support was added. Not sure which one is the source of truth.

Contributor


@saltyminty, that's exactly the crux of the issue. The PR description says "Tests cover dtypes BF16, FP16, FP8, and NVFP4" but the actual parametrization for both test_trtllm_batch_prefill_head_dim_512 and test_trtllm_batch_decode_head_dim_512 only includes bf16/fp16/fp8 combinations — no nvfp4.

There are two possibilities:

  1. NVFP4 is not supported by the TRT-LLM kernel at head_dim=512 — in which case the PR description is inaccurate and should be corrected to remove "NVFP4" from the claimed coverage.
  2. NVFP4 was accidentally omitted from the test matrices — in which case ("fp8", "fp8", "nvfp4") and ("fp8", "nvfp4", "fp8") should be added to both 512-dim parametrize blocks (the sdpa_paged_reference fallback already operates on the BF16 ref_kv_cache, which is the pre-quantization reference for nvfp4, so the reference path should be compatible).

@djmmoss — could you clarify which is the case? If the TRT-LLM kernel genuinely supports NVFP4 at head_dim=512, the test coverage should reflect that.



@djmmoss
Collaborator Author

djmmoss commented Apr 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !518 has been created, and the CI pipeline #47929723 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47929723: 8/20 passed

@djmmoss djmmoss force-pushed the dmoss/trtllm-fmha-head-dim-512 branch 2 times, most recently from bebd4fd to 55dae60 Compare April 8, 2026 21:50
@djmmoss
Collaborator Author

djmmoss commented Apr 9, 2026

/bot run

djmmoss added 6 commits April 9, 2026 20:45
Restore the TMA box reshape optimization that was removed in the
head_dim=512 PR but is unrelated to head_dim=512 support (reshapeFactor
is always 1 at head_dim >= 256). This keeps the PR focused on only the
changes needed for head_dim=512.

AI-assisted.
…llm-gen a339772b

The trtllm-gen commit a339772b (context sparse MQA/GQA support) changed:
1. Renamed mSparseMla (bool) → mSparseAttn (int) in TllmGenFmhaKernelMetaInfo
2. Moved mSparseAttnTopK to immediately after mSkipSoftmaxThresholdScaleFactor
   in the GPU KernelParams struct (was 12 bytes later), causing sparse MLA to
   read topK=0 → NaN output
3. Renamed kernel SparseP1 → StaticTokenSparseP1

Changes:
- kernelParams.h: rename field to mSparseAttnTopK and fix struct ordering
- fmhaKernels.cuh: use kernelMeta.mSparseAttn != 0 for kernel hash
- fmhaReduction.cu: use kernelMeta.mSparseAttn != 0 and params.mSparseAttnTopK
- artifacts.py: update to new artifact hash f8dea5d4 with fixed sparse kernels
- cubin_loader.py: add URM_USER/URM_TOKEN auth for internal artifact repo
- test_trtllm_gen_mla.py: guard FP8 NaN check consistent with early check

All 1344 sparse MLA BF16 tests now pass.

AI-assisted.
…parse MLA kernels

New artifact includes SM100a variants of sparse MLA kernels, fixing
CUDA_ERROR_INVALID_VALUE on SM100a (B200) hardware.

AI-assisted.
- Fix torch.tensor([0]) missing device: use torch.zeros(1, dtype=..., device=...)
  to avoid RuntimeError when q_lens is on GPU (Gemini review)
- Fix window_left sentinel: change `> 0` to `>= 0` so window_left=0 correctly
  applies the mask; -1 is the sentinel for unlimited window (CodeRabbit review)
- Add assert not enable_sink before SDPA fallback for head_dim>256: SDPA
  reference doesn't model attention sinks, guard both prefill and decode
  paths to prevent comparing against a wrong reference (CodeRabbit review)

AI-assisted.
…=512 GQA kernels

New artifact restores GQA head_dim=512 kernels (dropped in 23695e6d) while
retaining SM100a sparse MLA kernels added in that release.

AI-assisted.
@djmmoss djmmoss force-pushed the dmoss/trtllm-fmha-head-dim-512 branch from 6649dea to a100b1c on April 10, 2026 00:36
@djmmoss
Collaborator Author

djmmoss commented Apr 10, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !518 has been updated with latest changes, and the CI pipeline #48169167 is currently running. I'll report back once the pipeline job completes.

ianliuy added a commit to ianliuy/sglang that referenced this pull request Apr 13, 2026
When num_kv_shared_layers > 0 (e.g., Gemma4 E4B, Gemma3n), KV-sharing layers
set k=None/v=None and the backend reads directly from the fp8 KV cache without
dequantization. This causes a dtype mismatch (bf16 q * fp8 k) crash in the
Triton extend_attention kernel.

Fix: Add dequantization after reading shared KV from cache in
triton_backend.py. The dtype check is guarded by k.dtype != q.dtype, so it
is a no-op for non-quantized KV caches. Also apply k_scale_float/v_scale_float
to properly restore original values when quantization scales are present.

The flashinfer backend fix is deferred to a separate PR per maintainer
request, as it requires additional work blocked on flashinfer-ai/flashinfer#2959.

Fixes sgl-project#22277

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
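The dequantization guard described above can be sketched as follows (hypothetical names and plain lists standing in for cache tensors; the actual fix lives in sglang's triton_backend.py): the check is keyed on the dtype mismatch, so non-quantized caches pass through untouched, and fp8-cached values are rescaled back toward their original magnitudes.

```python
def maybe_dequant(k_cache, q_dtype: str, cache_dtype: str, k_scale: float):
    """Dequantize shared KV read from cache before a bf16 attention kernel.

    Mirrors the `k.dtype != q.dtype` guard: a no-op for non-quantized
    caches, otherwise applies the stored quantization scale.
    """
    if cache_dtype == q_dtype:
        return k_cache  # cache already matches the query dtype
    # cast-and-rescale restores (approximately) the original values
    return [x * k_scale for x in k_cache]

# fp8 cache: stored values must be rescaled before use
print(maybe_dequant([1.0, 2.0], "bf16", "fp8_e4m3", k_scale=0.5))  # [0.5, 1.0]
# bf16 cache: guard short-circuits, values pass through unchanged
print(maybe_dequant([1.0, 2.0], "bf16", "bf16", k_scale=0.5))      # [1.0, 2.0]
```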
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_q_len", [255])
@pytest.mark.parametrize("max_kv_len", [511])
Collaborator


Can we add more q_len and kv_len configurations?

Collaborator Author


done

)
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [110])
Collaborator


ditto

Collaborator Author


done

djmmoss added 2 commits April 13, 2026 22:36
…configs

Address PR review feedback from yzh119 to add more q_len and kv_len
configurations to the head_dim=512 tests:

- Prefill: max_q_len [255] -> [1, 255, 511], max_kv_len [511] -> [511, 2047]
- Decode: max_in_kv_len [110] -> [110, 4096, 8192]
- Relax wrapper comparison for head_dim > 256 (use rtol/atol=1e-2 instead
  of exact equality) since large head dims accumulate FP rounding error.
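The relaxed comparison in the last bullet can be sketched with a hand-rolled version of the rtol/atol check (illustrative only, not the test's actual code): rounding every partial sum to float32, as an attention accumulator effectively does across a 512-wide head dimension, drifts the result away from the exact sum, so a bit-exact wrapper-vs-direct check is too strict while rtol/atol=1e-2 passes comfortably.

```python
import struct

def close(a: float, b: float, rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Mimic of an assert_close-style tolerance check."""
    return abs(a - b) <= atol + rtol * abs(b)

def f32(x: float) -> float:
    """Round a Python double to float32 precision via a pack round-trip."""
    return struct.unpack("f", struct.pack("f", x))[0]

# 512 terms, one per head-dim channel, accumulated with fp32 rounding
vals = [0.1 + 1e-4 * i for i in range(512)]
acc = 0.0
for v in vals:
    acc = f32(acc + v)   # every partial sum rounded to fp32
exact = sum(vals)        # double-precision reference sum

print(close(acc, exact))  # True: well within rtol/atol=1e-2
print(acc == exact)       # bit-exact equality is not guaranteed
```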
@djmmoss djmmoss requested a review from qsang-nv as a code owner April 14, 2026 16:40
@djmmoss
Collaborator Author

djmmoss commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !518 has been updated with latest changes, and the CI pipeline #48514164 is currently running. I'll report back once the pipeline job completes.

@djmmoss
Collaborator Author

djmmoss commented Apr 16, 2026

@yzh119 when can we get this in?

@nvpohanh
Contributor

@djmmoss are we blocked by anything or just blocked by review?

@djmmoss
Collaborator Author

djmmoss commented Apr 17, 2026

@nvpohanh AFAIK it's blocked by review; the errors on the CI pipelines are either preexisting or unrelated to these changes

elif v_scale == o_scale == 1.0:
    # Large head dims (e.g. 512) accumulate enough FP error that
    # wrapper and direct outputs are not bit-identical.
    torch.testing.assert_close(
Collaborator


I'm confused how the wrapper and direct outputs could be different; aren't they using the same implementation?

Comment on lines +1658 to +1665
@pytest.mark.parametrize(
    "q_dtype,kv_dtype,o_dtype",
    [
        ("bf16", "bf16", "bf16"),
        ("fp16", "fp16", "fp16"),
        ("fp8", "fp8", "fp8"),
        ("fp8", "fp8", "bf16"),
    ],
Collaborator


+1, the PR description says nvfp4 support was added. Not sure which one is the source of truth.

Collaborator

@saltyminty saltyminty left a comment


Approved but please address comments before merging


5 participants