
INT4 per-token-head KV cache + kv_dequant dispatch scaffold #39668

Draft
lesj0610 wants to merge 3 commits into vllm-project:main from lesj0610:lesj/kv-dequant-dispatch-rfc

Conversation

@lesj0610 (Contributor) commented Apr 13, 2026

I was working on INT4 KV cache independently in my fork when JartX submitted #39074. After reading mgoin's comment on that PR about LOC growth and structural separation, I decided to fix the structure first before submitting, so this is not just my old INT4 code resubmitted: it was rebuilt from latest main, with the kv_dequant/ scaffold first and INT4 on top of that.

What this PR does

This PR has two main changes.

  1. It adds kv_dequant/ and moves per-token-head KV dequant dispatch there. INT8 / FP8 behavior should stay the same, but the logic is not mixed into the Triton attention core anymore. On this branch triton_unified_attention.py is 1254 lines, compared to 1268 on origin/main.

  2. It adds int4_per_token_head on top of that scaffold. The INT4 path uses Hadamard rotation, a symmetric Gaussian codebook, packed uint8 storage with 2 channels per byte, and fp16 grouped scales with 32 channels per group. For CUDA, hadacore_transform is used as the fast path. For non-CUDA, a simpler dense Hadamard fallback is used.
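To make that layout concrete, here is a minimal torch-level sketch of the quantize/dequantize path. This is only an illustration under stated assumptions: the actual kernels are Triton, the helper names are hypothetical, the 16-entry codebook values below are illustrative rather than the exact table used in the kernel, and the input is assumed to already be Hadamard-rotated.

```python
import torch

GROUP_SIZE = 32          # channels per fp16 scale group (from the PR description)
# Hypothetical symmetric Gaussian codebook (illustrative values only).
CODEBOOK = torch.tensor(
    [-2.73, -2.07, -1.62, -1.27, -0.96, -0.69, -0.44, -0.15,
      0.15,  0.44,  0.69,  0.96,  1.27,  1.62,  2.07,  2.73])

def quantize_head(x: torch.Tensor):
    """Quantize one (num_tokens, head_size) slice to packed uint8 + fp16 scales.

    Assumes head_size is even and divisible by GROUP_SIZE; x is already rotated.
    """
    num_tokens, head_size = x.shape
    # One fp16 scale per 32 rotated channels.
    groups = x.view(num_tokens, head_size // GROUP_SIZE, GROUP_SIZE)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / CODEBOOK.max()
    # Map each normalized value to the nearest codebook index (0..15).
    normalized = (groups / scales).reshape(num_tokens, head_size)
    idx = (normalized.unsqueeze(-1) - CODEBOOK).abs().argmin(dim=-1).to(torch.uint8)
    # Pack two 4-bit indices per byte: even channel in the low nibble.
    packed = idx[:, 0::2] | (idx[:, 1::2] << 4)
    return packed, scales.squeeze(-1).to(torch.float16)

def dequantize_head(packed: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    idx = torch.stack([packed & 0xF, packed >> 4], dim=-1).reshape(packed.shape[0], -1)
    num_tokens, head_size = idx.shape
    values = CODEBOOK[idx.long()].view(num_tokens, head_size // GROUP_SIZE, GROUP_SIZE)
    return (values * scales.unsqueeze(-1).float()).reshape(num_tokens, head_size)
```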

This branch also follows the packed-block approach from #36893 and #38378: the scale bytes are included in the page / block sizing path now, rather than handled as a separate side buffer.
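A rough sketch of what "scales inside the page" means for sizing, with the padding rule that the review comments below discuss. The helper name and signature here are illustrative, not the actual vLLM cache-shape API; only the group size of 32 and the 2-channels-per-byte packing come from the description above.

```python
def int4_packed_page_bytes(block_size: int, num_kv_heads: int, head_size: int,
                           group_size: int = 32, scale_bytes: int = 2) -> int:
    """Bytes for one K (or V) page when the INT4 payload and fp16 scales share the block.

    Illustrative arithmetic only; the real shape logic lives in the kv_dequant/
    scaffold and the cache-shape helpers touched by this PR.
    """
    packed_head_bytes = head_size // 2                       # 2 INT4 channels per byte
    # Pad so the inline scales start on a scale_bytes boundary (see the alignment
    # discussion in the review comments below).
    packed_head_bytes = (packed_head_bytes + scale_bytes - 1) // scale_bytes * scale_bytes
    scale_head_bytes = (head_size // group_size) * scale_bytes
    return block_size * num_kv_heads * (packed_head_bytes + scale_head_bytes)

# Example: block_size=16, 8 KV heads, head_size=128
# -> 16 * 8 * (64 + 8) = 9216 bytes per K page (and the same for V).
```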

Difference from #39074

I think reviewers will ask this first, so I want to explain it here.

This was built independently, and the math is different.

  • Math: #39074 uses an asymmetric [min, max] -> [0..15] mapping with the zero-point packed into the scale mantissa. Here a symmetric Gaussian codebook is used, because after the Hadamard rotation I expect it to give better reconstruction quality for the resulting distribution.
  • Structure: this branch moves the mode-specific dispatch into kv_dequant/, while #39074 keeps more of that logic inside the Triton attention core.
  • Hadamard fallback: #39074 uses three tiers (hadacore -> Triton MMA -> butterfly); this branch uses two (hadacore -> torch.matmul, sketched below) because I did not want to add another new Triton kernel in the same PR.
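For reference, the non-CUDA fallback is just a dense Hadamard applied with torch.matmul. A minimal sketch of that idea, assuming a power-of-two head size and an orthonormal scaling; the function name is hypothetical and this is not the exact fallback code in the branch:

```python
import torch

def dense_hadamard(x: torch.Tensor) -> torch.Tensor:
    """Apply an orthonormal Hadamard rotation along the last dim via torch.matmul.

    Fallback path only; on CUDA the PR uses hadacore_transform instead.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "dense Hadamard fallback assumes a power-of-two dim"
    h = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)   # Sylvester construction
    h = h / (n ** 0.5)                                      # orthonormal scaling
    return x @ h
```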

Validation

What I ran on my side:

  • tests/quantization/test_per_token_kv_cache.py -> 68 passed, 23 skipped
  • tests/models/quantization/test_per_token_kv_cache.py -> 2 passed, 1 skipped
  • tests/v1/attention/test_kv_dequant_hadamard.py -> 6 passed
  • tests/kernels/attention/test_attention_selector.py -> passed
  • tests/v1/core/test_kv_cache_utils.py -> passed
  • tests/v1/worker/test_gpu_model_runner.py -> passed

I also ran enough of the broader Triton attention test matrix to separate branch-specific behavior from the pre-existing SM86 fp8-output failures on origin/main; the branch-relevant subset gave 384 passed.

Benchmarks were run again after the helper fixes. Setup: 2x RTX 3090, TP=2, local OpenAI-compatible server, input 1024, output 128, 10 prompts.

Qwen3.5-27B

| Metric | INT8 | INT4 | Delta |
|--------|------|------|-------|
| req/s | 0.33 | 0.31 | -6.1% |
| tok/s | 42.16 | 39.98 | -5.2% |
| TTFT | 6777.22 ms | 6985.91 ms | +3.1% |
| TPOT | 183.83 ms | 195.11 ms | +6.1% |

Qwen3.5-35B-A3B

| Metric | INT8 | INT4 | Delta |
|--------|------|------|-------|
| req/s | 0.43 | 0.42 | -2.3% |
| tok/s | 55.42 | 53.28 | -3.9% |
| TTFT | 2368.72 ms | 2429.92 ms | +2.6% |
| TPOT | 161.46 ms | 168.22 ms | +4.2% |

Known limits

  • FlashInfer INT4 support is scaffold-only in this PR. The actual backend connection will be done later.
  • I left the _ensure_scale_caches() divisibility assert out for now because there is no real misalignment reproducer yet.
  • The SM86 fp8-output failures in the broader Triton attention matrix already exist on origin/main; they are not introduced by this branch.
  • For the Qwen hybrid benchmark path, vLLM still logs the current upstream behavior that disables calculate_kv_scales and falls back to scale 1.0.
  • Also planning to try contributing INT4 support to FlashInfer directly. Not sure if they will accept it, but I think it's worth trying.

mergify Bot commented Apr 13, 2026

Documentation preview: https://vllm--39668.org.readthedocs.build/en/39668/

mergify Bot added the documentation, nvidia, and v1 labels Apr 13, 2026
gemini-code-assist Bot left a comment


Code Review

This pull request implements int4_per_token_head KV cache quantization for the Triton attention backend. Key additions include a quantization kernel utilizing a Gaussian-friendly codebook, Hadamard transform integration for improved accuracy, and support for per-group scales in the unified attention kernel. The PR also updates the KV cache management logic to support padded attention pages, enabling the unification of layers with heterogeneous head sizes. Review feedback highlights critical improvement opportunities regarding memory alignment for inline scales and the need for explicit stride handling in kernels to ensure robustness against cache layout changes.

Comment on lines +332 to +337
packed_head_size = kv_quant_mode.packed_head_size(head_size)
scale_pad = (
get_per_token_head_scale_count(head_size, kv_quant_mode)
* get_dtype_size(scale_dtype)
// get_dtype_size(cache_dtype)
)

high

The packed_head_size should be padded to ensure that the subsequent scales are aligned to their data type size (e.g., 2 bytes for float16). If packed_head_size is odd and scale_dtype is float16, the scales will start at an unaligned byte offset, which can cause issues when creating typed views or during GPU memory access. This is particularly important for head sizes that are not multiples of 4 (since packed_head_size = head_size // 2).

            packed_head_size = kv_quant_mode.packed_head_size(head_size)
            # Ensure scales are aligned to scale_dtype size (e.g. 2 bytes for fp16)
            alignment = get_dtype_size(scale_dtype) // get_dtype_size(cache_dtype)
            packed_head_size = (packed_head_size + alignment - 1) // alignment * alignment
            scale_pad = (
                get_per_token_head_scale_count(head_size, kv_quant_mode)
                * get_dtype_size(scale_dtype)
                // get_dtype_size(cache_dtype)
            )

@lesj0610 (Author) replied Apr 13, 2026

Packed INT4 payload is now aligned before the inline fp16 scales start. The padding is applied in the cache shape path, so the scale region begins on a properly aligned boundary.

v_base_scale = kv_half_stride * dtype_sz // scale_dtype_sz
slot_scale = slot_stride * dtype_sz // scale_dtype_sz
head_scale = head_stride * dtype_sz // scale_dtype_sz
scale_off = hs * hs_stride * dtype_sz // scale_dtype_sz

high

As noted in the PR description, there is a potential misalignment issue here if hs * hs_stride * dtype_sz is not a multiple of scale_dtype_sz. This will happen if the packed head size is odd and the scale type is 2 bytes (e.g., float16). You should add an assertion here to catch this, or ensure that hs is always padded to the required alignment in get_kv_cache_shape.

@lesj0610 (Author) replied Apr 13, 2026

This was the same issue on the scale carve-out side. After padding the packed payload, this offset now stays divisible by the scale dtype size. An explicit check was added here too.

Comment on lines +404 to +411
key_cache_ptr
+ blk * stride_kc_blk
+ slot_in_blk * stride_kc_slot
+ head * stride_kc_head
+ pair_offs,
k_packed,
mask=k_pair_mask,
)

high

The kernel assumes that the last dimension of the key cache (the head dimension) is contiguous (stride 1) by adding pair_offs directly to the pointer. While this is a common layout in vLLM, it should be explicitly asserted or the stride should be passed to the kernel to avoid silent data corruption if the cache layout changes (e.g., due to a different permutation in attn_utils.py).

@lesj0610 (Author) replied Apr 13, 2026

The current layout keeps the last head dimension contiguous, but that assumption was implicit. There is an explicit invariant check here now so this does not silently break if the cache layout changes later.
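For context, the kind of invariant check referred to here can be as simple as a host-side assertion before the kernel launch. This is an illustrative form only; the tensor names and the exact placement in the launcher are assumptions, not the actual code in the branch.

```python
# Hypothetical placement: in the launcher, before the packed INT4 store kernel runs.
assert key_cache.stride(-1) == 1 and value_cache.stride(-1) == 1, (
    "int4_per_token_head kernels assume the packed head dimension is contiguous"
)
```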

chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3d270db3d8


Comment thread on vllm/v1/worker/gpu_model_runner.py (outdated)
)
logical_elems_per_page = logical_numel // num_blocks
assert logical_elems_per_page <= elems_per_page
contiguous_strides = torch.empty(kv_cache_shape).stride()

P1: Avoid materializing huge tensors just to read strides

In the padded-page branch, torch.empty(kv_cache_shape).stride() allocates a full dense tensor only to obtain contiguous strides. When num_blocks is large (realistic for production KV caches), this temporary can consume gigabytes and fail initialization with OOM even though no real data tensor is needed. Please compute the contiguous stride tuple arithmetically from kv_cache_shape instead of allocating a tensor (the same pattern also appears in vllm/v1/worker/gpu/attn_utils.py).


@lesj0610 (Author) replied Apr 13, 2026

torch.empty(...).stride() was too heavy for large KV cache shapes. This now uses arithmetic contiguous stride computation instead, and the matching padded-page path was cleaned up the same way.
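A minimal sketch of that arithmetic replacement, assuming the shape is a plain tuple of ints; the helper name is hypothetical, not the one used in the branch.

```python
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Compute the strides a contiguous tensor of `shape` would have,
    without allocating anything (stand-in for torch.empty(shape).stride())."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

# contiguous_strides((2, 3, 4)) == (12, 4, 1), matching torch.empty((2, 3, 4)).stride()
```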

v_base_scale = kv_half_stride * dtype_sz // scale_dtype_sz
slot_scale = slot_stride * dtype_sz // scale_dtype_sz
head_scale = head_stride * dtype_sz // scale_dtype_sz
scale_off = hs * hs_stride * dtype_sz // scale_dtype_sz

P2: Enforce INT4 head-size alignment for fp16 scale views

INT4 mode currently accepts any even head_size, but when head_size % 4 == 2 (for example 66), packed data uses an odd byte count per head, and converting byte strides/offsets to fp16 element units via // scale_dtype_sz truncates half-byte offsets. That makes _k_scale_cache/_v_scale_cache point to misaligned bytes and can corrupt KV data/scales during cache updates. Please reject these head sizes (or add extra padding so offsets are divisible by the scale dtype size) before creating scale views.


@lesj0610 (Author) replied Apr 13, 2026

This was the same issue on the scale carve-out side. After padding the packed payload, this offset now stays divisible by the scale dtype size. An explicit check was added here too.
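To make the failure mode concrete, using the head_size = 66 example from the review above; the padding rule mirrors the suggestion in the earlier comment and is shown as plain arithmetic.

```python
# head_size = 66  ->  66 // 2 = 33 packed bytes per head (odd).
# With fp16 scales (2 bytes), an offset of 33 bytes is not divisible by 2,
# so a typed fp16 view of the scale region would start mid-element.
# Padding the packed payload up to the scale alignment avoids that:
head_size = 66
scale_bytes = 2
packed = head_size // 2                                            # 33
padded = (packed + scale_bytes - 1) // scale_bytes * scale_bytes   # 34
assert padded % scale_bytes == 0
```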

@lesj0610 (Author) commented

I submitted the related FlashInfer work upstream.

Scope of this PR is unchanged — still the kv_dequant/int4 scaffold on vLLM side only. FlashInfer backend hookup will follow separately after those PRs settle.

@lesj0610 (Author) commented

Moving this to draft for now so the Triton/ROCm low-bit base can land first; I plan to bring FlashInfer back as follow-up once the seam is in place.

mergify Bot commented Apr 15, 2026

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @lesj0610.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label Apr 15, 2026