@@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
6969
7070 using Base::m_preload;
7171
72- static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
72+ static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
73+ static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
7374 static constexpr index_t KPerBlockBQ =
7475 integer_divide_ceil (BlockGemmShape::kK , QuantGroupSize::kK );
7576 static constexpr index_t QScalesPerBlockRow =
@@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
9596 // clang-format on
9697 }
9798
99+ template <index_t nloop>
100+ CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler ()
101+ {
102+ constexpr index_t Aload_inst =
103+ (kMPerBlock * kKPerBlock * sizeof (ADataType)) / BlockSize / VectorLoadSize;
104+ constexpr index_t Bload_inst =
105+ (kKPerBlock * kNPerBlock * sizeof (BDataType)) / BlockSize / VectorLoadSize;
106+ constexpr index_t BQload_inst = ck_tile::integer_divide_ceil (
107+ ck_tile::integer_divide_ceil (kKPerBlock * kNPerBlock * sizeof (BQDataType),
108+ QuantGroupSize::kK * QuantGroupSize::kK ),
109+ VectorLoadSize);
110+ constexpr index_t kLdsVec = 8 ;
111+ constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
112+ constexpr index_t ds_read_inst = kMPerBlock / kLdsVec ;
113+ constexpr index_t ds_write_inst = Aload_inst;
114+ constexpr index_t mfma_inst = (kMPerBlock / WG::kM ) * (kNPerBlock / WG::kN );
115+ constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
116+ constexpr index_t buffer_load_rep =
117+ min (mfma_inst / buffer_load_inst, 4 ); // 1 buffer_load cover 4 mfma
118+
119+ static_for<0 , nloop, 1 >{}([&](auto j_inst) {
120+ ignore = j_inst;
121+ static_for<0 , mfma_inst, 1 >{}([&](auto i_inst) {
122+ __builtin_amdgcn_sched_group_barrier (LLVMSchedGroupMask::MFMA, 1 , 0 ); // MFMA
123+
124+ if constexpr (ds_rep > 0 && i_inst % ds_rep == 0 )
125+ {
126+ __builtin_amdgcn_sched_group_barrier (
127+ LLVMSchedGroupMask::DS_READ, 1 , 0 ); // DS read
128+ }
129+ if constexpr (ds_rep > 0 && i_inst % ds_rep == 1 )
130+ {
131+ __builtin_amdgcn_sched_group_barrier (
132+ LLVMSchedGroupMask::DS_WRITE, 1 , 0 ); // DS write
133+ }
134+
135+ if constexpr (buffer_load_rep > 0 && i_inst % buffer_load_rep == 0 )
136+ {
137+ if constexpr (ds_write_inst > 0 )
138+ {
139+ __builtin_amdgcn_sched_group_barrier (
140+ LLVMSchedGroupMask::VMEM_READ, 1 , 0 ); // VMEM read
141+ }
142+ }
143+ __builtin_amdgcn_sched_group_barrier (LLVMSchedGroupMask::VALU, 2 , 0 ); // VALU
144+ });
145+ });
146+ __builtin_amdgcn_sched_barrier (0 );
147+ }
148+
98149 static constexpr bool PreshuffleB = Problem::PreshuffleB;
99150 static constexpr auto TailNum = Problem::TailNum;
100151
@@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
130181 static_assert (!is_b_row_major, " B must be col major (row major not supported yet)" );
131182
132183 const index_t iMWarp = get_warp_id () / NWarp;
184+ // Double-Buffering (loop_count=2) for full load/compute overlap.
185+ const index_t loop_count = 2 ;
133186
134187 __builtin_amdgcn_sched_barrier (0 );
135188
@@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
313366 __builtin_amdgcn_sched_barrier (0 );
314367
315368 // MAIN LOOP
316- index_t iCounter = (num_loop - 1 ) / 2 ;
369+ index_t iCounter = (num_loop - 1 ) / loop_count;
370+
317371 while (iCounter > 0 )
318372 {
373+ __builtin_amdgcn_sched_barrier (0 );
374+ // Prefill A(2i+1)
375+ a_block_tile_tmp = tile_elementwise_in (a_element_func, a_block_tile);
376+ store_tile (a_copy_lds_window_pong, a_block_tile_tmp);
377+
378+ // Prefetch A(2i+2)
379+ a_block_tile = load_tile (a_copy_dram_window);
380+ // move A window to next k
381+ move_tile_window (a_copy_dram_window, {0 , kKPerBlock });
382+
383+ // GEMM 2i
384+ block_weight_preshuffle (c_block_tile,
385+ a_warp_tensor,
386+ b_warp_tensor_ping,
387+ bq_block_tile,
388+ a_warp_windows_ping);
319389 // prefetch B(2i+1)
320390 static_for<0 , KIterPerWarp, 1 >{}([&](auto kIter ) {
321391 static_for<0 , NIterPerWarp, 1 >{}([&](auto nIter) {
@@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
342412 move_tile_window (bq_copy_dram_window, {0 , KPerBlockBQ});
343413 }
344414
345- // Prefill A(2i+1)
346- a_block_tile_tmp = tile_elementwise_in (a_element_func, a_block_tile);
347- store_tile (a_copy_lds_window_pong, a_block_tile_tmp);
348-
349- // Prefetch A(2i+2)
350- a_block_tile = load_tile (a_copy_dram_window);
351- // move A window to next k
352- move_tile_window (a_copy_dram_window, {0 , kKPerBlock });
353-
354- // GEMM 2i
355- block_weight_preshuffle (c_block_tile,
356- a_warp_tensor,
357- b_warp_tensor_ping,
358- bq_block_tile,
359- a_warp_windows_ping);
360-
361415 static_for<0 , m_preload, 1 >{}([&](auto loadIter) {
362416 constexpr auto mIter = loadIter % MIterPerWarp;
363417 constexpr auto kIter = loadIter / MIterPerWarp;
364418 a_warp_tensor (loadIter) =
365419 load_tile (a_warp_windows_pong (number<mIter >{})(number<kIter >{}));
366420 });
367- Base::HotLoopScheduler ();
368421
369422 // Next K
370423
@@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
416469 a_warp_tensor (loadIter) =
417470 load_tile (a_warp_windows_ping (number<mIter >{})(number<kIter >{}));
418471 });
419- Base::HotLoopScheduler ();
420-
421472 iCounter--;
473+ HotLoopScheduler<loop_count>();
422474 }
423475
424476 // tail
@@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
456508 load_tile (a_warp_windows_pong (number<mIter >{})(number<kIter >{}));
457509 });
458510
459- Base::Last2ndHotLoopScheduler ();
460-
461511 // GEMM loopK
462512 block_weight_preshuffle (c_block_tile,
463513 a_warp_tensor,
464514 b_warp_tensor_pong,
465515 bq_block_tile_2,
466516 a_warp_windows_pong);
467- Base::LastHotLoopScheduler ();
517+ HotLoopScheduler<loop_count> ();
468518 }
469519 else if constexpr (TailNum == TailNumber::Odd)
470520 {
0 commit comments