Skip to content

Commit 18fe146

Browse files
authored
Merge branch 'develop' into 2dQuantPreshuffleWeight
2 parents 2441260 + ffc3120 commit 18fe146

File tree

4 files changed: +93 additions, -31 deletions (lines changed)

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,15 @@
2929
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
3030

3131

32-
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
32+
DTYPE_BITS = {
33+
"fp32": 32,
34+
"fp16": 16,
35+
"bf16": 16,
36+
"fp8": 8,
37+
"fp8bf16": 8,
38+
"fp8fp32": 8,
39+
"bf8": 8,
40+
}
3341

3442
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
3543

@@ -678,6 +686,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
678686
return {
679687
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
680688
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
689+
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
681690
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
682691
} # fmt: skip
683692
elif dtype in ["fp8fp32"]:
@@ -742,8 +751,8 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli
742751
get_mask_map(mask_impl).keys(),
743752
["no"],
744753
):
745-
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
746-
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
754+
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
755+
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
747756
elif dtype in ["fp8", "fp8fp16", "bf8"]:
748757
# TODO
749758
None
@@ -958,7 +967,7 @@ def get_fwd_blobs(
958967
cond &= mode == "batch"
959968
cond &= pipeline.F_vlayout == "row"
960969
if dtype == "fp8bf16":
961-
cond &= hdim == 128 or hdim == 256
970+
cond &= hdim == 128 or hdim == 192
962971
if not cond:
963972
continue
964973
# Aiter(mha_varlen_fwd) integration
@@ -967,21 +976,21 @@ def get_fwd_blobs(
967976
cond &= mode == "group"
968977
cond &= pipeline.F_vlayout == "row"
969978
if dtype == "fp8bf16":
970-
cond &= hdim == 128 or hdim == 256
979+
cond &= hdim == 128 or hdim == 192
971980
if not cond:
972981
continue
973982
# aiter::mha_fwd C++ api integration
974983
elif receipt == 600:
975984
cond = dtype in ["fp16", "bf16", "fp8bf16"]
976985
cond &= pipeline.F_vlayout == "row"
977986
if dtype == "fp8bf16":
978-
cond &= hdim == 128 or hdim == 256
987+
cond &= hdim == 128 or hdim == 192
979988
if not cond:
980989
continue
981990
elif receipt == 888:
982991
cond = dtype in ["fp8bf16", "fp8fp32"]
983992
cond &= pipeline.F_vlayout == "row"
984-
cond &= hdim == 128 or hdim == 256
993+
cond &= hdim == 128 or hdim == 192
985994
if not cond:
986995
continue
987996

example/ck_tile/38_block_scale_gemm/gemm_utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
211211

212212
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
213213
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
214+
static constexpr int kBlockPerCu = 2;
214215
};
215216

216217
template <typename PrecType>

include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
8787
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
8888
static constexpr index_t kAlignmentBias =
8989
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
90+
static constexpr index_t kAlignmentRandVal =
91+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
9092

9193
#if CK_TILE_FMHA_FWD_FAST_EXP2
9294
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;

include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp

Lines changed: 74 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
6969

7070
using Base::m_preload;
7171

72-
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
72+
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
73+
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
7374
static constexpr index_t KPerBlockBQ =
7475
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
7576
static constexpr index_t QScalesPerBlockRow =
@@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
9596
// clang-format on
9697
}
9798

99+
template <index_t nloop>
100+
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
101+
{
102+
constexpr index_t Aload_inst =
103+
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
104+
constexpr index_t Bload_inst =
105+
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
106+
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
107+
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
108+
QuantGroupSize::kK * QuantGroupSize::kK),
109+
VectorLoadSize);
110+
constexpr index_t kLdsVec = 8;
111+
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
112+
constexpr index_t ds_read_inst = kMPerBlock / kLdsVec;
113+
constexpr index_t ds_write_inst = Aload_inst;
114+
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
115+
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
116+
constexpr index_t buffer_load_rep =
117+
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
118+
119+
static_for<0, nloop, 1>{}([&](auto j_inst) {
120+
ignore = j_inst;
121+
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
122+
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
123+
124+
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
125+
{
126+
__builtin_amdgcn_sched_group_barrier(
127+
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
128+
}
129+
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
130+
{
131+
__builtin_amdgcn_sched_group_barrier(
132+
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
133+
}
134+
135+
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
136+
{
137+
if constexpr(ds_write_inst > 0)
138+
{
139+
__builtin_amdgcn_sched_group_barrier(
140+
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
141+
}
142+
}
143+
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
144+
});
145+
});
146+
__builtin_amdgcn_sched_barrier(0);
147+
}
148+
98149
static constexpr bool PreshuffleB = Problem::PreshuffleB;
99150
static constexpr auto TailNum = Problem::TailNum;
100151

@@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
130181
static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
131182

132183
const index_t iMWarp = get_warp_id() / NWarp;
184+
// Double-Buffering (loop_count=2) for full load/compute overlap.
185+
const index_t loop_count = 2;
133186

134187
__builtin_amdgcn_sched_barrier(0);
135188

@@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
313366
__builtin_amdgcn_sched_barrier(0);
314367

315368
// MAIN LOOP
316-
index_t iCounter = (num_loop - 1) / 2;
369+
index_t iCounter = (num_loop - 1) / loop_count;
370+
317371
while(iCounter > 0)
318372
{
373+
__builtin_amdgcn_sched_barrier(0);
374+
// Prefill A(2i+1)
375+
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
376+
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
377+
378+
// Prefetch A(2i+2)
379+
a_block_tile = load_tile(a_copy_dram_window);
380+
// move A window to next k
381+
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
382+
383+
// GEMM 2i
384+
block_weight_preshuffle(c_block_tile,
385+
a_warp_tensor,
386+
b_warp_tensor_ping,
387+
bq_block_tile,
388+
a_warp_windows_ping);
319389
// prefetch B(2i+1)
320390
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
321391
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
342412
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
343413
}
344414

345-
// Prefill A(2i+1)
346-
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
347-
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
348-
349-
// Prefetch A(2i+2)
350-
a_block_tile = load_tile(a_copy_dram_window);
351-
// move A window to next k
352-
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
353-
354-
// GEMM 2i
355-
block_weight_preshuffle(c_block_tile,
356-
a_warp_tensor,
357-
b_warp_tensor_ping,
358-
bq_block_tile,
359-
a_warp_windows_ping);
360-
361415
static_for<0, m_preload, 1>{}([&](auto loadIter) {
362416
constexpr auto mIter = loadIter % MIterPerWarp;
363417
constexpr auto kIter = loadIter / MIterPerWarp;
364418
a_warp_tensor(loadIter) =
365419
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
366420
});
367-
Base::HotLoopScheduler();
368421

369422
// Next K
370423

@@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
416469
a_warp_tensor(loadIter) =
417470
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
418471
});
419-
Base::HotLoopScheduler();
420-
421472
iCounter--;
473+
HotLoopScheduler<loop_count>();
422474
}
423475

424476
// tail
@@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
456508
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
457509
});
458510

459-
Base::Last2ndHotLoopScheduler();
460-
461511
// GEMM loopK
462512
block_weight_preshuffle(c_block_tile,
463513
a_warp_tensor,
464514
b_warp_tensor_pong,
465515
bq_block_tile_2,
466516
a_warp_windows_pong);
467-
Base::LastHotLoopScheduler();
517+
HotLoopScheduler<loop_count>();
468518
}
469519
else if constexpr(TailNum == TailNumber::Odd)
470520
{

0 commit comments

Comments
 (0)