Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}

Expand All @@ -89,6 +90,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"per_token_head": "quant_scale_enum::per_token_head",
"mx": "quant_scale_enum::mx",
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,17 @@

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

SUPPORTED_PAGE_SIZE = [1, 16, 1024]
SUPPORTED_PAGE_SIZE = [1, 16, 64, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
# vec_k_col_v: K is 5D vectorized (same as "vectorized") and V is 4D ColumnMajor
# [NumBlocks, NumHeads, HeadDim, PageSize] (decode-aligned). Generated as an
# additional gated variant for fp8bf16 PER_TOKEN_HEAD only; see get_pipelines().
SUPPORTED_KV_MEMORY_LAYOUT_FP8_PTH_EXTRA = ["vec_k_col_v"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
"vec_k_col_v": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VEC_K_COL_V_LAYOUT",
}
KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
Expand Down Expand Up @@ -733,7 +738,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor", "kv_blockscale"],
["pertensor", "kv_blockscale", "per_token_head"],
get_mask_map(mask_impl).keys(),
["no"],
["t", "f"],
Expand All @@ -746,6 +751,31 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
if sink == "t" and mask in ("no", "s_no"):
continue
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip

# Decode-aligned VEC_K_COL_V variants: 5D vectorized K + 4D ColumnMajor V
# [NumBlocks, NumHeads, HeadDim, PageSize]. Gated to PER_TOKEN_HEAD only.
# Both lookup tables are emitted because the wrapper picks VLLM_BLOCK_TABLE_2D
# whenever a block_table is supplied (decode/prefill production path) and falls
# back to SGLANG_PAGE_TABLE_1D otherwise. Pipeline runs with vlayout="col";
# existing else-branches already skip the RowMajor V shuffle.
for (
logits,
mask,
bias,
sink,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
["no"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT_FP8_PTH_EXTRA,
SUPPORTED_KV_LOOKUP_TABLE,
):
if sink == "t" and mask in ("no", "s_no"):
continue
pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, "f", "f", "per_token_head", mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
Expand Down Expand Up @@ -819,9 +849,11 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
# kv_blockscale only supports page_size >= kN0; per_token_head can cross pages.
if (
pipeline.F_qscale == "kv_blockscale"
and page_size < tile.F_bn0
):
continue
k = FmhaFwdKernel(
F_idx=0,
Expand Down Expand Up @@ -867,7 +899,13 @@ def get_fwd_blobs(
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
# vlayout="row" everywhere except the decode-aligned
# vec_k_col_v variant (fp8bf16 PER_TOKEN_HEAD only) which
# uses vlayout="col"; see get_pipelines() above.
cond &= (
pipeline.F_vlayout == "row"
or pipeline.F_kv_memory_layout == "vec_k_col_v"
)
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
Expand Down
33 changes: 31 additions & 2 deletions projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,10 @@ struct fmha_batch_prefill_args
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
kv_memory_layout; // KV memory layout (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
// V tensor logical layout selector. true = RowMajor V (HeadDim contiguous), false =
// ColumnMajor V (PageSize contiguous, decode-aligned). Used by the auto-generated
// dispatcher to pick a kernel variant whose VLayout matches V's physical layout.
bool is_v_rowmajor = true;
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
Expand Down Expand Up @@ -673,6 +677,17 @@ struct fmha_batch_prefill_args
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension

// PER_TOKEN_HEAD: q/k use per-token-per-head descales; v uses per-head descales.
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)

// PER_TOKEN_HEAD optional per-q-head P scale [num_head_q] fp32.
const void* p_scale_ptr = nullptr;
};

// Selects the KV-cache load mode for a batch-prefill dispatch arm.
Expand Down Expand Up @@ -1342,7 +1357,14 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.p_scale_ptr);
}
else
{ // create batch mode kernel arguments
Expand Down Expand Up @@ -1397,7 +1419,14 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.p_scale_ptr);
}
}();

Expand Down
17 changes: 12 additions & 5 deletions projects/composablekernel/example/ck_tile/01_fmha/quant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
};

struct quant_scale_info
Expand All @@ -38,6 +39,8 @@ struct quant_scale_info
os << "kvbs";
else if(type == quant_scale_enum::mx)
os << "mx";
else if(type == quant_scale_enum::per_token_head)
os << "pth";
}

static quant_scale_info decode(std::string str)
Expand All @@ -63,6 +66,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::mx;
}
else if(str == "pth" || str == "5")
{
info.type = quant_scale_enum::per_token_head;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ namespace ck_tile {
// - LINEAR_LAYOUT:
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
// - VEC_K_COL_V_LAYOUT (decode-aligned, hybrid):
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] (same as VECTORIZED)
// V: [NumBlocks, NumHeads, HeadDim, PageSize] (4D, ColumnMajor)
// This matches the layout produced by aiter's reshape_and_cache_kernel and consumed
// by the decode paged-attention kernel, so prefill can ingest the live KV cache
// without an intermediate reshape.
enum class BlockAttentionKVCacheMemoryLayoutEnum
{
VECTORIZED_LAYOUT = 0,
LINEAR_LAYOUT = 1,
VECTORIZED_LAYOUT = 0,
LINEAR_LAYOUT = 1,
VEC_K_COL_V_LAYOUT = 2,
};

// KV cache lookup table layout selector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
MX = 4, // Microscaling
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
MX = 4, // Microscaling
PER_TOKEN_HEAD = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
};

template <BlockAttentionQuantScaleEnum>
Expand Down Expand Up @@ -45,5 +46,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::MX>
{
static constexpr const char* name = "mx";
};
template <>
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD>
{
static constexpr const char* name = "per_token_head";
};

} // namespace ck_tile
Loading
Loading