Support NVFP4 KV for prefill and batch attention kernels#2820
Tom-Zheng wants to merge 5 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough: Adds end-to-end NVFP4 (packed FP4) KV-cache support to the prefill and batch attention paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant PyAPI as Python API (BatchAttention/Prefill/Decode)
    participant JIT as JIT / custom-op module
    participant Kernel as CUDA kernel
    participant GMEM as KV GMEM (packed NVFP4 + block SF)
    Client->>PyAPI: call run(..., kv_block_scales=kv_sf)
    PyAPI->>JIT: forward tensors + kv_block_scales / maybe_*_cache_sf
    JIT->>Kernel: launch kernel with maybe_k_cache_sf / maybe_v_cache_sf pointers
    Kernel->>GMEM: page_produce_kv (load packed FP4) and page_produce_kv_sf (load SF bytes)
    Kernel->>Kernel: place SF into smem, expand/dequantize, compute_qk / compute_sfm_v using SF + lane_idx
    Kernel-->>JIT: return outputs
    JIT-->>PyAPI: relay outputs
    PyAPI-->>Client: deliver result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances FlashInfer's capabilities by integrating native support for NVFP4 KV-cache input across its batch prefill and batch attention kernels. This allows more memory-efficient and potentially faster inference on NVIDIA GPUs by leveraging 4-bit quantization for key and value tensors. The changes span from low-level CUDA kernel implementations for data loading and computation to Python-level utilities for quantization, dequantization, and comprehensive testing, ensuring broad compatibility and correctness.
/bot run
Code Review
This pull request introduces support for NVFP4 (NVIDIA FP4) KV cache quantization in FlashInfer's prefill and batch attention operations. Key changes include updating benchmark routines to handle NVFP4 as a KV data type, adjusting tolerances for lower precision, and filtering unsupported backends. The core C++ kernels and Python JIT modules are extended to manage packed NVFP4 data and per-group scale factors, including modifications to memory access patterns and MMA operations for proper dequantization. Review comments highlight the need for improved clarity in the `global_scale` calculation in `_to_nvfp4` and detailed explanations for the intricate scaling-factor application logic within the `compute_qk` and `compute_sfm_v` device functions.
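The two-level scheme the summary describes (a global scale plus per-16-element block scale factors feeding E2M1 codes) can be sketched in plain Python. This is an illustrative model only: `quantize_block` and `dequantize_block` are hypothetical names, and the real `_to_nvfp4` additionally packs two codes per byte and typically stores block scales in FP8 (E4M3) rather than float.

```python
# Illustrative NVFP4-style block quantization (not flashinfer's actual helpers).
# E2M1 can represent only these eight magnitudes; bit 3 of a code is the sign.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_block(block):
    """Quantize one 16-element block to (scale, signed 4-bit E2M1 codes)."""
    amax = max(abs(x) for x in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # map the block onto E2M1's max magnitude
    codes = []
    for x in block:
        mag = abs(x) / scale
        code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))  # nearest magnitude
        codes.append(code | (0x8 if x < 0 else 0))  # sign lives in bit 3
    return scale, codes

def dequantize_block(scale, codes):
    """Invert quantize_block: sign bit * magnitude table * block scale."""
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_VALUES[c & 0x7] * scale for c in codes]
```

Values that are already E2M1-representable round-trip exactly, which is why the review below asks tests to cover both signs of the code space.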
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
benchmarks/routines/attention.py (2)
1197-1201: ⚠️ Potential issue | 🟠 Major: Don't quantize the paged NVFP4 cache twice.
Lines 1036-1059 already build the packed `kv_cache` plus `kv_cache_sf`. Lines 1197-1201 then feed that packed `uint8` data back into `nvfp4_quantize_paged_kv_cache(...)`, which expects floating-point KV input and discards the scales you just computed. Reuse `kv_cache_sf` as `kv_block_scales` here.

Suggested change:

```diff
-        if is_nvfp4_kv:
-            kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
-                nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
-            )
-            kv_cache = kv_cache_nvfp4
+        if use_nvfp4_kv:
+            kv_block_scales = kv_cache_sf
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1197 - 1201, The code is re-quantizing an already-packed NVFP4 paged KV cache: when is_nvfp4_kv is true you call nvfp4_quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1]) on uint8-packed data and overwrite the correct scales; instead, reuse the precomputed packed cache and scales (kv_cache and kv_cache_sf) produced earlier: set kv_cache = kv_cache (or keep existing packed variable) and assign kv_block_scales = kv_cache_sf (and ensure k_scale and v_scale use the previously computed values), removing the nvfp4_quantize_paged_kv_cache call inside the is_nvfp4_kv branch so you don't discard the original scales.
865-867: ⚠️ Potential issue | 🔴 Critical: Use one NVFP4 feature-flag name throughout this function.
Line 865 defines `is_nvfp4_kv`, but the new branches later read `use_nvfp4_kv` (for example lines 883, 965, 1036, and 1218). `testBatchPrefillWithPagedKVCacheWrapper()` currently throws before any benchmark runs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 865 - 867, The function uses two inconsistent feature-flag names (is_nvfp4_kv and use_nvfp4_kv) causing branches to miss the intended flag; unify them by picking one name (e.g., replace the initial definition is_nvfp4_kv = args.kv_dtype == "nvfp4" with use_nvfp4_kv = args.kv_dtype == "nvfp4" or create use_nvfp4_kv = is_nvfp4_kv immediately after) and update all branch checks (references in the function such as the later conditionals at lines referencing use_nvfp4_kv) to use that single symbol so the NVFP4 path is consistently triggered (also ensure any dtype checks that set kv_dtype remain correct).

include/flashinfer/attention/prefill.cuh (1)
1671-1702: ⚠️ Potential issue | 🟠 Major: These prefill paths now consume FP4 scale tiles without ever producing them.
`compute_qk`/`compute_sfm_v` now dereference `k_sf_smem` and `v_sf_smem` for FP4 KV, but `SinglePrefillWithKVCacheDevice` and `BatchPrefillWithRaggedKVCacheKernel` still only call `produce_kv(...)`. If either path is instantiated with `__nv_fp4x2_e2m1`, it will multiply against uninitialized shared memory and silently corrupt the result. Please either plumb per-row SF loads into these kernels too, or add a compile-time guard that keeps FP4 limited to the paged path for now. Also applies to: 2118-2150
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/prefill.cuh` around lines 1671 - 1702, The prefill paths call compute_qk and compute_sfm_v which read k_sf_smem/v_sf_smem for FP4 but the kernels SinglePrefillWithKVCacheDevice and BatchPrefillWithRaggedKVCacheKernel only call produce_kv and never populate those scale-factor tiles, so instantiating with __nv_fp4x2_e2m1 will read uninitialized shared memory; fix by adding a compile-time guard that prevents FP4 (e.g. static_assert or if constexpr) in SinglePrefillWithKVCacheDevice and BatchPrefillWithRaggedKVCacheKernel (or the wrapper that calls produce_kv) when KTraits::ScalarType == __nv_fp4x2_e2m1, OR alternatively plumb the per-row SF loads into those kernels by ensuring produce_kv is invoked with the SharedMemFillMode that fills k_sf_smem/v_sf_smem (or explicitly call the SF fill helper) before compute_qk/compute_sfm_v are executed; pick one approach and apply it consistently to both call sites referencing compute_qk, compute_sfm_v, k_sf_smem, v_sf_smem, and produce_kv.
🧹 Nitpick comments (4)
flashinfer/jit/attention/modules.py (1)
1836-1870: Factor the batch-attention setter generation out of this function. This copies the nullable/scalar assignment rules from `generate_additional_params()`. A small helper that accepts the target prefix (`params`, `params[i]`, `params.additional_params`) would keep the batch path from drifting the next time additional-parameter semantics change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/attention/modules.py` around lines 1836 - 1870, Extract the logic that builds the param-assignment lines into a small helper (e.g., generate_additional_params_setter(prefix, additional_tensor_names, additional_tensor_dtypes, additional_scalar_names)) and replace the inline batch_additional_params_setter construction with a call to that helper; the helper should implement the same nullable tensor/ scalar rules currently duplicated (the conditional branch for var.startswith("maybe") and the scalar formatting) but use the provided target prefix (e.g., "params[i]", "params", or "params.additional_params") when formatting each assignment; update the call sites (the batch path that currently creates batch_additional_params_setter and any other place using generate_additional_params output) to call the new helper so semantics remain identical but the formatting logic is centralized..claude/memory/prefill_cuh_structure.md (1)
40-62: Call out the NVFP4 scale-factor path explicitly. This note still reads like a generic prefill overview. The new FP4-specific pieces (`maybe_k_cache_sf`/`maybe_v_cache_sf`, `page_produce_kv_sf`, and the shared-memory scale buffers consumed by `compute_qk`/`compute_sfm_v`) are exactly what future readers will look for in this PR.

Based on learnings: keep documentation in sync with code changes, particularly CLAUDE.md and `.claude/skills/`, when modifying infrastructure changes, patterns, new conventions, or deprecations. Also applies to: 75-86
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In @.claude/memory/prefill_cuh_structure.md around lines 40 - 62, Update the prefill overview to explicitly document the NVFP4 scale-factor path: describe the new FP4-specific symbols maybe_k_cache_sf and maybe_v_cache_sf, the page_produce_kv_sf path, and the shared-memory scale buffers that compute_qk and compute_sfm_v consume; note where these are emitted/loaded and how they flow through page_produce_kv_sf → shared-memory buffers → compute_qk/compute_sfm_v, and add a short cross-reference to the infra docs/skills that must be updated when changing these conventions so readers can find the FP4 scale-factor behavior quickly.tests/attention/test_batch_attention.py (1)
308-309: Exercise signed E2M1 codes too. Line 309 clears both sign bits, so this test never covers negative NVFP4 values. A sign-handling regression would still pass here; either remove the mask or add a second signed-data case.

Suggested change:

```diff
-        packed &= 0x77  # clear bit 3 (0x08) and bit 7 (0x80) to ensure non-negative
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_attention.py` around lines 308 - 309, The test currently forces all packed bytes non-negative by applying "packed &= 0x77", so negative NVFP4 (signed E2M1) values are never exercised; either remove the mask expression "packed &= 0x77" to allow both signs, or add a second test case that constructs a signed-data variant (e.g., copy the existing "packed" and set the sign bits for NVFP4 by OR-ing the appropriate bits such as 0x08 and/or 0x80) and run the same assertions on that signed copy so both unsigned and negative NVFP4 paths are covered.tests/attention/test_batch_prefill_kernels.py (1)
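For reference when extending the test, the signed E2M1 nibble layout works like this. This is a hypothetical host-side decoder, low nibble first; the kernel's actual packing order may differ.

```python
# Hypothetical model of unpacking two signed E2M1 codes from one byte.
# Bit 3 of each nibble is the sign bit -- exactly what `packed &= 0x77` clears,
# so the masked test never exercises negative values.
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_nibble(n):
    sign = -1.0 if n & 0x8 else 1.0
    return sign * E2M1[n & 0x7]

def unpack_byte(b):
    # low nibble = first element, high nibble = second (assumed ordering)
    return decode_nibble(b & 0xF), decode_nibble(b >> 4)
```

For example, byte `0x2A` holds `-1.0` in its low nibble, while masking with `0x77` flips it to `+1.0`, hiding any sign-handling bug.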
1074-1083: Add at least one causal NVFP4 case. This matrix hardcodes `causal=False`, so the new NVFP4 path never exercises the masked/tail-tile logic that changed in the kernel code. A small `causal=True` case would cover the scale-factor path under masking as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_prefill_kernels.py` around lines 1074 - 1083, The test matrix currently forces causal=False so the NVFP4 kernel path never hits masked/tail-tile logic; update the test_batch_prefill_with_paged_kv_cache_nvfp4 parameterization (the `@pytest.mark.parametrize`("causal", ...) on the test) to include True (e.g., [False, True]) so at least one run exercises the causal/masked path for NVFP4; keep existing q_dtype values unchanged so the NVFP4 path is still exercised.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1667-1675: The NVFP4 allowlist is too permissive: in the
use_nvfp4_kv block (variable backends and list nvfp4_unsupported) remove "auto"
and "trtllm-native" from the supported set so ragged NVFP4 only stays enabled
for backends that actually wrap ragged K/V (e.g., "fa2" and "fa3"); update the
allowed list used to compute nvfp4_unsupported from ["fa2", "trtllm-native",
"auto"] to only ["fa2", "fa3"] (or the concrete backends that implement ragged
NVFP4) so backends that don't forward k_sf/v_sf are filtered out.
In `@include/flashinfer/cp_async.cuh`:
- Around line 191-224: The cp.async call in pred_load_128b_from_64b uses
cp_size=8 which doesn’t zero the upper 8 bytes; change the cp.async
invocation(s) to use cp_size=16 and pass src_size as a variable (src_size =
predicate ? 8 : 0) so cp.async zero-fills bytes 8..15 when src_size is 0, and
similarly adjust the kNoFill branch to issue cp.async with cp_size=16 and
src_size conditionally 8 or 0 (instead of cp_size=8); also apply the same fix
pattern to the 32b helper described in the comment: use cp_size appropriate to
the full destination (e.g., 16 for 128b destination or 4 for 32b helper) and
make src_size variable (0 when wanting explicit zero-fill, nonzero when copying)
so cp.async actually zeros the upper bytes.
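The fix above relies on a PTX property: `cp.async` issued with a src-size smaller than its cp-size zero-fills the remaining destination bytes (e.g. `cp.async.ca.shared.global [dst], [src], 16, 8;`). A host-side Python model of that semantics, with illustrative names:

```python
# Host-side model of cp.async's zero-fill behavior: src_size bytes are copied
# into a cp_size-byte shared-memory slot and the tail is padded with zeros.
# Not real device code -- just the semantics the review's fix depends on.
def cp_async_model(src: bytes, cp_size: int, src_size: int) -> bytes:
    assert src_size <= cp_size and src_size <= len(src)
    return src[:src_size] + bytes(cp_size - src_size)
```

In the predicated-load helper the review describes, `src_size` becomes `8 if predicate else 0`, so bytes 8..15 of the 16-byte slot are zeroed either way.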
In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 486-510: The CUDA version gate incorrectly requires both
__CUDACC_VER_MAJOR__ >= 13 and __CUDACC_VER_MINOR__ >= 2, which fails for CUDA
14.x; update the preprocessor condition that guards the fast-path asm (the block
using cvt.rn.bf16x2.e2m1x2 and variable y/b) to check the combined version
(e.g., compare major/minor together or compute a single numeric version) so it
enables the fast path for CUDA >= 13.2 (including 14.0+).
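The version-gate bug is easiest to see as a tuple comparison. A small illustration of the logic the preprocessor condition should encode (the function names here are ours, not the codebase's):

```python
# The gate `major >= 13 and minor >= 2` wrongly rejects CUDA 14.0
# (major 14, minor 0). Comparing (major, minor) lexicographically
# expresses ">= 13.2" correctly.
def broken_gate(major: int, minor: int) -> bool:
    return major >= 13 and minor >= 2

def fixed_gate(major: int, minor: int) -> bool:
    return (major, minor) >= (13, 2)
```

The equivalent preprocessor form would compute a single numeric version, e.g. `major * 100 + minor >= 1302`.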
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9c342064-89d2-4570-b892-646302704193
📒 Files selected for processing (18)
- .claude/memory/MEMORY.md
- .claude/memory/prefill_cuh_structure.md
- benchmarks/routines/attention.py
- flashinfer/attention.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/utils.py
- include/flashinfer/attention/persistent.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/cp_async.cuh
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/permuted_smem.cuh
- include/flashinfer/vec_dtypes.cuh
- mha_ref.cu
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_prefill_kernels.py
👮 Files not reviewed due to content moderation or server errors (7)
- flashinfer/utils.py
- .claude/memory/MEMORY.md
- flashinfer/jit/utils.py
- flashinfer/attention.py
- include/flashinfer/permuted_smem.cuh
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/attention/test_batch_prefill_kernels.py (1)
1151-1165: Please add an asymmetric `head_dim_qk`/`head_dim_vo` case here.
`head_dim_vo` is omitted, so this suite only covers the default `head_dim_vo == head_dim_qk` path. A 192/128-style case would exercise the packed-V sizing logic that the current fixture cannot catch. The same omission exists in `tests/attention/test_batch_attention.py`, so it would be worth updating both together.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_prefill_kernels.py` around lines 1151 - 1165, The test only exercises symmetric head sizes because wrapper.plan is called with head_dim only; add an asymmetric case by passing explicit head_dim_qk and head_dim_vo arguments to wrapper.plan (for example head_dim_qk=192, head_dim_vo=128) so the packed-V sizing logic is exercised; update the wrapper.plan invocation in tests/attention/test_batch_prefill_kernels.py (and mirror the same change in tests/attention/test_batch_attention.py) to include these two explicit parameters instead of relying on the default head_dim equality.include/flashinfer/attention/prefill.cuh (1)
449-498: Consider adding a null check for defensive programming.
If `sf_ptr` is `nullptr` but `is_fp4_type_v<DTypeKV>` is true, the function computes offsets and calls `pred_load_32b` with an invalid source pointer. While the design assumes FP4 usage implies scales are provided, a null check would add robustness:

```diff
 template <bool produce_v, typename KTraits, typename IdType>
 __device__ __forceinline__ void page_produce_kv_sf(
     typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr, ...) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) return;
```

This prevents undefined behavior if FP4 is compiled but scales are accidentally omitted at runtime.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/prefill.cuh` around lines 449 - 498, The function page_produce_kv_sf may dereference sf_ptr when is_fp4_type_v<typename KTraits::DTypeKV> is true; add a defensive null check at the start of page_produce_kv_sf (after the is_fp4_type_v constexpr) that returns early if sf_ptr == nullptr to avoid computing sf_gmem_offset and calling cp_async::pred_load_32b with an invalid source pointer; keep the check independent of produce_v and ensure it triggers before the NUM_SF_ITERS loop so symbols page_produce_kv_sf, sf_ptr, is_fp4_type_v, and cp_async::pred_load_32b are addressed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 737-738: The code currently assumes packed NVFP4 V/O width equals
q.shape[-1] when kv_block_scales (or key_block_scales/value_block_scales)
indicate a packed V cache, which breaks configs where head_dim_vo !=
head_dim_qk; update the run path to use the planned head_dim_vo instead of
q.shape[-1] for packed outputs and persist the planned head_dim_vo from plan()
onto self (e.g., self.head_dim_vo) so run() can read it; adjust any branches
that check kv_block_scales/key_block_scales/value_block_scales to select
self.head_dim_vo as the V/O width when packed is detected.
In `@include/flashinfer/cp_async.cuh`:
- Around line 192-223: pred_load_128b_from_64b: ensure the cp.async path zeroes
the upper 8 bytes to match the fallback by changing the assembly copy size to 16
while keeping src-size=8 (i.e. use cp.async.ca.shared.global with cp-size=16,
src-size=8) in both the fill-mode (kFillZero) branch and the kNoFill branch so
the upper half of the 16-byte slot is zero-padded when only 8 bytes are sourced;
keep the predicate logic and the fallback (smem_u64[1] = 0) unchanged.
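The head-dim fix above turns on the difference between packed storage width and logical width: NVFP4 packs two 4-bit codes per `uint8`, so the logical V/O head dim is twice the packed V tensor's last dimension. A minimal sketch of that sizing rule (a hypothetical helper echoing the review's names, not flashinfer's actual code):

```python
# Logical output head_dim for a V cache that may be NVFP4-packed.
# Sizing from q.shape[-1] instead silently assumes head_dim_vo == head_dim_qk.
def out_head_dim(v_last_dim: int, packed: bool) -> int:
    # unpacking doubles the storage width: two E2M1 codes per uint8
    return v_last_dim * 2 if packed else v_last_dim
```

With an asymmetric 192/128 config, a packed V tensor has last dim 64 but the output must still be 128 wide, which is exactly the case the suggested test would catch.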
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bddea1b0-077d-4606-bfad-a761fbad1553
📒 Files selected for processing (14)
- flashinfer/attention.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/utils.py
- include/flashinfer/attention/persistent.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/cp_async.cuh
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/permuted_smem.cuh
- include/flashinfer/vec_dtypes.cuh
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_prefill_kernels.py
[FAILED] Pipeline #46514756: 6/20 passed

/bot run
…2725)

## Summary

SM120 desktop Blackwell GPUs (RTX PRO 6000, RTX 5090) are blocked from NVFP4 MoE grouped GEMM due to hardcoded SM100-only checks.

**Changes:**
- `jit/fused_moe.py`: Add major version 12 to `supported_major_versions`
- `csrc/trtllm_fused_moe_kernel_launcher.cu`: `ICHECK_EQ(major, 10)` -> `ICHECK_GE(major, 10)`

**Benchmark** (Qwen3.5-397B on 4x RTX PRO 6000 SM120):

| Config | tok/s | Output |
|--------|-------|--------|
| compute_120f (CUDA 13.0) | 39.0 | Correct |
| compute_120a (CUDA 12.8) | 14.6 | Correct (slow fallback) |
| Marlin W4A16 | 46-49 | Correct |

**Root cause:** All TMA WS grouped GEMM autotuner tactics fail on `compute_120a`, requiring `compute_120f` (CUDA 13.0). CuTe DSL `admissible_archs` in vendored CUTLASS also needs `sm_120a`/`sm_120f` (cpasync/copy.py, tcgen05/mma.py, arch/mbar.py, etc).

Related: CUTLASS #2820, #2800; vLLM #33416, #33333; FlashInfer #2577

## Summary by CodeRabbit

* **Bug Fixes**
  * Broadened GPU architecture checks to accept additional modern compute capabilities (SM 10.x and 12.x), improving compatibility and clearer SM reporting.
  * Improved compute-capability detection and encoding, preserving user-provided architecture suffixes and more accurately generating nvcc architecture flags.
  * Expanded JIT module generation to include additional CUDA majors so fused-MoE kernels run on more recent GPUs.

Signed-off-by: Brandon Music <brandon.m.music@gmail.com>
Co-authored-by: Brandon Music <brandonmmusic-max@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Brandon Music <brandonmusic@pop-os.tail8674da.ts.net>
[FAILED] Pipeline #46572392: 6/20 passed

/bot run
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
1319-1367: ⚠️ Potential issue | 🔴 Critical: `v_scale` never affects the new single/ragged NVFP4 paths.
Only `k_scale` is folded into `sm_scale`. `v_scale` is neither forwarded nor post-applied here, so any non-unit global V scale silently returns the wrong values.

🐛 Proposed fix:

```diff
@@
             module.run(
                 q,
                 k,
                 v,
@@
                 k_sf,
                 v_sf,
             )
+            is_float_one = isinstance(v_scale, float) and v_scale == 1.0
+            if v_scale is not None and not is_float_one:
+                if is_float8(out):
+                    out = (out.to(torch.float32) * v_scale).to(out.dtype)
+                else:
+                    out *= v_scale
@@
         assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
+        is_float_one = isinstance(v_scale, float) and v_scale == 1.0
+        if v_scale is not None and not is_float_one:
+            if is_float8(out):
+                out = (out.to(torch.float32) * v_scale).to(out.dtype)
+            else:
+                out *= v_scale
         return (out, lse) if return_lse else out
```

Also applies to: 3191-3330
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1319 - 1367, The NVFP4 path folds k_scale into sm_scale but never applies v_scale, so outputs are wrong when a non-unit V scale is used; fix by applying v_scale to the produced output for the NVFP4/packed-KV case (kv_cache_sf != None) — either pass v_scale through into the prefill kernel if it supports it or multiply out by v_scale after module.run (operate on out), using the existing symbols k_scale, v_scale, sm_scale, out, kv_cache_sf and v_sf to detect the packed NVFP4 branch and apply the correct scaling.
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)
1330-1341: ⚠️ Potential issue | 🔴 Critical: Don't size packed-NVFP4 outputs from `q.shape[-1]`.
Packed KV changes storage width, not the logical V/O width. These branches still assume `head_dim_vo == head_dim_qk`, which breaks asymmetric QK/VO shapes.

🐛 Proposed fix:

```diff
@@
-        out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1]
+        out_head_dim = v.shape[-1] * 2 if kv_cache_sf is not None else v.shape[-1]
@@
         if head_dim_vo is None:
             head_dim_vo = head_dim_qk
+        self._head_dim_vo = head_dim_vo
@@
-        out_head_dim = q.shape[-1] if kv_cache_sf is not None else v.shape[-1]
+        out_head_dim = (
+            self._head_dim_vo
+            if self._cached_kv_data_type == torch.uint8
+            else v.shape[-1]
+        )
```

Also applies to: 3199-3212
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1330 - 1341, The code incorrectly sets out_head_dim / head_dim_vo from q.shape[-1] when kv_cache_sf is not None (packed NVFP4); packed storage width differs from logical V/O width, so stop deriving logical head_dim_vo from q.shape[-1]. Instead compute out_head_dim = v.shape[-1] (the logical V/O head dim) regardless of kv_cache_sf, and pass that value as the head_dim_vo argument to get_single_prefill_module; also update the other symmetric block that mirrors this logic (the later occurrence handling packed KV) to use v.shape[-1] rather than q.shape[-1]. Ensure any allocation of out uses this logical out_head_dim and that get_single_prefill_module receives q.shape[-1] for head_dim_qk and v.shape[-1] for head_dim_vo.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3172-3177: The code currently always folds q_scale and k_scale
into sm_scale (in the sm_scale computation near sm_scale = 1.0 /
math.sqrt(q.size(-1))), which causes double-scaling when the ragged cuDNN path
is used because the cuDNN call below also receives q_scale and k_scale
separately; change the logic to mirror the paged wrapper's cuDNN guard used
elsewhere: only multiply sm_scale by q_scale and k_scale when NOT taking the
cuDNN ragged path (i.e., when cuDNN is not used), otherwise leave sm_scale as
the geometric/default scale and pass q_scale/k_scale unchanged to the cuDNN
call; apply the same fix to the second occurrence around lines 3245-3260 so both
regions use the same cuDNN guard and avoid double-scaling.
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 461-510: page_produce_kv_sf can dereference sf_ptr when in_bounds
is true; add a null-pointer guard at the top of page_produce_kv_sf (the FP4-only
branch) to avoid dereferencing sf_ptr (e.g., if (sf_ptr == nullptr) return; or
otherwise ensure all in_bounds are false when sf_ptr is null) so that the
subsequent call to cp_async::pred_load_32b(...) never receives sf_ptr +
sf_gmem_offset when sf_ptr is null.
In `@tests/attention/test_batch_prefill_kernels.py`:
- Around line 1030-1165: The NVFP4 tests (e.g.,
test_batch_prefill_with_paged_kv_cache_nvfp4) run unconditionally on all GPUs;
add the repo-standard compute-capability guard using flashinfer.utils to skip
unsupported architectures: import flashinfer.utils and at the start of the test
call get_compute_capability()/is_sm90a_supported()/is_sm100a_supported() (or
directly use is_sm90a_supported() or is_sm100a_supported()) and call
pytest.skip(...) when neither is supported; apply the same guard to the other
NVFP4 test block (the one referenced in the comment for lines 1168-1260) so
unsupported runners skip instead of failing.
In `@tests/attention/test_single_prefill.py`:
- Around line 107-160: The test test_single_prefill_with_kv_cache_nvfp4 must be
gated by GPU compute capability: import and call the repo-standard helpers from
flashinfer.utils (get_compute_capability(), is_sm90a_supported(),
is_sm100a_supported()) at the start of the test (inside
test_single_prefill_with_kv_cache_nvfp4) and skip the test when the current GPU
does not support NVFP4 (i.e., when neither is_sm90a_supported() nor
is_sm100a_supported() is true); use pytest.skip or pytest.mark.skipif with a
clear reason so unsupported runners skip cleanly.
---
Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 1319-1367: The NVFP4 path folds k_scale into sm_scale but never
applies v_scale, so outputs are wrong when a non-unit V scale is used; fix by
applying v_scale to the produced output for the NVFP4/packed-KV case
(kv_cache_sf != None) — either pass v_scale through into the prefill kernel if
it supports it or multiply out by v_scale after module.run (operate on out),
using the existing symbols k_scale, v_scale, sm_scale, out, kv_cache_sf and v_sf
to detect the packed NVFP4 branch and apply the correct scaling.
---
Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 1330-1341: The code incorrectly sets out_head_dim / head_dim_vo
from q.shape[-1] when kv_cache_sf is not None (packed NVFP4); packed storage
width differs from logical V/O width, so stop deriving logical head_dim_vo from
q.shape[-1]. Instead compute out_head_dim = v.shape[-1] (the logical V/O head
dim) regardless of kv_cache_sf, and pass that value as the head_dim_vo argument
to get_single_prefill_module; also update the other symmetric block that mirrors
this logic (the later occurrence handling packed KV) to use v.shape[-1] rather
than q.shape[-1]. Ensure any allocation of out uses this logical out_head_dim
and that get_single_prefill_module receives q.shape[-1] for head_dim_qk and
v.shape[-1] for head_dim_vo.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5b8dcc69-72db-4f1c-ac39-46beb9a4da1e
📒 Files selected for processing (7)
- flashinfer/jit/attention/modules.py
- flashinfer/prefill.py
- include/flashinfer/attention/prefill.cuh
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_prefill_kernels.py
- tests/attention/test_single_prefill.py
- tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/jit/attention/modules.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/attention/test_batch_attention.py
[FAILED] Pipeline #47354783: 6/20 passed

Force-pushed from 865f912 to 067bd9d (Compare)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/prefill.py (2)
599-629: ⚠️ Potential issue | 🟠 Major — Fake op signature missing FP8 scale parameters.

`_fake_ragged_run` is missing the `scale_q`, `scale_k`, `scale_v` parameters that are present in `ragged_run` (lines 503-505). This will cause signature mismatches during tracing.

🐛 Proposed fix

```diff
 @register_fake_op(f"flashinfer::{uri}_ragged_run")
 def _fake_ragged_run(
     float_workspace_buffer: torch.Tensor,
     int_workspace_buffer: torch.Tensor,
     plan_info_vec: List[int],
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     qo_indptr: torch.Tensor,
     kv_indptr: torch.Tensor,
     o: torch.Tensor,
     maybe_lse: Optional[torch.Tensor],
     mask_mode: int,
     layout: int,
     window_left: int,
     enable_pdl: bool,
     maybe_custom_mask: Optional[torch.Tensor],
     maybe_mask_indptr: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
     maybe_prefix_len_ptr: Optional[torch.Tensor],
     maybe_token_pos_in_items_ptr: Optional[torch.Tensor],
     maybe_max_item_len_ptr: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
     rope_scale: float,
     rope_theta: float,
     token_pos_in_items_len: int,
     maybe_k_cache_sf: Optional[torch.Tensor] = None,
     maybe_v_cache_sf: Optional[torch.Tensor] = None,
+    scale_q: Optional[torch.Tensor] = None,
+    scale_k: Optional[torch.Tensor] = None,
+    scale_v: Optional[torch.Tensor] = None,
 ) -> None:
     pass
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 599 - 629, The fake op _fake_ragged_run must match the real ragged_run signature: add the FP8 scale parameters scale_q, scale_k, scale_v to _fake_ragged_run with the same types and positions used in ragged_run so tracing won't fail; update the function signature for register_fake_op("flashinfer::{uri}_ragged_run") to include scale_q, scale_k, scale_v (use the same Optional/torch.Tensor typing and default values as ragged_run) and leave the body as pass.
421-441: ⚠️ Potential issue | 🟠 Major — Fake op signature missing scale parameters.

`_fake_run_single_prefill` is missing the `scale_q`, `scale_k`, `scale_v` parameters that are present in the real `run_single_prefill` function (lines 351-353). This signature mismatch can cause issues with `torch.compile` and other JIT tracing scenarios.

🐛 Proposed fix

```diff
 @register_fake_op(f"flashinfer::{uri}_run")
 def _fake_run_single_prefill(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     tmp: torch.Tensor,
     o: torch.Tensor,
     maybe_lse: Optional[torch.Tensor],
     mask_mode: int,
     layout: int,
     window_left: int,
     maybe_packed_custom_mask: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
+    scale_q: Optional[torch.Tensor],
+    scale_k: Optional[torch.Tensor],
+    scale_v: Optional[torch.Tensor],
     rope_scale: float,
     rope_theta: float,
     maybe_k_cache_sf: Optional[torch.Tensor] = None,
     maybe_v_cache_sf: Optional[torch.Tensor] = None,
 ) -> None:
     pass
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 421 - 441, The fake op _fake_run_single_prefill has a signature mismatch: add the missing scale parameters scale_q, scale_k, scale_v to its parameter list so it exactly matches the real run_single_prefill signature; ensure the new parameters use the same names, types/order and defaulting as in run_single_prefill (place them before maybe_k_cache_sf/maybe_v_cache_sf like the real function) so torch.compile/JIT tracing sees an identical call signature.
♻️ Duplicate comments (1)
include/flashinfer/attention/prefill.cuh (1)
461-510: ⚠️ Potential issue | 🔴 Critical — Fail fast when NVFP4 scale tensors are missing.

`maybe_k_cache_sf`/`maybe_v_cache_sf` default to `nullptr`, and these helpers still issue `pred_load_32b` through `sf_ptr` whenever the FP4 path is instantiated. A legacy or malformed caller will turn that into a null global-memory load.

🛡️ Possible fix

```diff
 template <bool produce_v, typename KTraits, typename IdType>
 __device__ __forceinline__ void page_produce_kv_sf(
     typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr,
     const uint32_t packed_page_iter_base, const uint32_t packed_kv_bound,
     const uint32_t kv_head_idx, const uint32_t kv_stride_page,
     const uint32_t kv_stride_h, const uint32_t kv_stride_n,
     const uint_fastdiv& page_size, const IdType* indices,
     const uint32_t kv_idx_base, const uint32_t kv_len,
     const uint32_t warp_idx, const uint32_t lane_idx) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) {
+    FLASHINFER_RUNTIME_ASSERT("NVFP4 KV cache requires block scale tensors.");
+  }
@@
 template <bool produce_v, typename KTraits>
 __device__ __forceinline__ void produce_kv_sf(
     typename KTraits::SharedStorage* smem_storage, uint8_t* sf_ptr,
     const uint32_t kv_abs_base, const uint32_t kv_head_idx,
     const uint32_t kv_stride_n, const uint32_t kv_stride_h,
     const uint32_t kv_idx_base, const uint32_t kv_len,
     const uint32_t warp_idx, const uint32_t lane_idx) {
   if constexpr (!is_fp4_type_v<typename KTraits::DTypeKV>) return;
+  if (sf_ptr == nullptr) {
+    FLASHINFER_RUNTIME_ASSERT("NVFP4 KV cache requires block scale tensors.");
+  }
```

Also applies to: 535-576
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/prefill.cuh` around lines 461 - 510, The kernel page_produce_kv_sf can attempt pred_load_32b through a null sf_ptr when the FP4 SF cache pointers (maybe_k_cache_sf / maybe_v_cache_sf passed as sf_ptr) are nullptr; guard against this by checking sf_ptr (or the original maybe_k_cache_sf/maybe_v_cache_sf) before issuing the cp_async load. Concretely, inside page_produce_kv_sf (and the analogous block at lines 535-576) update the in_bounds predicate to also require sf_ptr != nullptr (or return/skip early when sf_ptr is null) so cp_async::pred_load_32b is only called when sf_ptr is valid.
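Why the predicate alone is not enough can be emulated in Python: the load predicate suppresses the read, but the address `base + offset` is still formed first, and with a null base that arithmetic is itself the bug. `pred_load` below is a stand-in for `cp_async::pred_load_32b`, not the real primitive:

```python
def pred_load(base, offset, in_bounds, fill=0):
    """Emulation of a predicated load: check the base pointer (None here
    plays the role of nullptr) before anything touches base + offset;
    a predicated-off load yields the fill value instead of reading."""
    if base is None:       # the guard the review asks for
        return fill
    if not in_bounds:      # predicate off: zero-fill, no memory access
        return fill
    return base[offset]
```

Without the `base is None` check, the equivalent C++ would compute `nullptr + sf_gmem_offset` before the predicate could help, which is exactly the failure mode flagged above.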
🧹 Nitpick comments (2)
tests/test_helpers/utils_fp4.py (1)
103-128: Don't hardcode the NVFP4 global scale to `1.0`.

All tests that build fixtures through this helper now bypass the new `k_scale`/`v_scale` plumbing. A bug in the global-scale path would still pass because both the kernel and the reference see the identity scale.

♻️ Possible tweak

```diff
-def create_nvfp4_kv(shape, device):
+def create_nvfp4_kv(shape, device, global_scale=1.0):
 @@
-    return packed, sf, torch.tensor(1.0, device=device)
+    return packed, sf, torch.tensor(global_scale, device=device, dtype=torch.float32)
```

tests/attention/test_batch_attention.py (1)
293-296: Use the shared skip helpers here instead of another raw CC `xfail`.

If SM120/121 is still unsupported for this case, make it a helper-based skip; otherwise this marker hides failures on the exact architecture the new NVFP4 path needs to cover.

As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures.
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_attention.py` around lines 293 - 296, Replace the raw pytest.mark.xfail with a shared helper-based skip using the utilities in flashinfer.utils: import get_compute_capability (or the appropriate helper) from flashinfer.utils and change the marker on the test that currently uses pytest.mark.xfail(get_compute_capability(torch.device(device="cuda"))[0] == 12, ...) to pytest.mark.skipif(get_compute_capability(torch.device("cuda"))[0] == 12, reason="SM120/121 unsupported for this test") or, if a dedicated helper exists (e.g., is_sm120_supported()), use that helper instead to decide skipping; update the decorator on the test accordingly and remove the raw xfail usage.
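The capability gate itself reduces to a small predicate over the `(major, minor)` tuple. A sketch in the spirit of the `flashinfer.utils` helpers named above — the supported set here is an assumption for illustration (SM90/SM100 pass; major 12, i.e. SM120/121, and older parts are skipped), not the authoritative support matrix:

```python
def nvfp4_kv_supported(compute_capability):
    """Illustrative stand-in for is_sm90a_supported()/is_sm100a_supported():
    returns True only for the architecture majors assumed to run the
    NVFP4 KV path in this sketch."""
    major, _minor = compute_capability
    return major in (9, 10)
```

In a test this would back a `pytest.mark.skipif(not nvfp4_kv_supported(cc), reason=...)` decorator rather than a raw `xfail`.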
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1147-1149: The v_scale parameter is accepted in the function
signature alongside kv_cache_sf and k_scale but never applied to the produced
output; either remove v_scale or apply it consistently like
BatchPrefillWithPagedKVCacheWrapper.run(). Locate the single-request prefill
function that declares kv_cache_sf, k_scale, v_scale (same signature shown) and
add the same post-output scaling logic used in
BatchPrefillWithPagedKVCacheWrapper.run(): after computing the output, if
v_scale is not None and not is_float_one then multiply the output tensor by
v_scale (or conversely remove v_scale from the signature and all callers if this
mode shouldn't support value scaling). Ensure references to k_scale/kv_cache_sf
behavior remain unchanged.
- Around line 471-473: The mutates_args lists incorrectly include read-only
scale-factor tensors; remove "maybe_k_cache_sf" and "maybe_v_cache_sf" from
ragged_run's mutates_args and remove "key_block_scales" and "value_block_scales"
from paged_run's mutates_args, after confirming kernels do not mutate them (they
only produce new transposed tensors). Update the mutates_args declarations in
the ragged_run and paged_run call sites to omit those symbols, keeping
get_trtllm_gen_prefill_module and run_single_prefill behavior as-is, and run the
existing tests or a quick smoke-run to ensure torch.compile optimizations no
longer treat these tensors as mutated.
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback _e2m1_and_ufp8sf_scale_to_float_cpu can
receive a flat 1-D ufp8_scale_tensor (length either K/sf_vec_size or
M*(K/sf_vec_size)) and currently treats it as already shaped [M, K/sf_vec_size],
causing wrong broadcasting; before decoding the UFP8 scales, detect and
reshape/expand ufp8_scale_tensor: if ufp8_scale_tensor.dim() == 1 and its
numel() == (m * (k // sf_vec_size)) then view it as (m, k // sf_vec_size); if
numel() == (k // sf_vec_size) then expand/unsqueeze to (m, k // sf_vec_size);
otherwise if dim()==2 ensure its shape matches (m, k // sf_vec_size) and raise a
clear error if not; then proceed with the existing decoding and
repeat_interleave logic using this normalized per-row scale tensor.
---
Outside diff comments:
In `@flashinfer/prefill.py`:
- Around line 599-629: The fake op _fake_ragged_run must match the real
ragged_run signature: add the FP8 scale parameters scale_q, scale_k, scale_v to
_fake_ragged_run with the same types and positions used in ragged_run so tracing
won't fail; update the function signature for
register_fake_op("flashinfer::{uri}_ragged_run") to include scale_q, scale_k,
scale_v (use the same Optional/torch.Tensor typing and default values as
ragged_run) and leave the body as pass.
- Around line 421-441: The fake op _fake_run_single_prefill has a signature
mismatch: add the missing scale parameters scale_q, scale_k, scale_v to its
parameter list so it exactly matches the real run_single_prefill signature;
ensure the new parameters use the same names, types/order and defaulting as in
run_single_prefill (place them before maybe_k_cache_sf/maybe_v_cache_sf like the
real function) so torch.compile/JIT tracing sees an identical call signature.
---
Duplicate comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 461-510: The kernel page_produce_kv_sf can attempt pred_load_32b
through a null sf_ptr when the FP4 SF cache pointers (maybe_k_cache_sf /
maybe_v_cache_sf passed as sf_ptr) are nullptr; guard against this by checking
sf_ptr (or the original maybe_k_cache_sf/maybe_v_cache_sf) before issuing the
cp_async load. Concretely, inside page_produce_kv_sf (and the analogous block at
lines 535-576) update the in_bounds predicate to also require sf_ptr != nullptr
(or return/skip early when sf_ptr is null) so cp_async::pred_load_32b is only
called when sf_ptr is valid.
---
Nitpick comments:
In `@tests/attention/test_batch_attention.py`:
- Around line 293-296: Replace the raw pytest.mark.xfail with a shared
helper-based skip using the utilities in flashinfer.utils: import
get_compute_capability (or the appropriate helper) from flashinfer.utils and
change the marker on the test that currently uses
pytest.mark.xfail(get_compute_capability(torch.device(device="cuda"))[0] == 12,
...) to pytest.mark.skipif(get_compute_capability(torch.device("cuda"))[0] ==
12, reason="SM120/121 unsupported for this test") or, if a dedicated helper
exists (e.g., is_sm120_supported()), use that helper instead to decide skipping;
update the decorator on the test accordingly and remove the raw xfail usage.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 980616bb-d5f7-4d8c-962b-1548d733a9a3
📒 Files selected for processing (16)
- flashinfer/attention.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/utils.py
- include/flashinfer/attention/persistent.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/cp_async.cuh
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/permuted_smem.cuh
- include/flashinfer/vec_dtypes.cuh
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_prefill_kernels.py
- tests/attention/test_single_prefill.py
- tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (2)
- flashinfer/utils.py
- flashinfer/jit/utils.py
🚧 Files skipped from review as they are similar to previous changes (8)
- flashinfer/attention.py
- include/flashinfer/permuted_smem.cuh
- tests/attention/test_single_prefill.py
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/cp_async.cuh
- tests/attention/test_batch_prefill_kernels.py
- include/flashinfer/vec_dtypes.cuh
- flashinfer/jit/attention/modules.py
/bot run

[FAILED] Pipeline #47420565: 5/20 passed
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
jit: add torch.float4_e2m1fn_x2 to dtype maps
Add conditional entries for torch.float4_e2m1fn_x2 in
filename_safe_dtype_map ("fp4_e2m1") and dtype_map_kv
("__nv_fp4x2_e2m1") so that BatchPrefillWithPagedKVCacheWrapper
can select the FP4 kernel plan without KeyError when
kv_data_type=float4_e2m1fn_x2 is passed to begin_forward.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
fix batch decode UT
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
fix batch decode function and accuracy; add nvfp4 kv test
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Force-pushed from eb98c96 to 7be54dc (Compare)
/bot run
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)
1437-1452: ⚠️ Potential issue | 🔴 Critical — Remove the prefill-only placeholders from decode's TRT-LLM `paged_run` call.

`get_trtllm_gen_decode_module().paged_run()` does not accept the `max_q_len`/`batch_size`/`cum_seq_lens_*` slots you added here. With the current list, `self._max_kv_len` binds to `sinks`, `sinks` binds to `uses_shared_paged_kv_idx`, and the remaining args overflow the wrapper, so this branch will raise as soon as TRT-LLM decode runs.

Proposed fix

```diff
 run_args += [
     self._num_qo_heads,
     self._num_kv_heads,
     self._block_tables,
     self._kv_lens_buffer,
     page_size,
-    None,  # max_q_len (not applicable for decode)
     self._max_kv_len,
-    None,  # batch_size (not applicable for decode)
-    None,  # cum_seq_lens_q (not applicable for decode)
-    None,  # cum_seq_lens_kv (not applicable for decode)
     sinks,
     key_block_scales,
     value_block_scales,
     skip_softmax_threshold_scale_factor,
     True,  # uses_shared_paged_kv_idx
 ]
```

Also applies to: 2085-2128
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 1437 - 1452, The decode branch is passing prefill-only placeholder slots into get_trtllm_gen_decode_module().paged_run(), causing argument misalignment: remove the None placeholders for max_q_len, batch_size, cum_seq_lens_q and cum_seq_lens_kv from the run_args list (the entries between page_size and self._max_kv_len) so the call to paged_run receives the correct parameters (ensure sinks, key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor and uses_shared_paged_kv_idx bind to the intended positions); apply the same removal in the other decode occurrence that mirrors this block.
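The binding shift is easy to reproduce in miniature. A toy decode signature (parameter names borrowed from the comment, list heavily truncated) shows how one leftover prefill placeholder misbinds every argument after it:

```python
def paged_run_decode(page_size, max_kv_len, sinks, uses_shared_paged_kv_idx):
    """Toy decode-side signature: no max_q_len/batch_size/cum_seq_lens_*
    slots exist here, unlike the prefill variant."""
    return {"max_kv_len": max_kv_len, "sinks": sinks}

# A prefill-style None placeholder (for max_q_len) shifts later arguments:
bad = paged_run_decode(*[16, None, 4096, "sinks"])   # max_kv_len gets None!
good = paged_run_decode(*[16, 4096, "sinks", True])  # correct binding
```

With one extra placeholder, `max_kv_len` binds to `None` and `sinks` receives the length; with the full prefill list the call would overflow and raise `TypeError` instead.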
♻️ Duplicate comments (1)
include/flashinfer/attention/prefill.cuh (1)
462-516: ⚠️ Potential issue | 🟠 Major — Require non-null SF pointers before issuing FP4 SF loads.

The FP4 SF helpers still form `sf_ptr + sf_gmem_offset` unconditionally, but all three call sites default `maybe_k_cache_sf`/`maybe_v_cache_sf` to `nullptr` when the params pack does not expose scale tensors. Any FP4 specialization reached without SF tensors will dereference a null base pointer before `pred_load_32b` can predicate the load.

If missing SF tensors are invalid, fail earlier; if they are meant to be optional, the helpers need an explicit null-handling path.
Also applies to: 542-585, 1644-1651, 2064-2071, 2411-2418
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/attention/prefill.cuh` around lines 462 - 516, The FP4 SF load logic (e.g., in page_produce_kv_sf) unconditionally computes sf_ptr + sf_gmem_offset and passes it to cp_async::pred_load_32b, which will dereference a null when maybe_k_cache_sf/maybe_v_cache_sf are nullptr; add explicit null-handling: before forming sf_gmem_offset or calling pred_load_32b in page_produce_kv_sf (and the other FP4 SF helper call sites), test if sf_ptr==nullptr and either (a) fail fast with a clear error if SF tensors are required or (b) use a safe no-op path that supplies a dummy/valid pointer and sets in_bounds=false (so pred_load_32b is fully predicated) when SF is optional; update all related helpers/call sites (maybe_k_cache_sf, maybe_v_cache_sf, and other FP4 SF helper functions that call pred_load_32b) accordingly to avoid null pointer arithmetic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 110-117: Dispatcher SMEM accounting ignores the new FP4
scale-factor buffers in SharedStorageQKVO (k_sf_smem/v_sf_smem), causing
NUM_MMA_KV/DISPATCH_NUM_MMA_KV to pick configs that overflow on FP4 builds;
update the dispatcher math that computes num_ctas_per_sm and max_num_mma_kv_smem
to include the additional bytes of k_sf_smem and v_sf_smem when
is_fp4_type_v<DTypeKV> is true (i.e., add CTA_TILE_KV * HEAD_DIM_QK /
NVFP4_SF_VEC_SIZE and CTA_TILE_KV * HEAD_DIM_VO / NVFP4_SF_VEC_SIZE bytes
respectively, aligned as in SharedStorageQKVO) and apply the same change at the
other occurrences noted (around the other ranges) so the per-CTA dynamic SMEM
budget reflects the SF buffers before DISPATCH_NUM_MMA_KV selection.
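The byte accounting that prompt describes is simple arithmetic. A sketch of the extra per-CTA shared memory the SF buffers add (the vector size of 16 FP4 values per scale byte and the 128-byte alignment are assumptions mirroring typical NVFP4/smem layouts, not values read from the kernel):

```python
NVFP4_SF_VEC_SIZE = 16  # assumed: one scale byte per 16 FP4 elements

def fp4_sf_smem_bytes(cta_tile_kv, head_dim_qk, head_dim_vo, align=128):
    """Extra shared-memory bytes for hypothetical k_sf_smem/v_sf_smem
    buffers, which the dispatcher budget must include before
    DISPATCH_NUM_MMA_KV picks a config."""
    def align_up(n):
        return (n + align - 1) // align * align
    k_sf = align_up(cta_tile_kv * head_dim_qk // NVFP4_SF_VEC_SIZE)
    v_sf = align_up(cta_tile_kv * head_dim_vo // NVFP4_SF_VEC_SIZE)
    return k_sf + v_sf
```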
In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-691: The test_batch_decode_with_paged_kv_cache_nvfp4 test
lacks a GPU architecture guard; add a skip decorator using is_sm100a_supported()
via `@pytest.mark.skipif`(not is_sm100a_supported(), reason="NVFP4 tests require
SM100+/Blackwell") placed before the existing `@pytest.mark.parametrize`
decorators so the test is skipped on unsupported hardware; locate the test
function name test_batch_decode_with_paged_kv_cache_nvfp4 and add the skipif
decorator consistent with other NVFP4 tests (e.g.,
tests/moe/test_trtllm_cutlass_fused_moe.py).
---
Outside diff comments:
In `@flashinfer/decode.py`:
- Around line 1437-1452: The decode branch is passing prefill-only placeholder
slots into get_trtllm_gen_decode_module().paged_run(), causing argument
misalignment: remove the None placeholders for max_q_len, batch_size,
cum_seq_lens_q and cum_seq_lens_kv from the run_args list (the entries between
page_size and self._max_kv_len) so the call to paged_run receives the correct
parameters (ensure sinks, key_block_scales, value_block_scales,
skip_softmax_threshold_scale_factor and uses_shared_paged_kv_idx bind to the
intended positions); apply the same removal in the other decode occurrence that
mirrors this block.
---
Duplicate comments:
In `@include/flashinfer/attention/prefill.cuh`:
- Around line 462-516: The FP4 SF load logic (e.g., in page_produce_kv_sf)
unconditionally computes sf_ptr + sf_gmem_offset and passes it to
cp_async::pred_load_32b, which will dereference a null when
maybe_k_cache_sf/maybe_v_cache_sf are nullptr; add explicit null-handling:
before forming sf_gmem_offset or calling pred_load_32b in page_produce_kv_sf
(and the other FP4 SF helper call sites), test if sf_ptr==nullptr and either (a)
fail fast with a clear error if SF tensors are required or (b) use a safe no-op
path that supplies a dummy/valid pointer and sets in_bounds=false (so
pred_load_32b is fully predicated) when SF is optional; update all related
helpers/call sites (maybe_k_cache_sf, maybe_v_cache_sf, and other FP4 SF helper
functions that call pred_load_32b) accordingly to avoid null pointer arithmetic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f17c6f7f-060a-4ea6-8ddb-8c31da22da20
📒 Files selected for processing (4)
- flashinfer/decode.py
- flashinfer/jit/utils.py
- include/flashinfer/attention/prefill.cuh
- tests/attention/test_batch_decode_kernels.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/jit/utils.py
/bot run
Actionable comments posted: 4
♻️ Duplicate comments (6)
flashinfer/quantization/fp4_quantization.py (1)
101-145: ⚠️ Potential issue | 🟠 Major — Normalize flat scale buffers before this fallback.

The public dequant path still accepts a flattened `ufp8_scale_tensor`, but this implementation treats it as already shaped `[M, K // sf_vec_size]`. On `< SM90`, a 1-D unswizzled scale buffer will now either broadcast incorrectly or fail once `repeat_interleave(..., dim=-1)` runs.

Suggested fix

```diff
 def _e2m1_and_ufp8sf_scale_to_float_cpu(
     e2m1_tensor: torch.Tensor,
     ufp8_scale_tensor: torch.Tensor,
     global_scale_tensor: Optional[torch.Tensor],
     sf_vec_size: int,
     ufp8_type: int,
     is_sf_swizzled_layout: bool,
 ) -> torch.Tensor:
 @@
     device = e2m1_tensor.device
     m, k_half = e2m1_tensor.shape
     k = k_half * 2
+    expected_sf_cols = k // sf_vec_size
+
+    if ufp8_scale_tensor.dim() == 1:
+        if ufp8_scale_tensor.numel() == expected_sf_cols:
+            ufp8_scale_tensor = ufp8_scale_tensor.unsqueeze(0).expand(m, -1)
+        elif ufp8_scale_tensor.numel() == m * expected_sf_cols:
+            ufp8_scale_tensor = ufp8_scale_tensor.reshape(m, expected_sf_cols)
+        else:
+            raise ValueError(
+                f"Expected {expected_sf_cols} or {m * expected_sf_cols} scale values, "
+                f"got {ufp8_scale_tensor.numel()}"
+            )
+    elif tuple(ufp8_scale_tensor.shape) != (m, expected_sf_cols):
+        raise ValueError(
+            f"Expected scale tensor shape {(m, expected_sf_cols)}, "
+            f"got {tuple(ufp8_scale_tensor.shape)}"
+        )

     # Unpack two E2M1 nibbles per byte: low nibble = even indices, high nibble = odd
     fp4_vals = torch.empty(m, k, dtype=torch.uint8, device=device)
```
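For reference, the E2M1 nibble decode that this fallback performs can be sketched in pure Python with a lookup table. The low-nibble-even / high-nibble-odd convention follows the comment in the diff above; treat the sketch as illustrative, not the library's implementation:

```python
# FP4 E2M1 magnitudes for the 3 low bits; the nibble's MSB is the sign.
_E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def unpack_e2m1_byte(byte):
    """Decode one packed byte into two floats: low nibble holds the even
    element, high nibble the odd one."""
    def decode(nibble):
        mag = _E2M1_MAG[nibble & 0x7]
        return -mag if nibble & 0x8 else mag
    return decode(byte & 0xF), decode(byte >> 4)
```

Each decoded pair would then be multiplied by its block's UE4M3 scale and the global scale to recover the dequantized values.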
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/fp4_quantization.py` around lines 101 - 145, The CPU fallback _e2m1_and_ufp8sf_scale_to_float_cpu assumes ufp8_scale_tensor is shaped [M, K//sf_vec_size]; normalize flattened inputs first: inside _e2m1_and_ufp8sf_scale_to_float_cpu, detect if ufp8_scale_tensor.dim() == 1 and if so, if its length == sf_len (where sf_len = k // sf_vec_size) repeat it across the batch to shape [m, sf_len]; if its length == m * sf_len reshape it to [m, sf_len]; ensure the tensor is moved to the same device/dtype before later ops so the later sf_float, repeat_interleave, and broadcasting use the correct shape and device.

tests/attention/test_batch_decode_kernels.py (1)
684-691: ⚠️ Potential issue | 🟠 Major — Add the missing architecture skip for the NVFP4 decode test.

This new test still has no capability guard, so unsupported GPUs will fail before the decode assertions are reached. Please add a `pytest.mark.skipif(...)` using the NVFP4 helper/API capability check.

As per coding guidelines: "Skip test execution on unsupported GPU architectures using `flashinfer.utils` check functions (`is_sm90a_supported()`, `is_sm100a_supported()`, etc.) or API methods like `api_name.is_compute_capability_supported(cc)`"
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_batch_decode_kernels.py` around lines 684 - 691, Add a skip guard to the test_batch_decode_with_paged_kv_cache_nvfp4 test so it doesn't run on unsupported GPUs: import and use the NVFP4 capability check (e.g., flashinfer.utils.is_nvfp4_supported() or the equivalent API method) in a pytest.mark.skipif(...) decorator above the test function to skip when the check returns False; ensure the decorator message explains it's skipping due to missing NVFP4 support.

flashinfer/prefill.py (4)
3322-3325: ⚠️ Potential issue | 🟠 Major — Guard ragged `q_scale`/`k_scale` folding on the backend.

This still multiplies `q_scale` and `k_scale` into `sm_scale` before dispatch, but the cuDNN branch below also receives both scalars separately. Ragged cuDNN calls will be double-scaled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3322 - 3325, The code multiplies q_scale and k_scale into sm_scale unconditionally, but the cuDNN branch below also receives q_scale and k_scale separately causing double-scaling for ragged cuDNN calls; change the logic around sm_scale, q_scale, and k_scale so you only fold (multiply) q_scale/k_scale into sm_scale when not using the cuDNN path (or alternatively stop passing q_scale/k_scale separately to cuDNN), i.e., gate the sm_scale *= q_scale and sm_scale *= k_scale operations behind the condition that selects the non-cuDNN backend (use the same branch/flag used for dispatch to cuDNN), and ensure the cuDNN branch receives either folded sm_scale or the separate q_scale/k_scale but not both.
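The guard this prompt asks for reduces to one conditional. A sketch (illustrative names; the real code branches on the backend dispatch flag):

```python
import math

def effective_sm_scale(head_dim, q_scale, k_scale, use_cudnn):
    """Fold q_scale/k_scale into sm_scale only on the non-cuDNN path:
    the cuDNN call receives both scalars separately, so folding them
    there as well would apply each scale twice."""
    sm_scale = 1.0 / math.sqrt(head_dim)
    if not use_cudnn:
        sm_scale *= q_scale * k_scale
    return sm_scale
```

With `head_dim=64`, `q_scale=2.0`, `k_scale=3.0`: the cuDNN path keeps the plain `0.125` (the scalars travel separately), while the non-cuDNN path folds them in to get `0.75`.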
1148-1150: ⚠️ Potential issue | 🟠 Major — `v_scale` is still ignored in single-request prefill.

`k_scale` is folded into `sm_scale`, but `v_scale` never affects `out`. That makes `single_prefill_with_kv_cache()` numerically inconsistent with the paged wrapper and with its own signature.

Also applies to: 1354-1379
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1148 - 1150, single_prefill_with_kv_cache ignores v_scale (only folds k_scale into sm_scale), causing numerical inconsistency with the paged wrapper and the function signature; fix by applying v_scale to the value tensor before it's used to compute out (mirror how k_scale was folded into sm_scale) inside single_prefill_with_kv_cache: when kv_cache_sf (and v_scale) are present, scale the v component (from kv_cache_sf or the produced v) by v_scale (or incorporate it into existing sm_scale logic) so that out uses the scaled values; update any related paths where kv_cache_sf is unpacked and where out is computed so both k_scale and v_scale affect the final output consistently.
1337-1347: ⚠️ Potential issue | 🔴 Critical
Don’t infer packed V/O width from q.shape[-1].
kv_cache_sf only tells you the V cache is packed; it does not guarantee head_dim_vo == head_dim_qk. These branches still allocate the output with Q’s width and, for the single-request path, JIT the wrong head_dim_vo specialization for asymmetric QK/VO configs.
Also applies to: 2334-2346, 3347-3361
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1337 - 1347, The code incorrectly infers V/O packed width from q.shape[-1]; instead compute out_head_dim from v.shape[-1] (not q.shape[-1]) and use that value when allocating out and when calling get_single_prefill_module so the JIT specialization gets the correct head_dim_vo for asymmetric QK/VO configs (update the out = torch.empty(...) allocation and the get_single_prefill_module(...) call that currently passes q.shape[-1] to use out_head_dim/v.shape[-1]); apply the same fix to the other similar branches where out_head_dim is derived from q.shape[-1].
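The shape inference the prompt asks for can be sketched as follows. The helper name is illustrative, and the packing factor of two fp4 values per uint8 is an assumption about the NVFP4 layout; the essential change is that the V/O head dim comes from V's trailing dim, never Q's:

```python
def infer_out_head_dim(v_last_dim, v_is_packed_fp4):
    """Derive head_dim_vo from V's trailing dim rather than Q's.
    Assumption for this sketch: packed NVFP4 stores two fp4 values per
    uint8, so the logical width is twice the stored width."""
    return v_last_dim * 2 if v_is_packed_fp4 else v_last_dim
```

With this, an asymmetric config (e.g. head_dim_qk=192, head_dim_vo=128) allocates `out` and specializes the JIT module with the correct 128, regardless of Q's width.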
467-476: ⚠️ Potential issue | 🟡 Minor
Drop the scale-factor tensors from mutates_args.
maybe_k_cache_sf/maybe_v_cache_sf and key_block_scales/value_block_scales are forwarded as read-only inputs. Marking them as mutated pessimizes torch.compile alias analysis for no gain.
Also applies to: 636-647
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 467 - 476, The decorator call to register_custom_op currently lists maybe_k_cache_sf and maybe_v_cache_sf (and likewise key_block_scales/value_block_scales in the other occurrence) inside mutates_args which falsely marks these tensors as mutated; remove maybe_k_cache_sf and maybe_v_cache_sf from the mutates_args tuple (and remove key_block_scales/value_block_scales from the duplicate occurrence) so they are treated as read-only inputs by torch.compile, leaving only genuinely mutated buffers (e.g., float_workspace_buffer, int_workspace_buffer, o, maybe_lse) in mutates_args.
🧹 Nitpick comments (1)
include/flashinfer/permuted_smem.cuh (1)
176-181: Document the partial-copy contract here.
load_64b_async() writes a 64-bit source into a 128-bit shared-memory slot via pred_load_128b_from_64b. A one-line note about which half is populated and why this path was chosen over a dedicated 64-bit SMEM layout would make the later FP4 loader much harder to misuse.
As per coding guidelines: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/permuted_smem.cuh` around lines 176 - 181, Add a one-line comment in load_64b_async documenting the partial-copy contract: state that calling cp_async::pred_load_128b_from_64b with a 64-bit source writes the 64-bit value into one half of the 128-bit shared-memory slot (specify which half is populated and that the other half is left unchanged/undefined), and briefly justify why this 128-bit SMEM path is used (to reuse existing b128_t/128-bit alignment and avoid a separate 64-bit SMEM layout) and note the alternative considered (a dedicated 64-bit SMEM layout) so later users of load_64b_async, b128_t, base and cp_async::pred_load_128b_from_64b cannot misuse the partial-copy behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/attention.py`:
- Around line 149-151: The code currently accepts packed NVFP4/uint8 KV caches
while allowing kv_block_scales to be None, which yields incorrect attention
results; update the functions/methods that accept the parameter kv_block_scales
(the signature showing "kv_block_scales: Optional[Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor]]] = None" and the other similar occurrence
around lines 182-187) to validate inputs and raise a clear error when a
uint8/packed KV cache is provided but kv_block_scales is None; detect the dtype
(torch.uint8) or packed-NVFP4 indicator on the kv tensors and throw a ValueError
with an explanatory message requiring per-block scales instead of silently
proceeding.
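The fail-fast validation described above can be sketched like this. The function name and the string-based dtype check are illustrative (the real check would inspect `torch.uint8` on the K/V tensors); the point is to raise before dequantization ever runs with null scale pointers:

```python
def check_kv_block_scales(kv_dtype, kv_block_scales):
    """Reject a packed NVFP4 (uint8) KV cache that arrives without
    per-block scale factors; silently proceeding would dequantize with
    null scale pointers and return wrong attention output."""
    if kv_dtype == "uint8" and kv_block_scales is None:
        raise ValueError(
            "packed NVFP4 (uint8) KV cache requires kv_block_scales; "
            "pass the per-block scale-factor tensor(s)"
        )
```

Placing the guard at the Python entry point keeps the error message actionable instead of surfacing as a kernel-side NaN or crash.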
In `@flashinfer/decode.py`:
- Around line 1458-1468: The run_args list passed to
get_trtllm_gen_decode_module().paged_run() in flashinfer/decode.py currently
includes four prefill-only None placeholders after page_size, which overflows
the trtllm-gen decode wrapper; update the run_args construction in the decode
path (the code that appends self._num_qo_heads, self._num_kv_heads,
self._block_tables, self._kv_lens_buffer, page_size, ...) to only include
page_size and then self._max_kv_len (remove the subsequent None entries for
max_q_len, batch_size, cum_seq_lens_q, cum_seq_lens_kv) so the argument list
matches the paged_run() decode signature.
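The trimmed argument list can be sketched as a small builder. Names are illustrative (the real list is assembled inline in decode.py); the point is that on the decode path `page_size` is followed directly by `max_kv_len`, with none of the four prefill-only placeholders in between:

```python
def build_decode_run_args(prefix_args, page_size, max_kv_len):
    """Decode-path argument tail for paged_run(): page_size, then
    max_kv_len -- no max_q_len / batch_size / cum_seq_lens_q /
    cum_seq_lens_kv placeholders, which belong to the prefill signature."""
    return [*prefix_args, page_size, max_kv_len]
```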
In `@flashinfer/prefill.py`:
- Around line 1326-1333: The code currently unpacks kv_cache_sf into k_sf,v_sf
but does not reject packed NVFP4 KV when kv_cache_sf is missing; add a fail-fast
check in the prefill path: if kv_cache_sf is None and the KV cache tensors (the
K and V tensors used in this function) have dtype torch.uint8 (packed NVFP4),
raise a clear exception (ValueError) instead of proceeding with null
scale-factor pointers; apply the same guard in the other identical prefill spot
referenced (the second occurrence corresponding to the other block used by
BatchPrefillWithPagedKVCacheWrapper.run and trtllm_batch_context_with_kv_cache)
so both paths reject packed uint8 KV when kv_cache_sf is absent and avoid
incorrect dequantization.
In `@tests/attention/test_batch_attention.py`:
- Around line 293-306: The NVFP4-specific test test_batch_attention_nvfp4 must
be gated so it skips on unsupported GPU architectures; update the test to check
the appropriate capability via flashinfer.utils (e.g., call the relevant
is_smXX_supported() helper such as is_sm90a_supported()/is_sm100a_supported() or
use the API method is_compute_capability_supported(cc)) at the start of the test
and call pytest.skip with a clear message when the capability is absent so the
fixture setup (NVFP4 kernels) is not attempted on unsupported GPUs.
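A sketch of the capability gate the prompt asks for. The compute-capability threshold below (sm90) is an assumption for illustration only; the real tests should call the flashinfer.utils helpers (e.g. is_sm90a_supported() / is_sm100a_supported()) rather than compare raw numbers:

```python
def nvfp4_kv_test_supported(major, minor):
    """Return True when the GPU can run the NVFP4 KV tests.
    Threshold is an assumption for this sketch; prefer the
    flashinfer.utils is_smXXa_supported() helpers in real tests.
    Usage: @pytest.mark.skipif(not nvfp4_kv_test_supported(*cc),
                               reason="NVFP4 KV kernels unsupported")"""
    return (major, minor) >= (9, 0)
```

Using a decorator (rather than pytest.skip inside the body) ensures fixtures that build NVFP4 kernels are never even set up on unsupported GPUs.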
---
Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 3322-3325: The code multiplies q_scale and k_scale into sm_scale
unconditionally, but the cuDNN branch below also receives q_scale and k_scale
separately causing double-scaling for ragged cuDNN calls; change the logic
around sm_scale, q_scale, and k_scale so you only fold (multiply)
q_scale/k_scale into sm_scale when not using the cuDNN path (or alternatively
stop passing q_scale/k_scale separately to cuDNN), i.e., gate the sm_scale *=
q_scale and sm_scale *= k_scale operations behind the condition that selects the
non-cuDNN backend (use the same branch/flag used for dispatch to cuDNN), and
ensure the cuDNN branch receives either folded sm_scale or the separate
q_scale/k_scale but not both.
- Around line 1148-1150: single_prefill_with_kv_cache ignores v_scale (only
folds k_scale into sm_scale), causing numerical inconsistency with the paged
wrapper and the function signature; fix by applying v_scale to the value tensor
before it's used to compute out (mirror how k_scale was folded into sm_scale)
inside single_prefill_with_kv_cache: when kv_cache_sf (and v_scale) are present,
scale the v component (from kv_cache_sf or the produced v) by v_scale (or
incorporate it into existing sm_scale logic) so that out uses the scaled values;
update any related paths where kv_cache_sf is unpacked and where out is computed
so both k_scale and v_scale affect the final output consistently.
- Around line 1337-1347: The code incorrectly infers V/O packed width from
q.shape[-1]; instead compute out_head_dim from v.shape[-1] (not q.shape[-1]) and
use that value when allocating out and when calling get_single_prefill_module so
the JIT specialization gets the correct head_dim_vo for asymmetric QK/VO configs
(update the out = torch.empty(...) allocation and the
get_single_prefill_module(...) call that currently passes q.shape[-1] to use
out_head_dim/v.shape[-1]); apply the same fix to the other similar branches
where out_head_dim is derived from q.shape[-1].
- Around line 467-476: The decorator call to register_custom_op currently lists
maybe_k_cache_sf and maybe_v_cache_sf (and likewise
key_block_scales/value_block_scales in the other occurrence) inside mutates_args
which falsely marks these tensors as mutated; remove maybe_k_cache_sf and
maybe_v_cache_sf from the mutates_args tuple (and remove
key_block_scales/value_block_scales from the duplicate occurrence) so they are
treated as read-only inputs by torch.compile, leaving only genuinely mutated
buffers (e.g., float_workspace_buffer, int_workspace_buffer, o, maybe_lse) in
mutates_args.
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 101-145: The CPU fallback _e2m1_and_ufp8sf_scale_to_float_cpu
assumes ufp8_scale_tensor is shaped [M, K//sf_vec_size]; normalize flattened
inputs first: inside _e2m1_and_ufp8sf_scale_to_float_cpu, detect if
ufp8_scale_tensor.dim() == 1 and if so, if its length == sf_len (where sf_len =
k // sf_vec_size) repeat it across the batch to shape [m, sf_len]; if its length
== m * sf_len reshape it to [m, sf_len]; ensure the tensor is moved to the same
device/dtype before later ops so the later sf_float, repeat_interleave, and
broadcasting use the correct shape and device.
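The shape normalization described above can be sketched over plain nested lists (torch-free; the function name is illustrative). A 1-D scale tensor of length sf_len is broadcast across rows, and a flat one of length m * sf_len is reshaped row-major:

```python
def normalize_scale_shape(sf, m, k, sf_vec_size=16):
    """Normalize a flattened block-scale tensor to shape [m, k // sf_vec_size].
    Sketch over nested Python lists standing in for torch tensors."""
    sf_len = k // sf_vec_size
    if sf and isinstance(sf[0], list):      # already [m, sf_len]
        return sf
    if len(sf) == sf_len:                    # one row of scales: repeat per row
        return [list(sf) for _ in range(m)]
    if len(sf) == m * sf_len:                # flat: reshape row-major
        return [list(sf[i * sf_len:(i + 1) * sf_len]) for i in range(m)]
    raise ValueError("unexpected scale-factor length")
```

With torch, the equivalents would be `sf.expand(m, sf_len)` and `sf.view(m, sf_len)`, plus a `.to(device)` before any downstream `repeat_interleave`.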
In `@tests/attention/test_batch_decode_kernels.py`:
- Around line 684-691: Add a skip guard to the
test_batch_decode_with_paged_kv_cache_nvfp4 test so it doesn't run on
unsupported GPUs: import and use the NVFP4 capability check (e.g.,
flashinfer.utils.is_nvfp4_supported() or the equivalent API method) in a
pytest.mark.skipif(...) decorator above the test function to skip when the check
returns False; ensure the decorator message explains it's skipping due to
missing NVFP4 support.
---
Nitpick comments:
In `@include/flashinfer/permuted_smem.cuh`:
- Around line 176-181: Add a one-line comment in load_64b_async documenting the
partial-copy contract: state that calling cp_async::pred_load_128b_from_64b with
a 64-bit source writes the 64-bit value into one half of the 128-bit
shared-memory slot (specify which half is populated and that the other half is
left unchanged/undefined), and briefly justify why this 128-bit SMEM path is
used (to reuse existing b128_t/128-bit alignment and avoid a separate 64-bit
SMEM layout) and note the alternative considered (a dedicated 64-bit SMEM
layout) so later users of load_64b_async, b128_t, base and
cp_async::pred_load_128b_from_64b cannot misuse the partial-copy behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 67117767-fd5d-42a6-90fc-eac3b4d24d08
📒 Files selected for processing (18)
- flashinfer/attention.py
- flashinfer/decode.py
- flashinfer/jit/attention/modules.py
- flashinfer/jit/utils.py
- flashinfer/prefill.py
- flashinfer/quantization/fp4_quantization.py
- flashinfer/utils.py
- include/flashinfer/attention/persistent.cuh
- include/flashinfer/attention/prefill.cuh
- include/flashinfer/cp_async.cuh
- include/flashinfer/frag_layout_swizzle.cuh
- include/flashinfer/permuted_smem.cuh
- include/flashinfer/vec_dtypes.cuh
- tests/attention/test_batch_attention.py
- tests/attention/test_batch_decode_kernels.py
- tests/attention/test_batch_prefill_kernels.py
- tests/attention/test_single_prefill.py
- tests/test_helpers/utils_fp4.py
✅ Files skipped from review due to trivial changes (2)
- include/flashinfer/cp_async.cuh
- include/flashinfer/vec_dtypes.cuh
🚧 Files skipped from review as they are similar to previous changes (5)
- flashinfer/utils.py
- tests/test_helpers/utils_fp4.py
- flashinfer/jit/utils.py
- tests/attention/test_single_prefill.py
- flashinfer/jit/attention/modules.py
/bot run
Transfer to #3097
📌 Description
This PR adds support for NVFP4 KV input in the batch prefill and batch attention kernels, across all GPU architectures.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit