
Support NVFP4 KV for prefill and batch attention kernels#3097

Open
Tom-Zheng wants to merge 11 commits into flashinfer-ai:main from Tom-Zheng:add-sm120-nvfp4-kv-prefill-v2

Conversation

@Tom-Zheng (Contributor) commented Apr 17, 2026

📌 Description

This PR adds support for NVFP4 KV input in the batch prefill and batch attention kernels. It covers all architectures from SM80 onward.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • NVFP4 (packed FP4) KV-cache support across prefill, attention, and decode with optional per-block scale-factor inputs and expanded accepted formats for scale tensors.
    • New CPU dequantization fallback for older GPUs and broader FP4 vectorized handling.
  • Bug Fixes

    • Runtime guard disables an incompatible backend for packed-KV (uint8) cases.
  • Tests

    • Added extensive NVFP4 unit tests and test helpers for prefill, attention, and decode.

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

jit: add torch.float4_e2m1fn_x2 to dtype maps

Add conditional entries for torch.float4_e2m1fn_x2 in
filename_safe_dtype_map ("fp4_e2m1") and dtype_map_kv
("__nv_fp4x2_e2m1") so that BatchPrefillWithPagedKVCacheWrapper
can select the FP4 kernel plan without KeyError when
kv_data_type=float4_e2m1fn_x2 is passed to begin_forward.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

fix batch decode UT

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>

fix batch decode function and accuracy; add nvfp4 kv test

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@coderabbitai (Bot) commented Apr 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds NVFP4 (uint8-packed FP4) KV-cache support across APIs, JIT generators, CUDA kernels, quantization (with CPU fallback), backend validation, and tests; public signatures accept stacked or tuple KV-scale tensors and thread them through kernels and generated modules.

Changes

  • API — attention/decode/prefill (flashinfer/attention.py, flashinfer/prefill.py, flashinfer/decode.py): Extended public run signatures to accept kv_cache_sf as either a stacked tensor or a (k_sf, v_sf) tuple; unified unpacking via _unpack_paged_kv_cache(...) and forwarded per-K/V scale tensors into module/kernel calls and backend-specific run argument layouts.
  • JIT generator / dtype maps (flashinfer/jit/attention/modules.py, flashinfer/jit/utils.py): Added dtype_map_kv and an FP4-aware filename token; wired optional maybe_k_cache_sf/maybe_v_cache_sf into generated attention/prefill modules and customization setters; adjusted KV dtype template usages.
  • Prefill/attention kernel headers (include/flashinfer/attention/prefill.cuh, include/flashinfer/attention/persistent.cuh): Integrated FP4-packed KV loads and per-block KV scale-factor caches: new SF loaders (*_produce_kv_sf), an SF-aware shared-memory layout, and adjusted GMEM→SMEM pointer arithmetic and compute callsites to accept per-warp SF offsets and apply per-element scaling.
  • Low-level CUDA helpers (include/flashinfer/cp_async.cuh, include/flashinfer/frag_layout_swizzle.cuh, include/flashinfer/permuted_smem.cuh, include/flashinfer/vec_dtypes.cuh): Added predicated async-load helpers (pred_load_128b_from_64b, pred_load_32b), fragment swizzle helpers (16b→4b, transposed), smem_t::load_64b_async, and vec_cast specializations to dequantize __nv_fp4x2_e2m1 to half/bf16 (PTX or LUT fallbacks).
  • Quantization / CPU fallback (flashinfer/quantization/fp4_quantization.py): Added a CPU dequantization fallback _e2m1_and_ufp8sf_scale_to_float_cpu(...) with an E2M1 LUT and per-device LUT cache; dispatches to the CPU fallback when the device compute capability is below 90.
  • Backend validation (flashinfer/utils.py): is_fa3_backend_supported(...) now rejects the FA3 backend when dtype_kv == torch.uint8 (NVFP4 KV).
  • Tests & helpers (tests/test_helpers/utils_fp4.py, tests/attention/test_single_prefill.py, tests/attention/test_batch_prefill_kernels.py, tests/attention/test_batch_attention.py, tests/attention/test_batch_decode_kernels.py): Added NVFP4 test utilities (create_nvfp4_kv, nvfp4_to_float) and multiple NVFP4-parametrized tests (single prefill, batch prefill ragged/paged, batch attention, batch decode) with relaxed tolerances and workspace/setup adjustments.
  • Other C++/CUDA plumbing (csrc/xqa/utils.cuh, include/flashinfer/permuted_smem.cuh, include/flashinfer/frag_layout_swizzle.cuh, include/flashinfer/vec_dtypes.cuh, include/flashinfer/cp_async.cuh): Adjusted compile-time gating, added swizzle/vec/async helpers, and small compile-condition fixes to enable optimized conversions and FP4 paths.
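For intuition, the CPU fallback's E2M1-LUT approach can be sketched in a few lines. This is a hedged illustration, not the actual _e2m1_and_ufp8sf_scale_to_float_cpu; the low-nibble-first packing order and the flat per-group scale layout are assumptions:

```python
# E2M1 code -> float (1 sign, 2 exponent, 1 mantissa bit); codes 8-15 are
# the negations of codes 0-7.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_nvfp4(packed: bytes, scales, sf_vec_size: int = 16):
    """Dequantize packed FP4 (two values per byte, low nibble first --
    assumed packing order) with one scale per sf_vec_size elements."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            out.append(E2M1_LUT[nibble] * scales[len(out) // sf_vec_size])
    return out

vals = dequant_nvfp4(bytes([0x21, 0x73]), scales=[0.5])
# codes 1, 2, 3, 7 -> 0.5, 1.0, 1.5, 6.0, each scaled by 0.5
assert vals == [0.25, 0.5, 0.75, 3.0]
```

NVFP4 groups consecutive elements (16 per FP8 scale factor), which is why the kernels carry a separate per-block SF buffer alongside the packed bytes.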

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant API as flashinfer API
    participant JIT as JIT Generator
    participant Kernel as CUDA Kernel
    participant Quant as Quant/CpuFallback

    User->>API: call prefill/decode/attention with `kv_cache_sf`
    API->>API: normalize/unpack `kv_cache_sf` (tuple or stacked tensor)
    API->>JIT: plan/module with `kv_data_type=uint8` and SF tensor wiring
    JIT->>Kernel: emit kernels that accept maybe_k/v_cache_sf and SF offsets
    API->>Kernel: invoke kernel with packed FP4 bytes + kv_cache_sf
    Kernel->>Quant: load packed FP4 + sf bytes, call dequant (CUDA or CPU fallback)
    Kernel->>Kernel: apply per-element scaling and compute attention
    Kernel->>User: return attention outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • sricketts
  • yongwww
  • bkryu
  • jimmyzho
  • nv-yunzheq
  • samuellees
  • kahyunnam

Poem

🐇 I nibbled bytes, packed tiny flecks of light,

Swizzled nibbles under stars of CUDA night.
I threaded scales through kernels, soft and spry—
KV caches hop and spark; the GPUs sigh.
Hop, little cache, into fast compute delight.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 36.84%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Title check — ✅ Passed: the title 'Support NVFP4 KV for prefill and batch attention kernels' clearly and concisely describes the main change.
  • Description check — ✅ Passed: the description addresses the template requirements: what the PR does (NVFP4 KV support for batch prefill/attention), tests added/passing, and pre-commit checks completed.
  • Linked Issues check — ✅ Passed: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed: skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for NVFP4 KV cache across FlashInfer, including prefill and decode kernels. Key changes include the addition of scale factor parameters to attention APIs, updates to JIT module generation, and the implementation of CUDA kernels for packed FP4 data. Review feedback identified a critical bug in decode.py where the paged_kv_cache argument was incorrectly made conditional, potentially breaking default backends. Additionally, a discrepancy was noted in a code comment regarding shared memory alignment assumptions for certain head dimensions.

  • Comment thread: flashinfer/decode.py (Outdated)
  • Comment thread: include/flashinfer/attention/prefill.cuh (Outdated)
@coderabbitai (Bot) left a comment

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/attention.py (1)

163-187: ⚠️ Potential issue | 🟠 Major

Require kv_cache_sf for NVFP4 KV inputs.

This path now forwards k_cache_sf / v_cache_sf, but it never rejects NVFP4 KV runs when the scales are omitted. flashinfer/prefill.py:2276-2278 already guards the same condition; without that check here, torch.uint8/native FP4 KV can still flow into the kernel with None scale tensors.

💡 Proposed fix
         k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)
+        needs_kv_sf = k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            needs_kv_sf = needs_kv_sf or (
+                k_cache.dtype == torch.float4_e2m1fn_x2
+                or v_cache.dtype == torch.float4_e2m1fn_x2
+            )
+        if needs_kv_sf and kv_cache_sf is None:
+            raise ValueError("kv_cache_sf must be provided for NVFP4 KV cache.")
+
         if out is None:
             out = torch.empty_like(q)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/attention.py` around lines 163 - 187, The NVFP4 path currently
unpacks kv_cache_sf into k_cache_sf/v_cache_sf but doesn't enforce that k_scale
and v_scale are provided; add a guard after unpacking (where
_unpack_paged_kv_cache produces k_cache_sf, v_cache_sf) that checks if either
k_cache_sf or v_cache_sf is not None and either k_scale or v_scale is None, and
raise a clear ValueError (or AssertionError) rejecting NVFP4 KV inputs without
their corresponding scale tensors; reference the unpack step and variables
k_cache_sf, v_cache_sf, k_scale, and v_scale so the check is placed right after
the unpack and before any NVFP4 data can flow into the kernel.
flashinfer/decode.py (1)

1454-1474: ⚠️ Potential issue | 🔴 Critical

Remove the four None placeholders that don't exist in the trtllm-gen decode signature.

The get_trtllm_gen_decode_module().paged_run(...) function signature does not include max_q_len, batch_size, cum_seq_lens_q, or cum_seq_lens_kv parameters. Passing these four None values shifts max_kv_len, sinks, key_block_scales, and value_block_scales to wrong parameter slots, causing argument misalignment and runtime failure.

Suggested fix
                if self._backend == "trtllm-gen":
                    # decode.py's trtllm-gen paged_run (get_trtllm_gen_decode_module)
                    # has a different optional-param layout than prefill.py's paged_run
                    run_args += [paged_kv_cache]

                run_args += [
                    self._num_qo_heads,
                    self._num_kv_heads,
                    self._block_tables,
                    self._kv_lens_buffer,
                    page_size,
-                    None,  # max_q_len (not applicable for decode)
                    self._max_kv_len,
-                    None,  # batch_size (not applicable for decode)
-                    None,  # cum_seq_lens_q (not applicable for decode)
-                    None,  # cum_seq_lens_kv (not applicable for decode)
                    sinks,
                    key_block_scales,
                    value_block_scales,
                    skip_softmax_threshold_scale_factor,
                    True,  # uses_shared_paged_kv_idx
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1454 - 1474, The trtllm-gen decode path is
passing four extra None placeholders into run_args which don't exist in
get_trtllm_gen_decode_module().paged_run, causing argument misalignment; update
the code that builds run_args in decode.py (the branch where self._backend ==
"trtllm-gen" and the subsequent run_args += [...]) by removing the four None
entries for max_q_len, batch_size, cum_seq_lens_q and cum_seq_lens_kv so that
self._max_kv_len, sinks, key_block_scales and value_block_scales align with the
trtllm-gen paged_run signature. Ensure the special-case comment for trtllm-gen
remains and run_args ordering matches the trtllm-gen paged_run parameter list.
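The failure mode described above can be reproduced in miniature. The signature below is a hypothetical stand-in for the tail of the trtllm-gen decode paged_run, not the real API; it shows how a single stray None placeholder shifts every later positional argument:

```python
# Hypothetical stand-in for the trtllm-gen decode paged_run tail
# (parameter names assumed from the review, not the real signature).
def paged_run_tail(page_size, max_kv_len, sinks=None, key_block_scales=None,
                   value_block_scales=None, skip_softmax_thresh=None,
                   uses_shared_paged_kv_idx=None):
    return {"max_kv_len": max_kv_len, "sinks": sinks,
            "key_block_scales": key_block_scales}

# Correct call: every value lands in its intended slot.
ok = paged_run_tail(16, 4096, "sinks", "k_sf", "v_sf", 1.0, True)
assert ok["max_kv_len"] == 4096 and ok["sinks"] == "sinks"

# Buggy call: one extra None placeholder (a max_q_len slot the signature
# does not have) shifts every later positional argument by one.
bad = paged_run_tail(16, None, 4096, "sinks", "k_sf", "v_sf", 1.0)
assert bad["max_kv_len"] is None           # real max_kv_len is lost
assert bad["sinks"] == 4096                # max_kv_len landed in sinks
assert bad["key_block_scales"] == "sinks"  # everything after shifts too
```

With custom ops registered via torch.library there is no keyword checking at the call site, so the misalignment surfaces only at runtime, exactly as the comment warns.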
flashinfer/prefill.py (1)

467-476: ⚠️ Potential issue | 🟠 Major

mutates_args incorrectly lists read-only scale-factor tensors.

maybe_k_cache_sf / maybe_v_cache_sf (ragged) and key_block_scales / value_block_scales (paged) are inputs to the attention kernels, not outputs — the kernels only dequantize from them. Declaring them in mutates_args tells torch.library they are written by the op, which:

  • breaks aliasing/functionalization assumptions under torch.compile (spurious data dependencies, forced rematerialization / cloning),
  • is inconsistent with how other FP8 scale tensors (scale_q/k/v, fp8_scale_*) are wired into the same ops.

Please remove these from mutates_args in both ragged_run and paged_run registrations.

🔧 Suggested fix
     @register_custom_op(
         f"flashinfer::{uri}_ragged_run",
         mutates_args=(
             "float_workspace_buffer",
             "int_workspace_buffer",
             "o",
             "maybe_lse",
-            "maybe_k_cache_sf",
-            "maybe_v_cache_sf",
         ),
     )
     @register_custom_op(
         f"flashinfer::{uri}_paged_run",
         mutates_args=(
             "float_workspace_buffer",
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
             "o",
             "maybe_lse",
-            "key_block_scales",
-            "value_block_scales",
         ),
     )

Also applies to: 636-647

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 467 - 476, Remove read-only scale tensors
from the op mutation lists: in the register_custom_op call for the ragged_run
variant (symbol name contains "flashinfer::{uri}_ragged_run") delete
maybe_k_cache_sf and maybe_v_cache_sf from the mutates_args tuple; similarly, in
the paged_run registration (the other register_custom_op block that references
key_block_scales and value_block_scales) remove key_block_scales and
value_block_scales from mutates_args so only truly written buffers (e.g.,
float_workspace_buffer, int_workspace_buffer, o, maybe_lse) remain listed; keep
the tensors as regular inputs to the op registration rather than declaring them
as mutated.
🧹 Nitpick comments (4)
flashinfer/prefill.py (2)

2354-2359: Use logging.warning instead of print for the NHD→HND conversion warning.

logging is already imported (Line 18). print(...) bypasses user-configured log levels/filters and writes to stdout, which is noisy in library code and tends to pollute serving-stack logs.

🔧 Suggested fix
-                print(
-                    "[WARNING] NVFP4 KV cache with NHD layout will be converted to HND, "
-                    "incurring extra transpose and contiguous copy overhead. "
-                    "Use kv_layout='HND' for better performance."
-                )
+                logging.warning(
+                    "NVFP4 KV cache with NHD layout will be converted to HND, "
+                    "incurring extra transpose and contiguous copy overhead. "
+                    "Use kv_layout='HND' for better performance."
+                )

Also consider emitting the warning only once per process (e.g. warnings.warn with a category, or a module-level guard), otherwise it will fire on every layer/step of a long run.

Also applies to: 4073-4078

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 2354 - 2359, Replace the print-based
warning at the NHD→HND conversion site (the block checking key_block_scales)
with a proper logging call (logging.warning) so it honors user log
configuration; update the message there and the similar occurrences around the
other site (the block at ~4073-4078) to use logging.warning instead of print.
Also consider making the warning emit only once per process by using
warnings.warn with an appropriate Warning subclass or a module-level boolean
guard to avoid repeated spam during long runs.

1326-1329: Inconsistent kv_cache_sf unpacking/validation across prefill entry points.

Three call sites now unpack kv_cache_sf and they all diverge:

  • single_prefill_with_kv_cache (Lines 1326-1329): assumes tuple, blind k_sf, v_sf = kv_cache_sf — a stacked tensor or a list will fail with an opaque error.
  • BatchPrefillWithRaggedKVCacheWrapper.run (Lines 3342-3346): checks isinstance(kv_cache_sf, tuple) only (misses list), and otherwise blindly calls .unbind(dim=1) without a TypeError fallback.
  • BatchPrefillWithPagedKVCacheWrapper.run (Lines 2282-2291) and trtllm_batch_context_with_kv_cache (Lines 4058-4067): correctly accept (tuple, list) or a stacked tensor, with a TypeError for anything else — matching the decode pattern.

Please align the single-prefill and ragged paths with the paged pattern (accept tuple/list or a torch.Tensor of shape [num_pages, 2, ...], else raise TypeError). Factoring this into a small helper (e.g., _unpack_kv_cache_sf) would also remove the four near-duplicate blocks across this file and decode.py/attention.py.

Also applies to: 3340-3346

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1326 - 1329, Align kv_cache_sf handling
by extracting a small helper (e.g., _unpack_kv_cache_sf) and replace the four
near-duplicate unpack blocks in single_prefill_with_kv_cache,
BatchPrefillWithRaggedKVCacheWrapper.run,
BatchPrefillWithPagedKVCacheWrapper.run, and trtllm_batch_context_with_kv_cache
to call it; the helper should accept either a tuple/list of (k_sf, v_sf) and
return them, or accept a torch.Tensor of shape [num_pages, 2, ...] and
split/unbind dim=1 into (k_sf, v_sf), and otherwise raise a TypeError with a
clear message — this ensures consistent behavior for tuples, lists, and stacked
tensors and removes duplicate logic across the file (and similar blocks in
decode.py/attention.py).
include/flashinfer/cp_async.cuh (2)

215-220: Leftover debug markers in fallback path.

Commented-out sentinel values (0xcd..., 0xefef...) look like debug artifacts. Consider removing to keep the fallback clean.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/cp_async.cuh` around lines 215 - 220, Remove the leftover
commented sentinel debug markers in the shared-memory fallback path: delete the
commented-out lines containing 0xcdcd... and 0xefef... so only the intended
zero-fill logic remains (keep the existing smem_u64[1] = 0; and the
*((uint4*)smem_ptr) = make_uint4(0,0,0,0); under SharedMemFillMode::kFillZero).
This cleans up artifacts around smem_u64 and smem_ptr handling without changing
the zero-fill behavior.

191-193: prefetch_mode template parameter is silently ignored by the fast path.

Unlike pred_load_128b, this helper never branches on prefetch_mode (no .L2::128B variant is emitted, and cp.async.ca has no such hint). Callers such as smem_t::load_64b_async<...> that pass PrefetchMode::kPrefetch will not get prefetching. Consider either dropping the template parameter to make that explicit, or documenting that prefetch is a no-op for 64b loads.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/cp_async.cuh` around lines 191 - 193, The template
parameter prefetch_mode on pred_load_128b_from_64b is unused so callers (e.g.,
smem_t::load_64b_async) passing PrefetchMode::kPrefetch get no prefetch
behavior; either remove the prefetch_mode template parameter from
pred_load_128b_from_64b to make prefetch a no-op explicit, or implement a branch
on prefetch_mode (matching pred_load_128b's branching) and emit the appropriate
prefetch variant when prefetch_mode==PrefetchMode::kPrefetch (so callers like
smem_t::load_64b_async actually trigger prefetch), and update/delete callers or
comments accordingly to keep behavior consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Around line 1301-1309: The stacked-tensor branch for kv_cache_sf assumes dim 1
equals 2 before calling kv_cache_sf.unbind(dim=1); validate that kv_cache_sf is
a tensor with at least 2 dimensions and that kv_cache_sf.size(1) == 2 (and
optionally include kv_cache_sf.dim() and kv_cache_sf.shape in the error) before
unbinding, and if the check fails raise a TypeError with a clear message stating
the expected shape [num_pages, 2, ...] and the actual shape; update the
torch.is_tensor branch around the kv_cache_sf.unbind(dim=1) call to perform this
check and include the actual tensor shape in the raised error.

In `@flashinfer/prefill.py`:
- Around line 1148-1150: The function single_prefill_with_kv_cache now accepts
kv_cache_sf, k_scale, and v_scale but v_scale is ignored; update
single_prefill_with_kv_cache to either (preferred) apply v_scale to the output
tensor after the kernel (mirroring BatchPrefillWithPagedKVCacheWrapper.run where
out *= v_scale) and ensure kv_cache_sf and k_scale are folded/used consistently
into sm_scale, or (if v_scale not supported) remove v_scale from the signature
and raise NotImplementedError when callers pass it; additionally add docstring
entries for kv_cache_sf, k_scale, and v_scale describing the expected
scale-factor layout (match the paged wrapper docstring) so callers know the
semantics.
- Around line 3321-3326: The ragged-path double-scales sm_scale by multiplying
in q_scale and k_scale before passing it as attn_scale to
cudnn_batch_prefill_with_kv_cache; fix by only applying sm_scale *= q_scale /
k_scale when the backend is not cudnn (i.e., wrap the q_scale/k_scale
multiplications in if self._backend != "cudnn":) so cudnn receives unadjusted
attn_scale while other backends keep the existing descaling; apply the same
guard in the second cudnn call site (the other block that prepares sm_scale
before calling cudnn_batch_prefill_with_kv_cache) so q_scale/k_scale are not
applied twice.

In `@flashinfer/utils.py`:
- Around line 423-425: The current guard only checks for torch.uint8 (dtype_kv)
and misses native FP4; update the FA3-exclusion guard to also return False when
dtype_kv is the native FP4 dtype (torch.float4_e2m1fn_x2) so native FP4 KVs are
routed off FA3 like uint8; modify the condition around dtype_kv (the if block
that returns False for FA3) to include a check for torch.float4_e2m1fn_x2
alongside torch.uint8.
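A minimal sketch of the widened predicate; dtype names are strings standing in for torch dtypes, since torch.float4_e2m1fn_x2 exists only in newer torch builds and would be probed defensively in real code:

```python
def fa3_supports_kv_dtype(dtype_kv: str) -> bool:
    # FA3 has no packed-FP4 KV path: reject both the uint8-packed
    # representation and the native FP4 dtype, so both route to
    # another backend.
    return dtype_kv not in ("uint8", "float4_e2m1fn_x2")

assert fa3_supports_kv_dtype("float16")
assert not fa3_supports_kv_dtype("uint8")
assert not fa3_supports_kv_dtype("float4_e2m1fn_x2")  # the case the guard misses
```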

In `@include/flashinfer/attention/prefill.cuh`:
- Around line 110-117: The SMEM heuristics must include the FP4 scale-factor
buffers k_sf_smem and v_sf_smem when computing SharedStorage-related limits:
update the calculations that produce num_ctas_per_sm and max_num_mma_kv_smem
(and any logic that picks NUM_MMA_KV) to add the conditional size of the new
scale-factor arrays (the same conditional expression used for
k_sf_smem/v_sf_smem based on is_fp4_type_v<DTypeKV> and the widths CTA_TILE_KV *
HEAD_DIM_QK / NVFP4_SF_VEC_SIZE and CTA_TILE_KV * HEAD_DIM_VO /
NVFP4_SF_VEC_SIZE) into the total shared-memory per-CTA estimate so the
dispatcher uses the real SharedStorage size (the value later passed to
cudaFuncSetAttribute) when deciding CTAs and NUM_MMA_KV.
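The sizing concern can be made concrete with a toy byte count. This is a sketch only: the real heuristic also accounts for Q/O tiles and swizzled layouts, and the parameter names simply mirror the template parameters quoted above (NVFP4_SF_VEC_SIZE = 16 means one FP8 scale byte per 16 FP4 elements):

```python
NVFP4_SF_VEC_SIZE = 16

def kv_smem_bytes(cta_tile_kv, head_dim_qk, head_dim_vo, is_fp4):
    """Toy per-CTA KV shared-memory estimate including FP4 SF buffers."""
    if is_fp4:
        k_bytes = cta_tile_kv * head_dim_qk // 2   # packed FP4: 2 elems/byte
        v_bytes = cta_tile_kv * head_dim_vo // 2
        k_sf = cta_tile_kv * head_dim_qk // NVFP4_SF_VEC_SIZE
        v_sf = cta_tile_kv * head_dim_vo // NVFP4_SF_VEC_SIZE
        return k_bytes + v_bytes + k_sf + v_sf     # SF buffers must be counted
    return cta_tile_kv * (head_dim_qk + head_dim_vo) * 2  # fp16/bf16 KV

assert kv_smem_bytes(64, 128, 128, is_fp4=True) == 9216
assert kv_smem_bytes(64, 128, 128, is_fp4=False) == 32768
```

If the dispatcher omits the 1 KiB of SF storage in this example, the SharedStorage size passed to cudaFuncSetAttribute no longer matches the estimate used to pick CTAs and NUM_MMA_KV, which is exactly the mismatch the comment flags.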

In `@include/flashinfer/cp_async.cuh`:
- Around line 194-223: The PTX path currently issues cp.async with cp-size=8 so
the upper 8 bytes of the 16-byte SMEM slot remain uninitialized; update both
inline-asm cp.async invocations in the SharedMemFillMode handling (the block
using smem_int_ptr and the kNoFill asm block) to use cp-size=16 (i.e., change
the "n"(8) cp-size immediates to "n"(16)) while keeping src-size as 8 (or 0 when
predicate is false), so cp.async zero-fills the upper 8 bytes and matches the
fallback that writes smem_u64[1]=0; also remove or fix the misleading comment
that "cp.async always zeros the upper 8 bytes" to reflect that zeroing only
applies within cp-size.

In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 486-487: The CUDA version gate in
include/flashinfer/vec_dtypes.cuh currently checks (__CUDACC_VER_MAJOR__ >= 13)
&& (__CUDACC_VER_MINOR__ >= 2) which fails for newer majors (e.g., 14.x); change
the condition to allow any major greater than 13 or major equal to 13 with minor
>= 2 (i.e., use a combined check of __CUDACC_VER_MAJOR__ and
__CUDACC_VER_MINOR__ so CUDA 13.2+ and all CUDA 14+ are accepted) to ensure the
native cvt.rn.bf16x2.e2m1x2 path is used; update the `#if` surrounding that macro
check accordingly and keep the same `#endif` and related code paths (look for
__CUDACC_VER_MAJOR__, __CUDACC_VER_MINOR__, and the cvt.rn.bf16x2.e2m1x2 usage
locations).
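The gate bug is easiest to see as plain predicates (Python stand-ins for the preprocessor checks):

```python
def cuda_gate_buggy(major, minor):
    # Original check: (__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2)
    return major >= 13 and minor >= 2

def cuda_gate_fixed(major, minor):
    # Proposed check: any major > 13, or major == 13 with minor >= 2
    return major > 13 or (major == 13 and minor >= 2)

assert cuda_gate_buggy(13, 2) and cuda_gate_fixed(13, 2)      # CUDA 13.2: both accept
assert not cuda_gate_buggy(14, 0) and cuda_gate_fixed(14, 0)  # CUDA 14.0: bug rejects
assert not cuda_gate_fixed(13, 1) and not cuda_gate_fixed(12, 9)
```

The buggy form silently drops the native cvt.rn.bf16x2.e2m1x2 path on any future major release whose minor starts back at 0.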

In `@tests/attention/test_batch_attention.py`:
- Around line 293-424: The test test_batch_attention_nvfp4 currently only xfails
SM120/121 and should instead skip on unsupported GPU architectures; update the
test to early-skip using the project's utility/API (e.g.,
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() or
flashinfer.api_name.is_compute_capability_supported(cc)) before any CUDA work is
done, so that test_batch_attention_nvfp4 returns pytest.skip(...) when the
current device compute capability is not supported; place the check at the top
of test_batch_attention_nvfp4 (before torch.manual_seed and tensor creation) and
reference those utility calls to determine support.

In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-810: The NVFP4 test
(test_batch_decode_with_paged_kv_cache_nvfp4) must be skipped on unsupported
GPUs; add an early guard that checks compute capability support (e.g., using
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() or
api_name.is_compute_capability_supported(cc)) and call pytest.skip if
NVFP4/tensor-core paths aren't available before creating workspace/setting
use_tensor_cores=True and calling wrapper.plan/run; place the check near the
start of the test function so the hard-coded NVFP4 tensor-core path is never
executed on unsupported architectures.

In `@tests/attention/test_single_prefill.py`:
- Around line 145-154: The test calls
flashinfer.prefill.single_prefill_with_kv_cache but passes only k_scale
(k_global_scale.item()); add the missing V-side global scale by passing
v_scale=v_global_scale.item() (or v_global_scale if non-tensor) into the same
call so the V scaling path is exercised; update the call site in
tests/attention/test_single_prefill.py where single_prefill_with_kv_cache(...)
is invoked to include the v_scale named argument (the other args like
kv_cache_sf=(k_sf, v_sf) stay unchanged).
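The algebra behind folding k_scale into sm_scale and applying v_scale to the output, which the scale-related comments above rely on, can be checked in a tiny model (pure-Python sketch, head_dim 1, all names illustrative):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

# One query, three keys/values, head_dim = 1 for brevity.
q = 0.7
k_q = [1.0, -2.0, 3.0]   # "quantized" keys
v_q = [0.5, 1.5, -1.0]   # "quantized" values
k_scale, v_scale, sm_scale = 0.25, 0.5, 1.0

# Reference: dequantize K and V before attention.
p_ref = softmax([q * (k * k_scale) * sm_scale for k in k_q])
o_ref = sum(p * (v * v_scale) for p, v in zip(p_ref, v_q))

# Folded: fold k_scale into sm_scale, run attention on quantized V,
# then rescale the output by v_scale (the `out *= v_scale` pattern).
p = softmax([q * k * (sm_scale * k_scale) for k in k_q])
o = sum(pi * vi for pi, vi in zip(p, v_q)) * v_scale

assert abs(o - o_ref) < 1e-9
```

This identity is also why applying q_scale/k_scale both inside sm_scale and again in a backend (the cudnn double-scaling case above) gives wrong results: the fold must happen exactly once.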

---

Outside diff comments:
In `@flashinfer/attention.py`:
- Around line 163-187: The NVFP4 path currently unpacks kv_cache_sf into
k_cache_sf/v_cache_sf but doesn't enforce that k_scale and v_scale are provided;
add a guard after unpacking (where _unpack_paged_kv_cache produces k_cache_sf,
v_cache_sf) that checks if either k_cache_sf or v_cache_sf is not None and
either k_scale or v_scale is None, and raise a clear ValueError (or
AssertionError) rejecting NVFP4 KV inputs without their corresponding scale
tensors; reference the unpack step and variables k_cache_sf, v_cache_sf,
k_scale, and v_scale so the check is placed right after the unpack and before
any NVFP4 data can flow into the kernel.

In `@flashinfer/decode.py`:
- Around line 1454-1474: The trtllm-gen decode path is passing four extra None
placeholders into run_args which don't exist in
get_trtllm_gen_decode_module().paged_run, causing argument misalignment; update
the code that builds run_args in decode.py (the branch where self._backend ==
"trtllm-gen" and the subsequent run_args += [...]) by removing the four None
entries for max_q_len, batch_size, cum_seq_lens_q and cum_seq_lens_kv so that
self._max_kv_len, sinks, key_block_scales and value_block_scales align with the
trtllm-gen paged_run signature. Ensure the special-case comment for trtllm-gen
remains and run_args ordering matches the trtllm-gen paged_run parameter list.

In `@flashinfer/prefill.py`:
- Around line 467-476: Remove read-only scale tensors from the op mutation
lists: in the register_custom_op call for the ragged_run variant (symbol name
contains "flashinfer::{uri}_ragged_run") delete maybe_k_cache_sf and
maybe_v_cache_sf from the mutates_args tuple; similarly, in the paged_run
registration (the other register_custom_op block that references
key_block_scales and value_block_scales) remove key_block_scales and
value_block_scales from mutates_args so only truly written buffers (e.g.,
float_workspace_buffer, int_workspace_buffer, o, maybe_lse) remain listed; keep
the tensors as regular inputs to the op registration rather than declaring them
as mutated.

---

Nitpick comments:
In `@flashinfer/prefill.py`:
- Around line 2354-2359: Replace the print-based warning at the NHD→HND
conversion site (the block checking key_block_scales) with a proper logging call
(logging.warning) so it honors user log configuration; update the message there
and the similar occurrences around the other site (the block at ~4073-4078) to
use logging.warning instead of print. Also consider making the warning emit only
once per process by using warnings.warn with an appropriate Warning subclass or
a module-level boolean guard to avoid repeated spam during long runs.
- Around line 1326-1329: Align kv_cache_sf handling by extracting a small helper
(e.g., _unpack_kv_cache_sf) and replace the four near-duplicate unpack blocks in
single_prefill_with_kv_cache, BatchPrefillWithRaggedKVCacheWrapper.run,
BatchPrefillWithPagedKVCacheWrapper.run, and trtllm_batch_context_with_kv_cache
to call it; the helper should accept either a tuple/list of (k_sf, v_sf) and
return them, or accept a torch.Tensor of shape [num_pages, 2, ...] and
split/unbind dim=1 into (k_sf, v_sf), and otherwise raise a TypeError with a
clear message — this ensures consistent behavior for tuples, lists, and stacked
tensors and removes duplicate logic across the file (and similar blocks in
decode.py/attention.py).

In `@include/flashinfer/cp_async.cuh`:
- Around line 215-220: Remove the leftover commented sentinel debug markers in
the shared-memory fallback path: delete the commented-out lines containing
0xcdcd... and 0xefef... so only the intended zero-fill logic remains (keep the
existing smem_u64[1] = 0; and the *((uint4*)smem_ptr) = make_uint4(0,0,0,0);
under SharedMemFillMode::kFillZero). This cleans up artifacts around smem_u64
and smem_ptr handling without changing the zero-fill behavior.
- Around line 191-193: The template parameter prefetch_mode on
pred_load_128b_from_64b is unused so callers (e.g., smem_t::load_64b_async)
passing PrefetchMode::kPrefetch get no prefetch behavior; either remove the
prefetch_mode template parameter from pred_load_128b_from_64b to make prefetch a
no-op explicit, or implement a branch on prefetch_mode (matching
pred_load_128b's branching) and emit the appropriate prefetch variant when
prefetch_mode==PrefetchMode::kPrefetch (so callers like smem_t::load_64b_async
actually trigger prefetch), and update/delete callers or comments accordingly to
keep behavior consistent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a0310754-826d-4f9d-931b-dcd44974a801

📥 Commits

Reviewing files that changed from the base of the PR and between 0e18a1c and 5d6f6fd.

📒 Files selected for processing (18)
  • flashinfer/attention.py
  • flashinfer/decode.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/jit/utils.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/utils.py
  • include/flashinfer/attention/persistent.cuh
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh
  • include/flashinfer/frag_layout_swizzle.cuh
  • include/flashinfer/permuted_smem.cuh
  • include/flashinfer/vec_dtypes.cuh
  • tests/attention/test_batch_attention.py
  • tests/attention/test_batch_decode_kernels.py
  • tests/attention/test_batch_prefill_kernels.py
  • tests/attention/test_single_prefill.py
  • tests/test_helpers/utils_fp4.py

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

3466-3520: ⚠️ Potential issue | 🔴 Critical

Critical: NVFP4 k_scale/q_scale fold into sm_scale is a no-op in the ragged path.

run_args is built at Lines 3466-3511 with the current sm_scale float value already appended. The subsequent sm_scale *= q_scale / sm_scale *= k_scale at Lines 3517-3520 only rebind the local sm_scale variable — since Python floats are immutable, the value already inside run_args is not updated. The kernel therefore receives the unscaled sm_scale, silently producing wrong numerics for NVFP4 KV whenever a non-unit k_scale/q_scale is passed.

The fold must happen before sm_scale is placed into run_args (mirroring the paged wrapper at Lines 2333-2337 and single_prefill_with_kv_cache at Lines 1338-1339).

🔧 Proposed fix
         if return_lse:
             ...
+        # For NVFP4 KV, fuse q_scale/k_scale into sm_scale before packing run_args
+        if kv_cache_sf is not None:
+            if q_scale is not None:
+                sm_scale *= q_scale
+            if k_scale is not None:
+                sm_scale *= k_scale
+
         # Unpack kv_cache_sf for NVFP4 ragged KV
         k_sf, v_sf = None, None
         ...
         run_args += [
             ...
             logits_soft_cap,
             sm_scale,
             ...
         ]
         # For FP8, append scale tensors
         if is_float8(q):
             run_args.extend(list(args))  # scale_q, scale_k, scale_v
-        # For NVFP4 KV, fuse k_scale into sm_scale
-        elif kv_cache_sf is not None:
-            if q_scale is not None:
-                sm_scale *= q_scale
-            if k_scale is not None:
-                sm_scale *= k_scale
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3466 - 3520, The ragged-path builds
run_args and appends the current sm_scale before folding q_scale/k_scale, so
multiply q_scale and k_scale into sm_scale (when kv_cache_sf is not None and
q_scale/k_scale are not None) before sm_scale is appended to run_args; modify
the else branch around run_args construction in prefill.py to perform the NVFP4
fold (sm_scale *= q_scale and sm_scale *= k_scale) prior to extending run_args
with sm_scale (mirroring the paged wrapper and single_prefill_with_kv_cache
behavior), leaving the existing is_float8 and kv_cache_sf checks but relocating
the multiplication earlier so the kernel receives the updated sm_scale.
♻️ Duplicate comments (1)
include/flashinfer/cp_async.cuh (1)

194-222: ⚠️ Potential issue | 🟠 Major

PTX fast path leaves upper 8 bytes of the 128-bit SMEM slot uninitialized (divergence from docstring & fallback).

The docstring (lines 186-189) and the #else fallback (lines 213-218) both guarantee the upper 64 bits of the 128-bit destination are zero. The PTX path, however, issues cp.async.ca.shared.global with cp-size = 8; per PTX semantics zero-fill only applies in the range [src-size, cp-size), so with cp-size=8 nothing beyond byte 7 is written. All three meaningful cases (kFillZero pred=true/false, kNoFill pred=true) therefore leave bytes 8..15 indeterminate, which will silently corrupt NVFP4 KV loads that rely on the documented zero-padding.

The comment on line 202 (“cp.async always zeros the upper 8 bytes”) is also incorrect — zero-fill is bounded by cp-size, not the destination slot size.

Fix by using cp-size = 16 while keeping src-size = 8 (or 0 when predicate is false), which makes PTX zero-fill bytes 8..15 and aligns with the fallback.

🛠️ Proposed fix
   if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
     int src_in_bytes = predicate ? 8 : 0;
     asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
-                 "l"(gmem_ptr), "n"(8), "r"(src_in_bytes));
-
+                 "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
   } else {
-    // kNoFill: only issue the copy if predicate is true; cp.async always zeros the upper 8 bytes
+    // kNoFill: only issue the copy if predicate is true. cp-size=16 with src-size=8 makes
+    // cp.async zero-fill the upper 8 bytes of the 128-bit destination slot.
    asm volatile(
        "{\n"
        " .reg .pred p;\n"
        " setp.ne.b32 p, %0, 0;\n"
        " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n"
         "}\n" ::"r"((int)predicate),
-        "r"(smem_int_ptr), "l"(gmem_ptr), "n"(8), "n"(8));
+        "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16), "n"(8));
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/cp_async.cuh` around lines 194 - 222, The PTX branch for
cp.async (in the code handling SharedMemFillMode and predicate around
smem_ptr/gmem_ptr) uses cp-size=8 which leaves bytes 8..15 uninitialized; change
the cp.async invocations so cp-size=16 while keeping src-size as 8 (or 0 when
predicate is false) so the hardware zero-fills the upper 8 bytes to match the
docstring and the fallback path; update both asm blocks (the if constexpr
fill_mode==kFillZero path and the else kNoFill path) to pass 16 as the cp-size
operand and ensure the src-size operand remains 8 or 0 accordingly.
🧹 Nitpick comments (2)
flashinfer/decode.py (1)

1343-1352: Use warnings.warn instead of print to avoid per-call spam.

This warning sits on the per-call run() path and will print on every forward pass for every layer when NVFP4 KV is used with NHD. warnings.warn integrates with Python's logging/filter machinery and is deduplicated by default. The same pattern at lines 2511–2515 in trtllm_batch_decode_with_kv_cache should be updated together.

🔧 Proposed change
+import warnings
@@
             if key_block_scales is not None:
-                print(
-                    "[WARNING] NVFP4 KV cache with NHD layout will be converted to HND, "
-                    "incurring extra transpose and contiguous copy overhead. "
-                    "Use kv_layout='HND' for better performance."
-                )
+                warnings.warn(
+                    "NVFP4 KV cache with NHD layout will be converted to HND, "
+                    "incurring extra transpose and contiguous copy overhead. "
+                    "Use kv_layout='HND' for better performance.",
+                    stacklevel=2,
+                )

And analogously at lines 2511–2515.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1343 - 1352, Replace the print warning in
the NVFP4 KV cache NHD->HND conversion with warnings.warn so it's integrated
with Python's warning system and can be filtered/deduplicated; specifically, in
the run() path where key_block_scales, k_cache, v_cache, and value_block_scales
are being transposed/contiguified, call warnings.warn(...) instead of
print(...), and make the analogous change in trtllm_batch_decode_with_kv_cache
for the same message to keep behavior consistent across both code paths.
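The deduplication claim can be verified standalone: under the "default" filter, repeated `warnings.warn` calls from the same call site are emitted only once per process. `warn_nhd_layout` is a hypothetical stand-in for the real warning site:

```python
import warnings

def warn_nhd_layout():
    warnings.warn(
        "NVFP4 KV cache with NHD layout will be converted to HND",
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")  # show each (message, call site) once
    for _ in range(3):
        warn_nhd_layout()  # same call site on every iteration

assert len(caught) == 1  # the print-based version would have emitted 3 lines
```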
flashinfer/prefill.py (1)

2371-2380: Use warnings.warn (or the module logger) instead of print for the NHD→HND conversion warning.

print cannot be filtered/suppressed by users and mixes with stdout of the hosting service. The rest of this module uses logging; prefer warnings.warn(..., stacklevel=2) so callers can control it. Same issue at Lines 4096-4100 in trtllm_batch_context_with_kv_cache.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 2371 - 2380, Replace the inline print
used when converting NVFP4 KV cache NHD→HND with a proper warning via
warnings.warn(..., stacklevel=2) (or use the module logger) so callers can
filter/suppress it; specifically update the block that checks key_block_scales
is not None where k_cache, v_cache, key_block_scales and value_block_scales are
made contiguous/transposed to call warnings.warn with the same message and
stacklevel=2 instead of print, and make the same change in the analogous block
inside trtllm_batch_context_with_kv_cache (the other occurrence around the
key/value cache transpose).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback fails when ufp8_scale_tensor is 1D
because sf_float stays 1D and repeat_interleave produces a [K] vector that
cannot broadcast to float_vals [M, K]; fix this by normalizing sf_float to shape
[M, K/sf_vec_size] before expanding: after decoding into sf_float (the result of
view(torch.float8_e4m3fn).float() or torch.pow(...)), compute sf_len = k //
sf_vec_size and if sf_float.dim() == 1 then reshape/expand it to (m, sf_len) via
unsqueeze(0).expand(m, -1); if it is 2D but has shape (1, sf_len) also expand to
(m, sf_len); then continue with sf_expanded =
sf_float.repeat_interleave(sf_vec_size, dim=-1) and the rest of
_e2m1_and_ufp8sf_scale_to_float_cpu unchanged so broadcasting matches
float_vals.
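The shape normalization described above can be sketched with numpy standing in for the torch ops (`repeat_interleave(n, dim=-1)` corresponds to `np.repeat(..., axis=-1)`, `expand` to `np.broadcast_to`); the names `sf_float`, `m`, `k`, and `sf_vec_size` follow the comment, and the function name is hypothetical:

```python
import numpy as np

def normalize_and_expand_sf(sf_float, m, k, sf_vec_size):
    """Normalize decoded per-block scales to [m, k // sf_vec_size],
    then expand each scale over its sf_vec_size-wide block to [m, k]."""
    sf_len = k // sf_vec_size
    sf_float = np.asarray(sf_float, dtype=np.float32)
    if sf_float.ndim == 1:
        # The 1D [sf_len] case that previously failed to broadcast: lift to [m, sf_len].
        sf_float = np.broadcast_to(sf_float[None, :], (m, sf_len))
    elif sf_float.shape == (1, sf_len):
        sf_float = np.broadcast_to(sf_float, (m, sf_len))
    # Expand each scale across its vector block; result matches float_vals [m, k].
    return np.repeat(sf_float, sf_vec_size, axis=-1)

out = normalize_and_expand_sf([1.0, 2.0], m=3, k=8, sf_vec_size=4)
assert out.shape == (3, 8)
assert np.all(out[:, :4] == 1.0) and np.all(out[:, 4:] == 2.0)
```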

---

Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 3466-3520: The ragged-path builds run_args and appends the current
sm_scale before folding q_scale/k_scale, so multiply q_scale and k_scale into
sm_scale (when kv_cache_sf is not None and q_scale/k_scale are not None) before
sm_scale is appended to run_args; modify the else branch around run_args
construction in prefill.py to perform the NVFP4 fold (sm_scale *= q_scale and
sm_scale *= k_scale) prior to extending run_args with sm_scale (mirroring the
paged wrapper and single_prefill_with_kv_cache behavior), leaving the existing
is_float8 and kv_cache_sf checks but relocating the multiplication earlier so
the kernel receives the updated sm_scale.

---

Duplicate comments:
In `@include/flashinfer/cp_async.cuh`:
- Around line 194-222: The PTX branch for cp.async (in the code handling
SharedMemFillMode and predicate around smem_ptr/gmem_ptr) uses cp-size=8 which
leaves bytes 8..15 uninitialized; change the cp.async invocations so cp-size=16
while keeping src-size as 8 (or 0 when predicate is false) so the hardware
zero-fills the upper 8 bytes to match the docstring and the fallback path;
update both asm blocks (the if constexpr fill_mode==kFillZero path and the else
kNoFill path) to pass 16 as the cp-size operand and ensure the src-size operand
remains 8 or 0 accordingly.

---

Nitpick comments:
In `@flashinfer/decode.py`:
- Around line 1343-1352: Replace the print warning in the NVFP4 KV cache
NHD->HND conversion with warnings.warn so it's integrated with Python's warning
system and can be filtered/deduplicated; specifically, in the run() path where
key_block_scales, k_cache, v_cache, and value_block_scales are being
transposed/contiguified, call warnings.warn(...) instead of print(...), and make
the analogous change in trtllm_batch_decode_with_kv_cache for the same message
to keep behavior consistent across both code paths.

In `@flashinfer/prefill.py`:
- Around line 2371-2380: Replace the inline print used when converting NVFP4 KV
cache NHD→HND with a proper warning via warnings.warn(..., stacklevel=2) (or use
the module logger) so callers can filter/suppress it; specifically update the
block that checks key_block_scales is not None where k_cache, v_cache,
key_block_scales and value_block_scales are made contiguous/transposed to call
warnings.warn with the same message and stacklevel=2 instead of print, and make
the same change in the analogous block inside trtllm_batch_context_with_kv_cache
(the other occurrence around the key/value cache transpose).


📥 Commits

Reviewing files that changed from the base of the PR and between 7b481ac and ed840c8.

📒 Files selected for processing (5)
  • flashinfer/decode.py
  • flashinfer/prefill.py
  • flashinfer/quantization/fp4_quantization.py
  • include/flashinfer/attention/prefill.cuh
  • include/flashinfer/cp_async.cuh

.
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)

3267-3308: ⚠️ Potential issue | 🟡 Minor

Document kv_cache_sf in BatchPrefillWithRaggedKVCacheWrapper.run docstring.

The new kv_cache_sf parameter is missing from the docstring (Parameters block at Lines 3274-3300). Users won't know the accepted shapes/layout/dtype (float8_e4m3fn, stacked-tensor vs tuple form, etc.) — the paged wrapper's docstring at Lines 2249-2268 is the natural template. Also worth noting in docs: unlike the paged/single paths, post-kernel v_scale is applied only when kv_cache_sf is not None (Lines 3526-3528), so callers using v_scale with non-NVFP4 inputs on this wrapper will silently get unscaled output.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3267 - 3308, Add documentation for the
new kv_cache_sf parameter in BatchPrefillWithRaggedKVCacheWrapper.run: describe
accepted types and layouts (Optional[Union[torch.Tensor, Tuple[torch.Tensor,
torch.Tensor]]], explain stacked-tensor vs tuple form, expected dtypes like
float8_e4m3fn and NVFP4 specifics, and the tensor shapes/ordering consistent
with k/v inputs; also note the behavioral difference that post-kernel v_scale is
applied only when kv_cache_sf is not None (so callers using v_scale with
non-NVFP4 inputs on this wrapper will not get scaled output), mirroring the
style and details used in the paged wrapper docstring for clarity.
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)

636-645: ⚠️ Potential issue | 🟠 Major

key_block_scales / value_block_scales should not be in paged_run's mutates_args.

These tensors are read-only load sources in produce_kv_sf / page_produce_kv_sf; marking them as mutated forces unnecessary functionalization copies under torch.compile and inhibits CUDA graph capture / re-use. Note the analogous additions in run_single_prefill (Line 338) and ragged_run (Line 469) correctly leave maybe_k_cache_sf / maybe_v_cache_sf out of mutates_args — this paged path is the outlier. This matches prior feedback that was reportedly addressed but reappears here.

🔧 Suggested fix
         mutates_args=(
             "float_workspace_buffer",
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
             "o",
             "maybe_lse",
-            "key_block_scales",
-            "value_block_scales",
         ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 636 - 645, In paged_run remove
key_block_scales and value_block_scales from the mutates_args tuple so they are
treated as read-only (they are load-only sources used by produce_kv_sf /
page_produce_kv_sf); mirror the approach used in run_single_prefill and
ragged_run where maybe_k_cache_sf / maybe_v_cache_sf are not listed as mutated.
Update the mutates_args declaration in paged_run (remove "key_block_scales",
"value_block_scales") and verify there are no other places in paged_run that
rely on those names being reported as mutated.
🧹 Nitpick comments (2)
include/flashinfer/vec_dtypes.cuh (2)

498-511: Verify the fp16→bf16 fallback round-trip is always lossless for e2m1 values.

Each e2m1 value ∈ {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} is exactly representable in both fp16 and bf16, so the fallback __half22float2 → __float22bfloat162_rn round-trip produces the same bit pattern as the native cvt.rn.bf16x2.e2m1x2 output. This is correct.

Minor nit: the fallback could skip the PTX fp16 intermediate and directly use the LUT (lines 518–535) for one table lookup instead of a PTX cvt + two fp conversions. Only applicable if the fallback gets exercised on SM100+ with CUDA 13.0–13.1 (before the code path switches to the native instruction in 13.2+) — otherwise the current implementation is fine.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/vec_dtypes.cuh` around lines 498 - 511, The fallback
currently converts e2m1->fp16 via the inline asm into fp16x2, then to float and
to bf16 (__half2 h2 / __float22bfloat162_rn) before storing y; replace that
three-step conversion with a direct lookup into the existing e2m1→bf16 lookup
table used elsewhere (the same LUT referenced in the surrounding code) so you
produce y in one table lookup instead of doing the asm cvt +
__half2/__float22bfloat162_rn sequence; update the code paths that set fp16x2 /
h2 / bf16x2 to instead index the LUT with the original input (variable b) and
write the LUT result into y.

428-443: Align vec_cast with the batched extraction pattern used in csrc/xqa/utils.cuh for consistency and performance.

The identical batched-extraction optimization already exists in csrc/xqa/utils.cuh::convertKCacheWordToF16<half, __nv_fp4_e2m1> and its bf16 variant: both use mov.b32 {byte0, _, byte2, _} to extract two valid fp4 bytes per word, then invoke two separate cvt.rn.f16x2.e2m1x2 instructions on the extracted .b8 registers. This reduces the loop count from vec_size / 2 to vec_size / 4 and halves the extraction overhead on this KV-dequant hot path.

Apply the same pattern to both fp16 and bf16x2 branches in vec_cast<half, __nv_fp4x2_e2m1>::cast (lines 428–443) and vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast (lines 482–513).

♻️ Sketch of the batched form (fp16 branch)
-#pragma unroll
-    for (size_t i = 0; i < vec_size / 2; ++i) {
-      uint32_t y;
-      // Valid fp4x2 bytes are at even positions (stride 2); odd positions are padding.
-      uint32_t b = reinterpret_cast<const uint8_t*>(src)[i * 2];
-      asm volatile(
-          "{\n"
-          ".reg .b8 fp4_byte;\n"
-          "mov.b32 {fp4_byte, _, _, _}, %1;\n"
-          "cvt.rn.f16x2.e2m1x2 %0, fp4_byte;\n"
-          "}"
-          : "=r"(y)
-          : "r"(b));
-      reinterpret_cast<uint32_t*>(dst)[i] = y;
-    }
+    // Batch extract two valid bytes per word (stride-2 layout: valid at positions 0,2).
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      uint32_t word = reinterpret_cast<const uint32_t*>(src)[i];
+      uint32_t y0, y1;
+      asm volatile(
+          "{\n"
+          ".reg .b8 b0, b2;\n"
+          "mov.b32 {b0, _, b2, _}, %2;\n"
+          "cvt.rn.f16x2.e2m1x2 %0, b0;\n"
+          "cvt.rn.f16x2.e2m1x2 %1, b2;\n"
+          "}"
+          : "=r"(y0), "=r"(y1)
+          : "r"(word));
+      reinterpret_cast<uint32_t*>(dst)[i * 2 + 0] = y0;
+      reinterpret_cast<uint32_t*>(dst)[i * 2 + 1] = y1;
+    }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/vec_dtypes.cuh` around lines 428 - 443, The current fp16
and bf16x2 branches in vec_cast<half, __nv_fp4x2_e2m1>::cast and
vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast extract one valid fp4 byte per
iteration; change them to the batched extraction used in
convertKCacheWordToF16<half, __nv_fp4_e2m1> by loading a 32-bit word with
mov.b32 {byte0, _, byte2, _} to obtain two valid fp4 bytes, then invoke two
cvt.rn.f16x2.e2m1x2 (or appropriate bf16 convert) operations on the two
extracted .b8 registers, reduce the loop range from vec_size/2 to vec_size/4,
and store two uint32_t results per loop iteration to halve extraction overhead
on the hot KV-dequant path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 3267-3308: Add documentation for the new kv_cache_sf parameter in
BatchPrefillWithRaggedKVCacheWrapper.run: describe accepted types and layouts
(Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], explain
stacked-tensor vs tuple form, expected dtypes like float8_e4m3fn and NVFP4
specifics, and the tensor shapes/ordering consistent with k/v inputs; also note
the behavioral difference that post-kernel v_scale is applied only when
kv_cache_sf is not None (so callers using v_scale with non-NVFP4 inputs on this
wrapper will not get scaled output), mirroring the style and details used in the
paged wrapper docstring for clarity.

---

Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 636-645: In paged_run remove key_block_scales and
value_block_scales from the mutates_args tuple so they are treated as read-only
(they are load-only sources used by produce_kv_sf / page_produce_kv_sf); mirror
the approach used in run_single_prefill and ragged_run where maybe_k_cache_sf /
maybe_v_cache_sf are not listed as mutated. Update the mutates_args declaration
in paged_run (remove "key_block_scales", "value_block_scales") and verify there
are no other places in paged_run that rely on those names being reported as
mutated.

---

Nitpick comments:
In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 498-511: The fallback currently converts e2m1->fp16 via the inline
asm into fp16x2, then to float and to bf16 (__half2 h2 / __float22bfloat162_rn)
before storing y; replace that three-step conversion with a direct lookup into
the existing e2m1→bf16 lookup table used elsewhere (the same LUT referenced in
the surrounding code) so you produce y in one table lookup instead of doing the
asm cvt + __half2/__float22bfloat162_rn sequence; update the code paths that set
fp16x2 / h2 / bf16x2 to instead index the LUT with the original input (variable
b) and write the LUT result into y.
- Around line 428-443: The current fp16 and bf16x2 branches in vec_cast<half,
__nv_fp4x2_e2m1>::cast and vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast extract
one valid fp4 byte per iteration; change them to the batched extraction used in
convertKCacheWordToF16<half, __nv_fp4_e2m1> by loading a 32-bit word with
mov.b32 {byte0, _, byte2, _} to obtain two valid fp4 bytes, then invoke two
cvt.rn.f16x2.e2m1x2 (or appropriate bf16 convert) operations on the two
extracted .b8 registers, reduce the loop range from vec_size/2 to vec_size/4,
and store two uint32_t results per loop iteration to halve extraction overhead
on the hot KV-dequant path.


📥 Commits

Reviewing files that changed from the base of the PR and between ed840c8 and 7f0073e.

📒 Files selected for processing (3)
  • csrc/xqa/utils.cuh
  • flashinfer/prefill.py
  • include/flashinfer/vec_dtypes.cuh

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
tests/attention/test_single_prefill.py (1)

112-132: Dead-code skip: causal is parametrized only to False.

@pytest.mark.parametrize("causal", [False]) makes the if qo_len > kv_len and causal branch unreachable. Either expand the causal parametrization to include True (and verify the NVFP4 kernel supports it) or drop the skip and the causal parameter entirely to keep the test matrix honest.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_single_prefill.py` around lines 112 - 132, The test
parametrizes causal only as False making the conditional "if qo_len > kv_len and
causal" dead code; update test_single_prefill_with_kv_cache_nvfp4 by either
adding True to the pytest.mark.parametrize("causal", [False]) list (and ensure
NVFP4 kernel supports causal behavior) or remove the causal parameter and the
associated skip check entirely so the branch is not unreachable; adjust any
related test matrix or assumptions accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/attention/test_single_prefill.py`:
- Around line 112-132: The test parametrizes causal only as False making the
conditional "if qo_len > kv_len and causal" dead code; update
test_single_prefill_with_kv_cache_nvfp4 by either adding True to the
pytest.mark.parametrize("causal", [False]) list (and ensure NVFP4 kernel
supports causal behavior) or remove the causal parameter and the associated skip
check entirely so the branch is not unreachable; adjust any related test matrix
or assumptions accordingly.


📥 Commits

Reviewing files that changed from the base of the PR and between 7f0073e and cfd6387.

📒 Files selected for processing (1)
  • tests/attention/test_single_prefill.py

.
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
@qsang-nv qsang-nv left a comment

LGTM

@qsang-nv

/bot run

@flashinfer-bot

GitLab MR !580 has been created, and the CI pipeline #49155734 is currently running. I'll report back once the pipeline job completes.
