Add int4 paged KV support to main paths #3101

lesj0610 wants to merge 10 commits into flashinfer-ai:main from
Conversation
Code Review
This pull request adds support for INT4 and INT8 paged KV caches, introducing the INT4Tensor wrapper and associated quantization utilities. The reviewer flagged a significant performance risk from full-cache dequantization during inference and pointed out bugs in the handling of quantization scales, including a potential FFI type mismatch and a tensor broadcasting error. The review also suggested using torch.float32 for better GPU efficiency.
```python
k_cache, v_cache = _dequantize_int4_paged_kv_cache(
    paged_kv_cache,
    self._kv_layout,
)
```
Dequantizing the entire paged_kv_cache on every run call is extremely inefficient, especially for large caches. This operation allocates a new float16 tensor of the same size as the full cache (which is 4x larger than the int4 storage) and performs a full-pass dequantization kernel. This will likely lead to GPU OOM for large KV caches and significant latency overhead. Ideally, dequantization should be fused into the attention kernel or at least limited to the active pages indexed by the current request.
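The page-limited dequantization the comment asks for can be sketched as follows. This is a minimal NumPy illustration of the idea, not the PR's API: `dequantize_active_pages`, the per-page scalar `page_scales` layout, and the nibble ordering are all assumptions (the real cache stores per-group scales and operates on torch tensors).

```python
import numpy as np

def dequantize_active_pages(packed_cache, page_scales, kv_indices):
    """Dequantize only the pages referenced by the current batch.

    packed_cache: [num_pages, page_size, packed_dim] uint8, two nibbles per byte
    page_scales:  [num_pages] float (the real layout uses per-group scales)
    kv_indices:   page ids used by the current requests (may repeat)
    """
    pages = np.unique(kv_indices)        # active pages only, deduplicated
    packed = packed_cache[pages]         # gather, not a full-cache copy
    lo = (packed & 0x0F).astype(np.float32) - 8.0   # low nibble, stored as value + 8
    hi = (packed >> 4).astype(np.float32) - 8.0     # high nibble
    vals = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
    return pages, vals * page_scales[pages][:, None, None]
```

The float working set then scales with the batch's active pages rather than the whole cache; fusing dequantization into the attention kernel would avoid even that copy.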
```python
sm_scale *= scale_q
if scale_k is not None:
    sm_scale *= scale_k
```
If scale_q or scale_k are tensors (as indicated by the type hints and documentation), sm_scale will become a tensor. The fa2 kernel FFI expects a scalar float, so this will cause a TypeError when calling the kernel. Additionally, if is_float8(q) is true and backend != "fa3", these scales are currently ignored entirely in the fa2 path, leading to incorrect results.
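One way to guard against this is to coerce scales to Python floats before folding them into sm_scale. The helper below is a hypothetical sketch (the name `as_scalar_scale` is not from the PR); it accepts numbers and 1-element tensor-like objects and rejects anything larger, so the kernel's scalar argument never silently becomes a tensor:

```python
import numpy as np

def as_scalar_scale(scale, name="scale"):
    # Accept Python numbers and 1-element tensors/arrays; reject anything
    # larger so the kernel's scalar sm_scale stays a Python float.
    if scale is None:
        return 1.0
    if hasattr(scale, "item"):                 # torch.Tensor / np.ndarray-like
        numel = getattr(scale, "numel", None)  # torch exposes numel(), numpy exposes size
        count = numel() if callable(numel) else scale.size
        if count != 1:
            raise TypeError(f"{name} must be a scalar on this path, got {count} elements")
        return float(scale.item())
    return float(scale)

sm_scale = 0.125
sm_scale *= as_scalar_scale(np.array([2.0]), "scale_q")  # stays a float
```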
```python
):
    out = (out.to(float) * scale_v).to(out.dtype)
else:
    out *= scale_v
```
The in-place multiplication out *= scale_v will fail due to shape mismatch if scale_v is a tensor of shape [num_kv_heads] and out has shape [total_tokens, num_qo_heads, head_dim]. Even if num_qo_heads == num_kv_heads, broadcasting will attempt to match the last dimension (head_dim). scale_v should be reshaped to (1, -1, 1) and, in GQA cases, expanded to match num_qo_heads.
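The reshape-and-expand the comment describes could look like the NumPy sketch below. The helper name and signature are illustrative; in the GQA case each KV head's scale is repeated across its query-head group so the scale broadcasts over tokens and head_dim instead of colliding with them:

```python
import numpy as np

def apply_value_scale(out, scale_v, num_kv_heads):
    # out: [total_tokens, num_qo_heads, head_dim]; scale_v: [num_kv_heads]
    num_qo_heads = out.shape[1]
    scale = scale_v.reshape(1, num_kv_heads, 1)   # align with the head axis
    if num_qo_heads != num_kv_heads:              # GQA: one KV head serves a group
        scale = np.repeat(scale, num_qo_heads // num_kv_heads, axis=1)
    return out * scale                            # broadcasts over tokens and head_dim
```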
Note: Reviews paused. This branch is under active development, so CodeRabbit automatically paused this review to avoid overwhelming the author with comments from an influx of new commits. This behavior is configurable in the CodeRabbit settings.
📝 Walkthrough

Adds INT4 quantization and runtime support for paged KV caches (INT4Tensor, int4_quantize/int4_dequantize), INT4-aware append/prefill/decode flows and planner changes, int8 vector CUDA types, a new QKV dtype dispatch path, and extensive INT4/INT8/FP8 tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Client
    participant Py as flashinfer (Python)
    participant Q as Quantizer (packbits)
    participant KV as PagedKVCache (INT4 views)
    participant CU as CUDA Backend
    Client->>Py: append_paged_kv_cache(append_key, append_value, paged_kv_cache, ...)
    Py->>Py: is_int4_paged_kv_cache(...)?
    alt INT4 path
        Py->>Q: int4_quantize(append_key/value)
        Q-->>Py: INT4Tensor(packed, scale, meta)
        Py->>KV: _split_int4_paged_kv_cache_views(paged_kv_cache)
        Py->>KV: write packed nibbles + scales into pages (layout-specific)
        Py-->>Client: return
    else FP/FP16 path
        Py->>CU: _append_paged_kv_cache_kernel(...) (csrc dispatch)
        CU-->>Py: GPU write complete
        Py-->>Client: return
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 5
🧹 Nitpick comments (5)
csrc/page.cu (1)
93-93: Document why this path uses the QKV-specific dtype dispatcher.

This hot path now intentionally avoids the generic dispatcher in favor of DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV, but there is no local note explaining that the split exists to admit quantized KV dtypes without widening the generic tensor dispatcher. A brief rationale here would prevent an easy "cleanup" regression later.

As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/page.cu` at line 93, Add a brief comment immediately above the call to DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(paged_k_cache.dtype(), ...) explaining that this hot path intentionally uses the QKV-specific dtype dispatcher to accept quantized KV dtypes without expanding/widening the generic tensor dispatcher, and that this decision is for performance (to avoid extra branching/widening on the critical Q/K/V cache path); mention that the alternative considered was using the generic dispatcher but it was rejected to avoid widening dtype support on the hot path and regressing performance.

flashinfer/quantization/packbits.py (2)
148-217: Missing docstring content for new public APIs.

Both int4_quantize and int4_dequantize are exposed at the top level (flashinfer.int4_quantize, flashinfer.int4_dequantize) and decorated with @flashinfer_api, but the docstrings are one-liners. For public FlashInfer APIs users will grep for, please document at least:

- Shape contract (x.shape[-1] % group_size == 0, packed last-dim is ceil(hidden_dim/2), scale last-dim is hidden_dim // group_size).
- Quantization scheme (symmetric, scale = amax / 7, range [-8, 7], stored as unsigned nibble + 8).
- dtype supported for int4_dequantize (fp16 vs bf16 vs fp32; currently any dtype the .to() cast accepts; clarify if bf16/fp32 are intended use-cases).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/packbits.py` around lines 148 - 217, Add detailed docstrings to both int4_quantize and int4_dequantize: describe the input/output shape contract (x.shape[-1] must be divisible by group_size, packed data last-dim equals ceil(hidden_dim/2), scale last-dim equals hidden_dim // group_size and corresponds to groups), explain the quantization scheme (symmetric quantization with scale = amax/7.0, quantized integer range [-8,7], stored as unsigned nibble by adding +8 and packed as two 4-bit values per byte), and clarify supported dequantize dtypes (which types are valid for the dtype parameter such as torch.float16, torch.bfloat16, torch.float32 and that the implementation uses .to(dtype) on tensors). Also include brief examples of input vs output shapes and note grouping behavior (group_size parameter and how scales are computed per-group) in each docstring for int4_quantize and int4_dequantize.
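For readers tracking the scheme those docstrings should capture, here is a NumPy round-trip sketch of the described quantization (symmetric, scale = amax / 7, range [-8, 7], stored as unsigned nibble + 8, two values packed per byte). The helper names and the even-index-in-low-nibble packing order are assumptions for illustration; the real int4_quantize/int4_dequantize operate on torch tensors.

```python
import numpy as np

def int4_quantize_ref(x, group_size=32):
    # Per-group symmetric quantization: scale = amax / 7, q in [-8, 7].
    *lead, hidden = x.shape
    assert hidden % group_size == 0
    g = x.reshape(*lead, hidden // group_size, group_size).astype(np.float32)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    q = np.clip(np.rint(g / np.where(scale == 0, 1.0, scale)), -8, 7)
    nib = (q + 8).astype(np.uint8).reshape(*lead, hidden)  # unsigned nibble
    packed = nib[..., 0::2] | (nib[..., 1::2] << 4)        # two nibbles per byte
    return packed, scale.squeeze(-1)

def int4_dequantize_ref(packed, scale, group_size=32):
    lo = (packed & 0x0F).astype(np.int32) - 8
    hi = (packed >> 4).astype(np.int32) - 8
    q = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
    g = q.reshape(*q.shape[:-1], -1, group_size) * scale[..., None]
    return g.reshape(q.shape).astype(np.float32)
```

The round-trip error per element is bounded by half a group scale, which is a useful check for any docstring example.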
148-165: Enforcement of hidden_dim % group_size == 0 is stricter than INT4Tensor.

INT4Tensor (in flashinfer/utils.py:795-851) allows original_shape[-1] to be a non-multiple of group_size by storing scale.shape[-1] = ceil(hidden_dim / group_size). int4_quantize here refuses that case. This is fine as long as quantize is the only producer, but it's worth either:

- documenting the stricter contract in the docstring, or
- loosening to match the wrapper (pad to num_groups * group_size, quantize, then truncate padding in scale bookkeeping), the same treatment you already do for odd hidden_dim at lines 175-181, which is otherwise unreachable under the current divisibility precondition.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/packbits.py` around lines 148 - 165, The int4_quantize function currently rejects inputs where hidden_dim % group_size != 0 but INT4Tensor supports this by using ceil(hidden_dim/group_size); change int4_quantize to compute num_groups = math.ceil(hidden_dim / group_size), pad the input's last dimension to num_groups * group_size before quantization, perform the existing packing logic, and ensure the resulting INT4Tensor (and its scale bookkeeping) records the original hidden_dim so downstream decoding/truncation works (this aligns with the existing padding logic around the unreachable block in int4_quantize and the INT4Tensor expectations).

include/flashinfer/vec_dtypes.cuh (1)
1882-2114: LGTM: vec_t<int8_t, N> specializations mirror the uint8_t ones correctly.

- The static_cast<uint8_t>(val) cast before widening to uint32_t/uint16_t for the broadcast fill patterns is correct and avoids sign-extension pitfalls for negative val.
- Storage sizes (int8_t, uint16_t, uint32_t, uint2, int4[N/16]) match the uint8_t analogs.
- cast_from/cast_load/cast_store are wired through the existing *_impl helpers, so int8↔other conversions go through the generic elementwise fallback in vec_cast<dst, src>, which is the expected path for INT4 staged dequant chains landing on int8 storage.

Minor follow-up (optional): the 1-/2-/4-/8-element and ≥16 specializations are near-byte-identical to the uint8_t ones apart from type names. Consider a future refactor that factors them via a shared template parameterized by the storage byte type to reduce duplication; not needed in this PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/vec_dtypes.cuh` around lines 1882 - 2114, No change required: the vec_t<int8_t, N> specializations are correct and mirror the uint8_t variants; keep vec_t<int8_t, 1/2/4/8 and template<size_t vec_size> implementations as-is (fill, load, store, memcpy, cast_from/cast_load/cast_store use cast_*_impl), but optionally consider refactoring duplicate code between vec_t<int8_t,...> and vec_t<uint8_t,...> into a shared template parameterized by the underlying storage type (refer to vec_t<int8_t, 1/2/4/8>, vec_t<int8_t, vec_size>, and the fill/load/store/memcpy methods for locations to factor).

tests/attention/test_int4_paged_kv.py (1)
24-31: Nit: reference INT4_GROUP_SIZE instead of the literal 32.

_allocate_int4_tensor hardcodes the scale grouping at 32. If the default ever changes in flashinfer.utils.INT4_GROUP_SIZE, these test fixtures silently desync from the wrapper validation.

♻️ Proposed refactor
```diff
-import flashinfer
-from flashinfer import prefill as flashinfer_prefill
+import flashinfer
+from flashinfer import prefill as flashinfer_prefill
+from flashinfer.utils import INT4_GROUP_SIZE
@@
 def _allocate_int4_tensor(shape, device="cuda:0"):
     packed_dim = (shape[-1] + 1) // 2
-    scale_dim = shape[-1] // 32
+    scale_dim = shape[-1] // INT4_GROUP_SIZE
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_int4_paged_kv.py` around lines 24 - 31, The test helper _allocate_int4_tensor hardcodes the scale grouping size as 32, which can desync from the library constant; change the scale_dim computation to use flashinfer.utils.INT4_GROUP_SIZE (or import INT4_GROUP_SIZE) instead of the literal 32 so scale_dim = shape[-1] // INT4_GROUP_SIZE, leaving the rest of the call to flashinfer.INT4Tensor unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/page.py`:
- Around line 140-178: Add a head-count guard before the indexed writes: compute
the expected number of KV heads from the paged cache (e.g. expected_heads =
k_data.shape[2] if kv_layout == "NHD" else k_data.shape[1]) and validate that
packed_key.data.shape[1] and packed_value.data.shape[1] (and if applicable
packed_key.scale/packed_value.scale second dim) equal expected_heads; if not,
raise a ValueError mentioning num_kv_heads (or expected_heads) so a
[nnz,1,packed_dim] tensor cannot silently broadcast into all heads. Ensure this
check occurs just before the block that assigns into
k_data/v_data/k_scale/v_scale.
In `@flashinfer/prefill.py`:
- Around line 1879-1891: The code mutates wrapper state by assigning
self._backend = "fa2" when INT4 KV is detected, causing future plan() calls to
be pinned to fa2; instead, keep the decision local to the current invocation:
compute self._int4_kv_enabled and effective_kv_data_type as before but do not
assign to self._backend — create a local variable (e.g., backend =
self._backend) and if self._int4_kv_enabled set backend = "fa2", then run the
subsequent backend checks and logic against that local backend variable (and
raise the same NotImplementedError if backend != "fa2"), leaving self._backend
unchanged so future calls can re-evaluate auto selection.
- Around line 1309-1314: The else branch currently multiplies tensor-typed
scale_q/scale_k into sm_scale which turns sm_scale into a tensor and breaks the
non-FP8 API; update the validation in this branch to reject tensor inputs: check
scale_q and scale_k and if either is a tensor (e.g., torch.is_tensor(scale_q) or
torch.is_tensor(scale_k)) raise a clear TypeError/ValueError indicating tensor
scales are unsupported on the non-FP8 path so callers must pass Python scalars,
ensuring sm_scale remains a scalar before module.run() is called.
In `@flashinfer/quantization/packbits.py`:
- Around line 202-217: int4_dequantize can fail when hidden_dim isn't divisible
by group_size because num_groups is recomputed via ceil; instead derive
num_groups from the authoritative x.scale shape (num_groups = x.scale.shape[-1])
and compute expected_len = num_groups * x.group_size, then if the unpacked
last-dimension is shorter than expected_len pad it on the right with the neutral
nibble value (0x08) up to expected_len before converting to q and reshaping
(affecting variables hidden_dim, unpacked, q, num_groups, x.group_size, x.scale,
and the reshape call in int4_dequantize/INT4Tensor dequantization).
In `@tests/attention/test_int4_paged_kv.py`:
- Around line 17-22: Add a module-level GPU compute-capability skip so the tests
are skipped on unsupported architectures: import and call the appropriate
flashinfer utils check (e.g., flashinfer.utils.is_sm80_supported() or the API
method api_name.is_compute_capability_supported(cc)) at top-level in this module
and call pytest.skip(...) when it returns False so INT4 paged KV tests
(implemented via FA2/staged fp16) only run on SM80+ devices; reference the
module imports (flashinfer, flashinfer.prefill) to place the guard near them.
---
Nitpick comments:
In `@csrc/page.cu`:
- Line 93: Add a brief comment immediately above the call to
DISPATCH_DLPACK_DTYPE_TO_CTYPE_QKV(paged_k_cache.dtype(), ...) explaining that
this hot path intentionally uses the QKV-specific dtype dispatcher to accept
quantized KV dtypes without expanding/widening the generic tensor dispatcher,
and that this decision is for performance (to avoid extra branching/widening on
the critical Q/K/V cache path); mention that the alternative considered was
using the generic dispatcher but it was rejected to avoid widening dtype support
on the hot path and regressing performance.
In `@flashinfer/quantization/packbits.py`:
- Around line 148-217: Add detailed docstrings to both int4_quantize and
int4_dequantize: describe the input/output shape contract (x.shape[-1] must be
divisible by group_size, packed data last-dim equals ceil(hidden_dim/2), scale
last-dim equals hidden_dim // group_size and corresponds to groups), explain the
quantization scheme (symmetric quantization with scale = amax/7.0, quantized
integer range [-8,7], stored as unsigned nibble by adding +8 and packed as two
4-bit values per byte), and clarify supported dequantize dtypes (which types are
valid for the dtype parameter such as torch.float16, torch.bfloat16,
torch.float32 and that the implementation uses .to(dtype) on tensors). Also
include brief examples of input vs output shapes and note grouping behavior
(group_size parameter and how scales are computed per-group) in each docstring
for int4_quantize and int4_dequantize.
- Around line 148-165: The int4_quantize function currently rejects inputs where
hidden_dim % group_size != 0 but INT4Tensor supports this by using
ceil(hidden_dim/group_size); change int4_quantize to compute num_groups =
math.ceil(hidden_dim / group_size), pad the input's last dimension to num_groups
* group_size before quantization, perform the existing packing logic, and ensure
the resulting INT4Tensor (and its scale bookkeeping) records the original
hidden_dim so downstream decoding/truncation works (this aligns with the
existing padding logic around the unreachable block in int4_quantize and the
INT4Tensor expectations).
In `@include/flashinfer/vec_dtypes.cuh`:
- Around line 1882-2114: No change required: the vec_t<int8_t, N>
specializations are correct and mirror the uint8_t variants; keep vec_t<int8_t,
1/2/4/8 and template<size_t vec_size> implementations as-is (fill, load, store,
memcpy, cast_from/cast_load/cast_store use cast_*_impl), but optionally consider
refactoring duplicate code between vec_t<int8_t,...> and vec_t<uint8_t,...> into
a shared template parameterized by the underlying storage type (refer to
vec_t<int8_t, 1/2/4/8>, vec_t<int8_t, vec_size>, and the fill/load/store/memcpy
methods for locations to factor).
In `@tests/attention/test_int4_paged_kv.py`:
- Around line 24-31: The test helper _allocate_int4_tensor hardcodes the scale
grouping size as 32, which can desync from the library constant; change the
scale_dim computation to use flashinfer.utils.INT4_GROUP_SIZE (or import
INT4_GROUP_SIZE) instead of the literal 32 so scale_dim = shape[-1] //
INT4_GROUP_SIZE, leaving the rest of the call to flashinfer.INT4Tensor
unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: be5b1fdf-ae85-48ed-ae52-1fb6050adc99
📒 Files selected for processing (13)
- csrc/page.cu
- csrc/tvm_ffi_utils.h
- flashinfer/__init__.py
- flashinfer/decode.py
- flashinfer/page.py
- flashinfer/prefill.py
- flashinfer/quantization/__init__.py
- flashinfer/quantization/packbits.py
- flashinfer/utils.py
- include/flashinfer/vec_dtypes.cuh
- tests/attention/test_hopper_fp8_attention.py
- tests/attention/test_int4_paged_kv.py
- tests/attention/test_int8_paged_kv.py
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)
1930-1936: ⚠️ Potential issue | 🟠 Major

Do not persist the INT4 auto→FA2 fallback in wrapper state.

plan() currently rewrites self._backend to "fa2" for INT4. That pins future non-INT4 plan() calls on the same wrapper and bypasses normal auto selection.

Suggested fix (keep configured backend immutable; use per-plan effective backend)

```diff
@@
-        self._backend = backend
+        self._requested_backend = backend
+        self._backend = backend
@@ def plan(...):
-        if self._int4_kv_enabled:
-            if self._backend == "auto":
-                self._backend = "fa2"
-            elif self._backend != "fa2":
+        backend = self._requested_backend
+        if self._int4_kv_enabled:
+            if backend == "auto":
+                backend = "fa2"
+            elif backend != "fa2":
                 raise NotImplementedError(
                     "INT4 paged KV cache only supports the fa2/common prefill path."
                 )
@@
-        if self._backend == "auto":
-            self._backend = determine_attention_backend(
+        if backend == "auto":
+            backend = determine_attention_backend(
                 self.device,
                 PosEncodingMode[pos_encoding_mode].value,
                 use_fp16_qk_reduction,
                 self._custom_mask_buf is not None,  # use_custom_mask
                 q_data_type,
                 effective_kv_data_type,
             )
@@
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+        self._cached_module = get_batch_prefill_module(backend, *get_module_args)
+
+        # Persist only the effective backend for this plan/run pair.
+        self._backend = backend
```

Also applies to: 2085-2093
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1930 - 1936, The current plan() mutates wrapper state by assigning self._backend = "fa2" when self._int4_kv_enabled is true; instead leave self._backend immutable and compute a per-call effective_backend local variable (e.g., effective_backend = "fa2" if self._int4_kv_enabled and self._backend == "auto" else self._backend) and use effective_backend for the INT4-only checks and subsequent logic; do the same replacement for the similar block around the other occurrence (the block at the later lines referencing self._int4_kv_enabled/self._backend) so plan() never persists the auto→fa2 fallback on the wrapper instance.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 1930-1936: The current plan() mutates wrapper state by assigning
self._backend = "fa2" when self._int4_kv_enabled is true; instead leave
self._backend immutable and compute a per-call effective_backend local variable
(e.g., effective_backend = "fa2" if self._int4_kv_enabled and self._backend ==
"auto" else self._backend) and use effective_backend for the INT4-only checks
and subsequent logic; do the same replacement for the similar block around the
other occurrence (the block at the later lines referencing
self._int4_kv_enabled/self._backend) so plan() never persists the auto→fa2
fallback on the wrapper instance.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e09332f0-c248-4a84-9a77-b73d114e222c
📒 Files selected for processing (2)
- flashinfer/prefill.py
- tests/attention/test_int8_paged_kv.py
Actionable comments posted: 2
♻️ Duplicate comments (2)
flashinfer/prefill.py (1)
2399-2404: ⚠️ Potential issue | 🟠 Major

Keep sm_scale scalar on paged non-cudnn runs.

q_scale and k_scale are now allowed to be tensors, but these lines still multiply them straight into sm_scale. On the common paged path that turns a scalar kernel argument into a tensor, which is the same regression you already fixed in single_prefill_with_kv_cache(). Reuse _normalize_non_fp8_scalar_scale() here or reject tensor inputs on this branch.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 2399 - 2404, The code multiplies q_scale and k_scale (which may be tensors on paged non-cudnn runs) directly into sm_scale, turning it into a tensor; update the branch in the prefill flow to keep sm_scale scalar by normalizing q_scale/k_scale before multiplication: call _normalize_non_fp8_scalar_scale(q_scale) and _normalize_non_fp8_scalar_scale(k_scale) (the same helper used in single_prefill_with_kv_cache()) or explicitly reject tensor inputs, then multiply the normalized scalar results into sm_scale so sm_scale remains a scalar on paged non-cudnn runs.

flashinfer/quantization/packbits.py (1)
161-173: ⚠️ Potential issue | 🟠 Major

Allow partial tail groups in int4_quantize.

INT4Tensor and int4_dequantize() already support a final partial group via ceil(hidden_dim / group_size), but this guard still rejects any non-divisible last dimension. That makes _append_paged_kv_cache_int4() in flashinfer/page.py fail on otherwise-valid INT4 KV head sizes with a partial tail. Pad only for scale/quant computation, then keep the stored packed payload truncated to ceil(hidden_dim / 2) bytes.

Suggested fix
```diff
-    hidden_dim = x.shape[-1]
-    if hidden_dim % group_size != 0:
-        raise ValueError(
-            f"x.shape[-1] must be divisible by group_size, got {hidden_dim} and {group_size}"
-        )
-    x_fp32 = x.to(torch.float32)
-    num_groups = hidden_dim // group_size
-    x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
+    hidden_dim = x.shape[-1]
+    x_fp32 = x.to(torch.float32)
+    num_groups = math.ceil(hidden_dim / group_size)
+    padded_hidden_dim = num_groups * group_size
+    if padded_hidden_dim != hidden_dim:
+        x_fp32 = torch.cat(
+            [
+                x_fp32,
+                torch.zeros(
+                    (*x.shape[:-1], padded_hidden_dim - hidden_dim),
+                    dtype=x_fp32.dtype,
+                    device=x_fp32.device,
+                ),
+            ],
+            dim=-1,
+        )
+    x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
@@
-    q_unsigned = (q + 8).to(torch.uint8).reshape(*x.shape[:-1], hidden_dim)
+    q_unsigned = (q + 8).to(torch.uint8).reshape(
+        *x.shape[:-1], padded_hidden_dim
+    )[..., :hidden_dim]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/packbits.py` around lines 161 - 173, The code currently rejects x when hidden_dim % group_size != 0 inside int4_quantize; instead, allow a final partial group by computing num_groups = math.ceil(hidden_dim / group_size), pad x (e.g., with zeros) up to num_groups * group_size for the amax/scale and quantization steps in int4_quantize, perform grouping/scale/q on the padded tensor, then when producing the packed payload (q_unsigned or the stored INT4Tensor bytes) truncate the result to the original hidden_dim (i.e., ceil(hidden_dim / 2) bytes) so the stored payload matches the actual tail size; reference int4_quantize, INT4Tensor, int4_dequantize and _append_paged_kv_cache_int4 to ensure compatibility with existing dequantization and KV append logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1304-1313: When int4_input is detected in the prefill path (the
block using is_int4_tensor on k/v and int4_dequantize), explicitly reject any
FP8 q by checking q's dtype (or using the project helper for float8 tensors) and
raising a clear ValueError like "FP8 q is not supported with INT4 k/v"; place
this check before dequantization or before the later FP8 assertion that compares
q.dtype to k_tensor.dtype and v_tensor.dtype (referencing variables q, k_tensor,
v_tensor, int4_input, is_int4_tensor, int4_dequantize). Do the same for the
second INT4 branch that mirrors lines 1367-1373 so both INT4 single-prefill
paths produce a clear unsupported-combination error instead of hitting the
internal FP8 assert.
- Around line 1925-1938: When INT4 paged-KV is enabled (self._int4_kv_enabled
set in plan()), reject FP8 query dtypes up front instead of allowing planning to
succeed and failing at run-time; specifically, in the same planning logic where
backend/effective_kv_data_type are decided, add a guard that checks the query
dtype (the q_data_type or similar arg used by plan()) with
is_float8(q_data_type) and raise a clear NotImplementedError/ValueError if true
(message: "INT4 paged KV cache does not support FP8 query dtypes"). This mirrors
the FA2 runner assertion (assert not is_float8(q)) and prevents FP8 queries from
being planned when INT4 paged-KV (fa2) is selected.
---
Duplicate comments:
In `@flashinfer/prefill.py`:
- Around line 2399-2404: The code multiplies q_scale and k_scale (which may be
tensors on paged non-cudnn runs) directly into sm_scale, turning it into a
tensor; update the branch in the prefill flow to keep sm_scale scalar by
normalizing q_scale/k_scale before multiplication: call
_normalize_non_fp8_scalar_scale(q_scale) and
_normalize_non_fp8_scalar_scale(k_scale) (the same helper used in
single_prefill_with_kv_cache()) or explicitly reject tensor inputs, then
multiply the normalized scalar results into sm_scale so sm_scale remains a
scalar on paged non-cudnn runs.
In `@flashinfer/quantization/packbits.py`:
- Around line 161-173: The code currently rejects x when hidden_dim % group_size
!= 0 inside int4_quantize; instead, allow a final partial group by computing
num_groups = math.ceil(hidden_dim / group_size), pad x (e.g., with zeros) up to
num_groups * group_size for the amax/scale and quantization steps in
int4_quantize, perform grouping/scale/q on the padded tensor, then when
producing the packed payload (q_unsigned or the stored INT4Tensor bytes)
truncate the result to the original hidden_dim (i.e., ceil(hidden_dim / 2)
bytes) so the stored payload matches the actual tail size; reference
int4_quantize, INT4Tensor, int4_dequantize and _append_paged_kv_cache_int4 to
ensure compatibility with existing dequantization and KV append logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 61eb8ecc-c94f-4f31-b0d3-edd8b8083435
📒 Files selected for processing (5)
- flashinfer/page.py
- flashinfer/prefill.py
- flashinfer/quantization/packbits.py
- tests/attention/test_int4_paged_kv.py
- tests/utils/test_quantization.py
✅ Files skipped from review due to trivial changes (1)
- tests/attention/test_int4_paged_kv.py
🧹 Nitpick comments (3)
flashinfer/utils.py (2)
853-865: INT4Tensor.unbind silently breaks when called on the packed last dim.

self.original_shape[:dim] + self.original_shape[dim + 1:] and data.select(dim, i) both run for any dim, including the final (packed) axis. The current call sites only use dim=1, so this is fine today, but unbinding along the innermost axis would produce a child whose data.shape[-1] no longer equals ceil(original_shape[-1]/2) and would blow up inside __init__ with a confusing "data last dimension must be ceil(...)" error. Consider guarding explicitly (also handling negative dim) to fail fast with a clearer message.

Proposed guard
```diff
 def unbind(self, dim: int = 0) -> Tuple["INT4Tensor", ...]:
+    ndim = len(self.original_shape)
+    normalized_dim = dim + ndim if dim < 0 else dim
+    if normalized_dim < 0 or normalized_dim >= ndim - 1:
+        raise ValueError(
+            f"INT4Tensor.unbind does not support the packed last dim; got dim={dim}"
+        )
     size = self.data.shape[dim]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/utils.py` around lines 853 - 865, INT4Tensor.unbind currently allows unbinding along the packed (last) axis which breaks invariants in INT4Tensor.__init__ because data.select(dim, i) will yield tensors whose last dimension no longer equals ceil(original_shape[-1]/2); modify INT4Tensor.unbind to validate and reject attempts to unbind the packed axis (and normalize negative dim to positive) by checking if the target dim corresponds to the packed innermost dimension (compare dim against len(self.original_shape)-1) and raise a clear ValueError mentioning INT4Tensor.unbind, the packed last axis, and original_shape; keep existing behavior for all other dims and still construct children using data.select and scale.select when allowed.
953-953: Nit (Ruff RUF005): use unpacking instead of tuple concatenation.

Proposed fix

```diff
-        original_shape = (unique_page_indices.numel(),) + x.original_shape[1:]
+        original_shape = (unique_page_indices.numel(), *x.original_shape[1:])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/utils.py` at line 953, Replace the tuple concatenation used to build original_shape with tuple unpacking: construct original_shape by putting unique_page_indices.numel() as the first element and then unpacking the remainder of x.original_shape starting at index 1 (i.e., use unpacking of x.original_shape[1:] rather than (unique_page_indices.numel(),) + x.original_shape[1:]); update the assignment to original_shape in the function/section that references unique_page_indices and x.original_shape accordingly.

flashinfer/decode.py (1)
994-1018: Plan-time INT4 handling looks good; minor suggestion on the NotImplementedError wording.

The INT4 gating (reject int4 for Q/O, force fa2 when auto, reject cuda-graph) is coherent and correctly ordered before canonicalize_torch_dtype, so "int4" never hits getattr(torch, ...). One tiny clarification: the error at line 1009 says "only supports the fa2/common decode path," but this wrapper only allows fa2 (not trtllm-gen/fa3). Consider a more direct message, e.g. "INT4 paged KV cache requires backend='fa2' or 'auto'; got '{self._backend}'.".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 994 - 1018, Update the NotImplementedError message in the INT4 gating block so it clearly indicates that only backend 'fa2' (or 'auto' which becomes 'fa2') is allowed; specifically change the exception raised in the branch that checks self._backend != "fa2" to something like "INT4 paged KV cache requires backend='fa2' or 'auto'; got '{self._backend}'." Reference the INT4-related symbols in this block: is_int4_dtype, canonicalize_torch_dtype, self._int4_kv_enabled, and the conditional that sets/validates self._backend.
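The plan-time gating the comment above reviews can be sketched as a small standalone helper. `resolve_backend` is a hypothetical illustration of the described logic (reject CUDA graphs, force `fa2` when `auto`, reject other backends), not flashinfer's actual function:

```python
def resolve_backend(backend: str, kv_is_int4: bool, use_cuda_graph: bool) -> str:
    """Sketch of the plan-time INT4 gating described above (hypothetical
    helper; names and order are assumptions based on the review comment)."""
    if not kv_is_int4:
        return backend  # non-INT4 caches are unaffected by this gating
    if use_cuda_graph:
        raise NotImplementedError(
            "INT4 paged KV cache does not support CUDA graph capture"
        )
    if backend == "auto":
        return "fa2"  # auto selection falls back to the fa2 path
    if backend != "fa2":
        # Wording suggested by the reviewer: only fa2/auto are allowed.
        raise NotImplementedError(
            f"INT4 paged KV cache requires backend='fa2' or 'auto'; got '{backend}'."
        )
    return backend
```

The key point of the suggested wording is that the error names exactly which backends are accepted, rather than describing a "common decode path."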
📒 Files selected for processing (4)

- flashinfer/decode.py
- flashinfer/prefill.py
- flashinfer/utils.py
- tests/attention/test_int4_paged_kv.py

✅ Files skipped from review due to trivial changes (1)

- tests/attention/test_int4_paged_kv.py
Force-pushed from c3af1b0 to e8da95d
Addressed in latest commits.
📌 Description
Builds on the int8 paged-KV work in #3100 to add int4 support.
Incremental review against the int8 branch:
lesj0610/flashinfer@codex/int8-paged-kv-main-v068...codex/int4-paged-kv-main-v068
`torch.uint8` is already used in some paths as an FP4 container, so a plain uint8 input creates a semantic conflict. An explicit `INT4Tensor` wrapper is used to keep the contract unambiguous. Storage is packed uint8 with grouped fp16 scales (group_size=32). The implementation goes through staged dequantization to fp16 before calling existing kernels. On Hopper, auto backend selection falls back to FA2 the same way as in #3100. The following are not included in this PR:

Because the upstream PR base must stay on `main`, GitHub still shows the shared int8 commits in the main PR diff until #3100 lands. The compare link above shows the int4-only incremental delta.

Tested on Ampere (A100) and Hopper (H100):
51 tests passed on both architectures.
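The storage layout described above — packed uint8 with grouped fp16 scales at group_size=32 — can be sketched as a group-wise symmetric int4 quantizer. This is an illustrative NumPy model of the layout, not flashinfer's implementation; function names, the symmetric scheme, and the low-nibble-first packing are assumptions:

```python
import numpy as np

GROUP_SIZE = 32  # matches the group_size=32 stated in the description

def quantize_int4(x: np.ndarray):
    """Group-wise symmetric int4 quantization along the flattened last axis."""
    g = x.reshape(-1, GROUP_SIZE).astype(np.float32)
    # One fp16 scale per group of 32 values; int4 range is [-8, 7].
    scale = np.abs(g).max(axis=1, keepdims=True) / 7.0
    scale[scale == 0] = 1.0
    q = np.clip(np.round(g / scale), -8, 7).astype(np.int8)
    # Pack two signed 4-bit values per uint8 byte, low nibble first.
    u = (q & 0xF).astype(np.uint8)
    packed = (u[:, 0::2] | (u[:, 1::2] << 4)).astype(np.uint8)
    return packed, scale.astype(np.float16)

def dequantize_int4(packed: np.ndarray, scale: np.ndarray, shape):
    """Staged dequantization back to float, mirroring the fp16 staging path."""
    lo = (packed & 0xF).astype(np.int8)
    hi = (packed >> 4).astype(np.int8)
    # Sign-extend the 4-bit values.
    lo = np.where(lo > 7, lo - 16, lo)
    hi = np.where(hi > 7, hi - 16, hi)
    q = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.int8)
    q[:, 0::2], q[:, 1::2] = lo, hi
    return (q.astype(np.float32) * scale.astype(np.float32)).reshape(shape)
```

A roundtrip through this layout stores 4x less than fp16 for the values (plus one fp16 scale per 32 elements), at the cost of a per-group quantization error bounded by half a scale step.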
🔍 Related Issues
🚀 Pull Request Checklist
✅ Pre-commit Checks
- Installed pre-commit with `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
Depends on #3100.
Retargeted against `main` after the `v0.6.8` release branch cut.