Skip to content

Add PER_TOKEN_HEAD FP8 quantization and P-scale for mha_batch_prefill#3418

Open
msaffari-amd wants to merge 4 commits into
mainfrom
AITERKER-112
Open

Add PER_TOKEN_HEAD FP8 quantization and P-scale for mha_batch_prefill#3418
msaffari-amd wants to merge 4 commits into
mainfrom
AITERKER-112

Conversation

@msaffari-amd
Copy link
Copy Markdown

Adds a FP8 quantization path (PER_TOKEN_HEAD) and optional per-query-head P-scale to mha_batch_prefill, with corresponding CK kernel support.
ROCm/rocm-libraries#7883

This PR should be merged after CK PR is merged and CK submodule is updated.

Motivation

Existing FP8 quant modes (PERTENSOR, KV_BLOCKSCALE) applies descaling that doesn't capture per-token or per-head variance in activation magnitudes. PER_TOKEN_HEAD enables descaling for Q and K at per-token-per-head granularity.

Technical Details

Quantization scheme
Tensor Descale granularity Shape
Q per-token, per-head [total_q, nhead_q]
K per-token, per-head (paged) [num_total_pages, page_block_size, nhead_k]
V per-head [nhead_k]

P-scale
Optional per-q-head softmax P-scale [num_head_q] fp32. The kernel folds log2(p_scale) into the exp2 row-max shift — the factor appears in both P and rowsum l, cancelling in O = sum(P·V) / l without needing a separate output fixup.

Test Plan

Test and benchmarking scripts have been added.

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3418 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds support for a new FP8 quantization mode (PER_TOKEN_HEAD) and an optional per-query-head softmax P-scale for mha_batch_prefill, spanning the Python API, pybind interface, CK argument plumbing, and accompanying tests/bench tooling.

Changes:

  • Extend mha_batch_prefill APIs (Python + pybind + C++ CK interface) with PER_TOKEN_HEAD descale tensors and optional p_scale / p_scale_inv.
  • Add PER_TOKEN_HEAD quantization helpers and a new PER_TOKEN_HEAD test path in op_tests/test_batch_prefill.py.
  • Add a standalone benchmark script for PER_TOKEN_HEAD (op_tests/bench_per_token_head.py).

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
op_tests/test_batch_prefill.py Adds PER_TOKEN_HEAD quantization utilities and test entrypoints; adds skip_reference to enable benchmark-style runs.
op_tests/bench_per_token_head.py New benchmark driver for PER_TOKEN_HEAD batch prefill with optional verification.
csrc/py_itfs_ck/mha_batch_prefill_kernels.cu Plumbs PER_TOKEN_HEAD descales + optional P-scale into CK arg setup and quant-mode validation.
csrc/include/torch/mha_batch_prefill.h Extends the C++ torch interface signature to include PER_TOKEN_HEAD args and P-scale.
csrc/include/rocm_ops.hpp Updates pybind argument list to expose new PER_TOKEN_HEAD and P-scale parameters to Python.
aiter/ops/mha.py Extends the Python wrapper surface area for PER_TOKEN_HEAD descales and optional P-scale forwarding.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/ops/mha.py
Comment on lines +2790 to 2794
# PER_TOKEN_HEAD mode descales
q_descale_per_token: Optional[torch.Tensor] = None,
k_descale_per_token: Optional[torch.Tensor] = None,
v_descale_per_head: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
Comment on lines 557 to +560
quant_scale_enum qscale_type;
if(kv_block_descale.has_value())
if(q_descale_per_token.has_value() || k_descale_per_token.has_value() ||
v_descale_per_head.has_value())
{
Comment on lines +2621 to +2625
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("num_qo_heads,num_kv_heads", [(32, 8), (16, 16)])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("qo_len,kv_len", [(128, 1024), (512, 2048), (1024, 4096)])
@pytest.mark.parametrize("causal", [False, True])
msaffari-amd and others added 2 commits May 29, 2026 15:33
Plumbs the new VEC_K_COL_V_LAYOUT through aiter so vLLM/SGLang can call
mha_batch_prefill with the same KV cache layout produced by aiter's
reshape_and_cache_kernel and consumed by the decode paged-attention
kernel: K is 5D vectorized
[NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] and V
is 4D ColumnMajor [NumBlocks, NumHeads, HeadDim, PageSize]. This avoids
an intermediate reshape between decode and prefill.

Changes:
- csrc/include/mha_fwd.h: mha_batch_prefill_traits now accepts an
  optional is_v_rowmajor (default true) and forwards it to the
  underlying fmha_batch_prefill_traits, replacing the previously
  hard-coded true.
- csrc/cpp_itfs/mha_fwd_batch_prefill.cu: forward args.is_v_rowmajor
  into the traits.
- csrc/py_itfs_ck/mha_batch_prefill_kernels.cu: add the
  K.dim()==5 && V.dim()==4 branch that infers VEC_K_COL_V_LAYOUT,
  validates V shape [Pages, NumHeads, HeadDim, PageSize] and strides
  (innermost PageSize contiguous, HeadDim stride = page_block_size),
  sets args.is_v_rowmajor=false, args.stride_v=1 (per the V offset
  transform convention), and routes nhead_stride_v through v.stride(1)
  alongside the existing vectorized branch. Adds a symmetric
  CHECK_SHAPE block.
- aiter/ops/mha.py: relax the Python pre-flight to accept 5D K + 4D V
  (matches the kernel-side shape/stride checks).
- op_tests/test_batch_prefill.py: add vec_k_col_v_kv_cache() helper,
  extend apply_kv_layout(), wire run_batch_prefill_per_token_head() to
  build V in the new shape for kvcache_layout="vec_k_col_v", and add a
  new pytest test_batch_prefill_per_token_head_vec_k_col_v_pytest
  parametrized across batch/heads/qo_len/kv_len/causal/soft_cap/page.
- op_tests/bench_per_token_head.py: document BENCH_KV_LAYOUT=vec_k_col_v
  (env-var routing already worked).
- 3rdparty/composable_kernel: bump submodule to pick up
  VEC_K_COL_V_LAYOUT support in the CK FMHA batch_prefill kernel and
  codegen.

Co-authored-by: Cursor <cursoragent@cursor.com>
ref_masked_attention materialized the full [H, Q, K] fp32 attention matrix
in one shot. For the bench_per_token_head config (nhq=8, hd=128) that peak
allocation is H * seqlen_q * seqlen_k * 4 bytes, i.e. ~137 GB at Q=K=65536
and ~550 GB at Q=K=131072. Verification of those long-seq configurations
OOMed inside the reference even though the kernel itself ran fine:

    File "op_tests/test_batch_prefill.py", line 322, in ref_masked_attention
        attn_weights = scale * torch.einsum("qhd,khd->hqk",
                                            query.float(), key.float())
    torch.OutOfMemoryError: HIP out of memory. Tried to allocate 94.88 GiB.

This is purely a test-harness issue (kernel + wrapper untouched).

Changes:
- op_tests/test_batch_prefill.py:
  - Rewrite ref_masked_attention to chunk along the Q dimension. fp32 K/V
    are materialized once and reused; each iteration only allocates an
    [H, Q_chunk, K] fp32 scratch. At H=8, K=131072, Q_chunk=1024 that's
    ~4 GB peak, well within the 192 GB MI308X budget.
  - Chunk size is parameterizable via BENCH_REF_Q_CHUNK env var (default
    1024) and an internal q_chunk_size kwarg used by the equivalence
    test. q_chunk_size<=0 disables chunking (single-shot fallback).
  - Preserve every existing semantic: scaling, soft_cap, causal/window
    masks, fully-masked-row zeroing, and return_lse all match the
    pre-chunking path bit-for-bit in fp32. Mask construction uses global
    Q indices (via the new _local_mask_rows helper) so chunk boundaries
    do not shift which tokens see which mask state.
  - Keep the canonical pre-chunking implementation as a private helper
    (_ref_masked_attention_unchunked) so the equivalence test has a
    semantic anchor that does not silently track future changes to the
    chunked path.
  - Add test_ref_masked_attention_chunking_equivalence: at qlen=klen=4096,
    nhq=8, hd=128, parametrized over fp32 and bf16 inputs, causal on/off,
    soft_cap on/off, return_lse on/off, and three q_chunk sizes
    (including 1000, which does not evenly divide seqlen_q so the tail
    partial chunk is exercised). Tolerances: 1e-5 for fp32 inputs and
    fp32 LSE (the cuBLAS shape-dependent reduction order shifts last-bit
    rounding), 1e-3 for bf16 output (1 bf16 ULP at output magnitude near
    0.1, set by the final out.to(query) cast).

Validation:
- BENCH_VERIFY=1 BENCH_KV_LAYOUT=vec_k_col_v BENCH_PAGE_SIZE=1024
  python op_tests/bench_per_token_head.py 1024 / 16384 / 65536 / 131072
  all VERIFY PASSED on MI308X. 65536 was the previously OOMing config;
  131072 is twice that and now also verifies cleanly.
- pytest op_tests/test_batch_prefill.py::test_ref_masked_attention_chunking_equivalence
  passes all 48 parametrized cases.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants