1212#include " ck_tile/ops/gemm.hpp"
1313#include " ck_tile/utility/json_dump.hpp"
1414
15- template <typename PrecType, ck_tile::index_t M_Warp_Tile>
16- constexpr ck_tile::index_t get_k_warp_tile ()
17- {
18- #if defined(CK_GFX950_SUPPORT)
19- constexpr bool is_8bit_float =
20- std::is_same_v<PrecType, ck_tile::fp8_t > || std::is_same_v<PrecType, ck_tile::bf8_t >;
21- if constexpr (M_Warp_Tile == 32 )
22- return is_8bit_float ? 64 : 16 ;
23- else
24- return is_8bit_float ? 128 : 32 ;
25- #else
26- if constexpr (M_Warp_Tile == 32 )
27- return 16 ;
28- else
29- return 32 ;
30- #endif
31- }
32-
33- template <typename PrecType, ck_tile::index_t M_Warp_Tile>
34- constexpr ck_tile::index_t get_k_warp_tile_flatmm ()
35- {
36- #if defined(CK_GFX950_SUPPORT)
37- if constexpr (M_Warp_Tile == 32 )
38- return sizeof (PrecType) == 2 ? 16 : 64 ;
39- else
40- return sizeof (PrecType) == 2 ? 32 : 128 ;
41- #else
42- if constexpr (M_Warp_Tile == 32 )
43- return sizeof (PrecType) == 2 ? 16 : 32 ;
44- else
45- return sizeof (PrecType) == 2 ? 32 : 64 ;
46- #endif
47- }
48-
4915struct GemmConfigBase
5016{
5117 static constexpr bool kPadM = false ;
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
12288
12389 static constexpr ck_tile::index_t M_Warp_Tile = 16 ;
12490 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
125- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
91+ static constexpr ck_tile::index_t K_Warp_Tile =
92+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
12693
12794 static constexpr bool DoubleSmemBuffer = false ;
12895 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
141108
142109 static constexpr ck_tile::index_t M_Warp_Tile = 32 ;
143110 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
144- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
111+ static constexpr ck_tile::index_t K_Warp_Tile =
112+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
145113
146114 static constexpr bool DoubleSmemBuffer = false ;
147115 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
160128
161129 static constexpr ck_tile::index_t M_Warp_Tile = 16 ;
162130 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
163- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
131+ static constexpr ck_tile::index_t K_Warp_Tile =
132+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
164133
165134 static constexpr bool DoubleSmemBuffer = false ;
166135 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
204173
205174 static constexpr ck_tile::index_t M_Warp_Tile = 32 ;
206175 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
207- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
176+ static constexpr ck_tile::index_t K_Warp_Tile =
177+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
208178
209179 static constexpr bool DoubleSmemBuffer = true ;
210180 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
223193
224194 static constexpr ck_tile::index_t M_Warp_Tile = 32 ;
225195 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
226- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
196+ static constexpr ck_tile::index_t K_Warp_Tile =
197+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
227198
228199 static constexpr bool DoubleSmemBuffer = true ;
229200 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
242213
243214 static constexpr ck_tile::index_t M_Warp_Tile = 32 ;
244215 static constexpr ck_tile::index_t N_Warp_Tile = 32 ;
245- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
216+ static constexpr ck_tile::index_t K_Warp_Tile =
217+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
246218
247219 static constexpr bool DoubleSmemBuffer = false ;
248220 static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
282254
283255 static constexpr ck_tile::index_t M_Warp_Tile = 16 ;
284256 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
285- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
257+ static constexpr ck_tile::index_t K_Warp_Tile =
258+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true >();
286259
287260 static constexpr int kBlockPerCu = 1 ;
288261 static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
306279
307280 static constexpr ck_tile::index_t M_Warp_Tile = 16 ;
308281 static constexpr ck_tile::index_t N_Warp_Tile = 16 ;
309- static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
282+ static constexpr ck_tile::index_t K_Warp_Tile =
283+ ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true >();
310284
311285 static constexpr int kBlockPerCu = 2 ;
312286 static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
0 commit comments