Add int4 paged KV support to main paths #3101

Open
lesj0610 wants to merge 10 commits into flashinfer-ai:main from lesj0610:codex/int4-paged-kv-main-v068

Conversation


@lesj0610 lesj0610 commented Apr 17, 2026

📌 Description

Builds on the int8 paged-KV work in #3100 to add int4 support.

Incremental review against the int8 branch:
lesj0610/flashinfer@codex/int8-paged-kv-main-v068...codex/int4-paged-kv-main-v068

torch.uint8 is already used in some paths as an FP4 container, so a plain uint8 input creates a semantic conflict. An explicit INT4Tensor wrapper is used to keep the contract unambiguous. Storage is packed uint8 with grouped fp16 scales (group_size=32).
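
For reference, a minimal sketch of that packing contract, following the quantization scheme the review notes below describe (scale = amax / 7, unsigned nibble with +8 offset); the helper name and low-nibble-first byte order are assumptions, not the exact flashinfer API:

import torch

INT4_GROUP_SIZE = 32  # matches the group_size=32 contract above

def pack_int4_reference(x: torch.Tensor):
    # Symmetric per-group quantization: scale maps [-amax, amax] onto [-8, 7].
    hidden_dim = x.shape[-1]
    assert hidden_dim % INT4_GROUP_SIZE == 0
    groups = x.float().reshape(*x.shape[:-1], hidden_dim // INT4_GROUP_SIZE, INT4_GROUP_SIZE)
    scale = groups.abs().amax(dim=-1, keepdim=True) / 7.0
    q = torch.clamp(torch.round(groups / scale.clamp(min=1e-8)), -8, 7)
    # Stored as unsigned nibbles (+8 offset), two per uint8 byte.
    u = (q + 8).to(torch.uint8).reshape(*x.shape[:-1], hidden_dim)
    packed = u[..., 0::2] | (u[..., 1::2] << 4)
    return packed, scale.squeeze(-1).to(torch.float16)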

The implementation stages dequantization to fp16 before calling the existing kernels; a sketch of this staging follows the list below. On Hopper, auto backend selection falls back to FA2 in the same way as #3100. The following are not included in this PR:

  • CUDA graph: explicitly blocked, since the staging step requires a temporary allocation
  • Native FA3, XQA, and TRTLLM-gen int4 paths
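
A rough sketch of the staged path (names and signatures are hypothetical; the PR performs the staging inside the wrappers' run() methods):

import torch

def run_with_int4_kv(wrapper, q, int4_kv_cache, dequantize_fn):
    # Stage 1: expand packed nibbles + grouped scales into a dense fp16 cache.
    # This temporary allocation is why CUDA graph capture is blocked.
    k_cache, v_cache = dequantize_fn(int4_kv_cache, dtype=torch.float16)
    # Stage 2: reuse the existing fp16 attention kernel unchanged.
    return wrapper.run(q, (k_cache, v_cache))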

Because the upstream PR base must stay on main, GitHub will keep showing the shared int8 commits in this PR's diff until #3100 lands. The compare link above isolates the int4-only incremental delta.

Tested on Ampere (A100) and Hopper (H100):

python -m pytest tests/attention/test_int4_paged_kv.py -v

51 tests passed on both architectures.

🔍 Related Issues

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Depends on #3100.

Retargeted against main after the v0.6.8 release branch cut.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for INT4 and INT8 paged KV caches, introducing the INT4Tensor wrapper and associated quantization utilities. The review flags significant performance risks from full-cache dequantization during inference and bugs in the handling of quantization scales, including potential FFI type mismatches and tensor broadcasting errors. It also suggests using torch.float32 instead of an implicit float64 cast for better GPU efficiency.

Comment thread flashinfer/decode.py
Comment on lines +1330 to +1333
k_cache, v_cache = _dequantize_int4_paged_kv_cache(
    paged_kv_cache,
    self._kv_layout,
)
Contributor


high

Dequantizing the entire paged_kv_cache on every run call is extremely inefficient, especially for large caches. This operation allocates a new float16 tensor of the same size as the full cache (which is 4x larger than the int4 storage) and performs a full-pass dequantization kernel. This will likely lead to GPU OOM for large KV caches and significant latency overhead. Ideally, dequantization should be fused into the attention kernel or at least limited to the active pages indexed by the current request.
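
For illustration, a sketch of the page-restricted alternative suggested here, dequantizing only the pages referenced by the current batch (helper name, nibble order, and pool layout are assumptions):

import torch

def dequantize_active_pages(packed, scale, page_indices, group_size=32):
    # Gather only the active pages from the pool (dim 0 = page index).
    active_packed = packed.index_select(0, page_indices)
    active_scale = scale.index_select(0, page_indices)
    # Unpack two nibbles per byte and undo the +8 offset.
    low = (active_packed & 0x0F).to(torch.int16) - 8
    high = (active_packed >> 4).to(torch.int16) - 8
    q = torch.stack([low, high], dim=-1).flatten(start_dim=-2)
    # Apply grouped scales: one fp16 scale per group_size elements.
    groups = q.reshape(*q.shape[:-1], q.shape[-1] // group_size, group_size)
    deq = groups.to(torch.float16) * active_scale.unsqueeze(-1)
    return deq.flatten(start_dim=-2)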

Comment thread flashinfer/prefill.py Outdated
Comment on lines +1311 to +1313
sm_scale *= scale_q
if scale_k is not None:
    sm_scale *= scale_k
Contributor


high

If scale_q or scale_k are tensors (as indicated by the type hints and documentation), sm_scale will become a tensor. The fa2 kernel FFI expects a scalar float, so this will cause a TypeError when calling the kernel. Additionally, if is_float8(q) is true and backend != "fa3", these scales are currently ignored entirely in the fa2 path, leading to incorrect results.
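
One way to keep the kernel argument scalar, sketched below (the helper name is illustrative; later comments in this thread mention a similar _normalize_non_fp8_scalar_scale helper):

import torch

def _as_scalar_scale(scale, name: str) -> float:
    # Accept None, Python numbers, or 1-element tensors; reject anything
    # that would turn sm_scale into a tensor before the FFI call.
    if scale is None:
        return 1.0
    if isinstance(scale, torch.Tensor):
        if scale.numel() != 1:
            raise TypeError(f"{name} must be a scalar on the non-FP8 path")
        return float(scale.item())
    return float(scale)

# sm_scale *= _as_scalar_scale(scale_q, "scale_q") * _as_scalar_scale(scale_k, "scale_k")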

Comment thread flashinfer/prefill.py Outdated
):
    out = (out.to(float) * scale_v).to(out.dtype)
else:
    out *= scale_v
Contributor


high

The in-place multiplication out *= scale_v will fail due to shape mismatch if scale_v is a tensor of shape [num_kv_heads] and out has shape [total_tokens, num_qo_heads, head_dim]. Even if num_qo_heads == num_kv_heads, broadcasting will attempt to match the last dimension (head_dim). scale_v should be reshaped to (1, -1, 1) and, in GQA cases, expanded to match num_qo_heads.
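
A sketch of the reshape and GQA expansion this calls for (function name is illustrative; shapes follow the comment):

import torch

def apply_v_scale(out, scale_v, num_qo_heads, num_kv_heads):
    # out: [total_tokens, num_qo_heads, head_dim]; scale_v: [num_kv_heads].
    if isinstance(scale_v, torch.Tensor):
        s = scale_v.reshape(1, -1, 1)  # broadcast over tokens and head_dim
        if num_qo_heads != num_kv_heads:
            # GQA: each KV head serves a contiguous group of query heads.
            s = s.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
        return (out.to(torch.float32) * s).to(out.dtype)
    return out * scale_v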

Comment thread flashinfer/prefill.py
Comment on lines +2286 to +2289
    paged_kv_cache,
    self._kv_layout,
)
else:
Contributor


high

Dequantizing the entire paged_kv_cache in the run method is inefficient for large caches and can lead to GPU OOM, as it allocates a full float16 copy of the cache pool. This should be optimized to only dequantize active pages or fused into the kernel.

Comment thread flashinfer/prefill.py Outdated
    torch.float8_e4m3fn,
    torch.float8_e5m2,
):
    out = (out.to(float) * scale_v).to(out.dtype)
Contributor


medium

Using out.to(float) will cast the tensor to float64 in PyTorch. It is recommended to use torch.float32 for better performance on GPU.

Suggested change
-out = (out.to(float) * scale_v).to(out.dtype)
+out = (out.to(torch.float32) * scale_v).to(out.dtype)

Contributor

coderabbitai Bot commented Apr 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds INT4 quantization and runtime support for paged KV caches (INT4Tensor, int4_quantize/int4_dequantize), INT4-aware append/prefill/decode flows and planner changes, int8 vector CUDA types, a new QKV dtype dispatch path, and extensive INT4/INT8/FP8 tests.

Changes

  • C++ dispatch & page kernel (csrc/tvm_ffi_utils.h, csrc/page.cu): Introduce int8_code, _DISPATCH_CASE_I8, and a new DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV macro; switch QKV dispatch in append_paged_kv_cache to the QKV-aware dispatcher.
  • CUDA vec types (include/flashinfer/vec_dtypes.cuh): Add vec_t<int8_t, N> specializations (1, 2, 4, 8, >=16) plus load/store/cast/memcpy and global/volatile variants.
  • Python INT4 quant primitives & exports (flashinfer/quantization/packbits.py, flashinfer/quantization/__init__.py, flashinfer/__init__.py): Add int4_quantize/int4_dequantize, export them, and include INT4Tensor in the module and top-level exports.
  • Python utils & INT4 KV helpers (flashinfer/utils.py): Add INT4_GROUP_SIZE, INT4_DTYPE_NAME, is_int4_dtype, the INT4Tensor type, and helpers (is_int4_tensor, is_int4_paged_kv_cache, split/dequantize helpers); block FA3 for int8 KV.
  • Paged KV append, INT4 path (flashinfer/page.py): Widen append_paged_kv_cache to accept INT4 cache forms; add an INT4 branch that quantizes inputs, splits cache views, validates shapes/groups, computes page indices/offsets, and writes packed nibbles and scales into pages for both layouts.
  • Prefill & decode runtime/planner changes (flashinfer/prefill.py, flashinfer/decode.py): Widen single APIs to accept INT4Tensor, add INT4 detection/dequantization paths, force/select the fa2 backend for INT4, reject incompatible scale args and CUDA graph for INT4, and swap unpack→dequantize helpers in wrappers.
  • Tests, INT4/INT8/FP8 coverage (tests/attention/test_int4_paged_kv.py, tests/attention/test_int8_paged_kv.py, tests/attention/test_hopper_fp8_attention.py, tests/utils/test_quantization.py): Add CUDA-gated tests validating INT4 paged-KV append/dequantize/compute, int8 paged-KV flows, an FP8 scaling regression, and int4 group-tail dequantization handling.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Client
    participant Py as flashinfer (Python)
    participant Q as Quantizer (packbits)
    participant KV as PagedKVCache (INT4 views)
    participant CU as CUDA Backend

    Client->>Py: append_paged_kv_cache(append_key, append_value, paged_kv_cache, ...)
    Py->>Py: is_int4_paged_kv_cache(...)?
    alt INT4 path
        Py->>Q: int4_quantize(append_key/value)
        Q-->>Py: INT4Tensor(packed, scale, meta)
        Py->>KV: _split_int4_paged_kv_cache_views(paged_kv_cache)
        Py->>KV: write packed nibbles + scales into pages (layout-specific)
        Py-->>Client: return
    else FP/FP16 path
        Py->>CU: _append_paged_kv_cache_kernel(...) (csrc dispatch)
        CU-->>Py: GPU write complete
        Py-->>Client: return
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • yzh119
  • cyx-6
  • samuellees
  • saltyminty
  • bkryu
  • yyihuang
  • kahyunnam
  • jimmyzho
  • nv-yunzheq

Poem

I'm a rabbit with nibble-packed cheer, 🐇
I hop through pages and stash them near.
Four-bit whispers snug in a byte,
I pack, I scale, I keep them tight.
Tiny nibbles, speedy flight! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 26.87%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title 'Add int4 paged KV support to main paths' clearly and directly describes the main objective of this PR: adding int4 paged KV support to the primary code paths.
  • Description check (✅ Passed): The PR description comprehensively covers what changes are made, why they're needed, dependencies, tested configurations, and completed checklist items.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (5)
csrc/page.cu (1)

93-93: Document why this path uses the QKV-specific dtype dispatcher.

This hot path now intentionally avoids the generic dispatcher in favor of DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV, but there’s no local note explaining that the split exists to admit quantized KV dtypes without widening the generic tensor dispatcher. A brief rationale here would prevent an easy “cleanup” regression later.

As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/page.cu` at line 93, Add a brief comment immediately above the call to
DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(paged_k_cache.dtype(), ...) explaining that
this hot path intentionally uses the QKV-specific dtype dispatcher to accept
quantized KV dtypes without expanding/widening the generic tensor dispatcher,
and that this decision is for performance (to avoid extra branching/widening on
the critical Q/K/V cache path); mention that the alternative considered was
using the generic dispatcher but it was rejected to avoid widening dtype support
on the hot path and regressing performance.
flashinfer/quantization/packbits.py (2)

148-217: Missing docstring content for new public APIs.

Both int4_quantize and int4_dequantize are exposed at the top level (flashinfer.int4_quantize, flashinfer.int4_dequantize) and decorated with @flashinfer_api, but their docstrings are one-liners. For public FlashInfer APIs that users will grep for, please document at least the following (a docstring sketch follows the list):

  • Shape contract (x.shape[-1] % group_size == 0, packed last-dim is ceil(hidden_dim/2), scale last-dim is hidden_dim // group_size).
  • Quantization scheme (symmetric, scale = amax / 7, range [-8, 7], stored as unsigned nibble +8).
  • dtype supported for int4_dequantize (fp16 vs bf16 vs fp32 — currently any dtype the .to() cast accepts; clarify if bf16/fp32 are intended use-cases).
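
As a starting point, a hedged docstring sketch restating only the contract already listed above (verify against the implementation before adopting):

def int4_quantize(x, group_size=32):
    """Symmetric per-group INT4 quantization (docstring sketch).

    Shape contract: x.shape[-1] % group_size == 0. The packed output's last
    dim is ceil(hidden_dim / 2) (two nibbles per uint8 byte); the scale's
    last dim is hidden_dim // group_size.

    Scheme: per-group scale = amax / 7, quantized range [-8, 7], stored as
    an unsigned nibble with a +8 offset.

    Example: x of shape (n, 128) with group_size=32 yields packed data of
    shape (n, 64) and fp16 scales of shape (n, 4).
    """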
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/packbits.py` around lines 148 - 217, Add detailed
docstrings to both int4_quantize and int4_dequantize: describe the input/output
shape contract (x.shape[-1] must be divisible by group_size, packed data
last-dim equals ceil(hidden_dim/2), scale last-dim equals hidden_dim //
group_size and corresponds to groups), explain the quantization scheme
(symmetric quantization with scale = amax/7.0, quantized integer range [-8,7],
stored as unsigned nibble by adding +8 and packed as two 4-bit values per byte),
and clarify supported dequantize dtypes (which types are valid for the dtype
parameter such as torch.float16, torch.bfloat16, torch.float32 and that the
implementation uses .to(dtype) on tensors). Also include brief examples of input
vs output shapes and note grouping behavior (group_size parameter and how scales
are computed per-group) in each docstring for int4_quantize and int4_dequantize.

148-165: Enforcement of hidden_dim % group_size == 0 is stricter than INT4Tensor.

INT4Tensor (in flashinfer/utils.py:795-851) allows original_shape[-1] to be non-multiple of group_size by storing scale.shape[-1] = ceil(hidden_dim / group_size). int4_quantize here refuses that case. This is fine as long as quantize is the only producer, but it's worth either:

  • documenting the stricter contract in the docstring, or
  • loosening to match the wrapper (pad to num_groups * group_size, quantize, then truncate padding in scale bookkeeping) — same treatment you already do for odd hidden_dim at lines 175-181, which is otherwise unreachable under the current divisibility precondition.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/packbits.py` around lines 148 - 165, The
int4_quantize function currently rejects inputs where hidden_dim % group_size !=
0 but INT4Tensor supports this by using ceil(hidden_dim/group_size); change
int4_quantize to compute num_groups = math.ceil(hidden_dim / group_size), pad
the input's last dimension to num_groups * group_size before quantization,
perform the existing packing logic, and ensure the resulting INT4Tensor (and its
scale bookkeeping) records the original hidden_dim so downstream
decoding/truncation works (this aligns with the existing padding logic around
the unreachable block in int4_quantize and the INT4Tensor expectations).
include/flashinfer/vec_dtypes.cuh (1)

1882-2114: LGTM — vec_t<int8_t, N> specializations mirror the uint8_t ones correctly.

  • The static_cast<uint8_t>(val) cast before widening to uint32_t/uint16_t for the broadcast fill patterns is correct and avoids sign-extension pitfalls for negative val.
  • Storage sizes (int8_t, uint16_t, uint32_t, uint2, int4[N/16]) match the uint8_t analogs.
  • cast_from/cast_load/cast_store are wired through the existing *_impl helpers, so int8↔other conversions go through the generic elementwise fallback in vec_cast<dst, src>, which is the expected path for INT4 staged dequant chains landing on int8 storage.

Minor follow-up (optional): the 1-/2-/4-/8-element and ≥16 specializations are near-byte-identical to the uint8_t ones apart from type names. Consider a future refactor that factors them via a shared template parameterized by the storage byte type to reduce duplication; not needed in this PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/vec_dtypes.cuh` around lines 1882 - 2114, No change
required: the vec_t<int8_t, N> specializations are correct and mirror the
uint8_t variants; keep vec_t<int8_t, 1/2/4/8 and template<size_t vec_size>
implementations as-is (fill, load, store, memcpy, cast_from/cast_load/cast_store
use cast_*_impl), but optionally consider refactoring duplicate code between
vec_t<int8_t,...> and vec_t<uint8_t,...> into a shared template parameterized by
the underlying storage type (refer to vec_t<int8_t, 1/2/4/8>, vec_t<int8_t,
vec_size>, and the fill/load/store/memcpy methods for locations to factor).
tests/attention/test_int4_paged_kv.py (1)

24-31: Nit: reference INT4_GROUP_SIZE instead of the literal 32.

_allocate_int4_tensor hardcodes the scale grouping at 32. If the default ever changes in flashinfer.utils.INT4_GROUP_SIZE, these test fixtures silently desync from the wrapper validation.

♻️ Proposed refactor
-import flashinfer
-from flashinfer import prefill as flashinfer_prefill
+import flashinfer
+from flashinfer import prefill as flashinfer_prefill
+from flashinfer.utils import INT4_GROUP_SIZE
@@
 def _allocate_int4_tensor(shape, device="cuda:0"):
     packed_dim = (shape[-1] + 1) // 2
-    scale_dim = shape[-1] // 32
+    scale_dim = shape[-1] // INT4_GROUP_SIZE
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_int4_paged_kv.py` around lines 24 - 31, The test helper
_allocate_int4_tensor hardcodes the scale grouping size as 32, which can desync
from the library constant; change the scale_dim computation to use
flashinfer.utils.INT4_GROUP_SIZE (or import INT4_GROUP_SIZE) instead of the
literal 32 so scale_dim = shape[-1] // INT4_GROUP_SIZE, leaving the rest of the
call to flashinfer.INT4Tensor unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/page.py`:
- Around line 140-178: Add a head-count guard before the indexed writes: compute
the expected number of KV heads from the paged cache (e.g. expected_heads =
k_data.shape[2] if kv_layout == "NHD" else k_data.shape[1]) and validate that
packed_key.data.shape[1] and packed_value.data.shape[1] (and if applicable
packed_key.scale/packed_value.scale second dim) equal expected_heads; if not,
raise a ValueError mentioning num_kv_heads (or expected_heads) so a
[nnz,1,packed_dim] tensor cannot silently broadcast into all heads. Ensure this
check occurs just before the block that assigns into
k_data/v_data/k_scale/v_scale.

In `@flashinfer/prefill.py`:
- Around line 1879-1891: The code mutates wrapper state by assigning
self._backend = "fa2" when INT4 KV is detected, causing future plan() calls to
be pinned to fa2; instead, keep the decision local to the current invocation:
compute self._int4_kv_enabled and effective_kv_data_type as before but do not
assign to self._backend — create a local variable (e.g., backend =
self._backend) and if self._int4_kv_enabled set backend = "fa2", then run the
subsequent backend checks and logic against that local backend variable (and
raise the same NotImplementedError if backend != "fa2"), leaving self._backend
unchanged so future calls can re-evaluate auto selection.
- Around line 1309-1314: The else branch currently multiplies tensor-typed
scale_q/scale_k into sm_scale which turns sm_scale into a tensor and breaks the
non-FP8 API; update the validation in this branch to reject tensor inputs: check
scale_q and scale_k and if either is a tensor (e.g., torch.is_tensor(scale_q) or
torch.is_tensor(scale_k)) raise a clear TypeError/ValueError indicating tensor
scales are unsupported on the non-FP8 path so callers must pass Python scalars,
ensuring sm_scale remains a scalar before module.run() is called.

In `@flashinfer/quantization/packbits.py`:
- Around line 202-217: int4_dequantize can fail when hidden_dim isn't divisible
by group_size because num_groups is recomputed via ceil; instead derive
num_groups from the authoritative x.scale shape (num_groups = x.scale.shape[-1])
and compute expected_len = num_groups * x.group_size, then if the unpacked
last-dimension is shorter than expected_len pad it on the right with the neutral
nibble value (0x08) up to expected_len before converting to q and reshaping
(affecting variables hidden_dim, unpacked, q, num_groups, x.group_size, x.scale,
and the reshape call in int4_dequantize/INT4Tensor dequantization).

In `@tests/attention/test_int4_paged_kv.py`:
- Around line 17-22: Add a module-level GPU compute-capability skip so the tests
are skipped on unsupported architectures: import and call the appropriate
flashinfer utils check (e.g., flashinfer.utils.is_sm80_supported() or the API
method api_name.is_compute_capability_supported(cc)) at top-level in this module
and call pytest.skip(...) when it returns False so INT4 paged KV tests
(implemented via FA2/staged fp16) only run on SM80+ devices; reference the
module imports (flashinfer, flashinfer.prefill) to place the guard near them.

---

Nitpick comments:
In `@csrc/page.cu`:
- Line 93: Add a brief comment immediately above the call to
DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(paged_k_cache.dtype(), ...) explaining that
this hot path intentionally uses the QKV-specific dtype dispatcher to accept
quantized KV dtypes without expanding/widening the generic tensor dispatcher,
and that this decision is for performance (to avoid extra branching/widening on
the critical Q/K/V cache path); mention that the alternative considered was
using the generic dispatcher but it was rejected to avoid widening dtype support
on the hot path and regressing performance.

In `@flashinfer/quantization/packbits.py`:
- Around line 148-217: Add detailed docstrings to both int4_quantize and
int4_dequantize: describe the input/output shape contract (x.shape[-1] must be
divisible by group_size, packed data last-dim equals ceil(hidden_dim/2), scale
last-dim equals hidden_dim // group_size and corresponds to groups), explain the
quantization scheme (symmetric quantization with scale = amax/7.0, quantized
integer range [-8,7], stored as unsigned nibble by adding +8 and packed as two
4-bit values per byte), and clarify supported dequantize dtypes (which types are
valid for the dtype parameter such as torch.float16, torch.bfloat16,
torch.float32 and that the implementation uses .to(dtype) on tensors). Also
include brief examples of input vs output shapes and note grouping behavior
(group_size parameter and how scales are computed per-group) in each docstring
for int4_quantize and int4_dequantize.
- Around line 148-165: The int4_quantize function currently rejects inputs where
hidden_dim % group_size != 0 but INT4Tensor supports this by using
ceil(hidden_dim/group_size); change int4_quantize to compute num_groups =
math.ceil(hidden_dim / group_size), pad the input's last dimension to num_groups
* group_size before quantization, perform the existing packing logic, and ensure
the resulting INT4Tensor (and its scale bookkeeping) records the original
hidden_dim so downstream decoding/truncation works (this aligns with the
existing padding logic around the unreachable block in int4_quantize and the
INT4Tensor expectations).

In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 1882-2114: No change required: the vec_t<int8_t, N>
specializations are correct and mirror the uint8_t variants; keep vec_t<int8_t,
1/2/4/8 and template<size_t vec_size> implementations as-is (fill, load, store,
memcpy, cast_from/cast_load/cast_store use cast_*_impl), but optionally consider
refactoring duplicate code between vec_t<int8_t,...> and vec_t<uint8_t,...> into
a shared template parameterized by the underlying storage type (refer to
vec_t<int8_t, 1/2/4/8>, vec_t<int8_t, vec_size>, and the fill/load/store/memcpy
methods for locations to factor).

In `@tests/attention/test_int4_paged_kv.py`:
- Around line 24-31: The test helper _allocate_int4_tensor hardcodes the scale
grouping size as 32, which can desync from the library constant; change the
scale_dim computation to use flashinfer.utils.INT4_GROUP_SIZE (or import
INT4_GROUP_SIZE) instead of the literal 32 so scale_dim = shape[-1] //
INT4_GROUP_SIZE, leaving the rest of the call to flashinfer.INT4Tensor
unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: be5b1fdf-ae85-48ed-ae52-1fb6050adc99

📥 Commits

Reviewing files that changed from the base of the PR and between 0e18a1c and dfbdea3.

📒 Files selected for processing (13)
  • csrc/page.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/__init__.py
  • flashinfer/decode.py
  • flashinfer/page.py
  • flashinfer/prefill.py
  • flashinfer/quantization/__init__.py
  • flashinfer/quantization/packbits.py
  • flashinfer/utils.py
  • include/flashinfer/vec_dtypes.cuh
  • tests/attention/test_hopper_fp8_attention.py
  • tests/attention/test_int4_paged_kv.py
  • tests/attention/test_int8_paged_kv.py

Comment thread flashinfer/page.py
Comment thread flashinfer/prefill.py
Comment thread flashinfer/quantization/packbits.py Outdated
Comment thread tests/attention/test_int4_paged_kv.py
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
flashinfer/prefill.py (1)

1930-1936: ⚠️ Potential issue | 🟠 Major

Do not persist the INT4 auto→FA2 fallback in wrapper state.

plan() currently rewrites self._backend to "fa2" for INT4. That pins future non-INT4 plan() calls on the same wrapper and bypasses normal auto selection.

Suggested fix (keep configured backend immutable; use per-plan effective backend)
@@
-        self._backend = backend
+        self._requested_backend = backend
+        self._backend = backend
@@ def plan(...):
-        if self._int4_kv_enabled:
-            if self._backend == "auto":
-                self._backend = "fa2"
-            elif self._backend != "fa2":
+        backend = self._requested_backend
+        if self._int4_kv_enabled:
+            if backend == "auto":
+                backend = "fa2"
+            elif backend != "fa2":
                 raise NotImplementedError(
                     "INT4 paged KV cache only supports the fa2/common prefill path."
                 )
@@
-            if self._backend == "auto":
-                self._backend = determine_attention_backend(
+            if backend == "auto":
+                backend = determine_attention_backend(
                     self.device,
                     PosEncodingMode[pos_encoding_mode].value,
                     use_fp16_qk_reduction,
                     self._custom_mask_buf is not None,  # use_custom_mask
                     q_data_type,
                     effective_kv_data_type,
                 )
@@
-                self._cached_module = get_batch_prefill_module(
-                    self._backend, *get_module_args
-                )
+                self._cached_module = get_batch_prefill_module(backend, *get_module_args)
+
+        # Persist only the effective backend for this plan/run pair.
+        self._backend = backend

Also applies to: 2085-2093

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1930 - 1936, The current plan() mutates
wrapper state by assigning self._backend = "fa2" when self._int4_kv_enabled is
true; instead leave self._backend immutable and compute a per-call
effective_backend local variable (e.g., effective_backend = "fa2" if
self._int4_kv_enabled and self._backend == "auto" else self._backend) and use
effective_backend for the INT4-only checks and subsequent logic; do the same
replacement for the similar block around the other occurrence (the block at the
later lines referencing self._int4_kv_enabled/self._backend) so plan() never
persists the auto→fa2 fallback on the wrapper instance.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e09332f0-c248-4a84-9a77-b73d114e222c

📥 Commits

Reviewing files that changed from the base of the PR and between dfbdea3 and 346ad9e.

📒 Files selected for processing (2)
  • flashinfer/prefill.py
  • tests/attention/test_int8_paged_kv.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
flashinfer/prefill.py (1)

2399-2404: ⚠️ Potential issue | 🟠 Major

Keep sm_scale scalar on paged non-cudnn runs.

q_scale and k_scale are now allowed to be tensors, but these lines still multiply them straight into sm_scale. On the common paged path that turns a scalar kernel argument into a tensor, which is the same regression you already fixed in single_prefill_with_kv_cache(). Reuse _normalize_non_fp8_scalar_scale() here or reject tensor inputs on this branch.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 2399 - 2404, The code multiplies q_scale
and k_scale (which may be tensors on paged non-cudnn runs) directly into
sm_scale, turning it into a tensor; update the branch in the prefill flow to
keep sm_scale scalar by normalizing q_scale/k_scale before multiplication: call
_normalize_non_fp8_scalar_scale(q_scale) and
_normalize_non_fp8_scalar_scale(k_scale) (the same helper used in
single_prefill_with_kv_cache()) or explicitly reject tensor inputs, then
multiply the normalized scalar results into sm_scale so sm_scale remains a
scalar on paged non-cudnn runs.
flashinfer/quantization/packbits.py (1)

161-173: ⚠️ Potential issue | 🟠 Major

Allow partial tail groups in int4_quantize.

INT4Tensor and int4_dequantize() already support a final partial group via ceil(hidden_dim / group_size), but this guard still rejects any non-divisible last dimension. That makes _append_paged_kv_cache_int4() in flashinfer/page.py fail on otherwise-valid INT4 KV head sizes with a partial tail. Pad only for scale/quant computation, then keep the stored packed payload truncated to ceil(hidden_dim / 2) bytes.

Suggested fix
-    hidden_dim = x.shape[-1]
-    if hidden_dim % group_size != 0:
-        raise ValueError(
-            f"x.shape[-1] must be divisible by group_size, got {hidden_dim} and {group_size}"
-        )
-
     x_fp32 = x.to(torch.float32)
-    num_groups = hidden_dim // group_size
-    x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
+    hidden_dim = x.shape[-1]
+    num_groups = math.ceil(hidden_dim / group_size)
+    padded_hidden_dim = num_groups * group_size
+    if padded_hidden_dim != hidden_dim:
+        x_fp32 = torch.cat(
+            [
+                x_fp32,
+                torch.zeros(
+                    (*x.shape[:-1], padded_hidden_dim - hidden_dim),
+                    dtype=x_fp32.dtype,
+                    device=x_fp32.device,
+                ),
+            ],
+            dim=-1,
+        )
+    x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
@@
-    q_unsigned = (q + 8).to(torch.uint8).reshape(*x.shape[:-1], hidden_dim)
+    q_unsigned = (q + 8).to(torch.uint8).reshape(
+        *x.shape[:-1], padded_hidden_dim
+    )[..., :hidden_dim]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/packbits.py` around lines 161 - 173, The code
currently rejects x when hidden_dim % group_size != 0 inside int4_quantize;
instead, allow a final partial group by computing num_groups =
math.ceil(hidden_dim / group_size), pad x (e.g., with zeros) up to num_groups *
group_size for the amax/scale and quantization steps in int4_quantize, perform
grouping/scale/q on the padded tensor, then when producing the packed payload
(q_unsigned or the stored INT4Tensor bytes) truncate the result to the original
hidden_dim (i.e., ceil(hidden_dim / 2) bytes) so the stored payload matches the
actual tail size; reference int4_quantize, INT4Tensor, int4_dequantize and
_append_paged_kv_cache_int4 to ensure compatibility with existing dequantization
and KV append logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1304-1313: When int4_input is detected in the prefill path (the
block using is_int4_tensor on k/v and int4_dequantize), explicitly reject any
FP8 q by checking q's dtype (or using the project helper for float8 tensors) and
raising a clear ValueError like "FP8 q is not supported with INT4 k/v"; place
this check before dequantization or before the later FP8 assertion that compares
q.dtype to k_tensor.dtype and v_tensor.dtype (referencing variables q, k_tensor,
v_tensor, int4_input, is_int4_tensor, int4_dequantize). Do the same for the
second INT4 branch that mirrors lines 1367-1373 so both INT4 single-prefill
paths produce a clear unsupported-combination error instead of hitting the
internal FP8 assert.
- Around line 1925-1938: When INT4 paged-KV is enabled (self._int4_kv_enabled
set in plan()), reject FP8 query dtypes up front instead of allowing planning to
succeed and failing at run-time; specifically, in the same planning logic where
backend/effective_kv_data_type are decided, add a guard that checks the query
dtype (the q_data_type or similar arg used by plan()) with
is_float8(q_data_type) and raise a clear NotImplementedError/ValueError if true
(message: "INT4 paged KV cache does not support FP8 query dtypes"). This mirrors
the FA2 runner assertion (assert not is_float8(q)) and prevents FP8 queries from
being planned when INT4 paged-KV (fa2) is selected.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 61eb8ecc-c94f-4f31-b0d3-edd8b8083435

📥 Commits

Reviewing files that changed from the base of the PR and between 346ad9e and 6ccb91e.

📒 Files selected for processing (5)
  • flashinfer/page.py
  • flashinfer/prefill.py
  • flashinfer/quantization/packbits.py
  • tests/attention/test_int4_paged_kv.py
  • tests/utils/test_quantization.py
✅ Files skipped from review due to trivial changes (1)
  • tests/attention/test_int4_paged_kv.py

Comment thread flashinfer/prefill.py
Comment thread flashinfer/prefill.py
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (3)
flashinfer/utils.py (2)

853-865: INT4Tensor.unbind silently breaks when called on the packed last dim.

self.original_shape[:dim] + self.original_shape[dim + 1:] and data.select(dim, i) both run for any dim, including the final (packed) axis. The current call sites only use dim=1, so this is fine today, but unbinding along the innermost axis would produce a child whose data.shape[-1] no longer equals ceil(original_shape[-1]/2) and would blow up inside __init__ with a confusing “data last dimension must be ceil(...)” error. Consider guarding explicitly (also handling negative dim) to fail fast with a clearer message.

Proposed guard
     def unbind(self, dim: int = 0) -> Tuple["INT4Tensor", ...]:
+        ndim = len(self.original_shape)
+        normalized_dim = dim + ndim if dim < 0 else dim
+        if normalized_dim < 0 or normalized_dim >= ndim - 1:
+            raise ValueError(
+                f"INT4Tensor.unbind does not support the packed last dim; got dim={dim}"
+            )
         size = self.data.shape[dim]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 853 - 865, INT4Tensor.unbind currently
allows unbinding along the packed (last) axis which breaks invariants in
INT4Tensor.__init__ because data.select(dim, i) will yield tensors whose last
dimension no longer equals ceil(original_shape[-1]/2); modify INT4Tensor.unbind
to validate and reject attempts to unbind the packed axis (and normalize
negative dim to positive) by checking if the target dim corresponds to the
packed innermost dimension (compare dim against len(self.original_shape)-1) and
raise a clear ValueError mentioning INT4Tensor.unbind, the packed last axis, and
original_shape; keep existing behavior for all other dims and still construct
children using data.select and scale.select when allowed.

953-953: Nit (Ruff RUF005): use unpacking instead of tuple concatenation.

Proposed fix
-        original_shape = (unique_page_indices.numel(),) + x.original_shape[1:]
+        original_shape = (unique_page_indices.numel(), *x.original_shape[1:])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` at line 953, Replace the tuple concatenation used to
build original_shape with tuple unpacking: construct original_shape by putting
unique_page_indices.numel() as the first element and then unpacking the
remainder of x.original_shape starting at index 1 (i.e., use unpacking of
x.original_shape[1:] rather than (unique_page_indices.numel(),) +
x.original_shape[1:]); update the assignment to original_shape in the
function/section that references unique_page_indices and x.original_shape
accordingly.
flashinfer/decode.py (1)

994-1018: Plan-time INT4 handling looks good; minor suggestion on the NotImplementedError wording.

The INT4 gating (reject int4 for Q/O, force fa2 when auto, reject cuda-graph) is coherent and correctly ordered before canonicalize_torch_dtype so "int4" never hits getattr(torch, ...). One tiny clarification: the error at line 1009 says “only supports the fa2/common decode path,” but this wrapper only allows fa2 (not trtllm-gen/fa3). Consider a more direct message, e.g. "INT4 paged KV cache requires backend='fa2' or 'auto'; got '{self._backend}'.".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 994 - 1018, Update the NotImplementedError
message in the INT4 gating block so it clearly indicates that only backend 'fa2'
(or 'auto' which becomes 'fa2') is allowed; specifically change the exception
raised in the branch that checks self._backend != "fa2" to something like "INT4
paged KV cache requires backend='fa2' or 'auto'; got '{self._backend}'."
Reference the INT4-related symbols in this block: is_int4_dtype,
canonicalize_torch_dtype, self._int4_kv_enabled, and the conditional that
sets/validates self._backend.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 367576cd-5d18-4b58-acb9-133655960159

📥 Commits

Reviewing files that changed from the base of the PR and between 6ccb91e and d54c35f.

📒 Files selected for processing (4)
  • flashinfer/decode.py
  • flashinfer/prefill.py
  • flashinfer/utils.py
  • tests/attention/test_int4_paged_kv.py
✅ Files skipped from review due to trivial changes (1)
  • tests/attention/test_int4_paged_kv.py

@lesj0610 lesj0610 force-pushed the codex/int4-paged-kv-main-v068 branch from c3af1b0 to e8da95d on April 17, 2026 at 11:21
@lesj0610
Author

Addressed in the latest commits.
