INT4 per-token-head KV cache + kv_dequant dispatch scaffold #39668
lesj0610 wants to merge 3 commits into vllm-project:main
Conversation
Signed-off-by: lesj0610 <lesj0610@gmail.com>
Documentation preview: https://vllm--39668.org.readthedocs.build/en/39668/
Code Review
This pull request implements int4_per_token_head KV cache quantization for the Triton attention backend. Key additions include a quantization kernel that uses a Gaussian-friendly codebook, Hadamard transform integration for improved accuracy, and support for per-group scales in the unified attention kernel. The PR also updates the KV cache management logic to support padded attention pages, enabling the unification of layers with heterogeneous head sizes. Review feedback highlights two main concerns: memory alignment for the inline scales, and explicit stride handling in the kernels to keep them robust against cache layout changes.
packed_head_size = kv_quant_mode.packed_head_size(head_size)
scale_pad = (
    get_per_token_head_scale_count(head_size, kv_quant_mode)
    * get_dtype_size(scale_dtype)
    // get_dtype_size(cache_dtype)
)
The packed_head_size should be padded to ensure that the subsequent scales are aligned to their data type size (e.g., 2 bytes for float16). If packed_head_size is odd and scale_dtype is float16, the scales will start at an unaligned byte offset, which can cause issues when creating typed views or during GPU memory access. This is particularly important for head sizes that are not multiples of 4 (since packed_head_size = head_size // 2).
packed_head_size = kv_quant_mode.packed_head_size(head_size)
# Ensure scales are aligned to scale_dtype size (e.g. 2 bytes for fp16)
alignment = get_dtype_size(scale_dtype) // get_dtype_size(cache_dtype)
packed_head_size = (packed_head_size + alignment - 1) // alignment * alignment
scale_pad = (
    get_per_token_head_scale_count(head_size, kv_quant_mode)
    * get_dtype_size(scale_dtype)
    // get_dtype_size(cache_dtype)
)
The packed INT4 payload is now aligned before the inline fp16 scales start. The padding is applied in the cache shape path, so the scale region begins on the right boundary.
v_base_scale = kv_half_stride * dtype_sz // scale_dtype_sz
slot_scale = slot_stride * dtype_sz // scale_dtype_sz
head_scale = head_stride * dtype_sz // scale_dtype_sz
scale_off = hs * hs_stride * dtype_sz // scale_dtype_sz
As noted in the PR description, there is a potential misalignment issue here if hs * hs_stride * dtype_sz is not a multiple of scale_dtype_sz. This will happen if the packed head size is odd and the scale type is 2 bytes (e.g., float16). You should add an assertion here to catch this, or ensure that hs is always padded to the required alignment in get_kv_cache_shape.
This was the same issue on the scale carve-out side. After padding the packed payload, this offset now stays divisible by the scale dtype size. An explicit check was added here too.
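For reference, a minimal sketch of such a check, reusing the byte-stride names from the snippet above; the exact placement and wording are assumptions, not the PR's code.

# Every byte offset that gets converted into scale-dtype elements with //
# must divide evenly, otherwise the fp16 scale views would start mid-element.
for byte_off in (kv_half_stride * dtype_sz, slot_stride * dtype_sz,
                 head_stride * dtype_sz, hs * hs_stride * dtype_sz):
    assert byte_off % scale_dtype_sz == 0, (
        f"inline-scale offset {byte_off} is not aligned to the "
        f"{scale_dtype_sz}-byte scale dtype; pad the packed payload"
    )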
    key_cache_ptr
    + blk * stride_kc_blk
    + slot_in_blk * stride_kc_slot
    + head * stride_kc_head
    + pair_offs,
    k_packed,
    mask=k_pair_mask,
)
The kernel assumes that the last dimension of the key cache (the head dimension) is contiguous (stride 1) by adding pair_offs directly to the pointer. While this is a common layout in vLLM, it should be explicitly asserted or the stride should be passed to the kernel to avoid silent data corruption if the cache layout changes (e.g., due to a different permutation in attn_utils.py).
The current layout keeps the last head dimension contiguous, but that assumption was implicit. There is an explicit invariant check here now so this does not silently break if the cache layout changes later.
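Something like the following host-side guard captures that invariant before the kernel launch; the tensor names are assumed from the snippet above, and the value-cache analog would look the same.

# The kernel adds pair_offs directly to the base pointer, so the packed head
# dimension of key_cache must remain contiguous (stride 1).
assert key_cache.stride(-1) == 1, (
    "int4 KV-cache update kernel assumes the innermost (packed head) "
    "dimension is contiguous; pass its stride explicitly if the layout changes"
)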
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3d270db3d8
)
logical_elems_per_page = logical_numel // num_blocks
assert logical_elems_per_page <= elems_per_page
contiguous_strides = torch.empty(kv_cache_shape).stride()
Avoid materializing huge tensors just to read strides
In the padded-page branch, torch.empty(kv_cache_shape).stride() allocates a full dense tensor only to obtain contiguous strides. When num_blocks is large (realistic for production KV caches), this temporary can consume gigabytes and fail initialization with OOM even though no real data tensor is needed. Please compute the contiguous stride tuple arithmetically from kv_cache_shape instead of allocating a tensor (the same pattern also appears in vllm/v1/worker/gpu/attn_utils.py).
torch.empty(...).stride() was too heavy for large KV cache shapes. This now uses arithmetic contiguous stride computation instead, and the matching padded-page path was cleaned up the same way.
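A sketch of the arithmetic replacement, assuming the same kv_cache_shape tuple as before; the helper name and example shape are illustrative.

def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    # Row-major strides computed right-to-left, no tensor allocation needed.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

# e.g. contiguous_strides((num_blocks, 2, block_size, num_heads, head_size))
# matches torch.empty(kv_cache_shape).stride() without the allocation.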
v_base_scale = kv_half_stride * dtype_sz // scale_dtype_sz
slot_scale = slot_stride * dtype_sz // scale_dtype_sz
head_scale = head_stride * dtype_sz // scale_dtype_sz
scale_off = hs * hs_stride * dtype_sz // scale_dtype_sz
Enforce INT4 head-size alignment for fp16 scale views
INT4 mode currently accepts any even head_size, but when head_size % 4 == 2 (for example 66), packed data uses an odd byte count per head, and converting byte strides/offsets to fp16 element units via // scale_dtype_sz truncates half-byte offsets. That makes _k_scale_cache/_v_scale_cache point to misaligned bytes and can corrupt KV data/scales during cache updates. Please reject these head sizes (or add extra padding so offsets are divisible by the scale dtype size) before creating scale views.
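To make the arithmetic concrete, here is a small hedged example of the failure mode and a possible guard, assuming a uint8 cache dtype and fp16 scales; the variable names are illustrative.

head_size = 66
dtype_sz, scale_dtype_sz = 1, 2      # uint8 cache elements, fp16 scales
packed_bytes = head_size // 2        # 33 bytes of INT4 payload per head
# 33 // 2 == 16 silently truncates: the scale view would start at byte 32,
# overlapping the last payload byte and corrupting both data and scales.
if (packed_bytes * dtype_sz) % scale_dtype_sz != 0:
    raise ValueError(
        f"head_size={head_size} leaves the inline scales misaligned; "
        "pad the packed payload or reject this head size"
    )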
This was the same issue on the scale carve-out side. After padding the packed payload, this offset now stays divisible by the scale dtype size. An explicit check was added here too.
Signed-off-by: lesj0610 <lesj0610@gmail.com>
Signed-off-by: lesj0610 <lesj0610@gmail.com>
I submitted the related FlashInfer work upstream.
Scope of this PR is unchanged: still the kv_dequant/int4 scaffold on the vLLM side only. The FlashInfer backend hookup will follow separately after those PRs settle.
Moving this to draft for now so the Triton/ROCm low-bit base can land first; I plan to bring FlashInfer back as a follow-up once the seam is in place.
This pull request has merge conflicts that must be resolved before it can be merged.
I was working on INT4 KV cache independently in my fork when JartX submitted #39074. After I read mgoin's comment on that PR about LOC growth and structural separation, I decided to fix the structure first before submitting. So this is not just my old INT4 code submitted again. Rebuilt from latest main: kv_dequant/ scaffold first, then INT4 on top of that.

What this PR does
This PR has two main changes.
It adds kv_dequant/ and moves per-token-head KV dequant dispatch there. INT8 / FP8 behavior should stay the same, but the logic is not mixed into the Triton attention core anymore. On this branch triton_unified_attention.py is 1254 lines, compared to 1268 on origin/main.

It adds int4_per_token_head on top of that scaffold. The INT4 path uses Hadamard rotation, a symmetric Gaussian codebook, packed uint8 storage with 2 channels per byte, and fp16 grouped scales with 32 channels per group. For CUDA, hadacore_transform is used as the fast path. For non-CUDA, a simpler dense Hadamard fallback is used.

The packed-block approach from #36893 and #38378 is also followed here. So the scale bytes are included in the page / block sizing path now, not handled as a separate side buffer.
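For concreteness, here is a minimal sketch of the quantize/dequantize round trip this describes. The codebook values, group handling, and function names are illustrative assumptions, not the PR's kernels; the real path packs into the paged KV cache and runs in Triton, and this sketch assumes head_size is a multiple of 32.

import torch

GROUP_SIZE = 32
# Symmetric, roughly Gaussian-spaced 16-entry codebook (illustrative values).
CODEBOOK = torch.tensor(
    [-2.75, -2.11, -1.65, -1.29, -0.98, -0.70, -0.46, -0.15,
      0.15,  0.46,  0.70,  0.98,  1.29,  1.65,  2.11,  2.75]
)

def quantize_token_head(x: torch.Tensor, hadamard: torch.Tensor):
    """x: [head_size] values for one (token, head); hadamard: orthonormal
    [head_size, head_size] rotation. Returns packed uint8 data + fp16 scales."""
    x_rot = hadamard.float() @ x.float()                 # spread outlier energy
    groups = x_rot.view(-1, GROUP_SIZE)                  # [n_groups, 32]
    scales = groups.abs().amax(dim=1) / CODEBOOK.max()   # one fp16 scale per group
    scales = scales.clamp(min=1e-8)
    normalized = groups / scales[:, None]
    # nearest codebook entry -> 4-bit index in [0, 15]
    idx = (normalized[..., None] - CODEBOOK).abs().argmin(dim=-1)
    idx = idx.to(torch.uint8).view(-1)
    packed = (idx[0::2] & 0x0F) | (idx[1::2] << 4)       # 2 channels per byte
    return packed, scales.to(torch.float16)

def dequantize_token_head(packed: torch.Tensor, scales: torch.Tensor,
                          hadamard: torch.Tensor) -> torch.Tensor:
    lo = (packed & 0x0F).long()
    hi = (packed >> 4).long()
    idx = torch.stack([lo, hi], dim=1).view(-1)          # undo the packing order
    x_rot = (CODEBOOK[idx].view(-1, GROUP_SIZE) * scales.float()[:, None]).view(-1)
    return hadamard.float().t() @ x_rot                  # undo the rotation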
Difference from #39074
I think reviewers will ask this first, so I want to explain it here.
This was built independently. The math is different. #39074 uses an asymmetric [min, max] -> [0..15] mapping with zero-point packed into the scale mantissa. Here a symmetric Gaussian codebook is used, because after Hadamard I think it gives better reconstruction quality for the distribution I expect in that case. For structure, this branch moves the mode-specific dispatch into kv_dequant/, while #39074 keeps more of that logic inside the Triton attention core. For Hadamard fallback, #39074 uses 3-tier (hadacore -> Triton MMA -> butterfly), and this branch is 2-tier (hadacore -> torch.matmul) because I did not want to add another new Triton kernel in the same PR.
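A minimal sketch of what the dense torch.matmul fallback tier could look like; the Sylvester construction and function name are illustrative assumptions, and it assumes a power-of-two rotation size.

import torch

def dense_hadamard_fallback(x: torch.Tensor) -> torch.Tensor:
    """Build a normalized Sylvester Hadamard matrix once and apply it with a
    plain matmul over the last dimension."""
    n = x.shape[-1]
    assert n & (n - 1) == 0, "sketch assumes a power-of-two rotation size"
    h = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    h = h / n ** 0.5  # orthonormal, so the inverse is just the transpose
    return torch.matmul(x, h)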
Validation

What I ran on my side:
- tests/quantization/test_per_token_kv_cache.py -> 68 passed, 23 skipped
- tests/models/quantization/test_per_token_kv_cache.py -> 2 passed, 1 skipped
- tests/v1/attention/test_kv_dequant_hadamard.py -> 6 passed
- tests/kernels/attention/test_attention_selector.py -> passed
- tests/v1/core/test_kv_cache_utils.py -> passed
- tests/v1/worker/test_gpu_model_runner.py -> passed

The broader Triton attention matrix was also checked enough to separate branch-specific behavior from the pre-existing SM86 fp8-output failures on
origin/main. For the branch-relevant matrix I got 384 passed.

Benchmarks were run again after the helper fixes. Setup: 2x RTX 3090, TP=2, local OpenAI-compatible server, input
1024, output 128, 10 prompts.

Qwen3.5-27B
Qwen3.5-35B-A3B
Known limits
- flashinfer INT4 support is scaffold only in this PR. The actual backend connection will be done later.
- The _ensure_scale_caches() divisibility assert is left out for now because there is no real misalignment reproducer yet.
- The SM86 fp8-output failures already exist on origin/main; they are not introduced by this branch.
- calculate_kv_scales is not supported yet; the path falls back to scale 1.0.