Add PER_TOKEN_HEAD FP8 quant and P-scale to batch_prefill#7883
Open
msaffari-amd wants to merge 2 commits into
Open
Add PER_TOKEN_HEAD FP8 quant and P-scale to batch_prefill#7883msaffari-amd wants to merge 2 commits into
msaffari-amd wants to merge 2 commits into
Conversation
Add a new FP8 quantization scheme (PER_TOKEN_HEAD, enum value 5) for the
batch_prefill FMHA kernel. Unlike PERTENSOR (single scale for all of Q/K/V)
or KV_BLOCKSCALE (per-page K/V scales), PER_TOKEN_HEAD applies fine-grained
descales:
- Q descale: per-token, per-head [total_q, nhead_q]
- K descale: per-token, per-head [num_total_pages, page_block_size, nhead_k]
- V descale: per-head [nhead_k]
The dequantization of the QK dot product is staged through LDS to avoid
inflating the inner-loop instruction footprint. Cross-page tiles
(page_block_size < kN0) are supported via per-column physical page lookup,
unlike KV_BLOCKSCALE which requires page_block_size >= kN0.
Additionally, an optional per-q-head P-scale [num_head_q] is supported.
The kernel folds log2(p_scale) into the exp2 row-max shift, so the scale
factor appears in both P and the rowsum l, cancelling in O = sum(P*V) / l
with no separate V-descale fixup needed.
Also adds page_size=64 to the codegen page size list, and includes SRD
same-page-skip optimizations for K/V window rebasing.
Changes:
- block_attention_quant_scale_enum.hpp: PER_TOKEN_HEAD = 5
- quant.hpp: enum, serialize ("pth"), decode
- cpp_symbol_map.py: codegen symbol mappings
- fmha_batch_prefill.py: page_size=64, per_token_head qscale, filter update
- fmha_fwd.hpp: args struct (stride fields, p_scale_ptr), kargs forwarding
- fmha_batch_prefill_kernel.hpp: kargs struct, MakeKargs, get_scale_s,
pipeline dispatch
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: LDS-staged dequant,
p_scale_log2 exp2-shift fold, cross-page support, SRD same-page skip,
PER_TOKEN_HEAD convenience overload
1 task
New decode-aligned KV cache layout for FP8 PER_TOKEN_HEAD batch_prefill: 5D vectorized K + 4D ColumnMajor V [NumBlocks, NumHeads, HeadDim, PageSize]. Matches the layout produced by reshape_and_cache and consumed by the decode paged-attention kernel, so prefill can ingest the live KV cache without an intermediate reshape. - block_attention_kvcache_layout_enum.hpp: add VEC_K_COL_V_LAYOUT (= 2). - fmha_batch_prefill_kernel.hpp: route K dram through the vectorized branch for VEC_K_COL_V; add a new V dram branch building (Pages, HeadDim, PageSize) with stride (batch_stride_v, page_block_size, 1) and merging to logical (D, TotalSeqK). - block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: keep kAlignmentV = kMaxVecLoad for VEC_K_COL_V despite kPadSeqLenK=true (full-page invariant keeps vec loads safe along the contiguous PageSize dim). - block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp: new kUseVectorizedVPolicy<Problem>() predicate routes VEC_K_COL_V through the same V tile dist / LDS layout / SmemKPack / alignment as VECTORIZED_LAYOUT. - block_fmha_pipeline_problem.hpp: relax IsVLayoutRowMajor static_assert to accept ColumnMajor V for VEC_K_COL_V; introduce kIsKVectorized predicate so the page_size=1 + K-vectorized rejection covers the new layout. - tile_fmha_traits.hpp: extend the supported-layouts static_assert. - fmha_fwd.hpp: add `bool is_v_rowmajor = true` to fmha_batch_prefill_args so the wrapper can flip it for VEC_K_COL_V. - codegen/ops/fmha_batch_prefill.py: add SUPPORTED_KV_MEMORY_LAYOUT_FP8_PTH_EXTRA map entry and a gated emission loop for fp8bf16 PER_TOKEN_HEAD with vlayout="col" + kv_memory_layout="vec_k_col_v" across both lookup tables. Relax receipt 200 to allow vlayout="col" only when kv_memory_layout == "vec_k_col_v".
20e4392 to
475d8ae
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a new FP8 quantization scheme (PER_TOKEN_HEAD) to the CK batch_prefill FMHA kernel, along with optional per-query-head P-scale support.
Motivation
Existing FP8 quant modes (PERTENSOR, KV_BLOCKSCALE) applies descaling that doesn't capture per-token or per-head variance in activation magnitudes. PER_TOKEN_HEAD enables descaling for Q and K at per-token-per-head granularity.
Technical Details
Quantization scheme
Tensor Descale granularity Shape
Q per-token, per-head [total_q, nhead_q]
K per-token, per-head (paged) [num_total_pages, page_block_size, nhead_k]
V per-head [nhead_k]
The QK dequantization (s_acc[i,j] *= q_descale[i] * k_descale[j]) is staged through LDS to minimize inner-loop register pressure matching the approach used in the fmha_fwd pipeline.
P-scale
An optional per-q-head P-scale [num_head_q] is supported. log2(p_scale) is folded into the exp2 row-max shift, so the scale factor appears in both P and the rowsum l, cancelling in O = sum(P·V) / l without needing a separate fixup.
Cross-page support
Unlike KV_BLOCKSCALE (which requires page_block_size >= kN0), PER_TOKEN_HEAD supports cross-page tiles by precomputing per-column physical page indices. This enables page_size=64 (newly added to the codegen list).