Add int4 paged KV support to main paths#3049
lesj0610 wants to merge 6 commits into flashinfer-ai:release-v0.6.7 from
Conversation
Code Review
This pull request introduces support for INT4 and INT8 paged KV caches across the decode and prefill modules. Key changes include the implementation of int4_quantize and int4_dequantize functions, the addition of an INT4Tensor wrapper class, and the extension of paged KV cache management to handle these quantized formats. The PR also adds int8_t vector type support in the CUDA backend and ensures that quantized KV caches correctly fall back to the fa2 implementation. Review feedback suggests using itemsize for more robust 8-bit type detection when applying output scales and adjusting the INT4 quantization scaling factor to better utilize the available bit range.
```python
if out.dtype in (
    torch.int8,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
):
    out = (out.to(float) * scale_v).to(out.dtype)
else:
    out *= scale_v
```
This logic for applying scale_v is slightly different from the one in flashinfer/decode.py:598, which uses if out.itemsize == 1:. The itemsize check is more robust as it covers all 8-bit types (like uint8) without needing to list them explicitly. Using itemsize would also make the implementation consistent across the codebase.
```python
if out.itemsize == 1:
    out = (out.to(float) * scale_v).to(out.dtype)
else:
    out *= scale_v
```
Already updated in the latest branch.
```python
num_groups = hidden_dim // group_size
x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
amax = x_grouped.abs().amax(dim=-1, keepdim=True)
scale = torch.where(amax > 0, amax / 7.0, torch.ones_like(amax))
```
The scaling factor amax / 7.0 does not seem to fully utilize the int4 quantization range of [-8, 7]. With this scaling, the values of x_grouped / scale will be within [-7.0, 7.0]. After torch.round, the quantized values will be in [-7, 7], meaning the value -8 is never used.
To better utilize the available quantization range, consider using 8.0 as the divisor. This would map the input range [-amax, amax] to [-8, 8], and after clamping to [-8, 7], it would make use of the full available range.
```diff
- scale = torch.where(amax > 0, amax / 7.0, torch.ones_like(amax))
+ scale = torch.where(amax > 0, amax / 8.0, torch.ones_like(amax))
```
Not changing this. amax / 7.0 is intentional for symmetric ±amax reconstruction: with a divisor of 7.0, both +amax and -amax map exactly to the codes ±7 and dequantize back to ±amax, whereas a divisor of 8.0 would clamp +amax to 7 and reconstruct it as 7/8 · amax.
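A small, self-contained check of the trade-off being discussed (illustrative values only, not taken from the PR's tests):

```python
import torch

x = torch.tensor([-1.0, 1.0])  # a group containing both -amax and +amax (amax = 1.0)
amax = x.abs().amax()

# Divisor 7.0 (current PR): both endpoints land exactly on representable codes.
q7 = torch.clamp(torch.round(x / (amax / 7.0)), -8, 7)
print(q7, q7 * (amax / 7.0))   # tensor([-7., 7.])  tensor([-1., 1.])       -> symmetric, exact

# Divisor 8.0 (suggested): +amax rounds to 8, clamps to 7, reconstructs as 7/8 * amax.
q8 = torch.clamp(torch.round(x / (amax / 8.0)), -8, 7)
print(q8, q8 * (amax / 8.0))   # tensor([-8., 7.])  tensor([-1.0000, 0.8750]) -> asymmetric
```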
/bot run
Force-pushed from 3ee92f4 to 502f158 (Compare)
Keeping this PR as the
📌 Description
Builds on the int8 paged-KV work in #3048 to add int4 support.
torch.uint8 is already used in some paths as an FP4 container, so a plain uint8 input creates a semantic conflict. An explicit INT4Tensor wrapper is used to keep the contract unambiguous. Storage is packed uint8 with grouped fp16 scales (group_size=32). The implementation goes through staged dequantization to fp16 before calling the existing kernels (a rough sketch of this layout is given after the note below). On Hopper, auto backend selection falls back to FA2 the same way as in #3048. The following are not included in this PR:
Until #3048 is merged, GitHub will also show the int8 commits in this diff because this branch is stacked on top of that work.
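For readers skimming the diff, here is a minimal sketch of the storage layout described above: packed uint8 with fp16 group scales (group_size=32), symmetric amax / 7.0 scaling, and staged dequantization back to fp16. The function names mirror int4_quantize / int4_dequantize from the review summary, but the bodies below only illustrate the layout; the actual implementation and the INT4Tensor wrapper in the PR may differ.

```python
import torch

GROUP_SIZE = 32  # grouped fp16 scales, per the PR description

def int4_quantize(x: torch.Tensor, group_size: int = GROUP_SIZE):
    """Sketch: symmetric int4 quantization with per-group fp16 scales (amax / 7.0)."""
    x_fp32 = x.float()
    num_groups = x.shape[-1] // group_size
    x_grouped = x_fp32.reshape(*x.shape[:-1], num_groups, group_size)
    amax = x_grouped.abs().amax(dim=-1, keepdim=True)
    scale = torch.where(amax > 0, amax / 7.0, torch.ones_like(amax))
    q = torch.clamp(torch.round(x_grouped / scale), -8, 7).to(torch.int8)
    # Pack two int4 values per uint8, low nibble first, after shifting to [0, 15].
    q = (q + 8).reshape(*x.shape[:-1], -1).to(torch.uint8)
    packed = q[..., 0::2] | (q[..., 1::2] << 4)
    return packed, scale.squeeze(-1).half()

def int4_dequantize(packed: torch.Tensor, scale: torch.Tensor, group_size: int = GROUP_SIZE):
    """Sketch: staged dequantization back to fp16 before calling the existing kernels."""
    lo = (packed & 0x0F).to(torch.int16) - 8
    hi = (packed >> 4).to(torch.int16) - 8
    q = torch.stack((lo, hi), dim=-1).reshape(*packed.shape[:-1], -1).float()
    num_groups = q.shape[-1] // group_size
    q = q.reshape(*q.shape[:-1], num_groups, group_size)
    x = q * scale.float().unsqueeze(-1)
    return x.reshape(*packed.shape[:-1], -1).half()
```

Round-tripping an fp16 KV slab through these two helpers reproduces the original values up to the per-group quantization error, which is what the staged-dequantization path relies on before handing the fp16 tensors to the existing attention kernels.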
Tested on Ampere (A100) and Hopper (H100):
51 tests passed on both architectures.
🔍 Related Issues
🚀 Pull Request Checklist
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
Depends on #3048.