
Add int4 paged KV support to main paths #3049

Draft

lesj0610 wants to merge 6 commits into flashinfer-ai:release-v0.6.7 from lesj0610:codex/int4-paged-kv-main-path

Conversation

@lesj0610

📌 Description

Builds on the int8 paged-KV work in #3048 to add int4 support.

torch.uint8 is already used in some paths as an FP4 container, so a plain uint8 input creates a semantic conflict. An explicit INT4Tensor wrapper is used to keep the contract unambiguous. Storage is packed uint8 with grouped fp16 scales (group_size=32).
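
For reference, a rough sketch of the storage layout described above; the class and function names here are assumptions for illustration, not the actual INT4Tensor API in this PR:

import torch

class Int4TensorSketch:
    # packed: uint8 tensor, two signed 4-bit values per byte (low nibble = even index)
    # scale:  fp16 tensor, one scale per group of group_size elements along the last dim
    def __init__(self, packed: torch.Tensor, scale: torch.Tensor, group_size: int = 32):
        self.packed = packed
        self.scale = scale
        self.group_size = group_size

def pack_int4(q: torch.Tensor) -> torch.Tensor:
    # q holds int4 values in [-8, 7]; adjacent pairs along the last dim share one byte
    nibbles = (q & 0xF).to(torch.uint8)
    return nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)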

The implementation goes through staged dequantization to fp16 before calling the existing kernels; a rough sketch of that path follows the list below. On Hopper, auto backend selection falls back to FA2 in the same way as in #3048. The following are not included in this PR:

  • CUDA graph: explicitly blocked, as the staging step requires temporary allocation
  • Native FA3, XQA, and TRTLLM-gen int4 paths
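
Roughly, that staged path could look like the following; helper and field names are illustrative assumptions (matching the sketch above), not the exact functions added in this PR:

def dequantize_int4_kv(kv):  # kv: an Int4TensorSketch as sketched above
    packed, scale, group_size = kv.packed, kv.scale, kv.group_size
    lo = (packed & 0xF).to(torch.int8)
    hi = (packed >> 4).to(torch.int8)
    # sign-extend the 4-bit two's-complement nibbles back to [-8, 7]
    lo = torch.where(lo > 7, lo - 16, lo)
    hi = torch.where(hi > 7, hi - 16, hi)
    q = torch.stack((lo, hi), dim=-1).flatten(-2)                   # undo the byte packing
    q = q.reshape(*q.shape[:-1], -1, group_size).to(torch.float16)
    x = q * scale                                                   # scale broadcasts per group
    return x.reshape(*packed.shape[:-1], -1)                        # fp16 KV for the existing kernels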

Until #3048 is merged, GitHub will also show the int8 commits in this diff because this branch is stacked on top of that work.

Tested on Ampere (A100) and Hopper (H100):

python -m pytest tests/attention/test_int4_paged_kv.py -v

51 tests passed on both architectures.

🔍 Related Issues

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Depends on #3048.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 13, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for INT4 and INT8 paged KV caches across the decode and prefill modules. Key changes include the implementation of int4_quantize and int4_dequantize functions, the addition of an INT4Tensor wrapper class, and the extension of paged KV cache management to handle these quantized formats. The PR also adds int8_t vector type support in the CUDA backend and ensures that quantized KV caches correctly fall back to the fa2 implementation. Review feedback suggests using itemsize for more robust 8-bit type detection when applying output scales and adjusting the INT4 quantization scaling factor to better utilize the available bit range.

Comment thread: flashinfer/prefill.py, lines +1374 to +1381
if out.dtype in (
    torch.int8,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
):
    out = (out.to(float) * scale_v).to(out.dtype)
else:
    out *= scale_v
Contributor


medium

This logic for applying scale_v is slightly different from the one in flashinfer/decode.py:598, which uses if out.itemsize == 1:. The itemsize check is more robust as it covers all 8-bit types (like uint8) without needing to list them explicitly. Using itemsize would also make the implementation consistent across the codebase.

        if out.itemsize == 1:
            out = (out.to(float) * scale_v).to(out.dtype)
        else:
            out *= scale_v

Author


Already updated in the latest branch.

num_groups = hidden_dim // group_size
x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
amax = x_grouped.abs().amax(dim=-1, keepdim=True)
scale = torch.where(amax > 0, amax / 7.0, torch.ones_like(amax))
Contributor


medium

The scaling factor amax / 7.0 does not seem to fully utilize the int4 quantization range of [-8, 7]. With this scaling, the values of x_grouped / scale will be within [-7.0, 7.0]. After torch.round, the quantized values will be in [-7, 7], meaning the value -8 is never used.

To better utilize the available quantization range, consider using 8.0 as the divisor. This would map the input range [-amax, amax] to [-8, 8], and after clamping to [-8, 7], it would make use of the full available range.

Suggested change:
- scale = torch.where(amax > 0, amax / 7.0, torch.ones_like(amax))
+ scale = torch.where(amax > 0, amax / 8.0, torch.ones_like(amax))

Author


Not changing this. amax / 7.0 is intentional for symmetric ±amax reconstruction.
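
A minimal numeric illustration of the trade-off being discussed (a sketch, not code from the PR):

import torch

amax = 1.0
x = torch.tensor([-amax, amax])
s7, s8 = amax / 7.0, amax / 8.0
q7 = torch.clamp(torch.round(x / s7), -8, 7)   # [-7., 7.]
q8 = torch.clamp(torch.round(x / s8), -8, 7)   # [-8., 7.]
x7 = q7 * s7   # ~[-1.0, 1.0]: both endpoints of ±amax reconstructed
x8 = q8 * s8   # ~[-1.0, 0.875]: +amax clips to 7 and comes back at 7/8 of its magnitude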

@yzh119
Collaborator

yzh119 commented Apr 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !542 has been created, and the CI pipeline #48432957 is currently running. I'll report back once the pipeline job completes.

@lesj0610 lesj0610 force-pushed the codex/int4-paged-kv-main-path branch from 3ee92f4 to 502f158 on April 14, 2026 at 04:18
@lesj0610 lesj0610 marked this pull request as draft on April 17, 2026 at 07:16
@lesj0610
Author

Keeping this PR as the release-v0.6.7 snapshot and moving active review to #3101 against main.
