Support NVFP4 KV for prefill and batch attention kernels #3097
Tom-Zheng wants to merge 11 commits into flashinfer-ai:main
Conversation
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
jit: add torch.float4_e2m1fn_x2 to dtype maps
Add conditional entries for torch.float4_e2m1fn_x2 in
filename_safe_dtype_map ("fp4_e2m1") and dtype_map_kv
("__nv_fp4x2_e2m1") so that BatchPrefillWithPagedKVCacheWrapper
can select the FP4 kernel plan without KeyError when
kv_data_type=float4_e2m1fn_x2 is passed to begin_forward.
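The guarded registration this commit describes can be sketched roughly as below. The two FP4 map values come from the commit message; the surrounding entries and the flat module layout are illustrative assumptions (the real maps live in flashinfer's JIT utilities):

```python
import torch

# Sketch of the conditional dtype-map entries described in the commit message.
# Only the fp4_e2m1 / __nv_fp4x2_e2m1 values are from the commit; the other
# entries are placeholders for illustration.
filename_safe_dtype_map = {
    torch.float16: "f16",
    torch.uint8: "u8",
}
dtype_map_kv = {
    torch.float16: "half",
    torch.uint8: "uint8_t",
}

# torch.float4_e2m1fn_x2 only exists in recent PyTorch builds, so the entries
# are added conditionally to avoid an AttributeError on older versions.
if hasattr(torch, "float4_e2m1fn_x2"):
    filename_safe_dtype_map[torch.float4_e2m1fn_x2] = "fp4_e2m1"
    dtype_map_kv[torch.float4_e2m1fn_x2] = "__nv_fp4x2_e2m1"
```

With the entries present, a lookup keyed on `kv_data_type=torch.float4_e2m1fn_x2` resolves instead of raising `KeyError` during plan generation.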
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
fix batch decode UT
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
fix batch decode function and accuracy; add nvfp4 kv test
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior is configurable in the CodeRabbit settings.
📝 Walkthrough
Adds NVFP4 (uint8-packed FP4) KV-cache support across APIs, JIT generators, CUDA kernels, quantization (with CPU fallback), backend validation, and tests; public signatures accept stacked or tuple KV-scale tensors and thread them through kernels and generated modules.
Sequence Diagram(s):
sequenceDiagram
participant User as User Code
participant API as flashinfer API
participant JIT as JIT Generator
participant Kernel as CUDA Kernel
participant Quant as Quant/CpuFallback
User->>API: call prefill/decode/attention with `kv_cache_sf`
API->>API: normalize/unpack `kv_cache_sf` (tuple or stacked tensor)
API->>JIT: plan/module with `kv_data_type=uint8` and SF tensor wiring
JIT->>Kernel: emit kernels that accept maybe_k/v_cache_sf and SF offsets
API->>Kernel: invoke kernel with packed FP4 bytes + kv_cache_sf
Kernel->>Quant: load packed FP4 + sf bytes, call dequant (CUDA or CPU fallback)
Kernel->>Kernel: apply per-element scaling and compute attention
Kernel->>User: return attention outputs
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces support for NVFP4 KV cache across FlashInfer, including prefill and decode kernels. Key changes include the addition of scale factor parameters to attention APIs, updates to JIT module generation, and the implementation of CUDA kernels for packed FP4 data. Review feedback identified a critical bug in decode.py where the paged_kv_cache argument was incorrectly made conditional, potentially breaking default backends. Additionally, a discrepancy was noted in a code comment regarding shared memory alignment assumptions for certain head dimensions.
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/attention.py (1)
163-187: ⚠️ Potential issue | 🟠 Major: Require kv_cache_sf for NVFP4 KV inputs.
This path now forwards k_cache_sf/v_cache_sf, but it never rejects NVFP4 KV runs when the scales are omitted. flashinfer/prefill.py:2276-2278 already guards the same condition; without that check here, torch.uint8/native FP4 KV can still flow into the kernel with None scale tensors.
💡 Proposed fix

```diff
 k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout)
+needs_kv_sf = k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8
+if hasattr(torch, "float4_e2m1fn_x2"):
+    needs_kv_sf = needs_kv_sf or (
+        k_cache.dtype == torch.float4_e2m1fn_x2
+        or v_cache.dtype == torch.float4_e2m1fn_x2
+    )
+if needs_kv_sf and kv_cache_sf is None:
+    raise ValueError("kv_cache_sf must be provided for NVFP4 KV cache.")
+
 if out is None:
     out = torch.empty_like(q)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/attention.py` around lines 163-187: the NVFP4 path currently unpacks kv_cache_sf into k_cache_sf/v_cache_sf but doesn't enforce that k_scale and v_scale are provided; add a guard after unpacking (where _unpack_paged_kv_cache produces k_cache_sf, v_cache_sf) that checks if either k_cache_sf or v_cache_sf is not None and either k_scale or v_scale is None, and raise a clear ValueError (or AssertionError) rejecting NVFP4 KV inputs without their corresponding scale tensors; reference the unpack step and variables k_cache_sf, v_cache_sf, k_scale, and v_scale so the check is placed right after the unpack and before any NVFP4 data can flow into the kernel.
flashinfer/decode.py (1)
1454-1474: ⚠️ Potential issue | 🔴 Critical: Remove the four None placeholders that don't exist in the trtllm-gen decode signature.
The get_trtllm_gen_decode_module().paged_run(...) function signature does not include max_q_len, batch_size, cum_seq_lens_q, or cum_seq_lens_kv parameters. Passing these four None values shifts max_kv_len, sinks, key_block_scales, and value_block_scales to the wrong parameter slots, causing argument misalignment and runtime failure.
Suggested fix

```diff
 if self._backend == "trtllm-gen":
     # decode.py's trtllm-gen paged_run (get_trtllm_gen_decode_module)
     # has a different optional-param layout than prefill.py's paged_run
     run_args += [paged_kv_cache]
     run_args += [
         self._num_qo_heads,
         self._num_kv_heads,
         self._block_tables,
         self._kv_lens_buffer,
         page_size,
-        None,  # max_q_len (not applicable for decode)
         self._max_kv_len,
-        None,  # batch_size (not applicable for decode)
-        None,  # cum_seq_lens_q (not applicable for decode)
-        None,  # cum_seq_lens_kv (not applicable for decode)
         sinks,
         key_block_scales,
         value_block_scales,
         skip_softmax_threshold_scale_factor,
         True,  # uses_shared_paged_kv_idx
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/decode.py` around lines 1454-1474: the trtllm-gen decode path is passing four extra None placeholders into run_args which don't exist in get_trtllm_gen_decode_module().paged_run, causing argument misalignment; update the code that builds run_args in decode.py (the branch where self._backend == "trtllm-gen" and the subsequent run_args += [...]) by removing the four None entries for max_q_len, batch_size, cum_seq_lens_q and cum_seq_lens_kv so that self._max_kv_len, sinks, key_block_scales and value_block_scales align with the trtllm-gen paged_run signature. Ensure the special-case comment for trtllm-gen remains and run_args ordering matches the trtllm-gen paged_run parameter list.
flashinfer/prefill.py (1)
467-476: ⚠️ Potential issue | 🟠 Major: mutates_args incorrectly lists read-only scale-factor tensors.
maybe_k_cache_sf/maybe_v_cache_sf (ragged) and key_block_scales/value_block_scales (paged) are inputs to the attention kernels, not outputs — the kernels only dequantize from them. Declaring them in mutates_args tells torch.library they are written by the op, which:
- breaks aliasing/functionalization assumptions under torch.compile (spurious data dependencies, forced rematerialization/cloning),
- is inconsistent with how other FP8 scale tensors (scale_q/k/v, fp8_scale_*) are wired into the same ops.
Please remove these from mutates_args in both the ragged_run and paged_run registrations.
🔧 Suggested fix

```diff
 @register_custom_op(
     f"flashinfer::{uri}_ragged_run",
     mutates_args=(
         "float_workspace_buffer",
         "int_workspace_buffer",
         "o",
         "maybe_lse",
-        "maybe_k_cache_sf",
-        "maybe_v_cache_sf",
     ),
 )
 ...
 @register_custom_op(
     f"flashinfer::{uri}_paged_run",
     mutates_args=(
         "float_workspace_buffer",
         "int_workspace_buffer",
         "paged_k_cache",
         "paged_v_cache",
         "o",
         "maybe_lse",
-        "key_block_scales",
-        "value_block_scales",
     ),
 )
```

Also applies to: 636-647
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 467 - 476, Remove read-only scale tensors from the op mutation lists: in the register_custom_op call for the ragged_run variant (symbol name contains "flashinfer::{uri}_ragged_run") delete maybe_k_cache_sf and maybe_v_cache_sf from the mutates_args tuple; similarly, in the paged_run registration (the other register_custom_op block that references key_block_scales and value_block_scales) remove key_block_scales and value_block_scales from mutates_args so only truly written buffers (e.g., float_workspace_buffer, int_workspace_buffer, o, maybe_lse) remain listed; keep the tensors as regular inputs to the op registration rather than declaring them as mutated.
🧹 Nitpick comments (4)
flashinfer/prefill.py (2)
2354-2359: Use logging.warning instead of print(...).
logging is already imported (Line 18). print(...) bypasses user-configured log levels/filters and writes to stdout, which is noisy in library code and tends to pollute serving-stack logs.
🔧 Suggested fix

```diff
-print(
-    "[WARNING] NVFP4 KV cache with NHD layout will be converted to HND, "
-    "incurring extra transpose and contiguous copy overhead. "
-    "Use kv_layout='HND' for better performance."
-)
+logging.warning(
+    "NVFP4 KV cache with NHD layout will be converted to HND, "
+    "incurring extra transpose and contiguous copy overhead. "
+    "Use kv_layout='HND' for better performance."
+)
```

Also consider emitting the warning only once per process (e.g. warnings.warn with a category, or a module-level guard); otherwise it will fire on every layer/step of a long run.
Also applies to: 4073-4078
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 2354 - 2359, Replace the print-based warning at the NHD→HND conversion site (the block checking key_block_scales) with a proper logging call (logging.warning) so it honors user log configuration; update the message there and the similar occurrences around the other site (the block at ~4073-4078) to use logging.warning instead of print. Also consider making the warning emit only once per process by using warnings.warn with an appropriate Warning subclass or a module-level boolean guard to avoid repeated spam during long runs.
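The once-per-process behavior the reviewer asks for is what `warnings.warn` already gives under the default filter: a given message/category/call-site combination is shown only once. A minimal, flashinfer-independent illustration (the function name is hypothetical):

```python
import warnings

def warn_layout_conversion():
    # Hypothetical stand-in for the library's warning site; stacklevel=2
    # attributes the warning to the caller, as the review suggests.
    warnings.warn(
        "NVFP4 KV cache with NHD layout will be converted to HND; "
        "use kv_layout='HND' for better performance.",
        UserWarning,
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")
    for _ in range(100):  # e.g. one call per layer per step of a long run
        warn_layout_conversion()

# The "default" filter action shows a given warning once per call site,
# so the 100 calls collapse to a single recorded warning.
assert len(caught) == 1
```

A bare `print(...)` would have emitted 100 lines here; the warnings machinery deduplicates and remains filterable by the application.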
1326-1329: Inconsistent kv_cache_sf unpacking/validation across prefill entry points.
Three call sites now unpack kv_cache_sf and they all diverge:
- single_prefill_with_kv_cache (Lines 1326-1329): assumes a tuple, blind k_sf, v_sf = kv_cache_sf — a stacked tensor or a list will fail with an opaque error.
- BatchPrefillWithRaggedKVCacheWrapper.run (Lines 3342-3346): checks isinstance(kv_cache_sf, tuple) only (misses list), and otherwise blindly calls .unbind(dim=1) without a TypeError fallback.
- BatchPrefillWithPagedKVCacheWrapper.run (Lines 2282-2291) and trtllm_batch_context_with_kv_cache (Lines 4058-4067): correctly accept (tuple, list) or a stacked tensor, with a TypeError for anything else — matching the decode pattern.
Please align the single-prefill and ragged paths with the paged pattern (accept tuple/list or a torch.Tensor of shape [num_pages, 2, ...], else raise TypeError). Factoring this into a small helper (e.g., _unpack_kv_cache_sf) would also remove the four near-duplicate blocks across this file and decode.py/attention.py.
Also applies to: 3340-3346
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/prefill.py` around lines 1326-1329: align kv_cache_sf handling by extracting a small helper (e.g., _unpack_kv_cache_sf) and replace the four near-duplicate unpack blocks in single_prefill_with_kv_cache, BatchPrefillWithRaggedKVCacheWrapper.run, BatchPrefillWithPagedKVCacheWrapper.run, and trtllm_batch_context_with_kv_cache to call it; the helper should accept either a tuple/list of (k_sf, v_sf) and return them, or accept a torch.Tensor of shape [num_pages, 2, ...] and split/unbind dim=1 into (k_sf, v_sf), and otherwise raise a TypeError with a clear message — this ensures consistent behavior for tuples, lists, and stacked tensors and removes duplicate logic across the file (and similar blocks in decode.py/attention.py).
include/flashinfer/cp_async.cuh (2)
215-220: Leftover debug markers in fallback path.
Commented-out sentinel values (0xcd..., 0xefef...) look like debug artifacts. Consider removing them to keep the fallback clean.
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/cp_async.cuh` around lines 215 - 220, Remove the leftover commented sentinel debug markers in the shared-memory fallback path: delete the commented-out lines containing 0xcdcd... and 0xefef... so only the intended zero-fill logic remains (keep the existing smem_u64[1] = 0; and the *((uint4*)smem_ptr) = make_uint4(0,0,0,0); under SharedMemFillMode::kFillZero). This cleans up artifacts around smem_u64 and smem_ptr handling without changing the zero-fill behavior.
191-193: prefetch_mode template parameter is silently ignored by the fast path.
Unlike pred_load_128b, this helper never branches on prefetch_mode (no .L2::128B variant is emitted, and cp.async.ca has no such hint). Callers such as smem_t::load_64b_async<...> that pass PrefetchMode::kPrefetch will not get prefetching. Consider either dropping the template parameter to make that explicit, or documenting that prefetch is a no-op for 64b loads.
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/cp_async.cuh` around lines 191 - 193, The template parameter prefetch_mode on pred_load_128b_from_64b is unused so callers (e.g., smem_t::load_64b_async) passing PrefetchMode::kPrefetch get no prefetch behavior; either remove the prefetch_mode template parameter from pred_load_128b_from_64b to make prefetch a no-op explicit, or implement a branch on prefetch_mode (matching pred_load_128b's branching) and emit the appropriate prefetch variant when prefetch_mode==PrefetchMode::kPrefetch (so callers like smem_t::load_64b_async actually trigger prefetch), and update/delete callers or comments accordingly to keep behavior consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 1301-1309: The stacked-tensor branch for kv_cache_sf assumes dim 1
equals 2 before calling kv_cache_sf.unbind(dim=1); validate that kv_cache_sf is
a tensor with at least 2 dimensions and that kv_cache_sf.size(1) == 2 (and
optionally include kv_cache_sf.dim() and kv_cache_sf.shape in the error) before
unbinding, and if the check fails raise a TypeError with a clear message stating
the expected shape [num_pages, 2, ...] and the actual shape; update the
torch.is_tensor branch around the kv_cache_sf.unbind(dim=1) call to perform this
check and include the actual tensor shape in the raised error.
In `@flashinfer/prefill.py`:
- Around line 1148-1150: The function single_prefill_with_kv_cache now accepts
kv_cache_sf, k_scale, and v_scale but v_scale is ignored; update
single_prefill_with_kv_cache to either (preferred) apply v_scale to the output
tensor after the kernel (mirroring BatchPrefillWithPagedKVCacheWrapper.run where
out *= v_scale) and ensure kv_cache_sf and k_scale are folded/used consistently
into sm_scale, or (if v_scale not supported) remove v_scale from the signature
and raise NotImplementedError when callers pass it; additionally add docstring
entries for kv_cache_sf, k_scale, and v_scale describing the expected
scale-factor layout (match the paged wrapper docstring) so callers know the
semantics.
- Around line 3321-3326: The ragged-path double-scales sm_scale by multiplying
in q_scale and k_scale before passing it as attn_scale to
cudnn_batch_prefill_with_kv_cache; fix by only applying sm_scale *= q_scale /
k_scale when the backend is not cudnn (i.e., wrap the q_scale/k_scale
multiplications in if self._backend != "cudnn":) so cudnn receives unadjusted
attn_scale while other backends keep the existing descaling; apply the same
guard in the second cudnn call site (the other block that prepares sm_scale
before calling cudnn_batch_prefill_with_kv_cache) so q_scale/k_scale are not
applied twice.
In `@flashinfer/utils.py`:
- Around line 423-425: The current guard only checks for torch.uint8 (dtype_kv)
and misses native FP4; update the FA3-exclusion guard to also return False when
dtype_kv is the native FP4 dtype (torch.float4_e2m1fn_x2) so native FP4 KVs are
routed off FA3 like uint8; modify the condition around dtype_kv (the if block
that returns False for FA3) to include a check for torch.float4_e2m1fn_x2
alongside torch.uint8.
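The widened guard the prompt describes can be sketched as below; the function name is hypothetical and the real check lives inside flashinfer's backend-support logic:

```python
import torch

def fa3_supports_kv_dtype(dtype_kv) -> bool:
    """Sketch of the FA3-exclusion guard from the review: route both
    uint8-packed and native FP4 KV caches off the FA3 backend.
    Function name is illustrative, not flashinfer's actual API."""
    if dtype_kv == torch.uint8:
        return False
    # The native FP4 dtype only exists on recent PyTorch builds,
    # so resolve it defensively rather than referencing it directly.
    fp4 = getattr(torch, "float4_e2m1fn_x2", None)
    if fp4 is not None and dtype_kv == fp4:
        return False
    return True
```

The key point is that checking `torch.uint8` alone misses callers who pass the native FP4 dtype directly; both spellings of an FP4 KV cache must take the non-FA3 path.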
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 110-117: The SMEM heuristics must include the FP4 scale-factor
buffers k_sf_smem and v_sf_smem when computing SharedStorage-related limits:
update the calculations that produce num_ctas_per_sm and max_num_mma_kv_smem
(and any logic that picks NUM_MMA_KV) to add the conditional size of the new
scale-factor arrays (the same conditional expression used for
k_sf_smem/v_sf_smem based on is_fp4_type_v<DTypeKV> and the widths CTA_TILE_KV *
HEAD_DIM_QK / NVFP4_SF_VEC_SIZE and CTA_TILE_KV * HEAD_DIM_VO /
NVFP4_SF_VEC_SIZE) into the total shared-memory per-CTA estimate so the
dispatcher uses the real SharedStorage size (the value later passed to
cudaFuncSetAttribute) when deciding CTAs and NUM_MMA_KV.
In `@include/flashinfer/cp_async.cuh`:
- Around line 194-223: The PTX path currently issues cp.async with cp-size=8 so
the upper 8 bytes of the 16-byte SMEM slot remain uninitialized; update both
inline-asm cp.async invocations in the SharedMemFillMode handling (the block
using smem_int_ptr and the kNoFill asm block) to use cp-size=16 (i.e., change
the "n"(8) cp-size immediates to "n"(16)) while keeping src-size as 8 (or 0 when
predicate is false), so cp.async zero-fills the upper 8 bytes and matches the
fallback that writes smem_u64[1]=0; also remove or fix the misleading comment
that "cp.async always zeros the upper 8 bytes" to reflect that zeroing only
applies within cp-size.
In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 486-487: The CUDA version gate in
include/flashinfer/vec_dtypes.cuh currently checks (__CUDACC_VER_MAJOR__ >= 13)
&& (__CUDACC_VER_MINOR__ >= 2) which fails for newer majors (e.g., 14.x); change
the condition to allow any major greater than 13 or major equal to 13 with minor
>= 2 (i.e., use a combined check of __CUDACC_VER_MAJOR__ and
__CUDACC_VER_MINOR__ so CUDA 13.2+ and all CUDA 14+ are accepted) to ensure the
native cvt.rn.bf16x2.e2m1x2 path is used; update the `#if` surrounding that macro
check accordingly and keep the same `#endif` and related code paths (look for
__CUDACC_VER_MAJOR__, __CUDACC_VER_MINOR__, and the cvt.rn.bf16x2.e2m1x2 usage
locations).
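The bug pattern here (requiring major >= 13 AND minor >= 2) is a classic version-gate mistake: it rejects 14.0 and 14.1 even though any 14.x release is newer than 13.2. Tuple comparison expresses the intended ordering directly; a small illustration of the two predicates (in Python rather than the preprocessor, purely to show the logic):

```python
def wrong_gate(major: int, minor: int) -> bool:
    # Mirrors the original #if: both conditions must hold independently,
    # so (14, 0) fails because 0 >= 2 is false.
    return major >= 13 and minor >= 2

def fixed_gate(major: int, minor: int) -> bool:
    # Lexicographic comparison: true for 13.2+ and for every later major,
    # equivalent to `major > 13 || (major == 13 && minor >= 2)` in C.
    return (major, minor) >= (13, 2)
```

In the header itself the fix is the C form of `fixed_gate`: `#if (__CUDACC_VER_MAJOR__ > 13) || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 2)`.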
In `@tests/attention/test_batch_attention.py`:
- Around line 293-424: The test test_batch_attention_nvfp4 currently only xfails
SM120/121 and should instead skip on unsupported GPU architectures; update the
test to early-skip using the project's utility/API (e.g.,
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() or
flashinfer.api_name.is_compute_capability_supported(cc)) before any CUDA work is
done, so that test_batch_attention_nvfp4 returns pytest.skip(...) when the
current device compute capability is not supported; place the check at the top
of test_batch_attention_nvfp4 (before torch.manual_seed and tensor creation) and
reference those utility calls to determine support.
In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-810: The NVFP4 test
(test_batch_decode_with_paged_kv_cache_nvfp4) must be skipped on unsupported
GPUs; add an early guard that checks compute capability support (e.g., using
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() or
api_name.is_compute_capability_supported(cc)) and call pytest.skip if
NVFP4/tensor-core paths aren't available before creating workspace/setting
use_tensor_cores=True and calling wrapper.plan/run; place the check near the
start of the test function so the hard-coded NVFP4 tensor-core path is never
executed on unsupported architectures.
In `@tests/attention/test_single_prefill.py`:
- Around line 145-154: The test calls
flashinfer.prefill.single_prefill_with_kv_cache but passes only k_scale
(k_global_scale.item()); add the missing V-side global scale by passing
v_scale=v_global_scale.item() (or v_global_scale if non-tensor) into the same
call so the V scaling path is exercised; update the call site in
tests/attention/test_single_prefill.py where single_prefill_with_kv_cache(...)
is invoked to include the v_scale named argument (the other args like
kv_cache_sf=(k_sf, v_sf) stay unchanged).
---
Outside diff comments:
In `@flashinfer/attention.py`:
- Around line 163-187: The NVFP4 path currently unpacks kv_cache_sf into
k_cache_sf/v_cache_sf but doesn't enforce that k_scale and v_scale are provided;
add a guard after unpacking (where _unpack_paged_kv_cache produces k_cache_sf,
v_cache_sf) that checks if either k_cache_sf or v_cache_sf is not None and
either k_scale or v_scale is None, and raise a clear ValueError (or
AssertionError) rejecting NVFP4 KV inputs without their corresponding scale
tensors; reference the unpack step and variables k_cache_sf, v_cache_sf,
k_scale, and v_scale so the check is placed right after the unpack and before
any NVFP4 data can flow into the kernel.
In `@flashinfer/decode.py`:
- Around line 1454-1474: The trtllm-gen decode path is passing four extra None
placeholders into run_args which don't exist in
get_trtllm_gen_decode_module().paged_run, causing argument misalignment; update
the code that builds run_args in decode.py (the branch where self._backend ==
"trtllm-gen" and the subsequent run_args += [...]) by removing the four None
entries for max_q_len, batch_size, cum_seq_lens_q and cum_seq_lens_kv so that
self._max_kv_len, sinks, key_block_scales and value_block_scales align with the
trtllm-gen paged_run signature. Ensure the special-case comment for trtllm-gen
remains and run_args ordering matches the trtllm-gen paged_run parameter list.
In `@flashinfer/prefill.py`:
- Around line 467-476: Remove read-only scale tensors from the op mutation
lists: in the register_custom_op call for the ragged_run variant (symbol name
contains "flashinfer::{uri}_ragged_run") delete maybe_k_cache_sf and
maybe_v_cache_sf from the mutates_args tuple; similarly, in the paged_run
registration (the other register_custom_op block that references
key_block_scales and value_block_scales) remove key_block_scales and
value_block_scales from mutates_args so only truly written buffers (e.g.,
float_workspace_buffer, int_workspace_buffer, o, maybe_lse) remain listed; keep
the tensors as regular inputs to the op registration rather than declaring them
as mutated.
---
Nitpick comments:
In `@flashinfer/prefill.py`:
- Around line 2354-2359: Replace the print-based warning at the NHD→HND
conversion site (the block checking key_block_scales) with a proper logging call
(logging.warning) so it honors user log configuration; update the message there
and the similar occurrences around the other site (the block at ~4073-4078) to
use logging.warning instead of print. Also consider making the warning emit only
once per process by using warnings.warn with an appropriate Warning subclass or
a module-level boolean guard to avoid repeated spam during long runs.
- Around line 1326-1329: Align kv_cache_sf handling by extracting a small helper
(e.g., _unpack_kv_cache_sf) and replace the four near-duplicate unpack blocks in
single_prefill_with_kv_cache, BatchPrefillWithRaggedKVCacheWrapper.run,
BatchPrefillWithPagedKVCacheWrapper.run, and trtllm_batch_context_with_kv_cache
to call it; the helper should accept either a tuple/list of (k_sf, v_sf) and
return them, or accept a torch.Tensor of shape [num_pages, 2, ...] and
split/unbind dim=1 into (k_sf, v_sf), and otherwise raise a TypeError with a
clear message — this ensures consistent behavior for tuples, lists, and stacked
tensors and removes duplicate logic across the file (and similar blocks in
decode.py/attention.py).
In `@include/flashinfer/cp_async.cuh`:
- Around line 215-220: Remove the leftover commented sentinel debug markers in
the shared-memory fallback path: delete the commented-out lines containing
0xcdcd... and 0xefef... so only the intended zero-fill logic remains (keep the
existing smem_u64[1] = 0; and the *((uint4*)smem_ptr) = make_uint4(0,0,0,0);
under SharedMemFillMode::kFillZero). This cleans up artifacts around smem_u64
and smem_ptr handling without changing the zero-fill behavior.
- Around line 191-193: The template parameter prefetch_mode on
pred_load_128b_from_64b is unused so callers (e.g., smem_t::load_64b_async)
passing PrefetchMode::kPrefetch get no prefetch behavior; either remove the
prefetch_mode template parameter from pred_load_128b_from_64b to make prefetch a
no-op explicit, or implement a branch on prefetch_mode (matching
pred_load_128b's branching) and emit the appropriate prefetch variant when
prefetch_mode==PrefetchMode::kPrefetch (so callers like smem_t::load_64b_async
actually trigger prefetch), and update/delete callers or comments accordingly to
keep behavior consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a0310754-826d-4f9d-931b-dcd44974a801
📒 Files selected for processing (18)
- flashinfer/attention.py
- flashinfer/decode.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/utils.py
- include/flashinfer/attention/persistent.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/cp_async.cuh
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/permuted_smem.cuh
- include/flashinfer/vec_dtypes.cuh
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_decode_kernels.py
- tests/attention/test_batch_prefill_kernels.py
- tests/attention/test_single_prefill.py
- tests/test_helpers/utils_fp4.py
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
3466-3520: ⚠️ Potential issue | 🔴 Critical: NVFP4 k_scale/q_scale fold into sm_scale is a no-op in the ragged path.
run_args is built at Lines 3466-3511 with the current sm_scale float value already appended. The subsequent sm_scale *= q_scale / sm_scale *= k_scale at Lines 3517-3520 only rebind the local sm_scale variable — since Python floats are immutable, the value already inside run_args is not updated. The kernel therefore receives the unscaled sm_scale, silently producing wrong numerics for NVFP4 KV whenever a non-unit k_scale/q_scale is passed.
The fold must happen before sm_scale is placed into run_args (mirroring the paged wrapper at Lines 2333-2337 and single_prefill_with_kv_cache at Lines 1338-1339).
🔧 Proposed fix

```diff
 if return_lse:
     ...
+# For NVFP4 KV, fuse q_scale/k_scale into sm_scale before packing run_args
+if kv_cache_sf is not None:
+    if q_scale is not None:
+        sm_scale *= q_scale
+    if k_scale is not None:
+        sm_scale *= k_scale
+
 # Unpack kv_cache_sf for NVFP4 ragged KV
 k_sf, v_sf = None, None
 ...
 run_args += [
     ...
     logits_soft_cap,
     sm_scale,
     ...
 ]
 # For FP8, append scale tensors
 if is_float8(q):
     run_args.extend(list(args))  # scale_q, scale_k, scale_v
-# For NVFP4 KV, fuse k_scale into sm_scale
-elif kv_cache_sf is not None:
-    if q_scale is not None:
-        sm_scale *= q_scale
-    if k_scale is not None:
-        sm_scale *= k_scale
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3466 - 3520, The ragged-path builds run_args and appends the current sm_scale before folding q_scale/k_scale, so multiply q_scale and k_scale into sm_scale (when kv_cache_sf is not None and q_scale/k_scale are not None) before sm_scale is appended to run_args; modify the else branch around run_args construction in prefill.py to perform the NVFP4 fold (sm_scale *= q_scale and sm_scale *= k_scale) prior to extending run_args with sm_scale (mirroring the paged wrapper and single_prefill_with_kv_cache behavior), leaving the existing is_float8 and kv_cache_sf checks but relocating the multiplication earlier so the kernel receives the updated sm_scale.
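The root cause is plain Python semantics: rebinding an immutable float after it has been appended to a list does not change the list. A minimal reproduction of the ordering bug (variable names follow the review; values are made up):

```python
# Buggy order: pack first, fold later. The kernel would see the value that
# was in the list at append time, not the later-rebound local.
sm_scale = 0.125
run_args = [sm_scale]   # the float value 0.125 is stored in the list
sm_scale *= 2.0         # rebinds the local name only; run_args is untouched
assert run_args[0] == 0.125

# Correct order: fold scales into sm_scale first, then pack run_args.
sm_scale = 0.125
sm_scale *= 2.0
run_args = [sm_scale]
assert run_args[0] == 0.25
```

This is why the fold must move above the `run_args` construction rather than sit in an `elif` after it.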
♻️ Duplicate comments (1)
include/flashinfer/cp_async.cuh (1)
194-222: ⚠️ Potential issue | 🟠 Major: PTX fast path leaves upper 8 bytes of the 128-bit SMEM slot uninitialized (divergence from docstring & fallback).
The docstring (lines 186-189) and the #else fallback (lines 213-218) both guarantee the upper 64 bits of the 128-bit destination are zero. The PTX path, however, issues cp.async.ca.shared.global with cp-size = 8; per PTX semantics zero-fill only applies in the range [src-size, cp-size), so with cp-size = 8 nothing beyond byte 7 is written. All three meaningful cases (kFillZero pred=true/false, kNoFill pred=true) therefore leave bytes 8..15 indeterminate, which will silently corrupt NVFP4 KV loads that rely on the documented zero-padding.
The comment on line 202 ("cp.async always zeros the upper 8 bytes") is also incorrect — zero-fill is bounded by cp-size, not the destination slot size.
Fix by using cp-size = 16 while keeping src-size = 8 (or 0 when the predicate is false), which makes PTX zero-fill bytes 8..15 and aligns with the fallback.
🛠️ Proposed fix

```diff
 if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
   int src_in_bytes = predicate ? 8 : 0;
   asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
-               "l"(gmem_ptr), "n"(8), "r"(src_in_bytes));
+               "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
 } else {
-  // kNoFill: only issue the copy if predicate is true; cp.async always zeros the upper 8 bytes
+  // kNoFill: only issue the copy if predicate is true. cp-size=16 with src-size=8 makes
+  // cp.async zero-fill the upper 8 bytes of the 128-bit destination slot.
   asm volatile(
       "{\n"
       " .reg .pred p;\n"
       " setp.ne.b32 p, %0, 0;\n"
       " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n"
       "}\n" ::"r"((int)predicate),
-      "r"(smem_int_ptr), "l"(gmem_ptr), "n"(8), "n"(8));
+      "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16), "n"(8));
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/cp_async.cuh` around lines 194 - 222, The PTX branch for cp.async (in the code handling SharedMemFillMode and predicate around smem_ptr/gmem_ptr) uses cp-size=8 which leaves bytes 8..15 uninitialized; change the cp.async invocations so cp-size=16 while keeping src-size as 8 (or 0 when predicate is false) so the hardware zero-fills the upper 8 bytes to match the docstring and the fallback path; update both asm blocks (the if constexpr fill_mode==kFillZero path and the else kNoFill path) to pass 16 as the cp-size operand and ensure the src-size operand remains 8 or 0 accordingly.
🧹 Nitpick comments (2)
flashinfer/decode.py (1)
1343-1352: Usewarnings.warninstead ofThis warning sits on the per-call
run()path and will print on every forward pass for every layer when NVFP4 KV is used withNHD.warnings.warnintegrates with Python's logging/filter machinery and is deduplicated by default. The same pattern at lines 2511–2515 intrtllm_batch_decode_with_kv_cacheshould be updated together.🔧 Proposed change
```diff
+import warnings
@@
         if key_block_scales is not None:
-            print(
-                "[WARNING] NVFP4 KV cache with NHD layout will be converted to HND, "
-                "incurring extra transpose and contiguous copy overhead. "
-                "Use kv_layout='HND' for better performance."
-            )
+            warnings.warn(
+                "NVFP4 KV cache with NHD layout will be converted to HND, "
+                "incurring extra transpose and contiguous copy overhead. "
+                "Use kv_layout='HND' for better performance.",
+                stacklevel=2,
+            )
```

And analogously at lines 2511–2515.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 1343 - 1352, Replace the print warning in the NVFP4 KV cache NHD->HND conversion with warnings.warn so it's integrated with Python's warning system and can be filtered/deduplicated; specifically, in the run() path where key_block_scales, k_cache, v_cache, and value_block_scales are being transposed/contiguified, call warnings.warn(...) instead of print(...), and make the analogous change in trtllm_batch_decode_with_kv_cache for the same message to keep behavior consistent across both code paths.

flashinfer/prefill.py (1)
2371-2380: Use `warnings.warn` (or the module logger) instead of `print`.

Plain `print` output bypasses Python's warning and logging machinery; prefer `warnings.warn(..., stacklevel=2)` so callers can control it. Same issue at Lines 4096-4100 in `trtllm_batch_context_with_kv_cache`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 2371 - 2380, Replace the inline print used when converting NVFP4 KV cache NHD→HND with a proper warning via warnings.warn(..., stacklevel=2) (or use the module logger) so callers can filter/suppress it; specifically update the block that checks key_block_scales is not None where k_cache, v_cache, key_block_scales and value_block_scales are made contiguous/transposed to call warnings.warn with the same message and stacklevel=2 instead of print, and make the same change in the analogous block inside trtllm_batch_context_with_kv_cache (the other occurrence around the key/value cache transpose).
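The deduplication and filtering benefits called out above can be demonstrated with a small runnable sketch. `convert_layout_with_warning` is a hypothetical stand-in for the NHD-to-HND conversion path, not the actual flashinfer code:

```python
import warnings

def convert_layout_with_warning(layout: str) -> None:
    # Hypothetical stand-in for the NHD->HND conversion path discussed above.
    if layout == "NHD":
        warnings.warn(
            "NVFP4 KV cache with NHD layout will be converted to HND; "
            "use kv_layout='HND' for better performance.",
            stacklevel=2,
        )

# Under the "default" filter action, identical warnings from the same call
# site are shown once, so a warning on the per-call run() path fires a single
# time instead of on every forward pass.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")
    for _ in range(100):
        convert_layout_with_warning("NHD")
dedup_count = len(caught)

# Callers can also silence the warning entirely, which print() never allows.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("ignore")
    convert_layout_with_warning("NHD")
ignored_count = len(caught)
```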
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback fails when ufp8_scale_tensor is 1D
because sf_float stays 1D and repeat_interleave produces a [K] vector that
cannot broadcast to float_vals [M, K]; fix this by normalizing sf_float to shape
[M, K/sf_vec_size] before expanding: after decoding into sf_float (the result of
view(torch.float8_e4m3fn).float() or torch.pow(...)), compute sf_len = k //
sf_vec_size and if sf_float.dim() == 1 then reshape/expand it to (m, sf_len) via
unsqueeze(0).expand(m, -1); if it is 2D but has shape (1, sf_len) also expand to
(m, sf_len); then continue with sf_expanded =
sf_float.repeat_interleave(sf_vec_size, dim=-1) and the rest of
_e2m1_and_ufp8sf_scale_to_float_cpu unchanged so broadcasting matches
float_vals.
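The broadcasting failure and the proposed normalization can be reproduced with a NumPy sketch. Shapes and names mirror the prompt above, but this is an illustration with hypothetical helpers, not the actual `_e2m1_and_ufp8sf_scale_to_float_cpu` code (which uses torch and `repeat_interleave`):

```python
import numpy as np

def expand_scales(sf_float: np.ndarray, m: int, k: int, sf_vec_size: int) -> np.ndarray:
    """Normalize per-block scales to shape [m, k // sf_vec_size], then expand
    each scale across its sf_vec_size-wide block so the result broadcasts
    against dequantized values of shape [m, k]."""
    sf_len = k // sf_vec_size
    if sf_float.ndim == 1:
        # A 1D scale vector applies to every row; without this step, repeating
        # it yields a [k] vector that cannot broadcast against [m, k].
        sf_float = np.broadcast_to(sf_float, (m, sf_len))
    elif sf_float.shape == (1, sf_len):
        sf_float = np.broadcast_to(sf_float, (m, sf_len))
    # Analogue of torch's repeat_interleave(sf_vec_size, dim=-1).
    return np.repeat(sf_float, sf_vec_size, axis=-1)  # shape [m, k]

m, k, sf_vec_size = 4, 32, 16
float_vals = np.ones((m, k), dtype=np.float32)
sf_1d = np.array([0.5, 2.0], dtype=np.float32)  # one scale per 16-wide block

scaled = float_vals * expand_scales(sf_1d, m, k, sf_vec_size)
```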
---
Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 3466-3520: The ragged-path builds run_args and appends the current
sm_scale before folding q_scale/k_scale, so multiply q_scale and k_scale into
sm_scale (when kv_cache_sf is not None and q_scale/k_scale are not None) before
sm_scale is appended to run_args; modify the else branch around run_args
construction in prefill.py to perform the NVFP4 fold (sm_scale *= q_scale and
sm_scale *= k_scale) prior to extending run_args with sm_scale (mirroring the
paged wrapper and single_prefill_with_kv_cache behavior), leaving the existing
is_float8 and kv_cache_sf checks but relocating the multiplication earlier so
the kernel receives the updated sm_scale.
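The ordering bug described above is purely about when the fold happens relative to building `run_args`. A simplified sketch of the corrected order (hypothetical function and argument names, not the real wrapper):

```python
import math

def build_run_args(head_dim: int, kv_cache_sf, q_scale=None, k_scale=None):
    sm_scale = 1.0 / math.sqrt(head_dim)
    # Fold the NVFP4 dequant scales into sm_scale *before* it is appended to
    # run_args, mirroring the paged wrapper and single_prefill_with_kv_cache.
    if kv_cache_sf is not None:
        if q_scale is not None:
            sm_scale *= q_scale
        if k_scale is not None:
            sm_scale *= k_scale
    run_args = ["q", "k", "v", sm_scale]  # the kernel sees the folded scale
    return run_args

folded = build_run_args(64, kv_cache_sf=object(), q_scale=2.0, k_scale=4.0)
unfolded = build_run_args(64, kv_cache_sf=None, q_scale=2.0, k_scale=4.0)
```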
---
Duplicate comments:
In `@include/flashinfer/cp_async.cuh`:
- Around line 194-222: The PTX branch for cp.async (in the code handling
SharedMemFillMode and predicate around smem_ptr/gmem_ptr) uses cp-size=8 which
leaves bytes 8..15 uninitialized; change the cp.async invocations so cp-size=16
while keeping src-size as 8 (or 0 when predicate is false) so the hardware
zero-fills the upper 8 bytes to match the docstring and the fallback path;
update both asm blocks (the if constexpr fill_mode==kFillZero path and the else
kNoFill path) to pass 16 as the cp-size operand and ensure the src-size operand
remains 8 or 0 accordingly.
---
Nitpick comments:
In `@flashinfer/decode.py`:
- Around line 1343-1352: Replace the print warning in the NVFP4 KV cache
NHD->HND conversion with warnings.warn so it's integrated with Python's warning
system and can be filtered/deduplicated; specifically, in the run() path where
key_block_scales, k_cache, v_cache, and value_block_scales are being
transposed/contiguified, call warnings.warn(...) instead of print(...), and make
the analogous change in trtllm_batch_decode_with_kv_cache for the same message
to keep behavior consistent across both code paths.
In `@flashinfer/prefill.py`:
- Around line 2371-2380: Replace the inline print used when converting NVFP4 KV
cache NHD→HND with a proper warning via warnings.warn(..., stacklevel=2) (or use
the module logger) so callers can filter/suppress it; specifically update the
block that checks key_block_scales is not None where k_cache, v_cache,
key_block_scales and value_block_scales are made contiguous/transposed to call
warnings.warn with the same message and stacklevel=2 instead of print, and make
the same change in the analogous block inside trtllm_batch_context_with_kv_cache
(the other occurrence around the key/value cache transpose).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cbd4e1b7-3e62-4bdc-bfe5-0a1422f36591
📒 Files selected for processing (5)
flashinfer/decode.py
flashinfer/prefill.py
flashinfer/quantization/fp4_quantization.py
include/flashinfer/attention/prefill.cuh
include/flashinfer/cp_async.cuh
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
3267-3308: ⚠️ Potential issue | 🟡 Minor

Document `kv_cache_sf` in the `BatchPrefillWithRaggedKVCacheWrapper.run` docstring.

The new `kv_cache_sf` parameter is missing from the docstring (Parameters block at Lines 3274-3300). Users won't know the accepted shapes/layout/dtype (float8_e4m3fn, stacked-tensor vs tuple form, etc.); the paged wrapper's docstring at Lines 2249-2268 is the natural template. Also worth noting in the docs: unlike the paged/single paths, the post-kernel `v_scale` is applied only when `kv_cache_sf is not None` (Lines 3526-3528), so callers using `v_scale` with non-NVFP4 inputs on this wrapper will silently get unscaled output.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3267 - 3308, Add documentation for the new kv_cache_sf parameter in BatchPrefillWithRaggedKVCacheWrapper.run: describe accepted types and layouts (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], explain stacked-tensor vs tuple form, expected dtypes like float8_e4m3fn and NVFP4 specifics, and the tensor shapes/ordering consistent with k/v inputs; also note the behavioral difference that post-kernel v_scale is applied only when kv_cache_sf is not None (so callers using v_scale with non-NVFP4 inputs on this wrapper will not get scaled output), mirroring the style and details used in the paged wrapper docstring for clarity.
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)
636-645: ⚠️ Potential issue | 🟠 Major

`key_block_scales`/`value_block_scales` should not be in `paged_run`'s `mutates_args`.

These tensors are read-only load sources in `produce_kv_sf`/`page_produce_kv_sf`; marking them as mutated forces unnecessary functionalization copies under `torch.compile` and inhibits CUDA graph capture and re-use. Note the analogous additions in `run_single_prefill` (Line 338) and `ragged_run` (Line 469) correctly leave `maybe_k_cache_sf`/`maybe_v_cache_sf` out of `mutates_args`; this paged path is the outlier. This matches prior feedback that was reportedly addressed but reappears here.

🔧 Suggested fix
```diff
         mutates_args=(
             "float_workspace_buffer",
             "int_workspace_buffer",
             "paged_k_cache",
             "paged_v_cache",
             "o",
             "maybe_lse",
-            "key_block_scales",
-            "value_block_scales",
         ),
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 636 - 645, In paged_run remove key_block_scales and value_block_scales from the mutates_args tuple so they are treated as read-only (they are load-only sources used by produce_kv_sf / page_produce_kv_sf); mirror the approach used in run_single_prefill and ragged_run where maybe_k_cache_sf / maybe_v_cache_sf are not listed as mutated. Update the mutates_args declaration in paged_run (remove "key_block_scales", "value_block_scales") and verify there are no other places in paged_run that rely on those names being reported as mutated.
🧹 Nitpick comments (2)
include/flashinfer/vec_dtypes.cuh (2)
498-511: Verify the fp16→bf16 fallback round-trip is always lossless for e2m1 values.

Each e2m1 value ∈ {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} is exactly representable in both fp16 and bf16, so the fallback `__half22float2` → `__float22bfloat162_rn` round-trip produces the same bit pattern as the native `cvt.rn.bf16x2.e2m1x2` output. This is correct.

Minor nit: the fallback could skip the PTX fp16 intermediate and directly use the LUT (lines 518–535) for one table lookup instead of a PTX cvt + two fp conversions. This only applies if the fallback gets exercised on SM100+ with CUDA 13.0–13.1 (before the code path switches to the native instruction in 13.2+); otherwise the current implementation is fine.

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/vec_dtypes.cuh` around lines 498 - 511, The fallback currently converts e2m1->fp16 via the inline asm into fp16x2, then to float and to bf16 (__half2 h2 / __float22bfloat162_rn) before storing y; replace that three-step conversion with a direct lookup into the existing e2m1→bf16 lookup table used elsewhere (the same LUT referenced in the surrounding code) so you produce y in one table lookup instead of doing the asm cvt + __half2/__float22bfloat162_rn sequence; update the code paths that set fp16x2 / h2 / bf16x2 to instead index the LUT with the original input (variable b) and write the LUT result into y.
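The losslessness claim rests on the e2m1 value set, which a small decoder can enumerate. Each 4-bit code is sign (1 bit) | exponent (2 bits) | mantissa (1 bit); decoding all 16 codes reproduces exactly {±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, each of which needs only one mantissa bit and so fits exactly in both fp16 and bf16. A sketch of such a LUT (this mirrors the format, not the literal table in `vec_dtypes.cuh`):

```python
def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit e2m1 (FP4) code: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        # Exponent 0 is subnormal: value is 0.0 or 0.5.
        return sign * man * 0.5
    # Normal: 1.man * 2^(exp - 1), i.e. (1 + 0.5 * man) scaled by the exponent.
    return sign * (1.0 + 0.5 * man) * (2.0 ** (exp - 1))

E2M1_LUT = [decode_e2m1(code) for code in range(16)]
```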
428-443: Align `vec_cast` with the batched extraction pattern used in `csrc/xqa/utils.cuh` for consistency and performance.

The identical batched-extraction optimization already exists in `csrc/xqa/utils.cuh::convertKCacheWordToF16<half, __nv_fp4_e2m1>` and its bf16 variant: both use `mov.b32 {byte0, _, byte2, _}` to extract two valid fp4 bytes per word, then invoke two separate `cvt.rn.f16x2.e2m1x2` instructions on the extracted .b8 registers. This reduces the loop count from `vec_size / 2` to `vec_size / 4` and halves the extraction overhead on this KV-dequant hot path.

Apply the same pattern to both the fp16 and bf16x2 branches in `vec_cast<half, __nv_fp4x2_e2m1>::cast` (lines 428–443) and `vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast` (lines 482–513).

♻️ Sketch of the batched form (fp16 branch)
```diff
-#pragma unroll
-  for (size_t i = 0; i < vec_size / 2; ++i) {
-    uint32_t y;
-    // Valid fp4x2 bytes are at even positions (stride 2); odd positions are padding.
-    uint32_t b = reinterpret_cast<const uint8_t*>(src)[i * 2];
-    asm volatile(
-        "{\n"
-        ".reg .b8 fp4_byte;\n"
-        "mov.b32 {fp4_byte, _, _, _}, %1;\n"
-        "cvt.rn.f16x2.e2m1x2 %0, fp4_byte;\n"
-        "}"
-        : "=r"(y)
-        : "r"(b));
-    reinterpret_cast<uint32_t*>(dst)[i] = y;
-  }
+  // Batch extract two valid bytes per word (stride-2 layout: valid at positions 0, 2).
+#pragma unroll
+  for (size_t i = 0; i < vec_size / 4; ++i) {
+    uint32_t word = reinterpret_cast<const uint32_t*>(src)[i];
+    uint32_t y0, y1;
+    asm volatile(
+        "{\n"
+        ".reg .b8 b0, b2;\n"
+        "mov.b32 {b0, _, b2, _}, %2;\n"
+        "cvt.rn.f16x2.e2m1x2 %0, b0;\n"
+        "cvt.rn.f16x2.e2m1x2 %1, b2;\n"
+        "}"
+        : "=r"(y0), "=r"(y1)
+        : "r"(word));
+    reinterpret_cast<uint32_t*>(dst)[i * 2 + 0] = y0;
+    reinterpret_cast<uint32_t*>(dst)[i * 2 + 1] = y1;
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/vec_dtypes.cuh` around lines 428 - 443, The current fp16 and bf16x2 branches in vec_cast<half, __nv_fp4x2_e2m1>::cast and vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast extract one valid fp4 byte per iteration; change them to the batched extraction used in convertKCacheWordToF16<half, __nv_fp4_e2m1> by loading a 32-bit word with mov.b32 {byte0, _, byte2, _} to obtain two valid fp4 bytes, then invoke two cvt.rn.f16x2.e2m1x2 (or appropriate bf16 convert) operations on the two extracted .b8 registers, reduce the loop range from vec_size/2 to vec_size/4, and store two uint32_t results per loop iteration to halve extraction overhead on the hot KV-dequant path.
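The stride-2 layout and the batched extraction above can be modeled in plain Python: valid fp4x2 bytes sit at even offsets, and one little-endian 32-bit word yields two of them (bytes 0 and 2), halving the number of loads versus one byte per iteration. This toy model (hypothetical names) only demonstrates the indexing, not the PTX conversions:

```python
import struct

def extract_per_byte(buf: bytes, vec_size: int) -> list:
    # Current form: one valid byte per iteration, vec_size / 2 loads.
    return [buf[i * 2] for i in range(vec_size // 2)]

def extract_batched(buf: bytes, vec_size: int) -> list:
    # Batched form: one 32-bit word per iteration yields bytes 0 and 2,
    # so only vec_size / 4 loads are needed.
    out = []
    for i in range(vec_size // 4):
        (word,) = struct.unpack_from("<I", buf, i * 4)
        out.append(word & 0xFF)          # byte 0 of the word (valid)
        out.append((word >> 16) & 0xFF)  # byte 2 of the word (valid)
    return out

# 8 fp4 elements -> 4 valid fp4x2 bytes at even offsets, padding at odd offsets.
buf = bytes([0xA0, 0x00, 0xB1, 0x00, 0xC2, 0x00, 0xD3, 0x00])
vec_size = 8
```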
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 3267-3308: Add documentation for the new kv_cache_sf parameter in
BatchPrefillWithRaggedKVCacheWrapper.run: describe accepted types and layouts
(Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], explain
stacked-tensor vs tuple form, expected dtypes like float8_e4m3fn and NVFP4
specifics, and the tensor shapes/ordering consistent with k/v inputs; also note
the behavioral difference that post-kernel v_scale is applied only when
kv_cache_sf is not None (so callers using v_scale with non-NVFP4 inputs on this
wrapper will not get scaled output), mirroring the style and details used in the
paged wrapper docstring for clarity.
---
Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 636-645: In paged_run remove key_block_scales and
value_block_scales from the mutates_args tuple so they are treated as read-only
(they are load-only sources used by produce_kv_sf / page_produce_kv_sf); mirror
the approach used in run_single_prefill and ragged_run where maybe_k_cache_sf /
maybe_v_cache_sf are not listed as mutated. Update the mutates_args declaration
in paged_run (remove "key_block_scales", "value_block_scales") and verify there
are no other places in paged_run that rely on those names being reported as
mutated.
---
Nitpick comments:
In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 498-511: The fallback currently converts e2m1->fp16 via the inline
asm into fp16x2, then to float and to bf16 (__half2 h2 / __float22bfloat162_rn)
before storing y; replace that three-step conversion with a direct lookup into
the existing e2m1→bf16 lookup table used elsewhere (the same LUT referenced in
the surrounding code) so you produce y in one table lookup instead of doing the
asm cvt + __half2/__float22bfloat162_rn sequence; update the code paths that set
fp16x2 / h2 / bf16x2 to instead index the LUT with the original input (variable
b) and write the LUT result into y.
- Around line 428-443: The current fp16 and bf16x2 branches in vec_cast<half,
__nv_fp4x2_e2m1>::cast and vec_cast<nv_bfloat16, __nv_fp4x2_e2m1>::cast extract
one valid fp4 byte per iteration; change them to the batched extraction used in
convertKCacheWordToF16<half, __nv_fp4_e2m1> by loading a 32-bit word with
mov.b32 {byte0, _, byte2, _} to obtain two valid fp4 bytes, then invoke two
cvt.rn.f16x2.e2m1x2 (or appropriate bf16 convert) operations on the two
extracted .b8 registers, reduce the loop range from vec_size/2 to vec_size/4,
and store two uint32_t results per loop iteration to halve extraction overhead
on the hot KV-dequant path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f92d400a-15e0-48d4-b0d6-a6e4ea771ddb
📒 Files selected for processing (3)
csrc/xqa/utils.cuh
flashinfer/prefill.py
include/flashinfer/vec_dtypes.cuh
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/attention/test_single_prefill.py (1)
112-132: Dead-code skip: `causal` is parametrized only to `False`.

`@pytest.mark.parametrize("causal", [False])` makes the `if qo_len > kv_len and causal` branch unreachable. Either expand the `causal` parametrization to include `True` (and verify the NVFP4 kernel supports it) or drop the skip and the `causal` parameter entirely to keep the test matrix honest.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_single_prefill.py` around lines 112 - 132, The test parametrizes causal only as False making the conditional "if qo_len > kv_len and causal" dead code; update test_single_prefill_with_kv_cache_nvfp4 by either adding True to the pytest.mark.parametrize("causal", [False]) list (and ensure NVFP4 kernel supports causal behavior) or remove the causal parameter and the associated skip check entirely so the branch is not unreachable; adjust any related test matrix or assumptions accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/attention/test_single_prefill.py`:
- Around line 112-132: The test parametrizes causal only as False making the
conditional "if qo_len > kv_len and causal" dead code; update
test_single_prefill_with_kv_cache_nvfp4 by either adding True to the
pytest.mark.parametrize("causal", [False]) list (and ensure NVFP4 kernel
supports causal behavior) or remove the causal parameter and the associated skip
check entirely so the branch is not unreachable; adjust any related test matrix
or assumptions accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 37f05332-993d-4acf-8aca-133f0474232c
📒 Files selected for processing (1)
tests/attention/test_single_prefill.py
|
/bot run |
📌 Description
This MR adds NVFP4 KV input support for the batch prefill and batch attention kernels. It supports all architectures from SM80 onward.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
I have installed the hooks with `pre-commit install`.
I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
Tests have been added or updated as needed, and all tests pass (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Tests