Add PER_TOKEN_HEAD FP8 quantization and P-scale for mha_batch_prefill#3418
Open
msaffari-amd wants to merge 4 commits into
Open
Add PER_TOKEN_HEAD FP8 quantization and P-scale for mha_batch_prefill#3418msaffari-amd wants to merge 4 commits into
msaffari-amd wants to merge 4 commits into
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
Adds support for a new FP8 quantization mode (PER_TOKEN_HEAD) and an optional per-query-head softmax P-scale for mha_batch_prefill, spanning the Python API, pybind interface, CK argument plumbing, and accompanying tests/bench tooling.
Changes:
- Extend
mha_batch_prefillAPIs (Python + pybind + C++ CK interface) with PER_TOKEN_HEAD descale tensors and optionalp_scale/p_scale_inv. - Add PER_TOKEN_HEAD quantization helpers and a new PER_TOKEN_HEAD test path in
op_tests/test_batch_prefill.py. - Add a standalone benchmark script for PER_TOKEN_HEAD (
op_tests/bench_per_token_head.py).
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
op_tests/test_batch_prefill.py |
Adds PER_TOKEN_HEAD quantization utilities and test entrypoints; adds skip_reference to enable benchmark-style runs. |
op_tests/bench_per_token_head.py |
New benchmark driver for PER_TOKEN_HEAD batch prefill with optional verification. |
csrc/py_itfs_ck/mha_batch_prefill_kernels.cu |
Plumbs PER_TOKEN_HEAD descales + optional P-scale into CK arg setup and quant-mode validation. |
csrc/include/torch/mha_batch_prefill.h |
Extends the C++ torch interface signature to include PER_TOKEN_HEAD args and P-scale. |
csrc/include/rocm_ops.hpp |
Updates pybind argument list to expose new PER_TOKEN_HEAD and P-scale parameters to Python. |
aiter/ops/mha.py |
Extends the Python wrapper surface area for PER_TOKEN_HEAD descales and optional P-scale forwarding. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+2790
to
2794
| # PER_TOKEN_HEAD mode descales | ||
| q_descale_per_token: Optional[torch.Tensor] = None, | ||
| k_descale_per_token: Optional[torch.Tensor] = None, | ||
| v_descale_per_head: Optional[torch.Tensor] = None, | ||
| sink_ptr: Optional[Tensor] = None, |
Comment on lines
557
to
+560
| quant_scale_enum qscale_type; | ||
| if(kv_block_descale.has_value()) | ||
| if(q_descale_per_token.has_value() || k_descale_per_token.has_value() || | ||
| v_descale_per_head.has_value()) | ||
| { |
Comment on lines
+2621
to
+2625
| @pytest.mark.parametrize("batch_size", [1, 4]) | ||
| @pytest.mark.parametrize("num_qo_heads,num_kv_heads", [(32, 8), (16, 16)]) | ||
| @pytest.mark.parametrize("head_dim", [128]) | ||
| @pytest.mark.parametrize("qo_len,kv_len", [(128, 1024), (512, 2048), (1024, 4096)]) | ||
| @pytest.mark.parametrize("causal", [False, True]) |
Plumbs the new VEC_K_COL_V_LAYOUT through aiter so vLLM/SGLang can call mha_batch_prefill with the same KV cache layout produced by aiter's reshape_and_cache_kernel and consumed by the decode paged-attention kernel: K is 5D vectorized [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] and V is 4D ColumnMajor [NumBlocks, NumHeads, HeadDim, PageSize]. This avoids an intermediate reshape between decode and prefill. Changes: - csrc/include/mha_fwd.h: mha_batch_prefill_traits now accepts an optional is_v_rowmajor (default true) and forwards it to the underlying fmha_batch_prefill_traits, replacing the previously hard-coded true. - csrc/cpp_itfs/mha_fwd_batch_prefill.cu: forward args.is_v_rowmajor into the traits. - csrc/py_itfs_ck/mha_batch_prefill_kernels.cu: add the K.dim()==5 && V.dim()==4 branch that infers VEC_K_COL_V_LAYOUT, validates V shape [Pages, NumHeads, HeadDim, PageSize] and strides (innermost PageSize contiguous, HeadDim stride = page_block_size), sets args.is_v_rowmajor=false, args.stride_v=1 (per the V offset transform convention), and routes nhead_stride_v through v.stride(1) alongside the existing vectorized branch. Adds a symmetric CHECK_SHAPE block. - aiter/ops/mha.py: relax the Python pre-flight to accept 5D K + 4D V (matches the kernel-side shape/stride checks). - op_tests/test_batch_prefill.py: add vec_k_col_v_kv_cache() helper, extend apply_kv_layout(), wire run_batch_prefill_per_token_head() to build V in the new shape for kvcache_layout="vec_k_col_v", and add a new pytest test_batch_prefill_per_token_head_vec_k_col_v_pytest parametrized across batch/heads/qo_len/kv_len/causal/soft_cap/page. - op_tests/bench_per_token_head.py: document BENCH_KV_LAYOUT=vec_k_col_v (env-var routing already worked). - 3rdparty/composable_kernel: bump submodule to pick up VEC_K_COL_V_LAYOUT support in the CK FMHA batch_prefill kernel and codegen. Co-authored-by: Cursor <cursoragent@cursor.com>
ref_masked_attention materialized the full [H, Q, K] fp32 attention matrix
in one shot. For the bench_per_token_head config (nhq=8, hd=128) that peak
allocation is H * seqlen_q * seqlen_k * 4 bytes, i.e. ~137 GB at Q=K=65536
and ~550 GB at Q=K=131072. Verification of those long-seq configurations
OOMed inside the reference even though the kernel itself ran fine:
File "op_tests/test_batch_prefill.py", line 322, in ref_masked_attention
attn_weights = scale * torch.einsum("qhd,khd->hqk",
query.float(), key.float())
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 94.88 GiB.
This is purely a test-harness issue (kernel + wrapper untouched).
Changes:
- op_tests/test_batch_prefill.py:
- Rewrite ref_masked_attention to chunk along the Q dimension. fp32 K/V
are materialized once and reused; each iteration only allocates an
[H, Q_chunk, K] fp32 scratch. At H=8, K=131072, Q_chunk=1024 that's
~4 GB peak, well within the 192 GB MI308X budget.
- Chunk size is parameterizable via BENCH_REF_Q_CHUNK env var (default
1024) and an internal q_chunk_size kwarg used by the equivalence
test. q_chunk_size<=0 disables chunking (single-shot fallback).
- Preserve every existing semantic: scaling, soft_cap, causal/window
masks, fully-masked-row zeroing, and return_lse all match the
pre-chunking path bit-for-bit in fp32. Mask construction uses global
Q indices (via the new _local_mask_rows helper) so chunk boundaries
do not shift which tokens see which mask state.
- Keep the canonical pre-chunking implementation as a private helper
(_ref_masked_attention_unchunked) so the equivalence test has a
semantic anchor that does not silently track future changes to the
chunked path.
- Add test_ref_masked_attention_chunking_equivalence: at qlen=klen=4096,
nhq=8, hd=128, parametrized over fp32 and bf16 inputs, causal on/off,
soft_cap on/off, return_lse on/off, and three q_chunk sizes
(including 1000, which does not evenly divide seqlen_q so the tail
partial chunk is exercised). Tolerances: 1e-5 for fp32 inputs and
fp32 LSE (the cuBLAS shape-dependent reduction order shifts last-bit
rounding), 1e-3 for bf16 output (1 bf16 ULP at output magnitude near
0.1, set by the final out.to(query) cast).
Validation:
- BENCH_VERIFY=1 BENCH_KV_LAYOUT=vec_k_col_v BENCH_PAGE_SIZE=1024
python op_tests/bench_per_token_head.py 1024 / 16384 / 65536 / 131072
all VERIFY PASSED on MI308X. 65536 was the previously OOMing config;
131072 is twice that and now also verifies cleanly.
- pytest op_tests/test_batch_prefill.py::test_ref_masked_attention_chunking_equivalence
passes all 48 parametrized cases.
Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a FP8 quantization path (PER_TOKEN_HEAD) and optional per-query-head P-scale to mha_batch_prefill, with corresponding CK kernel support.
ROCm/rocm-libraries#7883
This PR should be merged after CK PR is merged and CK submodule is updated.
Motivation
Existing FP8 quant modes (PERTENSOR, KV_BLOCKSCALE) applies descaling that doesn't capture per-token or per-head variance in activation magnitudes. PER_TOKEN_HEAD enables descaling for Q and K at per-token-per-head granularity.
Technical Details
Quantization scheme
Tensor Descale granularity Shape
Q per-token, per-head [total_q, nhead_q]
K per-token, per-head (paged) [num_total_pages, page_block_size, nhead_k]
V per-head [nhead_k]
P-scale
Optional per-q-head softmax P-scale [num_head_q] fp32. The kernel folds log2(p_scale) into the exp2 row-max shift — the factor appears in both P and rowsum l, cancelling in O = sum(P·V) / l without needing a separate output fixup.
Test Plan
Test and benchmarking scripts have been added.