[Fmha] Add head_dim=512 support for trtllm attention kernels#2959
[Fmha] Add head_dim=512 support for trtllm attention kernels#2959djmmoss wants to merge 16 commits intoflashinfer-ai:mainfrom
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR updates the TRTLLM-Gen FMHA artifact path and checksum, adds/adjusts kernel selection heuristics for large head dimensions and sparse-attention flags, introduces a new paged-SDPA reference and tests for head_dim=512, adjusts MLA NaN checks, and adds session auth for remote cubin downloads. Changes
Sequence Diagram(s)(omitted) Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces support for head dimensions greater than 256 by capping the head dimension per CTA and providing a PyTorch-based SDPA reference fallback for testing. It also refactors the handling of KV scaling factor strides by deriving them from KV data strides and removing explicit stride parameters from the runner configuration. Review feedback highlights the removal of critical technical documentation regarding TMA layout constraints and hardware requirements, as well as a potential device mismatch bug in the new test reference function.
| // cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1. | ||
| // - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual | ||
| // strides are read from the tensor and passed to the TMA descriptor. | ||
| // Create the TMA shape/stride for K. |
There was a problem hiding this comment.
The detailed documentation regarding TMA layout requirements and hardware constraints for K/V data tensors has been removed. This information is valuable for understanding the underlying TMA configuration (e.g., the requirement for stride 1 on dim 0) and should be preserved. Additionally, the updated comment incorrectly implies this function is only for 'K', whereas it is used for both 'K' and 'V'.
// Create the TMA shape/stride for K/V data tensors.
//
// Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim]
// - head_dim (last dim) MUST have stride 1. This is a TMA hardware constraint:
// cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1.
// - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual
// strides are read from the tensor and passed to the TMA descriptor.| // kSfStrideBatch) and can differ from the KV data strides. | ||
| // - cuTensorMapEncodeTiled requires all non-dim0 strides to be multiples of 16 bytes, so | ||
| // sfStrideHeads and sfStrideBatch must each be a multiple of 16. | ||
| // Create the TMA shape/stride for KV scaling factors. |
There was a problem hiding this comment.
The detailed documentation regarding TMA layout requirements for KV scaling factors (NVFP4) has been removed. This comment explained the 16-byte box width requirement and the reshaping logic (merging page_size and head_dim // 16), which is crucial for maintaining this complex TMA configuration. Furthermore, the new implementation derives SF strides from KV strides (stride / 16), which implicitly assumes a specific layout and alignment that was previously explicitly documented and configurable.
// Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache).
//
// Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16]
// - The last two dims (page_size, head_dim // 16) MUST be contiguous (stride[-1] = 1,
// stride[-2] = head_dim // 16). This is because we reshape them into
// (16, page_size * head_dim / 16 / 16) with hardcoded stride[1] = 16 to satisfy TMA's
// 16-byte box width requirement. Each scale factor is 1 byte (FP8), and head_dim // 16
// can be < 16 (e.g., 8 for head_dim=128), so we must merge with page_size to reach 16.
// - The head and batch/page strides are derived from the KV data strides and must
// be multiples of 16 bytes to satisfy cuTensorMapEncodeTiled requirements.| q_indptr = torch.cat( | ||
| [torch.tensor([0]), torch.cumsum(q_lens, dim=0)] | ||
| ) |
There was a problem hiding this comment.
Creating torch.tensor([0]) without specifying a device will create it on the default device (usually CPU). If q_lens is on a different device (e.g., GPU), torch.cat will raise a RuntimeError. It is safer to create the zero tensor on the same device and with the same dtype as q_lens to ensure compatibility.
| q_indptr = torch.cat( | |
| [torch.tensor([0]), torch.cumsum(q_lens, dim=0)] | |
| ) | |
| q_indptr = torch.cat( | |
| [torch.zeros(1, dtype=q_lens.dtype, device=q_lens.device), torch.cumsum(q_lens, dim=0)] | |
| ) |
b1492ab to
d1fb9e1
Compare
70505f3 to
ba9b7ec
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 479-570: The SDPA reference sdpa_paged_reference() and the tests
routing to it don't cover all requested attention modes: callers
(_test_trtllm_batch_prefill and _test_trtllm_batch_decode) currently send heads
>256 to the SDPA fallback before checking enable_sink, and sdpa_paged_reference
only implements plain causal/windowed attention and treats window_left==0 as
disabled instead of using -1 sentinel. Fix by moving the enable_sink check so
kernels that request sink attention (enable_sink True) do not fall back to
sdpa_paged_reference (ensure callers check enable_sink before routing by
head_dim), and update sdpa_paged_reference to support the sentinel window_left
== -1 (treat -1 as unlimited) and to honor sink-mode masks if needed; reference
symbols: sdpa_paged_reference, _test_trtllm_batch_prefill,
_test_trtllm_batch_decode, enable_sink, window_left.
- Around line 1658-1665: The parametrization for q_dtype/kv_dtype/o_dtype is
missing FP4 (nvfp4) variants for the 512-dim test cases; update the
pytest.mark.parametrize tuples used around the 512-dim matrix tests (the list
containing ("bf16","bf16","bf16"), ("fp16","fp16","fp16"), ("fp8","fp8","fp8"),
("fp8","fp8","bf16")) to also include ("fp8","fp8","nvfp4") and
("fp8","nvfp4","fp8"); make the same additions to the other identical
parametrization block referenced in the comment (the second block around lines
1728-1734) so both 512-dim test matrices exercise FP4 paths.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fb64a718-9c12-4030-be3f-4ffebf35a5e8
📒 Files selected for processing (6)
csrc/trtllm_fmha_kernel_launcher.cuflashinfer/artifacts.pyinclude/flashinfer/trtllm/fmha/fmhaKernels.cuhinclude/flashinfer/trtllm/fmha/fmhaRunnerParams.hinclude/flashinfer/trtllm/fmha/kernelParams.htests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
| def sdpa_paged_reference( | ||
| ref_q: torch.Tensor, | ||
| ref_kv_cache: torch.Tensor, | ||
| q_lens: torch.Tensor, | ||
| seq_lens: torch.Tensor, | ||
| page_table: torch.Tensor, | ||
| page_size: int, | ||
| num_qo_heads: int, | ||
| num_kv_heads: int, | ||
| head_dim: int, | ||
| kv_layout: str, | ||
| window_left: int, | ||
| ): | ||
| """Pure PyTorch SDPA reference for head dims unsupported by FlashInfer kernels. | ||
|
|
||
| ref_kv_cache layout: | ||
| HND: [num_pages, 2, num_kv_heads, page_size, head_dim] | ||
| NHD: [num_pages, 2, page_size, num_kv_heads, head_dim] | ||
| """ | ||
| sm_scale = 1.0 / (head_dim**0.5) | ||
| batch_size = q_lens.shape[0] | ||
| q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, dim=0)]) | ||
| outputs = [] | ||
| for b in range(batch_size): | ||
| q_start = q_indptr[b].item() | ||
| q_end = q_indptr[b + 1].item() | ||
| q_b = ref_q[q_start:q_end] # [q_len, num_qo_heads, head_dim] | ||
| s_len = seq_lens[b].item() | ||
| num_pages = (s_len + page_size - 1) // page_size | ||
|
|
||
| # Gather KV from paged cache | ||
| page_ids = page_table[b, :num_pages] | ||
| kv_pages = ref_kv_cache[page_ids] | ||
| k_pages = kv_pages[:, 0] # K half | ||
| v_pages = kv_pages[:, 1] # V half | ||
|
|
||
| if kv_layout == "HND": | ||
| # k_pages: [num_pages, num_kv_heads, page_size, head_dim] | ||
| # transpose to [num_pages, page_size, num_kv_heads, head_dim] then flatten pages | ||
| k_flat = k_pages.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim)[ | ||
| :s_len | ||
| ] | ||
| v_flat = v_pages.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim)[ | ||
| :s_len | ||
| ] | ||
| else: # NHD | ||
| # k_pages: [num_pages, page_size, num_kv_heads, head_dim] | ||
| k_flat = k_pages.reshape(-1, num_kv_heads, head_dim)[:s_len] | ||
| v_flat = v_pages.reshape(-1, num_kv_heads, head_dim)[:s_len] | ||
|
|
||
| # k_flat, v_flat: [s_len, num_kv_heads, head_dim] | ||
| q_len = q_b.shape[0] | ||
| head_grp = num_qo_heads // num_kv_heads | ||
|
|
||
| # Expand KV for GQA: [s_len, num_qo_heads, head_dim] | ||
| k_exp = ( | ||
| k_flat.unsqueeze(2) | ||
| .expand(-1, num_kv_heads, head_grp, -1) | ||
| .reshape(s_len, num_qo_heads, head_dim) | ||
| ) | ||
| v_exp = ( | ||
| v_flat.unsqueeze(2) | ||
| .expand(-1, num_kv_heads, head_grp, -1) | ||
| .reshape(s_len, num_qo_heads, head_dim) | ||
| ) | ||
|
|
||
| # Transpose to [num_qo_heads, seq_len, head_dim] for SDPA | ||
| q_t = q_b.transpose(0, 1).float() # [num_qo_heads, q_len, head_dim] | ||
| k_t = k_exp.transpose(0, 1).float() # [num_qo_heads, s_len, head_dim] | ||
| v_t = v_exp.transpose(0, 1).float() # [num_qo_heads, s_len, head_dim] | ||
|
|
||
| # Build causal mask: query position i can attend to kv position j if j <= (s_len - q_len) + i | ||
| kv_offset = s_len - q_len | ||
| q_pos = torch.arange(q_len, device=q_b.device).unsqueeze(1) + kv_offset | ||
| k_pos = torch.arange(s_len, device=q_b.device).unsqueeze(0) | ||
| causal_mask = k_pos <= q_pos # [q_len, s_len] | ||
| if window_left > 0: | ||
| causal_mask = causal_mask & (q_pos - k_pos <= window_left) | ||
| attn_mask = causal_mask.unsqueeze(0).expand(num_qo_heads, -1, -1) | ||
|
|
||
| out_b = torch.nn.functional.scaled_dot_product_attention( | ||
| q_t, | ||
| k_t, | ||
| v_t, | ||
| attn_mask=attn_mask, | ||
| scale=sm_scale, | ||
| ) | ||
| outputs.append( | ||
| out_b.transpose(0, 1).to(ref_q.dtype) | ||
| ) # [q_len, num_qo_heads, head_dim] | ||
|
|
||
| return torch.cat(outputs, dim=0) |
There was a problem hiding this comment.
Make the SDPA fallback match all requested attention modes.
The head_dim > 256 branches route here before checking enable_sink, but sdpa_paged_reference() only models plain causal/windowed attention. That makes a future 512-dim sink case compare the kernel against a no-sink reference. Also, window_left == 0 currently behaves like “disabled” because the mask is only clamped when window_left > 0; -1 is the sentinel.
Small guard/fix
def sdpa_paged_reference(
ref_q: torch.Tensor,
ref_kv_cache: torch.Tensor,
@@
- if window_left > 0:
+ if window_left >= 0:
causal_mask = causal_mask & (q_pos - k_pos <= window_left)- if head_dim > 256:
+ if head_dim > 256:
+ assert not enable_sink, "SDPA fallback does not model attention sinks yet"
output_ref = sdpa_paged_reference(Apply the same guard in both _test_trtllm_batch_prefill(...) and _test_trtllm_batch_decode(...).
Also applies to: 676-693, 1116-1133
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_trtllm_gen_attention.py` around lines 479 - 570, The
SDPA reference sdpa_paged_reference() and the tests routing to it don't cover
all requested attention modes: callers (_test_trtllm_batch_prefill and
_test_trtllm_batch_decode) currently send heads >256 to the SDPA fallback before
checking enable_sink, and sdpa_paged_reference only implements plain
causal/windowed attention and treats window_left==0 as disabled instead of using
-1 sentinel. Fix by moving the enable_sink check so kernels that request sink
attention (enable_sink True) do not fall back to sdpa_paged_reference (ensure
callers check enable_sink before routing by head_dim), and update
sdpa_paged_reference to support the sentinel window_left == -1 (treat -1 as
unlimited) and to honor sink-mode masks if needed; reference symbols:
sdpa_paged_reference, _test_trtllm_batch_prefill, _test_trtllm_batch_decode,
enable_sink, window_left.
| @pytest.mark.parametrize( | ||
| "q_dtype,kv_dtype,o_dtype", | ||
| [ | ||
| ("bf16", "bf16", "bf16"), | ||
| ("fp16", "fp16", "fp16"), | ||
| ("fp8", "fp8", "fp8"), | ||
| ("fp8", "fp8", "bf16"), | ||
| ], |
There was a problem hiding this comment.
The new 512-dim matrices still miss the FP4 variants.
If this PR is meant to validate the 512-dim FP4 paths too, add the ("fp8", "fp8", "nvfp4") and ("fp8", "nvfp4", "fp8") cases here. Right now the added coverage only exercises bf16/fp16/fp8 for head_dim=512.
Add the missing dtype pairs to both 512-dim matrices
`@pytest.mark.parametrize`(
"q_dtype,kv_dtype,o_dtype",
[
("bf16", "bf16", "bf16"),
("fp16", "fp16", "fp16"),
("fp8", "fp8", "fp8"),
("fp8", "fp8", "bf16"),
+ ("fp8", "fp8", "nvfp4"),
+ ("fp8", "nvfp4", "fp8"),
],
)Also applies to: 1728-1734
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_trtllm_gen_attention.py` around lines 1658 - 1665, The
parametrization for q_dtype/kv_dtype/o_dtype is missing FP4 (nvfp4) variants for
the 512-dim test cases; update the pytest.mark.parametrize tuples used around
the 512-dim matrix tests (the list containing ("bf16","bf16","bf16"),
("fp16","fp16","fp16"), ("fp8","fp8","fp8"), ("fp8","fp8","bf16")) to also
include ("fp8","fp8","nvfp4") and ("fp8","nvfp4","fp8"); make the same additions
to the other identical parametrization block referenced in the comment (the
second block around lines 1728-1734) so both 512-dim test matrices exercise FP4
paths.
There was a problem hiding this comment.
+1, the PR description says nvfp4 support was added. Not sure which one is the source of truth.
There was a problem hiding this comment.
@saltyminty, that's exactly the crux of the issue. The PR description says "Tests cover dtypes BF16, FP16, FP8, and NVFP4" but the actual parametrization for both test_trtllm_batch_prefill_head_dim_512 and test_trtllm_batch_decode_head_dim_512 only includes bf16/fp16/fp8 combinations — no nvfp4.
There are two possibilities:
- NVFP4 is not supported by the TRT-LLM kernel at head_dim=512 — in which case the PR description is inaccurate and should be corrected to remove "NVFP4" from the claimed coverage.
- NVFP4 was accidentally omitted from the test matrices — in which case
("fp8", "fp8", "nvfp4")and("fp8", "nvfp4", "fp8")should be added to both 512-dim parametrize blocks (thesdpa_paged_referencefallback already operates on the BF16ref_kv_cache, which is the pre-quantization reference for nvfp4, so the reference path should be compatible).
@djmmoss — could you clarify which is the case? If the TRT-LLM kernel genuinely supports NVFP4 at head_dim=512, the test coverage should reflect that.
🧠 Learnings used
Learnt from: kahyunnam
Repo: flashinfer-ai/flashinfer PR: 2965
File: tests/norm/test_fused_rmsnorm_silu.py:138-141
Timestamp: 2026-04-03T21:06:16.453Z
Learning: In `tests/norm/test_fused_rmsnorm_silu.py` (flashinfer-ai/flashinfer PR `#2965`), the full `ALL_LUT_SHAPES` test matrix (8 hidden sizes × 5 token counts, up to 399,360 tokens) across bf16, FP8, and NVFP4 is intentionally kept as the default CI parametrization. The maintainer confirmed the tests are fast and do not need to be split into a smoke subset vs. a slow marker. Do not flag this test matrix as too large for CI.
Learnt from: xrq-phys
Repo: flashinfer-ai/flashinfer PR: 2711
File: csrc/trtllm_fmha_kernel_launcher.cu:552-563
Timestamp: 2026-03-07T06:34:53.719Z
Learning: In `csrc/trtllm_fmha_kernel_launcher.cu` (flashinfer-ai/flashinfer), dtype validation for SageAttention scaling-factor tensors (`sage_attn_sfs_q/k/p/v`) is intentionally absent. This file is a TVM FFI path (not a PyTorch extension path), and dtype validation is expected to be handled at a different layer/entry point. Do not flag missing `TVM_FFI_ICHECK_EQ(...dtype(), dl_float32)` checks for these tensors in this file.
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2838
File: flashinfer/quantization/kernels/nvfp4_quantize.py:967-976
Timestamp: 2026-03-23T18:58:22.437Z
Learning: In `flashinfer/quantization/kernels/nvfp4_quantize.py` (flashinfer-ai/flashinfer), the TMA dispatch predicate `m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD` (i.e., floor(log2(M)) + floor(log2(K)) >= 25) is intentional. It is a deliberate approximation of the `M*K >= 2^25` threshold — not a bug. The maintainer acknowledged this and will add a clarifying comment in a follow-up commit. Do not flag this as incorrect or suggest replacing it with `m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)`.
Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2962
File: benchmarks/bench_ssu_sweep_mtp.py:351-359
Timestamp: 2026-04-02T17:54:18.493Z
Learning: In `benchmarks/bench_ssu_sweep_mtp.py` (flashinfer-ai/flashinfer), `parse_dtype_spec()` intentionally allows `bf16-philox-*` and `f32-philox-*` specs even though the MTP path only supports stochastic rounding for fp16 state. This is a deliberate permissive design for a benchmark script (not production code); unsupported combinations will simply fail at runtime. Do not flag this as a bug or suggest restricting the regex to `f16` only.
Learnt from: TomerBN-Nvidia
Repo: flashinfer-ai/flashinfer PR: 3024
File: csrc/fused_moe/noAuxTcKernels.cu:351-369
Timestamp: 2026-04-12T12:18:22.194Z
Learning: In `csrc/fused_moe/noAuxTcKernels.cu` (flashinfer-ai/flashinfer PR `#3024`), the `routing_replay_out` validation in `NoAuxTc` intentionally does NOT check `replay.sizes()[0] >= num_tokens`. This is by design: with CUDA graphs, the buffer is pre-allocated at maximum batch size and reused across steps with varying `num_tokens`; the kernel only writes to indices `[0, num_tokens)` so a larger buffer is always safe. The same policy applies to `csrc/trtllm_fused_moe_kernel_launcher.cu` (documented at line ~1795). Do not flag the missing lower-bound dim0 check as a bug.
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3007
File: tests/utils/test_norm.py:0-0
Timestamp: 2026-04-07T21:44:40.431Z
Learning: In `tests/utils/test_norm.py` (flashinfer-ai/flashinfer), when writing regression tests for large-stride (> INT32_MAX) tensor paths, using `torch.as_strided(small_buf, (M, H), (_INT64_STRIDE, 1))` on a small buffer is unsafe and will segfault because row 1 is at byte offset `2^31` beyond the allocation. The correct pattern is: allocate a flat buffer of at least `_INT64_STRIDE + H` elements, then create the strided view from it so every row is backed by real memory. For fused_add_rmsnorm tests, only the input tensor `x` needs to be non-contiguous (using the large flat buffer); the residual `r` can remain a normally-allocated contiguous tensor — one non-contiguous tensor is sufficient to trigger `is_contiguous() == False` in the kernel selection path, avoiding a second ~4 GB allocation.
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3066
File: benchmarks/routines/moe.py:1360-1369
Timestamp: 2026-04-14T20:22:02.405Z
Learning: In `benchmarks/routines/moe.py` (flashinfer-ai/flashinfer PR `#3066`), the functional API path (`cute_dsl_fused_moe_nvfp4`, enabled via `--use_functional_api`) is intentionally allowed to run under CUDA graph capture (`bench_gpu_time` with `use_cuda_graph=is_cuda_graph_compatible`). The maintainer confirmed this is fine for the benchmark. Do not flag the lack of a CUDA-graph guard for the functional API path as a bug.
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp4_quantize.py:668-685
Timestamp: 2026-03-27T20:21:04.233Z
Learning: In `flashinfer/quantization/kernels/mxfp4_quantize.py` (flashinfer-ai/flashinfer), the swizzled MXFP4 layout requires K % 128 == 0, which guarantees that `num_sf_blocks_per_row = K // 32` is always a multiple of 4, so `padded_sf_cols == num_sf_blocks_per_row` — no SF column padding ever occurs. The final `scale_output.reshape(-1, num_sf_blocks_per_row)` at the end of `mxfp4_quantize_cute_dsl` is therefore correct for both linear and swizzled paths: it uses the logical SF column count (which downstream consumers expect), not a physically padded count. This matches the CUDA backend behavior (`sf.reshape((-1, input.shape[-1] // sf_vec_size))`). Do not flag this reshape as incorrect or suggest branching on `padded_sf_cols`.
Learnt from: DomBrown
Repo: flashinfer-ai/flashinfer PR: 2770
File: flashinfer/decode.py:2231-2235
Timestamp: 2026-03-19T20:24:35.442Z
Learning: In `flashinfer/decode.py` (and related files `flashinfer/prefill.py`, `flashinfer/mla.py`), the `uses_shared_paged_kv_idx=False` mode is intended for direct TRT-LLM integration. When this flag is False, the `kv_cache` and `kv_block_scales` are expected to already be in TRT-LLM's native paged layout (separate K/V page indices, 3D block_tables `[batch_size, 2, max_num_pages_per_seq]`). The test code interleaves/reshapes tensors only to simulate TRT-LLM layout from a FlashInfer/vLLM-layout fixture — this is a test artifact, not a requirement imposed on real callers.
Learnt from: blake-snc
Repo: flashinfer-ai/flashinfer PR: 0
File: :0-0
Timestamp: 2026-04-16T15:52:27.219Z
Learning: In `flashinfer/prefill.py` (flashinfer-ai/flashinfer PR `#3016`), the fmha_v2 execution path for `BatchPrefillWithRaggedKVCacheWrapper` intentionally raises `NotImplementedError` for `return_lse=True`. TRT-LLM fmha_v2 returns raw stats in `(total_tokens, num_heads, 2)` format which is incompatible with FlashInfer's LSE format; proper conversion is non-trivial and has not been implemented. Do not flag the missing LSE return as an oversight — the NotImplementedError with explanation is the correct guard.
Learnt from: blake-snc
Repo: flashinfer-ai/flashinfer PR: 0
File: :0-0
Timestamp: 2026-04-16T15:52:27.219Z
Learning: In `flashinfer/prefill.py` (flashinfer-ai/flashinfer PR `#3016`), the fmha_v2 backend path for SM120 standard shapes packs Q/K/V via `torch.stack([q, k, v], dim=1)` into PACKED_QKV format. This is currently the only supported layout because SM120 standard kernels are only generated for PACKED_QKV; separate Q/K/V kernel variants would require additional kernel generation. The torch.stack overhead is an acknowledged limitation tracked as a future optimization. Do not flag this as a correctness issue.
Learnt from: yzh119
Repo: flashinfer-ai/flashinfer PR: 2370
File: tests/gdn/conftest.py:25-34
Timestamp: 2026-01-21T21:26:00.701Z
Learning: Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures. Ensure test files under tests/ follow this convention and avoid adding CPU-only guards in fixtures unless explicitly handling a non-CUDA environment.
|
/bot run |
|
[FAILED] Pipeline #47929723: 8/20 passed |
bebd4fd to
55dae60
Compare
|
/bot run |
Restore the TMA box reshape optimization that was removed in the head_dim=512 PR but is unrelated to head_dim=512 support (reshapeFactor is always 1 at head_dim >= 256). This keeps the PR focused on only the changes needed for head_dim=512. AI-assisted.
…llm-gen a339772b The trtllm-gen commit a339772b (context sparse MQA/GQA support) changed: 1. Renamed mSparseMla (bool) → mSparseAttn (int) in TllmGenFmhaKernelMetaInfo 2. Moved mSparseAttnTopK to immediately after mSkipSoftmaxThresholdScaleFactor in the GPU KernelParams struct (was 12 bytes later), causing sparse MLA to read topK=0 → NaN output 3. Renamed kernel SparseP1 → StaticTokenSparseP1 Changes: - kernelParams.h: rename field to mSparseAttnTopK and fix struct ordering - fmhaKernels.cuh: use kernelMeta.mSparseAttn != 0 for kernel hash - fmhaReduction.cu: use kernelMeta.mSparseAttn != 0 and params.mSparseAttnTopK - artifacts.py: update to new artifact hash f8dea5d4 with fixed sparse kernels - cubin_loader.py: add URM_USER/URM_TOKEN auth for internal artifact repo - test_trtllm_gen_mla.py: guard FP8 NaN check consistent with early check All 1344 sparse MLA BF16 tests now pass. AI-assisted.
…parse MLA kernels New artifact includes SM100a variants of sparse MLA kernels, fixing CUDA_ERROR_INVALID_VALUE on SM100a (B200) hardware. AI-assisted.
- Fix torch.tensor([0]) missing device: use torch.zeros(1, dtype=..., device=...) to avoid RuntimeError when q_lens is on GPU (Gemini review) - Fix window_left sentinel: change `> 0` to `>= 0` so window_left=0 correctly applies the mask; -1 is the sentinel for unlimited window (CodeRabbit review) - Add assert not enable_sink before SDPA fallback for head_dim>256: SDPA reference doesn't model attention sinks, guard both prefill and decode paths to prevent comparing against a wrong reference (CodeRabbit review) AI-assisted.
…=512 GQA kernels New artifact restores GQA head_dim=512 kernels (dropped in 23695e6d) while retaining SM100a sparse MLA kernels added in that release. AI-assisted.
6649dea to
a100b1c
Compare
|
/bot run |
When num_kv_shared_layers > 0 (e.g., Gemma4 E4B, Gemma3n), KV-sharing layers set k=None/v=None and the backend reads directly from the fp8 KV cache without dequantization. This causes a dtype mismatch (bf16 q * fp8 k) crash in the Triton extend_attention kernel. Fix: Add dequantization after reading shared KV from cache in triton_backend.py. The dtype check is guarded by k.dtype != q.dtype, so it is a no-op for non-quantized KV caches. Also apply k_scale_float/v_scale_float to properly restore original values when quantization scales are present. The flashinfer backend fix is deferred to a separate PR per maintainer request, as it requires additional work blocked on flashinfer-ai/flashinfer#2959. Fixes sgl-project#22277 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
| @pytest.mark.parametrize("enable_pdl", [None]) | ||
| @pytest.mark.parametrize("enable_sink", [False]) | ||
| @pytest.mark.parametrize("max_q_len", [255]) | ||
| @pytest.mark.parametrize("max_kv_len", [511]) |
There was a problem hiding this comment.
Can we add more q_len and kv_len configurations?
| ) | ||
| @pytest.mark.parametrize("enable_pdl", [None]) | ||
| @pytest.mark.parametrize("enable_sink", [False]) | ||
| @pytest.mark.parametrize("max_in_kv_len", [110]) |
…configs Address PR review feedback from yzh119 to add more q_len and kv_len configurations to the head_dim=512 tests: - Prefill: max_q_len [255] -> [1, 255, 511], max_kv_len [511] -> [511, 2047] - Decode: max_in_kv_len [110] -> [110, 4096, 8192] - Relax wrapper comparison for head_dim > 256 (use rtol/atol=1e-2 instead of exact equality) since large head dims accumulate FP rounding error.
|
/bot run |
|
@yzh119 when can we get this in? |
|
@djmmoss are we blocked by anything or just blocked by review? |
|
@nvpohanh AFAIK blocked by review, the errors on the CI pipelines are either preexisting or unrelated to these changes |
| elif v_scale == o_scale == 1.0: | ||
| # Large head dims (e.g. 512) accumulate enough FP error that | ||
| # wrapper and direct outputs are not bit-identical. | ||
| torch.testing.assert_close( |
There was a problem hiding this comment.
I'm confused how could wrapper and direct output be different, are they using the same implementation?
| @pytest.mark.parametrize( | ||
| "q_dtype,kv_dtype,o_dtype", | ||
| [ | ||
| ("bf16", "bf16", "bf16"), | ||
| ("fp16", "fp16", "fp16"), | ||
| ("fp8", "fp8", "fp8"), | ||
| ("fp8", "fp8", "bf16"), | ||
| ], |
There was a problem hiding this comment.
+1, the PR description says nvfp4 support was added. Not sure which one is the source of truth.
saltyminty
left a comment
There was a problem hiding this comment.
Approved but please address comments before merging
Add support for
head_dim=512in the trtllm FMHA kernel selection.Changes
head_dim > 256in tests (FlashInfer FA2/FA3 kernels don't supporthead_dim > 256)test_trtllm_batch_prefill_head_dim_512andtest_trtllm_batch_decode_head_dim_512covering BF16, FP16, FP8, and NVFP4 dtypesSummary by CodeRabbit
Bug Fixes
Tests
Chores