Skip to content

Commit 20e4392

Browse files
Add PER_TOKEN_HEAD FP8 quant and P-scale to batch_prefill
Add a new FP8 quantization scheme (PER_TOKEN_HEAD, enum value 5) for the batch_prefill FMHA kernel. Unlike PERTENSOR (single scale for all of Q/K/V) or KV_BLOCKSCALE (per-page K/V scales), PER_TOKEN_HEAD applies fine-grained descales:

- Q descale: per-token, per-head [total_q, nhead_q]
- K descale: per-token, per-head [num_total_pages, page_block_size, nhead_k]
- V descale: per-head [nhead_k]

The dequantization of the QK dot product is staged through LDS to avoid inflating the inner-loop instruction footprint. Cross-page tiles (page_block_size < kN0) are supported via per-column physical page lookup, unlike KV_BLOCKSCALE which requires page_block_size >= kN0.

Additionally, an optional per-q-head P-scale [num_head_q] is supported. The kernel folds log2(p_scale) into the exp2 row-max shift, so the scale factor appears in both P and the rowsum l, cancelling in O = sum(P*V) / l with no separate V-descale fixup needed.

Also adds page_size=64 to the codegen page size list, and includes SRD same-page-skip optimizations for K/V window rebasing.

Changes:

- block_attention_quant_scale_enum.hpp: PER_TOKEN_HEAD = 5
- quant.hpp: enum, serialize ("pth"), decode
- cpp_symbol_map.py: codegen symbol mappings
- fmha_batch_prefill.py: page_size=64, per_token_head qscale, filter update
- fmha_fwd.hpp: args struct (stride fields, p_scale_ptr), kargs forwarding
- fmha_batch_prefill_kernel.hpp: kargs struct, MakeKargs, get_scale_s, pipeline dispatch
- block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: LDS-staged dequant, p_scale_log2 exp2-shift fold, cross-page support, SRD same-page skip, PER_TOKEN_HEAD convenience overload

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent ca4930e commit 20e4392

7 files changed

Lines changed: 460 additions & 30 deletions

File tree

projects/composablekernel/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
8181
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
8282
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
8383
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
84+
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
8485
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
8586
}
8687

@@ -89,6 +90,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
8990
"pertensor": "quant_scale_enum::pertensor",
9091
"blockscale": "quant_scale_enum::blockscale",
9192
"kv_blockscale": "quant_scale_enum::kv_blockscale",
93+
"per_token_head": "quant_scale_enum::per_token_head",
9294
"mx": "quant_scale_enum::mx",
9395
}
9496

projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
5050

51-
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
51+
SUPPORTED_PAGE_SIZE = [1, 16, 64, 1024]
5252
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
5353
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
5454
KV_MEMORY_LAYOUT_ENUM_MAP = {
@@ -733,7 +733,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
733733
kv_lookup_table,
734734
) in itertools.product(
735735
["t", "f"],
736-
["pertensor", "kv_blockscale"],
736+
["pertensor", "kv_blockscale", "per_token_head"],
737737
get_mask_map(mask_impl).keys(),
738738
["no"],
739739
["t", "f"],
@@ -819,9 +819,11 @@ def get_fwd_blobs(
819819
for page_size in SUPPORTED_PAGE_SIZE:
820820
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
821821
continue
822-
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
823-
# This ensures all tokens in a main loop iteration belong to the same page
824-
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
822+
# kv_blockscale only supports page_size >= kN0; per_token_head can cross pages.
823+
if (
824+
pipeline.F_qscale == "kv_blockscale"
825+
and page_size < tile.F_bn0
826+
):
825827
continue
826828
k = FmhaFwdKernel(
827829
F_idx=0,

projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,17 @@ struct fmha_batch_prefill_args
673673
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
674674
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
675675
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
676+
677+
// PER_TOKEN_HEAD: q/k use per-token-per-head descales; v uses per-head descales.
678+
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
679+
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
680+
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
681+
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
682+
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
683+
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
684+
685+
// PER_TOKEN_HEAD optional per-q-head P scale [num_head_q] fp32.
686+
const void* p_scale_ptr = nullptr;
676687
};
677688

678689
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
@@ -1342,7 +1353,14 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
13421353
args.drop_seed_offset,
13431354
args.sink_ptr,
13441355
args.nblock_stride_kv_block_descale,
1345-
args.nhead_stride_kv_block_descale);
1356+
args.nhead_stride_kv_block_descale,
1357+
args.stride_q_descale_token,
1358+
args.nhead_stride_q_descale,
1359+
args.nblock_stride_k_descale_page,
1360+
args.stride_k_descale_token,
1361+
args.nhead_stride_k_descale,
1362+
args.nhead_stride_v_descale,
1363+
args.p_scale_ptr);
13461364
}
13471365
else
13481366
{ // create batch mode kernel arguments
@@ -1397,7 +1415,14 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
13971415
args.drop_seed_offset,
13981416
args.sink_ptr,
13991417
args.nblock_stride_kv_block_descale,
1400-
args.nhead_stride_kv_block_descale);
1418+
args.nhead_stride_kv_block_descale,
1419+
args.stride_q_descale_token,
1420+
args.nhead_stride_q_descale,
1421+
args.nblock_stride_k_descale_page,
1422+
args.stride_k_descale_token,
1423+
args.nhead_stride_k_descale,
1424+
args.nhead_stride_v_descale,
1425+
args.p_scale_ptr);
14011426
}
14021427
}();
14031428

projects/composablekernel/example/ck_tile/01_fmha/quant.hpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// keep sync with BlockAttentionQuantScaleEnum
1616
enum class quant_scale_enum
1717
{
18-
no_scale = 0,
19-
pertensor = 1,
20-
blockscale = 2,
21-
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
22-
mx = 4, // Microscaling (MX)
18+
no_scale = 0,
19+
pertensor = 1,
20+
blockscale = 2,
21+
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
22+
mx = 4, // Microscaling (MX)
23+
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
2324
};
2425

2526
struct quant_scale_info
@@ -38,6 +39,8 @@ struct quant_scale_info
3839
os << "kvbs";
3940
else if(type == quant_scale_enum::mx)
4041
os << "mx";
42+
else if(type == quant_scale_enum::per_token_head)
43+
os << "pth";
4144
}
4245

4346
static quant_scale_info decode(std::string str)
@@ -63,6 +66,10 @@ struct quant_scale_info
6366
{
6467
info.type = quant_scale_enum::mx;
6568
}
69+
else if(str == "pth" || str == "5")
70+
{
71+
info.type = quant_scale_enum::per_token_head;
72+
}
6673
else
6774
{
6875
throw std::invalid_argument("invalid quant scale value: " + str);

projects/composablekernel/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ namespace ck_tile {
1010
// This class is used for codegen pattern matching
1111
enum class BlockAttentionQuantScaleEnum
1212
{
13-
NO_SCALE = 0,
14-
PERTENSOR = 1,
15-
BLOCKSCALE = 2,
16-
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
17-
MX = 4, // Microscaling
13+
NO_SCALE = 0,
14+
PERTENSOR = 1,
15+
BLOCKSCALE = 2,
16+
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
17+
MX = 4, // Microscaling
18+
PER_TOKEN_HEAD = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
1819
};
1920

2021
template <BlockAttentionQuantScaleEnum>
@@ -45,5 +46,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::MX>
4546
{
4647
static constexpr const char* name = "mx";
4748
};
49+
// Maps BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD to its string name,
// "per_token_head".
template <>
50+
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD>
51+
{
52+
static constexpr const char* name = "per_token_head";
53+
};
4854

4955
} // namespace ck_tile

projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,25 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
205205
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
206206
};
207207

208+
// PER_TOKEN_HEAD: Q per-token-per-head, K per-token-per-head (indexed per page), V per-head
209+
// q_descale: [total_q, nhead_q]
210+
// k_descale: [num_total_pages, page_block_size, nhead_k]
211+
// v_descale: [nhead_k]
212+
struct FmhaFwdPerTokenHeadKargs
213+
{
214+
const void* q_descale_ptr = nullptr; // Q descales; reinterpreted as const float* at dispatch
215+
const void* k_descale_ptr = nullptr; // K descales; reinterpreted as const float* at dispatch
216+
const void* v_descale_ptr = nullptr; // V descales; reinterpreted as const float* at dispatch
217+
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
218+
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
219+
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
220+
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
221+
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
222+
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
223+
// Optional per-q-head P scale [num_head_q] fp32.
224+
const void* p_scale_ptr = nullptr; // defaults to nullptr when no P scale is supplied
225+
};
226+
208227
// Helper template to select QScale Kargs type based on QScaleEnum
209228
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
210229
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
@@ -225,6 +244,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
225244
using type = FmhaFwdKVBlockScaleKargs;
226245
};
227246

247+
// PER_TOKEN_HEAD specialization: the QScale kargs type is FmhaFwdPerTokenHeadKargs.
// EmptyType is not referenced by this specialization.
template <typename EmptyType>
248+
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD, EmptyType>
249+
{
250+
using type = FmhaFwdPerTokenHeadKargs;
251+
};
252+
228253
struct FmhaFwdDropoutSeedOffset
229254
{
230255
template <typename T>
@@ -379,7 +404,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
379404
drop_seed_offset,
380405
const void* sink_ptr = nullptr,
381406
ck_tile::index_t nblock_stride_kv_block_descale = 0,
382-
ck_tile::index_t nhead_stride_kv_block_descale = 0)
407+
ck_tile::index_t nhead_stride_kv_block_descale = 0,
408+
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
409+
ck_tile::index_t stride_q_descale_token = 0,
410+
ck_tile::index_t nhead_stride_q_descale = 0,
411+
ck_tile::index_t nblock_stride_k_descale_page = 0,
412+
ck_tile::index_t stride_k_descale_token = 0,
413+
ck_tile::index_t nhead_stride_k_descale = 0,
414+
ck_tile::index_t nhead_stride_v_descale = 0,
415+
const void* p_scale_ptr = nullptr)
383416
{
384417
Kargs kargs{{q_ptr,
385418
k_ptr,
@@ -458,6 +491,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
458491
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
459492
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
460493
}
494+
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
495+
{
496+
kargs.q_descale_ptr = q_descale_ptr;
497+
kargs.k_descale_ptr = k_descale_ptr;
498+
kargs.v_descale_ptr = v_descale_ptr;
499+
kargs.stride_q_descale_token = stride_q_descale_token;
500+
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
501+
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
502+
kargs.stride_k_descale_token = stride_k_descale_token;
503+
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
504+
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
505+
kargs.p_scale_ptr = p_scale_ptr;
506+
}
461507
if constexpr(kHasDropout)
462508
{
463509
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -536,7 +582,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
536582
drop_seed_offset,
537583
const void* sink_ptr = nullptr,
538584
ck_tile::index_t nblock_stride_kv_block_descale = 0,
539-
ck_tile::index_t nhead_stride_kv_block_descale = 0)
585+
ck_tile::index_t nhead_stride_kv_block_descale = 0,
586+
// PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD)
587+
ck_tile::index_t stride_q_descale_token = 0,
588+
ck_tile::index_t nhead_stride_q_descale = 0,
589+
ck_tile::index_t nblock_stride_k_descale_page = 0,
590+
ck_tile::index_t stride_k_descale_token = 0,
591+
ck_tile::index_t nhead_stride_k_descale = 0,
592+
ck_tile::index_t nhead_stride_v_descale = 0,
593+
const void* p_scale_ptr = nullptr)
540594
{
541595
Kargs kargs{{q_ptr,
542596
k_ptr,
@@ -612,6 +666,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
612666
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
613667
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
614668
}
669+
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
670+
{
671+
kargs.q_descale_ptr = q_descale_ptr;
672+
kargs.k_descale_ptr = k_descale_ptr;
673+
kargs.v_descale_ptr = v_descale_ptr;
674+
kargs.stride_q_descale_token = stride_q_descale_token;
675+
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
676+
kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page;
677+
kargs.stride_k_descale_token = stride_k_descale_token;
678+
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
679+
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
680+
kargs.p_scale_ptr = p_scale_ptr;
681+
}
615682
if constexpr(kHasDropout)
616683
{
617684
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1222,6 +1289,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
12221289
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
12231290
return kargs.scale_s * q_descale;
12241291
}
1292+
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
1293+
{
1294+
// Q/K descales are per-token-per-head, applied as outer product in pipeline.
1295+
// Here we only forward the softmax scale (1/sqrt(d)).
1296+
return kargs.scale_s;
1297+
}
12251298
else
12261299
{
12271300
return kargs.scale_s;
@@ -1339,6 +1412,50 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
13391412
kargs.nblock_stride_kv_block_descale,
13401413
kargs.nhead_stride_kv_block_descale);
13411414
}
1415+
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
1416+
{
1417+
// PER_TOKEN_HEAD: Q/K descales are per-token-per-head, V is per-head.
1418+
assert(kargs.q_descale_ptr != nullptr);
1419+
assert(kargs.k_descale_ptr != nullptr);
1420+
assert(kargs.v_descale_ptr != nullptr);
1421+
const float* q_descale_ptr = reinterpret_cast<const float*>(kargs.q_descale_ptr);
1422+
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
1423+
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
1424+
1425+
const float* p_scale_ptr = reinterpret_cast<const float*>(kargs.p_scale_ptr);
1426+
1427+
return FmhaPipeline{}(q_dram_window,
1428+
k_dram_window,
1429+
v_dram_window,
1430+
bias_dram_window,
1431+
randval_dram_window,
1432+
lse_dram_window,
1433+
mask,
1434+
position_encoding,
1435+
variant_params.sm_scale,
1436+
variant,
1437+
variant_params,
1438+
block_indices,
1439+
smem_ptr,
1440+
page_idx,
1441+
stride_k_for_pipeline,
1442+
stride_v_for_pipeline,
1443+
kargs.batch_stride_k,
1444+
kargs.batch_stride_v,
1445+
dropout,
1446+
sink_value,
1447+
max_page_table_idx,
1448+
q_descale_ptr,
1449+
k_descale_ptr,
1450+
v_descale_ptr,
1451+
kargs.stride_q_descale_token,
1452+
kargs.nhead_stride_q_descale,
1453+
kargs.nblock_stride_k_descale_page,
1454+
kargs.stride_k_descale_token,
1455+
kargs.nhead_stride_k_descale,
1456+
kargs.nhead_stride_v_descale,
1457+
p_scale_ptr);
1458+
}
13421459
else
13431460
{
13441461
return FmhaPipeline{}(q_dram_window,

0 commit comments

Comments
 (0)