
Support NVFP4 KV for prefill and batch attention kernels#2820

Closed

Tom-Zheng wants to merge 5 commits into flashinfer-ai:main from Tom-Zheng:add-sm120-nvfp4-kv-prefill

Conversation

@Tom-Zheng
Contributor

@Tom-Zheng Tom-Zheng commented Mar 19, 2026

📌 Description

This MR adds NVFP4 KV input support for the batch prefill and batch attention kernels, across all supported architectures.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • NVFP4 (packed 4-bit) KV cache support with optional per-block scale factors and FP4 decode paths.
  • Public API
    • Prefill/decode/attention APIs accept optional KV scale-factor inputs and KV-scale arguments; batch/run wrappers forward them.
  • Compatibility
    • CPU fallback for FP4 decode and runtime checks to disable unsupported backend paths.
  • Tests
    • New NVFP4-focused tests and helpers covering prefill, batch attention, decode, and single-request flows.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 19, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds end-to-end NVFP4 (packed FP4 / torch.uint8) KV-cache scale-factor support: new Python parameters, JIT/module generator changes to accept KV-scale tensors, prefill/decode attention device kernels to load/apply per-block scale-factors, FP4 dequantization helpers including CPU fallback, and tests exercising the flow.

Changes

  • Python API & runtime callers (flashinfer/attention.py, flashinfer/prefill.py, flashinfer/decode.py): Added kv_block_scales / maybe_*_cache_sf / key_block_scales / value_block_scales parameters and forwarded them into JIT/custom-op calls; adjusted paged-run arg ordering for trtllm-gen.
  • JIT generators & dtype maps (flashinfer/jit/attention/modules.py, flashinfer/jit/utils.py): Added dtype_map_kv, switched KV-specific dtype mapping usage, declared optional maybe_k_cache_sf/maybe_v_cache_sf additional tensor inputs, and generated a custom additional-params setter.
  • CUDA kernels / device templates (include/flashinfer/attention/prefill.cuh, include/flashinfer/attention/persistent.cuh): Implemented FP4-packed GMEM handling, per-warp/shared-memory KV scale-factor buffers, new produce_*_sf/page_produce_*_sf loaders, adjusted offset math, and passed SF smem pointers + lane indices into compute routines.
  • cp.async / smem / swizzle helpers (include/flashinfer/cp_async.cuh, include/flashinfer/permuted_smem.cuh, include/flashinfer/frag_layout_swizzle.cuh): Added predicate-aware cp.async helpers (pred_load_128b_from_64b, pred_load_32b), a 64B async load API, and two 16b→4b frag-layout swizzle helpers for FP4 expansion.
  • Vector casts & dequantization (include/flashinfer/vec_dtypes.cuh, flashinfer/quantization/fp4_quantization.py): Added vec_cast specializations to expand __nv_fp4x2_e2m1 to half/bfloat16 and a CPU fallback dequantization helper (E2M1 LUT + ufp8 scale decoding) for older GPUs; a rough CPU sketch follows after this list.
  • Backend selection utils (flashinfer/utils.py): is_fa3_backend_supported now early-rejects torch.uint8 KV dtype for the FA3 backend.
  • Tests & helpers (tests/test_helpers/utils_fp4.py, tests/attention/*): Added NVFP4 test helpers (create_nvfp4_kv, nvfp4_to_float) and new tests for single/batch/ragged/paged prefill, batch attention, and batch decode using packed NVFP4 KV + per-block scale tensors.
  • Vectorized dtype helpers (include/flashinfer/vec_dtypes.cuh): Added FP4→FP16/BF16 vec_cast specializations and software fallback paths for older architectures.
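The CPU-fallback dequantization noted in the "Vector casts & dequantization" item above can be sketched in plain PyTorch. This is a minimal illustration rather than the helper in flashinfer/quantization/fp4_quantization.py: the low-nibble-first packing order, the 16-element block size, and taking the block scales as plain floats (instead of decoding UFP8 bytes) are all assumptions.

import torch

# E2M1 magnitudes for the 3 low bits of each 4-bit code; bit 3 is the sign bit.
E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_dequant_cpu(packed: torch.Tensor, scales: torch.Tensor, sf_vec_size: int = 16) -> torch.Tensor:
    # packed: uint8 [..., K // 2], two FP4 codes per byte (assumed: low nibble = even element)
    # scales: float [..., K // sf_vec_size], one scale per contiguous block of sf_vec_size elements
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)               # interleave back to [..., K]
    mag = E2M1_LUT[(codes & 0x07).long()]                           # magnitude from the LUT
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).to(torch.float32)         # bit 3 carries the sign
    vals = sign * mag                                                # dequantized codes, still unscaled
    sf = scales.to(torch.float32).repeat_interleave(sf_vec_size, dim=-1)
    return vals * sf                                                 # apply the per-block scale factor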

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant PyAPI as Python API (BatchAttention/Prefill/Decode)
    participant JIT as JIT / custom-op module
    participant Kernel as CUDA kernel
    participant GMEM as KV GMEM (packed NVFP4 + block SF)

    Client->>PyAPI: call run(..., kv_block_scales=kv_sf)
    PyAPI->>JIT: forward tensors + kv_block_scales / maybe_*_cache_sf
    JIT->>Kernel: launch kernel with maybe_k_cache_sf / maybe_v_cache_sf pointers
    Kernel->>GMEM: page_produce_kv (load packed FP4) and page_produce_kv_sf (load SF bytes)
    Kernel->>Kernel: place SF into smem, expand/dequantize, compute_qk / compute_sfm_v using SF + lane_idx
    Kernel-->>JIT: return outputs
    JIT-->>PyAPI: relay outputs
    PyAPI-->>Client: deliver result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • yzh119
  • aleozlx
  • sricketts
  • nvmbreughe
  • samuellees
  • bkryu
  • cyx-6
  • jimmyzho
  • kahyunnam
  • nv-yunzheq

Poem

🐰 I nibble at packed nibbles, two by two,
I carry scales that tell kernels what to do,
shared mem hums, pages glide and sway,
FP4 bits wake up and join the play,
carrots for tests — precision hops away.

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 33.33%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The title accurately and concisely describes the main change: adding NVFP4 KV support to prefill and batch attention kernels.
  • Description check: ✅ Passed. The description covers the main objective (NVFP4 KV support for batch prefill/attention kernels) and confirms checklist completion, though it lacks related issue links and detailed reviewer notes.



@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by integrating native support for NVFP4 KV cache input across its batch prefill and batch attention kernels. This allows for more memory-efficient and potentially faster inference on NVIDIA GPUs by leveraging 4-bit quantization for key and value tensors. The changes span from low-level CUDA kernel implementations for data loading and computation to Python-level utilities for quantization, dequantization, and comprehensive testing, ensuring broad compatibility and correctness.

Highlights

  • NVFP4 KV Cache Support: Introduced support for NVFP4 (NVIDIA FP4) KV cache input for both batch prefill and batch attention kernels, extending low-precision capabilities across all supported architectures.
  • Quantization and Dequantization Utilities: Added Python utilities (_to_nvfp4, _nvfp4_to_float) for quantizing float tensors to NVFP4 and dequantizing them back, including handling of per-group FP8 scaling factors (a rough sketch follows after this list).
  • CUDA Kernel Modifications: Modified core CUDA kernels (prefill.cuh, persistent.cuh) to handle NVFP4 data loading, shared memory management, and apply scaling factors during QK and SFM*V computations. This includes new is_fp4_type traits and shared memory allocations for scale factors.
  • Asynchronous Memory Operations: Implemented new cp_async primitives (pred_load_128b_from_64b, pred_load_32b) for efficient asynchronous loading of packed FP4 data and their corresponding scale factors from global to shared memory.
  • Fragment Layout Swizzling: Added specialized fragment layout swizzling functions (frag_layout_swizzle_16b_to_4b, frag_layout_swizzle_16b_to_4b_trans) to correctly process packed FP4 data within MMA operations.
  • Type Casting and Fallbacks: Extended vec_cast to support conversion from __nv_fp4x2_e2m1 (packed FP4) to half and nv_bfloat16, including a pure-PyTorch CPU fallback for dequantization on older GPU architectures (pre-SM90).
  • Benchmarking and Testing: Updated benchmark routines to include NVFP4 as a supported KV data type, adjusted numerical tolerances for lower precision, and added new correctness tests for batch attention and batch prefill with paged KV cache using NVFP4.
  • Documentation: Added new internal documentation files (.claude/memory/MEMORY.md, .claude/memory/prefill_cuh_structure.md) to analyze and describe the structure of prefill.cuh.
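The quantization direction called out in the bullets above can be sketched roughly as follows. The 16-element block size, the amax/6 per-block scale choice, and the nibble packing order are assumptions for illustration; the real _to_nvfp4 utility additionally folds in a global scale and stores the block scales as FP8.

import torch

E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def to_nvfp4_sketch(x: torch.Tensor, sf_vec_size: int = 16):
    # x: float [..., K] with K divisible by sf_vec_size.
    # Returns packed uint8 [..., K // 2] and per-block scales float [..., K // sf_vec_size].
    lut = E2M1_VALUES.to(x.device)
    blocks = x.unflatten(-1, (-1, sf_vec_size))                       # [..., K // 16, 16]
    scales = blocks.abs().amax(dim=-1) / 6.0                          # 6.0 is the largest E2M1 magnitude
    scaled = blocks / scales.clamp_min(1e-12).unsqueeze(-1)
    idx = (scaled.abs().unsqueeze(-1) - lut).abs().argmin(dim=-1)     # nearest E2M1 magnitude
    codes = (idx + (scaled < 0).to(torch.long) * 8).to(torch.uint8)   # bit 3 carries the sign
    codes = codes.flatten(-2)                                         # back to [..., K]
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    packed = lo | (hi << 4)                                           # assumed: low nibble = even element
    return packed, scales                                             # scales kept as float for simplicity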


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@samuellees
Collaborator

/bot run

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for NVFP4 (NVIDIA FP4) KV cache quantization in FlashInfer's prefill and batch attention operations. Key changes include updating benchmark routines to handle NVFP4 as a KV data type, adjusting tolerances for lower precision, and filtering unsupported backends. The core C++ kernels and Python JIT modules are extended to manage packed NVFP4 data and per-group scale factors, including modifications to memory access patterns and MMA operations for proper dequantization. Review comments highlight the need for improved clarity in the global_scale calculation in _to_nvfp4 and detailed explanations for the intricate scaling factor application logic within the compute_qk and compute_sfm_v device functions.

@flashinfer-bot
Collaborator

GitLab MR !431 has been created, and the CI pipeline #46514756 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/routines/attention.py (2)

1197-1201: ⚠️ Potential issue | 🟠 Major

Don't quantize the paged NVFP4 cache twice.

Lines 1036-1059 already build packed kv_cache plus kv_cache_sf. Lines 1197-1201 then feed that packed uint8 data back into nvfp4_quantize_paged_kv_cache(...), which expects floating-point KV input and discards the scales you just computed. Reuse kv_cache_sf as kv_block_scales here.

Suggested change
-    if is_nvfp4_kv:
-        kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
-            nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
-        )
-        kv_cache = kv_cache_nvfp4
+    if use_nvfp4_kv:
+        kv_block_scales = kv_cache_sf
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1197 - 1201, The code is
re-quantizing an already-packed NVFP4 paged KV cache: when is_nvfp4_kv is true
you call nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1]) on
uint8-packed data and overwrite the correct scales; instead, reuse the
precomputed packed cache and scales (kv_cache and kv_cache_sf) produced earlier:
set kv_cache = kv_cache (or keep existing packed variable) and assign
kv_block_scales = kv_cache_sf (and ensure k_scale and v_scale use the previously
computed values), removing the nvfp4_quantize_paged_kv_cache call inside the
is_nvfp4_kv branch so you don't discard the original scales.

865-867: ⚠️ Potential issue | 🔴 Critical

Use one NVFP4 feature-flag name throughout this function.

Line 865 defines is_nvfp4_kv, but the new branches later read use_nvfp4_kv (for example Lines 883, 965, 1036, and 1218). testBatchPrefillWithPagedKVCacheWrapper() currently throws before any benchmark runs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 865 - 867, The function uses
two inconsistent feature-flag names (is_nvfp4_kv and use_nvfp4_kv) causing
branches to miss the intended flag; unify them by picking one name (e.g.,
replace the initial definition is_nvfp4_kv = args.kv_dtype == "nvfp4" with
use_nvfp4_kv = args.kv_dtype == "nvfp4" or create use_nvfp4_kv = is_nvfp4_kv
immediately after) and update all branch checks (references in the function such
as the later conditionals at lines referencing use_nvfp4_kv) to use that single
symbol so the NVFP4 path is consistently triggered (also ensure any dtype checks
that set kv_dtype remain correct).
include/flashinfer/attention/prefill.cuh (1)

1671-1702: ⚠️ Potential issue | 🟠 Major

These prefill paths now consume FP4 scale tiles without ever producing them.

compute_qk / compute_sfm_v now dereference k_sf_smem and v_sf_smem for FP4 KV, but SinglePrefillWithKVCacheDevice and BatchPrefillWithRaggedKVCacheKernel still only call produce_kv(...). If either path is instantiated with __nv_fp4x2_e2m1, it will multiply against uninitialized shared memory and silently corrupt the result. Please either plumb per-row SF loads into these kernels too, or add a compile-time guard that keeps FP4 limited to the paged path for now.

Also applies to: 2118-2150

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/prefill.cuh` around lines 1671 - 1702, The
prefill paths call compute_qk and compute_sfm_v which read k_sf_smem/v_sf_smem
for FP4 but the kernels SinglePrefillWithKVCacheDevice and
BatchPrefillWithRaggedKVCacheKernel only call produce_kv and never populate
those scale-factor tiles, so instantiating with __nv_fp4x2_e2m1 will read
uninitialized shared memory; fix by adding a compile-time guard that prevents
FP4 (e.g. static_assert or if constexpr) in SinglePrefillWithKVCacheDevice and
BatchPrefillWithRaggedKVCacheKernel (or the wrapper that calls produce_kv) when
KTraits::ScalarType == __nv_fp4x2_e2m1, OR alternatively plumb the per-row SF
loads into those kernels by ensuring produce_kv is invoked with the
SharedMemFillMode that fills k_sf_smem/v_sf_smem (or explicitly call the SF fill
helper) before compute_qk/compute_sfm_v are executed; pick one approach and
apply it consistently to both call sites referencing compute_qk, compute_sfm_v,
k_sf_smem, v_sf_smem, and produce_kv.
🧹 Nitpick comments (4)
flashinfer/jit/attention/modules.py (1)

1836-1870: Factor the batch-attention setter generation out of this function.

This copies the nullable/scalar assignment rules from generate_additional_params(). A small helper that accepts the target prefix (params, params[i], params.additional_params) would keep the batch path from drifting the next time additional-parameter semantics change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/attention/modules.py` around lines 1836 - 1870, Extract the
logic that builds the param-assignment lines into a small helper (e.g.,
generate_additional_params_setter(prefix, additional_tensor_names,
additional_tensor_dtypes, additional_scalar_names)) and replace the inline
batch_additional_params_setter construction with a call to that helper; the
helper should implement the same nullable tensor/ scalar rules currently
duplicated (the conditional branch for var.startswith("maybe") and the scalar
formatting) but use the provided target prefix (e.g., "params[i]", "params", or
"params.additional_params") when formatting each assignment; update the call
sites (the batch path that currently creates batch_additional_params_setter and
any other place using generate_additional_params output) to call the new helper
so semantics remain identical but the formatting logic is centralized.
.claude/memory/prefill_cuh_structure.md (1)

40-62: Call out the NVFP4 scale-factor path explicitly.

This note still reads like a generic prefill overview. The new FP4-specific pieces—maybe_k_cache_sf / maybe_v_cache_sf, page_produce_kv_sf, and the shared-memory scale buffers consumed by compute_qk / compute_sfm_v—are exactly what future readers will look for in this PR.

Based on learnings: Keep documentation in sync with code changes, particularly CLAUDE.md and .claude/skills/ when modifying infrastructure changes, patterns, new conventions, or deprecations.

Also applies to: 75-86

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.claude/memory/prefill_cuh_structure.md around lines 40 - 62, Update the
prefill overview to explicitly document the NVFP4 scale-factor path: describe
the new FP4-specific symbols maybe_k_cache_sf and maybe_v_cache_sf, the
page_produce_kv_sf path, and the shared-memory scale buffers that compute_qk and
compute_sfm_v consume; note where these are emitted/loaded and how they flow
through page_produce_kv_sf → shared-memory buffers → compute_qk/compute_sfm_v,
and add a short cross-reference to the infra docs/skills that must be updated
when changing these conventions so readers can find the FP4 scale-factor
behavior quickly.
tests/attention/test_batch_attention.py (1)

308-309: Exercise signed E2M1 codes too.

Line 309 clears both sign bits, so this test never covers negative NVFP4 values. A sign-handling regression would still pass here; either remove the mask or add a second signed-data case.

Suggested change
-    packed &= 0x77  # clear bit 3 (0x08) and bit 7 (0x80) to ensure non-negative
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_attention.py` around lines 308 - 309, The test
currently forces all packed bytes non-negative by applying "packed &= 0x77", so
negative NVFP4 (signed E2M1) values are never exercised; either remove the mask
expression "packed &= 0x77" to allow both signs, or add a second test case that
constructs a signed-data variant (e.g., copy the existing "packed" and set the
sign bits for NVFP4 by OR-ing the appropriate bits such as 0x08 and/or 0x80) and
run the same assertions on that signed copy so both unsigned and negative NVFP4
paths are covered.
tests/attention/test_batch_prefill_kernels.py (1)

1074-1083: Add at least one causal NVFP4 case.

This matrix hardcodes causal=False, so the new NVFP4 path never exercises the masked/tail-tile logic that changed in the kernel code. A small causal=True case would cover the scale-factor path under masking as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_prefill_kernels.py` around lines 1074 - 1083, The
test matrix currently forces causal=False so the NVFP4 kernel path never hits
masked/tail-tile logic; update the test_batch_prefill_with_paged_kv_cache_nvfp4
parameterization (the `@pytest.mark.parametrize`("causal", ...) on the test) to
include True (e.g., [False, True]) so at least one run exercises the
causal/masked path for NVFP4; keep existing q_dtype values unchanged so the
NVFP4 path is still exercised.
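A sketch of the suggested parameter change; only the test name and the causal parameter come from the comment above, and the decorator placement is illustrative:

import pytest

@pytest.mark.parametrize("causal", [False, True])   # previously [False] only
def test_batch_prefill_with_paged_kv_cache_nvfp4(causal):
    ...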
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1667-1675: The NVFP4 allowlist is too permissive: in the
use_nvfp4_kv block (variable backends and list nvfp4_unsupported) remove "auto"
and "trtllm-native" from the supported set so ragged NVFP4 only stays enabled
for backends that actually wrap ragged K/V (e.g., "fa2" and "fa3"); update the
allowed list used to compute nvfp4_unsupported from ["fa2", "trtllm-native",
"auto"] to only ["fa2", "fa3"] (or the concrete backends that implement ragged
NVFP4) so backends that don't forward k_sf/v_sf are filtered out.

In `@include/flashinfer/cp_async.cuh`:
- Around line 191-224: The cp.async call in pred_load_128b_from_64b uses
cp_size=8 which doesn’t zero the upper 8 bytes; change the cp.async
invocation(s) to use cp_size=16 and pass src_size as a variable (src_size =
predicate ? 8 : 0) so cp.async zero-fills bytes 8..15 when src_size is 0, and
similarly adjust the kNoFill branch to issue cp.async with cp_size=16 and
src_size conditionally 8 or 0 (instead of cp_size=8); also apply the same fix
pattern to the 32b helper described in the comment: use cp_size appropriate to
the full destination (e.g., 16 for 128b destination or 4 for 32b helper) and
make src_size variable (0 when wanting explicit zero-fill, nonzero when copying)
so cp.async actually zeros the upper bytes.

In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 486-510: The CUDA version gate incorrectly requires both
__CUDACC_VER_MAJOR__ >= 13 and __CUDACC_VER_MINOR__ >= 2, which fails for CUDA
14.x; update the preprocessor condition that guards the fast-path asm (the block
using cvt.rn.bf16x2.e2m1x2 and variable y/b) to check the combined version
(e.g., compare major/minor together or compute a single numeric version) so it
enables the fast path for CUDA >= 13.2 (including 14.0+).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9c342064-89d2-4570-b892-646302704193

📥 Commits

Reviewing files that changed from the base of the PR and between c4a159a and 193a6c6.

📒 Files selected for processing (18)
  • .claude/memory/MEMORY.md
  • .claude/memory/prefill_cuh_structure.md
  • benchmarks/routines/attention.py
  • flashinfer/attention.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/utils.py
  • include/flashinfer/attention/persistent.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/permuted_smem.cuh
  • include/flashinfer/vec_dtypes.cuh
  • mha_ref.cu
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_prefill_kernels.py
👮 Files not reviewed due to content moderation or server errors (7)
  • flashinfer/utils.py
  • .claude/memory/MEMORY.md
  • flashinfer/jit/utils.py
  • flashinfer/attention.py
  • include/flashinfer/permuted_smem.cuh
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/attention/test_batch_prefill_kernels.py (1)

1151-1165: Please add an asymmetric head_dim_qk/head_dim_vo case here.

head_dim_vo is omitted, so this suite only covers the default head_dim_vo == head_dim_qk path. A 192/128-style case would exercise the packed-V sizing logic that the current fixture cannot catch. The same omission exists in tests/attention/test_batch_attention.py, so it would be worth updating both together.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_prefill_kernels.py` around lines 1151 - 1165, The
test only exercises symmetric head sizes because wrapper.plan is called with
head_dim only; add an asymmetric case by passing explicit head_dim_qk and
head_dim_vo arguments to wrapper.plan (for example head_dim_qk=192,
head_dim_vo=128) so the packed-V sizing logic is exercised; update the
wrapper.plan invocation in tests/attention/test_batch_prefill_kernels.py (and
mirror the same change in tests/attention/test_batch_attention.py) to include
these two explicit parameters instead of relying on the default head_dim
equality.
include/flashinfer/attention/prefill.cuh (1)

449-498: Consider adding null check for defensive programming.

If sf_ptr is nullptr but is_fp4_type_v<DTypeKV> is true, the function computes offsets and calls pred_load_32b with an invalid source pointer. While the design assumes FP4 usage implies scales are provided, a null check would add robustness:

 template <bool produce_v, typename KTraits, typename IdType>
 __device__ __forceinline__ void page_produce_kv_sf(
     typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr,
     ...) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) return;

This prevents undefined behavior if FP4 is compiled but scales are accidentally omitted at runtime.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/prefill.cuh` around lines 449 - 498, The
function page_produce_kv_sf may dereference sf_ptr when is_fp4_type_v<typename
KTraits::DTypeKV> is true; add a defensive null check at the start of
page_produce_kv_sf (after the is_fp4_type_v constexpr) that returns early if
sf_ptr == nullptr to avoid computing sf_gmem_offset and calling
cp_async::pred_load_32b with an invalid source pointer; keep the check
independent of produce_v and ensure it triggers before the NUM_SF_ITERS loop so
symbols page_produce_kv_sf, sf_ptr, is_fp4_type_v, and cp_async::pred_load_32b
are addressed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 737-738: The code currently assumes packed NVFP4 V/O width equals
q.shape[-1] when kv_block_scales (or key_block_scales/value_block_scales)
indicate a packed V cache, which breaks configs where head_dim_vo !=
head_dim_qk; update the run path to use the planned head_dim_vo instead of
q.shape[-1] for packed outputs and persist the planned head_dim_vo from plan()
onto self (e.g., self.head_dim_vo) so run() can read it; adjust any branches
that check kv_block_scales/key_block_scales/value_block_scales to select
self.head_dim_vo as the V/O width when packed is detected.

In `@include/flashinfer/cp_async.cuh`:
- Around line 192-223: pred_load_128b_from_64b: ensure the cp.async path zeroes
the upper 8 bytes to match the fallback by changing the assembly copy size to 16
while keeping src-size=8 (i.e. use cp.async.ca.shared.global with cp-size=16,
src-size=8) in both the fill-mode (kFillZero) branch and the kNoFill branch so
the upper half of the 16-byte slot is zero-padded when only 8 bytes are sourced;
keep the predicate logic and the fallback (smem_u64[1] = 0) unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bddea1b0-077d-4606-bfad-a761fbad1553

📥 Commits

Reviewing files that changed from the base of the PR and between c4a159a and 27930f3.

📒 Files selected for processing (14)
  • flashinfer/attention.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/utils.py
  • include/flashinfer/attention/persistent.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/permuted_smem.cuh
  • include/flashinfer/vec_dtypes.cuh
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_prefill_kernels.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46514756: 6/20 passed

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #46572392 is currently running. I'll report back once the pipeline job completes.

aleozlx pushed a commit that referenced this pull request Mar 20, 2026
…2725)

## Summary
SM120 desktop Blackwell GPUs (RTX PRO 6000, RTX 5090) are blocked from
NVFP4 MoE grouped GEMM due to hardcoded SM100-only checks.

**Changes:**
- `jit/fused_moe.py`: Add major version 12 to `supported_major_versions`
- `csrc/trtllm_fused_moe_kernel_launcher.cu`: `ICHECK_EQ(major, 10)` ->
`ICHECK_GE(major, 10)`

**Benchmark** (Qwen3.5-397B on 4x RTX PRO 6000 SM120):
| Config | tok/s | Output |
|--------|-------|--------|
| compute_120f (CUDA 13.0) | 39.0 | Correct |
| compute_120a (CUDA 12.8) | 14.6 | Correct (slow fallback) |
| Marlin W4A16 | 46-49 | Correct |

**Root cause:** All TMA WS grouped GEMM autotuner tactics fail on
`compute_120a`, requiring `compute_120f` (CUDA 13.0).

CuTe DSL `admissible_archs` in vendored CUTLASS also needs
`sm_120a`/`sm_120f` (cpasync/copy.py, tcgen05/mma.py, arch/mbar.py,
etc).

Related: CUTLASS #2820, #2800; vLLM #33416, #33333; FlashInfer #2577

## Summary by CodeRabbit

* **Bug Fixes**
* Broadened GPU architecture checks to accept additional modern compute
capabilities (SM 10.x and 12.x), improving compatibility and clearer SM
reporting.
* Improved compute-capability detection and encoding, preserving
user-provided architecture suffixes and more accurately generating nvcc
architecture flags.
* Expanded JIT module generation to include additional CUDA majors so
fused-MoE kernels run on more recent GPUs.

---------

Signed-off-by: Brandon Music <brandon.m.music@gmail.com>
Co-authored-by: Brandon Music <brandonmmusic-max@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Brandon Music <brandonmusic@pop-os.tail8674da.ts.net>
@flashinfer-bot
Collaborator

[FAILED] Pipeline #46572392: 6/20 passed

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #47354783 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

1319-1367: ⚠️ Potential issue | 🔴 Critical

v_scale never affects the new single/ragged NVFP4 paths.

Only k_scale is folded into sm_scale. v_scale is neither forwarded nor post-applied here, so any non-unit global V scale silently returns the wrong values.

🐛 Proposed fix
@@
     module.run(
         q,
         k,
         v,
@@
         k_sf,
         v_sf,
     )
+    is_float_one = isinstance(v_scale, float) and v_scale == 1.0
+    if v_scale is not None and not is_float_one:
+        if is_float8(out):
+            out = (out.to(torch.float32) * v_scale).to(out.dtype)
+        else:
+            out *= v_scale
@@
         assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
+        is_float_one = isinstance(v_scale, float) and v_scale == 1.0
+        if v_scale is not None and not is_float_one:
+            if is_float8(out):
+                out = (out.to(torch.float32) * v_scale).to(out.dtype)
+            else:
+                out *= v_scale
         return (out, lse) if return_lse else out

Also applies to: 3191-3330

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1319 - 1367, The NVFP4 path folds k_scale
into sm_scale but never applies v_scale, so outputs are wrong when a non-unit V
scale is used; fix by applying v_scale to the produced output for the
NVFP4/packed-KV case (kv_cache_sf != None) — either pass v_scale through into
the prefill kernel if it supports it or multiply out by v_scale after module.run
(operate on out), using the existing symbols k_scale, v_scale, sm_scale, out,
kv_cache_sf and v_sf to detect the packed NVFP4 branch and apply the correct
scaling.
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)

1330-1341: ⚠️ Potential issue | 🔴 Critical

Don't size packed-NVFP4 outputs from q.shape[-1].

Packed KV changes storage width, not the logical V/O width. These branches still assume head_dim_vo == head_dim_qk, which breaks asymmetric QK/VO shapes.

🐛 Proposed fix
@@
-    out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1]
+    out_head_dim = v.shape[-1] * 2 if kv_cache_sf is not None else v.shape[-1]
@@
         if head_dim_vo is None:
             head_dim_vo = head_dim_qk
+        self._head_dim_vo = head_dim_vo
@@
-        out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1]
+        out_head_dim = (
+            self._head_dim_vo
+            if self._cached_kv_data_type == torch.uint8
+            else v.shape[-1]
+        )

Also applies to: 3199-3212

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1330 - 1341, The code incorrectly sets
out_head_dim / head_dim_vo from q.shape[-1] when kv_cache_sf is not None (packed
NVFP4); packed storage width differs from logical V/O width, so stop deriving
logical head_dim_vo from q.shape[-1]. Instead compute out_head_dim = v.shape[-1]
(the logical V/O head dim) regardless of kv_cache_sf, and pass that value as the
head_dim_vo argument to get_single_prefill_module; also update the other
symmetric block that mirrors this logic (the later occurrence handling packed
KV) to use v.shape[-1] rather than q.shape[-1]. Ensure any allocation of out
uses this logical out_head_dim and that get_single_prefill_module receives
q.shape[-1] for head_dim_qk and v.shape[-1] for head_dim_vo.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3172-3177: The code currently always folds q_scale and k_scale
into sm_scale (in the sm_scale computation near sm_scale = 1.0 /
math.sqrt(q.size(-1))), which causes double-scaling when the ragged cuDNN path
is used because the cuDNN call below also receives q_scale and k_scale
separately; change the logic to mirror the paged wrapper's cuDNN guard used
elsewhere: only multiply sm_scale by q_scale and k_scale when NOT taking the
cuDNN ragged path (i.e., when cuDNN is not used), otherwise leave sm_scale as
the geometric/default scale and pass q_scale/k_scale unchanged to the cuDNN
call; apply the same fix to the second occurrence around lines 3245-3260 so both
regions use the same cuDNN guard and avoid double-scaling.

In `@include/flashinfer/attention/prefill.cuh`:
- Around line 461-510: page_produce_kv_sf can dereference sf_ptr when in_bounds
is true; add a null-pointer guard at the top of page_produce_kv_sf (the FP4-only
branch) to avoid dereferencing sf_ptr (e.g., if (sf_ptr == nullptr) return; or
otherwise ensure all in_bounds are false when sf_ptr is null) so that the
subsequent call to cp_async::pred_load_32b(...) never receives sf_ptr +
sf_gmem_offset when sf_ptr is null.

In `@tests/attention/test_batch_prefill_kernels.py`:
- Around line 1030-1165: The NVFP4 tests (e.g.,
test_batch_prefill_with_paged_kv_cache_nvfp4) run unconditionally on all GPUs;
add the repo-standard compute-capability guard using flashinfer.utils to skip
unsupported architectures: import flashinfer.utils and at the start of the test
call get_compute_capability()/is_sm90a_supported()/is_sm100a_supported() (or
directly use is_sm90a_supported() or is_sm100a_supported()) and call
pytest.skip(...) when neither is supported; apply the same guard to the other
NVFP4 test block (the one referenced in the comment for lines 1168-1260) so
unsupported runners skip instead of failing.

In `@tests/attention/test_single_prefill.py`:
- Around line 107-160: The test test_single_prefill_with_kv_cache_nvfp4 must be
gated by GPU compute capability: import and call the repo-standard helpers from
flashinfer.utils (get_compute_capability(), is_sm90a_supported(),
is_sm100a_supported()) at the start of the test (inside
test_single_prefill_with_kv_cache_nvfp4) and skip the test when the current GPU
does not support NVFP4 (i.e., when neither is_sm90a_supported() nor
is_sm100a_supported() is true); use pytest.skip or pytest.mark.skipif with a
clear reason so unsupported runners skip cleanly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5b8dcc69-72db-4f1c-ac39-46beb9a4da1e

📥 Commits

Reviewing files that changed from the base of the PR and between 27930f3 and 865f912.

📒 Files selected for processing (7)
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py
  • include/flashinfer/attention/prefill.cuh
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_prefill_kernels.py
  • tests/attention/test_single_prefill.py
  • tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/jit/attention/modules.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/attention/test_batch_attention.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47354783: 6/20 passed

@Tom-Zheng force-pushed the add-sm120-nvfp4-kv-prefill branch from 865f912 to 067bd9d on April 1, 2026 03:20
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/prefill.py (2)

599-629: ⚠️ Potential issue | 🟠 Major

Fake op signature missing FP8 scale parameters.

_fake_ragged_run is missing scale_q, scale_k, scale_v parameters that are present in ragged_run (lines 503-505). This will cause signature mismatches during tracing.

🐛 Proposed fix
 `@register_fake_op`(f"flashinfer::{uri}_ragged_run")
 def _fake_ragged_run(
     float_workspace_buffer: torch.Tensor,
     int_workspace_buffer: torch.Tensor,
     plan_info_vec: List[int],
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     qo_indptr: torch.Tensor,
     kv_indptr: torch.Tensor,
     o: torch.Tensor,
     maybe_lse: Optional[torch.Tensor],
     mask_mode: int,
     layout: int,
     window_left: int,
     enable_pdl: bool,
     maybe_custom_mask: Optional[torch.Tensor],
     maybe_mask_indptr: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
     maybe_prefix_len_ptr: Optional[torch.Tensor],
     maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
     maybe_max_item_len_ptr: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
     rope_scale: float,
     rope_theta: float,
     token_pos_in_items_len: int,
     maybe_k_cache_sf: Optional[torch.Tensor] = None,
     maybe_v_cache_sf: Optional[torch.Tensor] = None,
+    scale_q: Optional[torch.Tensor] = None,
+    scale_k: Optional[torch.Tensor] = None,
+    scale_v: Optional[torch.Tensor] = None,
 ) -> None:
     pass
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 599 - 629, The fake op _fake_ragged_run
must match the real ragged_run signature: add the FP8 scale parameters scale_q,
scale_k, scale_v to _fake_ragged_run with the same types and positions used in
ragged_run so tracing won't fail; update the function signature for
register_fake_op("flashinfer::{uri}_ragged_run") to include scale_q, scale_k,
scale_v (use the same Optional/torch.Tensor typing and default values as
ragged_run) and leave the body as pass.

421-441: ⚠️ Potential issue | 🟠 Major

Fake op signature missing scale parameters.

_fake_run_single_prefill is missing scale_q, scale_k, scale_v parameters that are present in the real run_single_prefill function (lines 351-353). This signature mismatch can cause issues with torch.compile and other JIT tracing scenarios.

🐛 Proposed fix
 `@register_fake_op`(f"flashinfer::{uri}_run")
 def _fake_run_single_prefill(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     tmp: torch.Tensor,
     o: torch.Tensor,
     maybe_lse: Optional[torch.Tensor],
     mask_mode: int,
     layout: int,
     window_left: int,
     maybe_packed_custom_mask: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
+    scale_q: Optional[torch.Tensor],
+    scale_k: Optional[torch.Tensor],
+    scale_v: Optional[torch.Tensor],
     rope_scale: float,
     rope_theta: float,
     maybe_k_cache_sf: Optional[torch.Tensor] = None,
     maybe_v_cache_sf: Optional[torch.Tensor] = None,
 ) -> None:
     pass
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 421 - 441, The fake op
_fake_run_single_prefill has a signature mismatch: add the missing scale
parameters scale_q, scale_k, scale_v to its parameter list so it exactly matches
the real run_single_prefill signature; ensure the new parameters use the same
names, types/order and defaulting as in run_single_prefill (place them before
maybe_k_cache_sf/maybe_v_cache_sf like the real function) so torch.compile/JIT
tracing sees an identical call signature.
♻️ Duplicate comments (1)
include/flashinfer/attention/prefill.cuh (1)

461-510: ⚠️ Potential issue | 🔴 Critical

Fail fast when NVFP4 scale tensors are missing.

maybe_k_cache_sf / maybe_v_cache_sf default to nullptr, and these helpers still issue pred_load_32b through sf_ptr whenever the FP4 path is instantiated. A legacy or malformed caller will turn that into a null global-memory load.

🛡️ Possible fix
 template <bool produce_v, typename KTraits, typename IdType>
 __device__ __forceinline__ void page_produce_kv_sf(
     typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr,
     const uint32_t packed_page_iter_base, const uint32_t packed_kv_bound,
     const uint32_t kv_head_idx, const uint32_t kv_stride_page, const uint32_t kv_stride_h,
     const uint32_t kv_stride_n, const uint_fastdiv& page_size, const IdType* indices,
     const uint32_t kv_idx_base, const uint32_t kv_len, const uint32_t warp_idx,
     const uint32_t lane_idx) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) {
+    FLASHINFER_RUNTIME_ASSERT("NVFP4 KV cache requires block scale tensors.");
+  }
@@
 template <bool produce_v, typename KTraits>
 __device__ __forceinline__ void produce_kv_sf(typename KTraits::SharedStorage* smem_storage,
                                               uint8_t* sf_ptr, const uint32_t kv_abs_base,
                                               const uint32_t kv_head_idx,
                                               const uint32_t kv_stride_n,
                                               const uint32_t kv_stride_h,
                                               const uint32_t kv_idx_base, const uint32_t kv_len,
                                               const uint32_t warp_idx, const uint32_t lane_idx) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) {
+    FLASHINFER_RUNTIME_ASSERT("NVFP4 KV cache requires block scale tensors.");
+  }

Also applies to: 535-576

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/prefill.cuh` around lines 461 - 510, The kernel
page_produce_kv_sf can attempt pred_load_32b through a null sf_ptr when the FP4
SF cache pointers (maybe_k_cache_sf / maybe_v_cache_sf passed as sf_ptr) are
nullptr; guard against this by checking sf_ptr (or the original
maybe_k_cache_sf/maybe_v_cache_sf) before issuing the cp_async load. Concretely,
inside page_produce_kv_sf (and the analogous block at lines 535-576) update the
in_bounds predicate to also require sf_ptr != nullptr (or return/skip early when
sf_ptr is null) so cp_async::pred_load_32b is only called when sf_ptr is valid.
🧹 Nitpick comments (2)
tests/test_helpers/utils_fp4.py (1)

103-128: Don't hardcode the NVFP4 global scale to 1.0.

All tests that build fixtures through this helper now bypass the new k_scale / v_scale plumbing. A bug in the global-scale path would still pass because both the kernel and the reference see the identity scale.

♻️ Possible tweak
-def create_nvfp4_kv(shape, device):
+def create_nvfp4_kv(shape, device, global_scale=1.0):
@@
-    return packed, sf, torch.tensor(1.0, device=device)
+    return packed, sf, torch.tensor(global_scale, device=device, dtype=torch.float32)
tests/attention/test_batch_attention.py (1)

293-296: Use the shared skip helpers here instead of another raw CC xfail.

If SM120/121 is still unsupported for this case, make it a helper-based skip; otherwise this marker hides failures on the exact architecture the new NVFP4 path needs to cover.

As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_attention.py` around lines 293 - 296, Replace the
raw pytest.mark.xfail with a shared helper-based skip using the utilities in
flashinfer.utils: import get_compute_capability (or the appropriate helper) from
flashinfer.utils and change the marker on the test that currently uses
pytest.mark.xfail(get_compute_capability(torch.device(device="cuda"))[0] == 12,
...) to pytest.mark.skipif(get_compute_capability(torch.device("cuda"))[0] ==
12, reason="SM120/121 unsupported for this test") or, if a dedicated helper
exists (e.g., is_sm120_supported()), use that helper instead to decide skipping;
update the decorator on the test accordingly and remove the raw xfail usage.
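
For reference, a minimal sketch of the helper-based skip described above (the test name and reason string are illustrative, not the actual test):

import pytest
import torch
from flashinfer.utils import get_compute_capability

# Skip (rather than xfail) on SM120/121, gating on compute capability via
# flashinfer.utils as the coding guidelines require.
@pytest.mark.skipif(
    torch.cuda.is_available()
    and get_compute_capability(torch.device("cuda"))[0] == 12,
    reason="NVFP4 batch attention not yet supported on SM120/121",
)
def test_batch_attention_nvfp4_smoke():
    ...  # test body unchanged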
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1147-1149: The v_scale parameter is accepted in the function
signature alongside kv_cache_sf and k_scale but never applied to the produced
output; either remove v_scale or apply it consistently like
BatchPrefillWithPagedKVCacheWrapper.run(). Locate the single-request prefill
function that declares kv_cache_sf, k_scale, v_scale (same signature shown) and
add the same post-output scaling logic used in
BatchPrefillWithPagedKVCacheWrapper.run(): after computing the output, if
v_scale is not None and not is_float_one then multiply the output tensor by
v_scale (or conversely remove v_scale from the signature and all callers if this
mode shouldn't support value scaling). Ensure references to k_scale/kv_cache_sf
behavior remain unchanged.
- Around line 471-473: The mutates_args lists incorrectly include read-only
scale-factor tensors; remove "maybe_k_cache_sf" and "maybe_v_cache_sf" from
ragged_run's mutates_args and remove "key_block_scales" and "value_block_scales"
from paged_run's mutates_args, after confirming kernels do not mutate them (they
only produce new transposed tensors). Update the mutates_args declarations in
the ragged_run and paged_run call sites to omit those symbols, keeping
get_trtllm_gen_prefill_module and run_single_prefill behavior as-is, and run the
existing tests or a quick smoke-run to ensure torch.compile optimizations no
longer treat these tensors as mutated.

In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback _e2m1_and_ufp8sf_scale_to_float_cpu can
receive a flat 1-D ufp8_scale_tensor (length either K/sf_vec_size or
M*(K/sf_vec_size)) and currently treats it as already shaped [M, K/sf_vec_size],
causing wrong broadcasting; before decoding the UFP8 scales, detect and
reshape/expand ufp8_scale_tensor: if ufp8_scale_tensor.dim() == 1 and its
numel() == (m * (k // sf_vec_size)) then view it as (m, k // sf_vec_size); if
numel() == (k // sf_vec_size) then expand/unsqueeze to (m, k // sf_vec_size);
otherwise if dim()==2 ensure its shape matches (m, k // sf_vec_size) and raise a
clear error if not; then proceed with the existing decoding and
repeat_interleave logic using this normalized per-row scale tensor.

---

Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 599-629: The fake op _fake_ragged_run must match the real
ragged_run signature: add the FP8 scale parameters scale_q, scale_k, scale_v to
_fake_ragged_run with the same types and positions used in ragged_run so tracing
won't fail; update the function signature for
register_fake_op("flashinfer::{uri}_ragged_run") to include scale_q, scale_k,
scale_v (use the same Optional/torch.Tensor typing and default values as
ragged_run) and leave the body as pass.
- Around line 421-441: The fake op _fake_run_single_prefill has a signature
mismatch: add the missing scale parameters scale_q, scale_k, scale_v to its
parameter list so it exactly matches the real run_single_prefill signature;
ensure the new parameters use the same names, types/order and defaulting as in
run_single_prefill (place them before maybe_k_cache_sf/maybe_v_cache_sf like the
real function) so torch.compile/JIT tracing sees an identical call signature.
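
As a standalone illustration of the signature-matching requirement (this is a toy op, not FlashInfer's actual ragged_run; the name and parameters are invented):

import torch

# The fake (meta) registration must declare exactly the same parameters as the
# real op, including the scale arguments, or torch.compile tracing fails.
@torch.library.custom_op("demo::scaled_identity", mutates_args=())
def scaled_identity(x: torch.Tensor, scale_q: float, scale_k: float, scale_v: float) -> torch.Tensor:
    return x * scale_q * scale_k * scale_v

@scaled_identity.register_fake
def _(x: torch.Tensor, scale_q: float, scale_k: float, scale_v: float) -> torch.Tensor:
    # Same parameter list as the real op; only shape/dtype propagation here.
    return torch.empty_like(x)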

---

Duplicate comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 461-510: The kernel page_produce_kv_sf can attempt pred_load_32b
through a null sf_ptr when the FP4 SF cache pointers (maybe_k_cache_sf /
maybe_v_cache_sf passed as sf_ptr) are nullptr; guard against this by checking
sf_ptr (or the original maybe_k_cache_sf/maybe_v_cache_sf) before issuing the
cp_async load. Concretely, inside page_produce_kv_sf (and the analogous block at
lines 535-576) update the in_bounds predicate to also require sf_ptr != nullptr
(or return/skip early when sf_ptr is null) so cp_async::pred_load_32b is only
called when sf_ptr is valid.

---

Nitpick comments:
In `@tests/attention/test_batch_attention.py`:
- Around line 293-296: Replace the raw pytest.mark.xfail with a shared
helper-based skip using the utilities in flashinfer.utils: import
get_compute_capability (or the appropriate helper) from flashinfer.utils and
change the marker on the test that currently uses
pytest.mark.xfail(get_compute_capability(torch.device(device="cuda"))[0] == 12,
...) to pytest.mark.skipif(get_compute_capability(torch.device("cuda"))[0] ==
12, reason="SM120/121 unsupported for this test") or, if a dedicated helper
exists (e.g., is_sm120_supported()), use that helper instead to decide skipping;
update the decorator on the test accordingly and remove the raw xfail usage.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 980616bb-d5f7-4d8c-962b-1548d733a9a3

📥 Commits

Reviewing files that changed from the base of the PR and between 865f912 and 067bd9d.

📒 Files selected for processing (16)
  • flashinfer/attention.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/utils.py
  • include/flashinfer/attention/persistent.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/permuted_smem.cuh
  • include/flashinfer/vec_dtypes.cuh
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_prefill_kernels.py
  • tests/attention/test_single_prefill.py
  • tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (2)
  • flashinfer/utils.py
  • flashinfer/jit/utils.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • flashinfer/attention.py
  • include/flashinfer/permuted_smem.cuh
  • tests/attention/test_single_prefill.py
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/cp_async.cuh
  • tests/attention/test_batch_prefill_kernels.py
  • include/flashinfer/vec_dtypes.cuh
  • flashinfer/jit/attention/modules.py

Comment thread flashinfer/prefill.py
Comment thread flashinfer/prefill.py
Comment thread flashinfer/quantization/fp4_quantization.py
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #47420565 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47420565: 5/20 passed

@Tom-Zheng Tom-Zheng requested a review from qsang-nv as a code owner April 15, 2026 05:48
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

jit: add torch.float4_e2m1fn_x2 to dtype maps

Add conditional entries for torch.float4_e2m1fn_x2 in
filename_safe_dtype_map ("fp4_e2m1") and dtype_map_kv
("__nv_fp4x2_e2m1") so that BatchPrefillWithPagedKVCacheWrapper
can select the FP4 kernel plan without KeyError when
kv_data_type=float4_e2m1fn_x2 is passed to begin_forward.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

fix batch decode UT

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

fix batch decode function and accuracy; add nvfp4 kv test

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@Tom-Zheng Tom-Zheng force-pushed the add-sm120-nvfp4-kv-prefill branch from eb98c96 to 7be54dc on April 15, 2026 05:52
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #48569185 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)

1437-1452: ⚠️ Potential issue | 🔴 Critical

Remove the prefill-only placeholders from decode's TRT-LLM paged_run call.

get_trtllm_gen_decode_module().paged_run() does not accept the max_q_len / batch_size / cum_seq_lens_* slots you added here. With the current list, self._max_kv_len binds to sinks, sinks binds to uses_shared_paged_kv_idx, and the remaining args overflow the wrapper, so this branch will raise as soon as TRT-LLM decode runs.

Proposed fix
                 run_args += [
                     self._num_qo_heads,
                     self._num_kv_heads,
                     self._block_tables,
                     self._kv_lens_buffer,
                     page_size,
-                    None,  # max_q_len (not applicable for decode)
                     self._max_kv_len,
-                    None,  # batch_size (not applicable for decode)
-                    None,  # cum_seq_lens_q (not applicable for decode)
-                    None,  # cum_seq_lens_kv (not applicable for decode)
                     sinks,
                     key_block_scales,
                     value_block_scales,
                     skip_softmax_threshold_scale_factor,
                     True,  # uses_shared_paged_kv_idx
                 ]

Also applies to: 2085-2128

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1437 - 1452, The decode branch is passing
prefill-only placeholder slots into get_trtllm_gen_decode_module().paged_run(),
causing argument misalignment: remove the None placeholders for max_q_len,
batch_size, cum_seq_lens_q and cum_seq_lens_kv from the run_args list (the
entries between page_size and self._max_kv_len) so the call to paged_run
receives the correct parameters (ensure sinks, key_block_scales,
value_block_scales, skip_softmax_threshold_scale_factor and
uses_shared_paged_kv_idx bind to the intended positions); apply the same removal
in the other decode occurrence that mirrors this block.
♻️ Duplicate comments (1)
include/flashinfer/attention/prefill.cuh (1)

462-516: ⚠️ Potential issue | 🟠 Major

Require non-null SF pointers before issuing FP4 SF loads.

The FP4 SF helpers still form sf_ptr + sf_gmem_offset unconditionally, but all three call sites default maybe_k_cache_sf / maybe_v_cache_sf to nullptr when the params pack does not expose scale tensors. Any FP4 specialization reached without SF tensors will dereference a null base pointer before pred_load_32b can predicate the load.

If missing SF tensors are invalid, fail earlier; if they are meant to be optional, the helpers need an explicit null-handling path.

Also applies to: 542-585, 1644-1651, 2064-2071, 2411-2418

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/attention/prefill.cuh` around lines 462 - 516, The FP4 SF
load logic (e.g., in page_produce_kv_sf) unconditionally computes sf_ptr +
sf_gmem_offset and passes it to cp_async::pred_load_32b, which will dereference
a null when maybe_k_cache_sf/maybe_v_cache_sf are nullptr; add explicit
null-handling: before forming sf_gmem_offset or calling pred_load_32b in
page_produce_kv_sf (and the other FP4 SF helper call sites), test if
sf_ptr==nullptr and either (a) fail fast with a clear error if SF tensors are
required or (b) use a safe no-op path that supplies a dummy/valid pointer and
sets in_bounds=false (so pred_load_32b is fully predicated) when SF is optional;
update all related helpers/call sites (maybe_k_cache_sf, maybe_v_cache_sf, and
other FP4 SF helper functions that call pred_load_32b) accordingly to avoid null
pointer arithmetic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 110-117: Dispatcher SMEM accounting ignores the new FP4
scale-factor buffers in SharedStorageQKVO (k_sf_smem/v_sf_smem), causing
NUM_MMA_KV/DISPATCH_NUM_MMA_KV to pick configs that overflow on FP4 builds;
update the dispatcher math that computes num_ctas_per_sm and max_num_mma_kv_smem
to include the additional bytes of k_sf_smem and v_sf_smem when
is_fp4_type_v<DTypeKV> is true (i.e., add CTA_TILE_KV * HEAD_DIM_QK /
NVFP4_SF_VEC_SIZE and CTA_TILE_KV * HEAD_DIM_VO / NVFP4_SF_VEC_SIZE bytes
respectively, aligned as in SharedStorageQKVO) and apply the same change at the
other occurrences noted (around the other ranges) so the per-CTA dynamic SMEM
budget reflects the SF buffers before DISPATCH_NUM_MMA_KV selection.

In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-691: The test_batch_decode_with_paged_kv_cache_nvfp4 test
lacks a GPU architecture guard; add a skip decorator using is_sm100a_supported()
via `@pytest.mark.skipif`(not is_sm100a_supported(), reason="NVFP4 tests require
SM100+/Blackwell") placed before the existing `@pytest.mark.parametrize`
decorators so the test is skipped on unsupported hardware; locate the test
function name test_batch_decode_with_paged_kv_cache_nvfp4 and add the skipif
decorator consistent with other NVFP4 tests (e.g.,
tests/moe/test_trtllm_cutlass_fused_moe.py).

---

Outside diff comments:
In `@flashinfer/decode.py`:
- Around line 1437-1452: The decode branch is passing prefill-only placeholder
slots into get_trtllm_gen_decode_module().paged_run(), causing argument
misalignment: remove the None placeholders for max_q_len, batch_size,
cum_seq_lens_q and cum_seq_lens_kv from the run_args list (the entries between
page_size and self._max_kv_len) so the call to paged_run receives the correct
parameters (ensure sinks, key_block_scales, value_block_scales,
skip_softmax_threshold_scale_factor and uses_shared_paged_kv_idx bind to the
intended positions); apply the same removal in the other decode occurrence that
mirrors this block.

---

Duplicate comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 462-516: The FP4 SF load logic (e.g., in page_produce_kv_sf)
unconditionally computes sf_ptr + sf_gmem_offset and passes it to
cp_async::pred_load_32b, which will dereference a null when
maybe_k_cache_sf/maybe_v_cache_sf are nullptr; add explicit null-handling:
before forming sf_gmem_offset or calling pred_load_32b in page_produce_kv_sf
(and the other FP4 SF helper call sites), test if sf_ptr==nullptr and either (a)
fail fast with a clear error if SF tensors are required or (b) use a safe no-op
path that supplies a dummy/valid pointer and sets in_bounds=false (so
pred_load_32b is fully predicated) when SF is optional; update all related
helpers/call sites (maybe_k_cache_sf, maybe_v_cache_sf, and other FP4 SF helper
functions that call pred_load_32b) accordingly to avoid null pointer arithmetic.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f17c6f7f-060a-4ea6-8ddb-8c31da22da20

📥 Commits

Reviewing files that changed from the base of the PR and between 067bd9d and eb98c96.

📒 Files selected for processing (4)
  • flashinfer/decode.py
  • flashinfer/jit/utils.py
  • include/flashinfer/attention/prefill.cuh
  • tests/attention/test_batch_decode_kernels.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/jit/utils.py

Comment thread include/flashinfer/attention/prefill.cuh
Comment thread tests/attention/test_batch_decode_kernels.py
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #48569551 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (6)
flashinfer/quantization/fp4_quantization.py (1)

101-145: ⚠️ Potential issue | 🟠 Major

Normalize flat scale buffers before this fallback.

The public dequant path still accepts flattened ufp8_scale_tensor, but this implementation treats it as already shaped [M, K // sf_vec_size]. On < SM90, a 1-D unswizzled scale buffer will now either broadcast incorrectly or fail once repeat_interleave(..., dim=-1) runs.

Suggested fix
 def _e2m1_and_ufp8sf_scale_to_float_cpu(
     e2m1_tensor: torch.Tensor,
     ufp8_scale_tensor: torch.Tensor,
     global_scale_tensor: Optional[torch.Tensor],
     sf_vec_size: int,
     ufp8_type: int,
     is_sf_swizzled_layout: bool,
 ) -> torch.Tensor:
@@
     device = e2m1_tensor.device
     m, k_half = e2m1_tensor.shape
     k = k_half * 2
+    expected_sf_cols = k // sf_vec_size
+
+    if ufp8_scale_tensor.dim() == 1:
+        if ufp8_scale_tensor.numel() == expected_sf_cols:
+            ufp8_scale_tensor = ufp8_scale_tensor.unsqueeze(0).expand(m, -1)
+        elif ufp8_scale_tensor.numel() == m * expected_sf_cols:
+            ufp8_scale_tensor = ufp8_scale_tensor.reshape(m, expected_sf_cols)
+        else:
+            raise ValueError(
+                f"Expected {expected_sf_cols} or {m * expected_sf_cols} scale values, "
+                f"got {ufp8_scale_tensor.numel()}"
+            )
+    elif tuple(ufp8_scale_tensor.shape) != (m, expected_sf_cols):
+        raise ValueError(
+            f"Expected scale tensor shape {(m, expected_sf_cols)}, "
+            f"got {tuple(ufp8_scale_tensor.shape)}"
+        )
 
     # Unpack two E2M1 nibbles per byte: low nibble = even indices, high nibble = odd
     fp4_vals = torch.empty(m, k, dtype=torch.uint8, device=device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 101 - 145, The CPU
fallback _e2m1_and_ufp8sf_scale_to_float_cpu assumes ufp8_scale_tensor is shaped
[M, K//sf_vec_size]; normalize flattened inputs first: inside
_e2m1_and_ufp8sf_scale_to_float_cpu, detect if ufp8_scale_tensor.dim() == 1 and
if so, if its length == sf_len (where sf_len = k // sf_vec_size) repeat it
across the batch to shape [m, sf_len]; if its length == m * sf_len reshape it to
[m, sf_len]; ensure the tensor is moved to the same device/dtype before later
ops so the later sf_float, repeat_interleave, and broadcasting use the correct
shape and device.
tests/attention/test_batch_decode_kernels.py (1)

684-691: ⚠️ Potential issue | 🟠 Major

Add the missing architecture skip for the NVFP4 decode test.

This new test still has no capability guard, so on unsupported GPUs it fails before the decode assertions are reached. Please add a pytest.mark.skipif(...) using the NVFP4 helper/API capability check.

As per coding guidelines "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc)"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_decode_kernels.py` around lines 684 - 691, Add a
skip guard to the test_batch_decode_with_paged_kv_cache_nvfp4 test so it doesn't
run on unsupported GPUs: import and use the NVFP4 capability check (e.g.,
flashinfer.utils.is_nvfp4_supported() or the equivalent API method) in a
pytest.mark.skipif(...) decorator above the test function to skip when the check
returns False; ensure the decorator message explains it's skipping due to
missing NVFP4 support.
flashinfer/prefill.py (4)

3322-3325: ⚠️ Potential issue | 🟠 Major

Guard ragged q_scale / k_scale folding on the backend.

This still multiplies q_scale and k_scale into sm_scale before dispatch, but the cuDNN branch below also receives both scalars separately. Ragged cuDNN calls will be double-scaled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3322 - 3325, The code multiplies q_scale
and k_scale into sm_scale unconditionally, but the cuDNN branch below also
receives q_scale and k_scale separately causing double-scaling for ragged cuDNN
calls; change the logic around sm_scale, q_scale, and k_scale so you only fold
(multiply) q_scale/k_scale into sm_scale when not using the cuDNN path (or
alternatively stop passing q_scale/k_scale separately to cuDNN), i.e., gate the
sm_scale *= q_scale and sm_scale *= k_scale operations behind the condition that
selects the non-cuDNN backend (use the same branch/flag used for dispatch to
cuDNN), and ensure the cuDNN branch receives either folded sm_scale or the
separate q_scale/k_scale but not both.

1148-1150: ⚠️ Potential issue | 🟠 Major

v_scale is still ignored in single-request prefill.

k_scale is folded into sm_scale, but v_scale never affects out. That makes single_prefill_with_kv_cache() numerically inconsistent with the paged wrapper and with its own signature.

Also applies to: 1354-1379

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1148 - 1150, single_prefill_with_kv_cache
ignores v_scale (only folds k_scale into sm_scale), causing numerical
inconsistency with the paged wrapper and the function signature; fix by applying
v_scale to the value tensor before it's used to compute out (mirror how k_scale
was folded into sm_scale) inside single_prefill_with_kv_cache: when kv_cache_sf
(and v_scale) are present, scale the v component (from kv_cache_sf or the
produced v) by v_scale (or incorporate it into existing sm_scale logic) so that
out uses the scaled values; update any related paths where kv_cache_sf is
unpacked and where out is computed so both k_scale and v_scale affect the final
output consistently.

1337-1347: ⚠️ Potential issue | 🔴 Critical

Don’t infer packed V/O width from q.shape[-1].

kv_cache_sf only tells you the V cache is packed; it does not guarantee head_dim_vo == head_dim_qk. These branches still allocate the output with Q’s width and, for the single-request path, JIT the wrong head_dim_vo specialization for asymmetric QK/VO configs.

Also applies to: 2334-2346, 3347-3361

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1337 - 1347, The code incorrectly infers
V/O packed width from q.shape[-1]; instead compute out_head_dim from v.shape[-1]
(not q.shape[-1]) and use that value when allocating out and when calling
get_single_prefill_module so the JIT specialization gets the correct head_dim_vo
for asymmetric QK/VO configs (update the out = torch.empty(...) allocation and
the get_single_prefill_module(...) call that currently passes q.shape[-1] to use
out_head_dim/v.shape[-1]); apply the same fix to the other similar branches
where out_head_dim is derived from q.shape[-1].
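
A small sketch of deriving the output width from V instead of Q (the packed-FP4 handling and all names here are illustrative assumptions):

import torch

def output_head_dim(v: torch.Tensor, v_is_packed_fp4: bool) -> int:
    # Size the output from V, not Q, so asymmetric head_dim_qk / head_dim_vo
    # configs allocate and JIT-specialize correctly. When V packs two FP4
    # values per uint8, the logical width is twice the stored last dimension.
    head_dim_vo = v.shape[-1]
    return 2 * head_dim_vo if v_is_packed_fp4 else head_dim_vo

def allocate_out(q: torch.Tensor, v: torch.Tensor, v_is_packed_fp4: bool) -> torch.Tensor:
    return torch.empty(
        (*q.shape[:-1], output_head_dim(v, v_is_packed_fp4)),
        dtype=q.dtype,
        device=q.device,
    )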

467-476: ⚠️ Potential issue | 🟡 Minor

Drop the scale-factor tensors from mutates_args.

maybe_k_cache_sf / maybe_v_cache_sf and key_block_scales / value_block_scales are forwarded as read-only inputs. Marking them as mutated pessimizes torch.compile alias analysis for no gain.

Also applies to: 636-647

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 467 - 476, The decorator call to
register_custom_op currently lists maybe_k_cache_sf and maybe_v_cache_sf (and
likewise key_block_scales/value_block_scales in the other occurrence) inside
mutates_args which falsely marks these tensors as mutated; remove
maybe_k_cache_sf and maybe_v_cache_sf from the mutates_args tuple (and remove
key_block_scales/value_block_scales from the duplicate occurrence) so they are
treated as read-only inputs by torch.compile, leaving only genuinely mutated
buffers (e.g., float_workspace_buffer, int_workspace_buffer, o, maybe_lse) in
mutates_args.
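
A toy example of the mutates_args contract (unrelated to FlashInfer's actual op; names are invented): only tensors written in place belong in the list, while read-only inputs such as scale factors must be omitted so alias analysis is not pessimized.

import torch

@torch.library.custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(out: torch.Tensor, x: torch.Tensor, scales: torch.Tensor) -> None:
    # `out` is mutated in place; `x` and `scales` are only read.
    out.copy_(x * scales)

@scale_into.register_fake
def _(out: torch.Tensor, x: torch.Tensor, scales: torch.Tensor) -> None:
    return None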
🧹 Nitpick comments (1)
include/flashinfer/permuted_smem.cuh (1)

176-181: Document the partial-copy contract here.

load_64b_async() writes a 64-bit source into a 128-bit shared-memory slot via pred_load_128b_from_64b. A one-line note about which half is populated and why this path was chosen over a dedicated 64-bit SMEM layout would make the later FP4 loader much harder to misuse.

As per coding guidelines "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/permuted_smem.cuh` around lines 176 - 181, Add a one-line
comment in load_64b_async documenting the partial-copy contract: state that
calling cp_async::pred_load_128b_from_64b with a 64-bit source writes the 64-bit
value into one half of the 128-bit shared-memory slot (specify which half is
populated and that the other half is left unchanged/undefined), and briefly
justify why this 128-bit SMEM path is used (to reuse existing b128_t/128-bit
alignment and avoid a separate 64-bit SMEM layout) and note the alternative
considered (a dedicated 64-bit SMEM layout) so later users of load_64b_async,
b128_t, base and cp_async::pred_load_128b_from_64b cannot misuse the
partial-copy behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/attention.py`:
- Around line 149-151: The code currently accepts packed NVFP4/uint8 KV caches
while allowing kv_block_scales to be None, which yields incorrect attention
results; update the functions/methods that accept the parameter kv_block_scales
(the signature showing "kv_block_scales: Optional[Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor]]] = None" and the other similar occurrence
around lines 182-187) to validate inputs and raise a clear error when a
uint8/packed KV cache is provided but kv_block_scales is None; detect the dtype
(torch.uint8) or packed-NVFP4 indicator on the kv tensors and throw a ValueError
with an explanatory message requiring per-block scales instead of silently
proceeding.
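
A minimal sketch of the fail-fast validation described above (function and parameter names are assumptions):

import torch
from typing import Optional, Tuple, Union

def check_kv_block_scales(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    kv_block_scales: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
) -> None:
    # Packed NVFP4 KV caches are stored as torch.uint8 and cannot be
    # dequantized without per-block scale factors, so reject the call early
    # instead of silently producing wrong attention output.
    if (k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8) and kv_block_scales is None:
        raise ValueError(
            "Packed NVFP4 (uint8) KV cache requires kv_block_scales; got None."
        )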

In `@flashinfer/decode.py`:
- Around line 1458-1468: The run_args list passed to
get_trtllm_gen_decode_module().paged_run() in flashinfer/decode.py currently
includes four prefill-only None placeholders after page_size, which overflows
the trtllm-gen decode wrapper; update the run_args construction in the decode
path (the code that appends self._num_qo_heads, self._num_kv_heads,
self._block_tables, self._kv_lens_buffer, page_size, ...) to only include
page_size and then self._max_kv_len (remove the subsequent None entries for
max_q_len, batch_size, cum_seq_lens_q, cum_seq_lens_kv) so the argument list
matches the paged_run() decode signature.

In `@flashinfer/prefill.py`:
- Around line 1326-1333: The code currently unpacks kv_cache_sf into k_sf,v_sf
but does not reject packed NVFP4 KV when kv_cache_sf is missing; add a fail-fast
check in the prefill path: if kv_cache_sf is None and the KV cache tensors (the
K and V tensors used in this function) have dtype torch.uint8 (packed NVFP4),
raise a clear exception (ValueError) instead of proceeding with null
scale-factor pointers; apply the same guard in the other identical prefill spot
referenced (the second occurrence corresponding to the other block used by
BatchPrefillWithPagedKVCacheWrapper.run and trtllm_batch_context_with_kv_cache)
so both paths reject packed uint8 KV when kv_cache_sf is absent and avoid
incorrect dequantization.

In `@tests/attention/test_batch_attention.py`:
- Around line 293-306: The NVFP4-specific test test_batch_attention_nvfp4 must
be gated so it skips on unsupported GPU architectures; update the test to check
the appropriate capability via flashinfer.utils (e.g., call the relevant
is_smXX_supported() helper such as is_sm90a_supported()/is_sm100a_supported() or
use the API method is_compute_capability_supported(cc)) at the start of the test
and call pytest.skip with a clear message when the capability is absent so the
fixture setup (NVFP4 kernels) is not attempted on unsupported GPUs.

---

Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 3322-3325: The code multiplies q_scale and k_scale into sm_scale
unconditionally, but the cuDNN branch below also receives q_scale and k_scale
separately causing double-scaling for ragged cuDNN calls; change the logic
around sm_scale, q_scale, and k_scale so you only fold (multiply)
q_scale/k_scale into sm_scale when not using the cuDNN path (or alternatively
stop passing q_scale/k_scale separately to cuDNN), i.e., gate the sm_scale *=
q_scale and sm_scale *= k_scale operations behind the condition that selects the
non-cuDNN backend (use the same branch/flag used for dispatch to cuDNN), and
ensure the cuDNN branch receives either folded sm_scale or the separate
q_scale/k_scale but not both.
- Around line 1148-1150: single_prefill_with_kv_cache ignores v_scale (only
folds k_scale into sm_scale), causing numerical inconsistency with the paged
wrapper and the function signature; fix by applying v_scale to the value tensor
before it's used to compute out (mirror how k_scale was folded into sm_scale)
inside single_prefill_with_kv_cache: when kv_cache_sf (and v_scale) are present,
scale the v component (from kv_cache_sf or the produced v) by v_scale (or
incorporate it into existing sm_scale logic) so that out uses the scaled values;
update any related paths where kv_cache_sf is unpacked and where out is computed
so both k_scale and v_scale affect the final output consistently.
- Around line 1337-1347: The code incorrectly infers V/O packed width from
q.shape[-1]; instead compute out_head_dim from v.shape[-1] (not q.shape[-1]) and
use that value when allocating out and when calling get_single_prefill_module so
the JIT specialization gets the correct head_dim_vo for asymmetric QK/VO configs
(update the out = torch.empty(...) allocation and the
get_single_prefill_module(...) call that currently passes q.shape[-1] to use
out_head_dim/v.shape[-1]); apply the same fix to the other similar branches
where out_head_dim is derived from q.shape[-1].
- Around line 467-476: The decorator call to register_custom_op currently lists
maybe_k_cache_sf and maybe_v_cache_sf (and likewise
key_block_scales/value_block_scales in the other occurrence) inside mutates_args
which falsely marks these tensors as mutated; remove maybe_k_cache_sf and
maybe_v_cache_sf from the mutates_args tuple (and remove
key_block_scales/value_block_scales from the duplicate occurrence) so they are
treated as read-only inputs by torch.compile, leaving only genuinely mutated
buffers (e.g., float_workspace_buffer, int_workspace_buffer, o, maybe_lse) in
mutates_args.

In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback _e2m1_and_ufp8sf_scale_to_float_cpu
assumes ufp8_scale_tensor is shaped [M, K//sf_vec_size]; normalize flattened
inputs first: inside _e2m1_and_ufp8sf_scale_to_float_cpu, detect if
ufp8_scale_tensor.dim() == 1 and if so, if its length == sf_len (where sf_len =
k // sf_vec_size) repeat it across the batch to shape [m, sf_len]; if its length
== m * sf_len reshape it to [m, sf_len]; ensure the tensor is moved to the same
device/dtype before later ops so the later sf_float, repeat_interleave, and
broadcasting use the correct shape and device.

In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-691: Add a skip guard to the
test_batch_decode_with_paged_kv_cache_nvfp4 test so it doesn't run on
unsupported GPUs: import and use the NVFP4 capability check (e.g.,
flashinfer.utils.is_nvfp4_supported() or the equivalent API method) in a
pytest.mark.skipif(...) decorator above the test function to skip when the check
returns False; ensure the decorator message explains it's skipping due to
missing NVFP4 support.

---

Nitpick comments:
In `@include/flashinfer/permuted_smem.cuh`:
- Around line 176-181: Add a one-line comment in load_64b_async documenting the
partial-copy contract: state that calling cp_async::pred_load_128b_from_64b with
a 64-bit source writes the 64-bit value into one half of the 128-bit
shared-memory slot (specify which half is populated and that the other half is
left unchanged/undefined), and briefly justify why this 128-bit SMEM path is
used (to reuse existing b128_t/128-bit alignment and avoid a separate 64-bit
SMEM layout) and note the alternative considered (a dedicated 64-bit SMEM
layout) so later users of load_64b_async, b128_t, base and
cp_async::pred_load_128b_from_64b cannot misuse the partial-copy behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 67117767-fd5d-42a6-90fc-eac3b4d24d08

📥 Commits

Reviewing files that changed from the base of the PR and between eb98c96 and 6ea7fde.

📒 Files selected for processing (18)
  • flashinfer/attention.py
  • flashinfer/decode.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/utils.py
  • include/flashinfer/attention/persistent.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/permuted_smem.cuh
  • include/flashinfer/vec_dtypes.cuh
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_decode_kernels.py
  • tests/attention/test_batch_prefill_kernels.py
  • tests/attention/test_single_prefill.py
  • tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (2)
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/vec_dtypes.cuh
🚧 Files skipped from review as they are similar to previous changes (5)
  • flashinfer/utils.py
  • tests/test_helpers/utils_fp4.py
  • flashinfer/jit/utils.py
  • tests/attention/test_single_prefill.py
  • flashinfer/jit/attention/modules.py

Comment thread flashinfer/attention.py Outdated
Comment thread flashinfer/decode.py
Comment thread flashinfer/prefill.py
Comment thread tests/attention/test_batch_attention.py
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !431 has been updated with latest changes, and the CI pipeline #48591867 is currently running. I'll report back once the pipeline job completes.

@Tom-Zheng Tom-Zheng closed this Apr 17, 2026
@Tom-Zheng
Contributor Author

Transfer to #3097
