Skip to content

Commit b27380c

Browse files
committed
[ck_tile] move get_k_warp to gemm_shape
1 parent 8bb4da0 commit b27380c

File tree

8 files changed

+59
-209
lines changed

8 files changed

+59
-209
lines changed

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,6 @@
1212
#include "ck_tile/ops/gemm.hpp"
1313
#include "ck_tile/utility/json_dump.hpp"
1414

15-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
16-
constexpr ck_tile::index_t get_k_warp_tile()
17-
{
18-
#if defined(CK_GFX950_SUPPORT)
19-
constexpr bool is_8bit_float =
20-
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
21-
if constexpr(M_Warp_Tile == 32)
22-
return is_8bit_float ? 64 : 16;
23-
else
24-
return is_8bit_float ? 128 : 32;
25-
#else
26-
if constexpr(M_Warp_Tile == 32)
27-
return 16;
28-
else
29-
return 32;
30-
#endif
31-
}
32-
33-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
34-
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
35-
{
36-
#if defined(CK_GFX950_SUPPORT)
37-
if constexpr(M_Warp_Tile == 32)
38-
return sizeof(PrecType) == 2 ? 16 : 64;
39-
else
40-
return sizeof(PrecType) == 2 ? 32 : 128;
41-
#else
42-
if constexpr(M_Warp_Tile == 32)
43-
return sizeof(PrecType) == 2 ? 16 : 32;
44-
else
45-
return sizeof(PrecType) == 2 ? 32 : 64;
46-
#endif
47-
}
48-
4915
struct GemmConfigBase
5016
{
5117
static constexpr bool kPadM = false;
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
12288

12389
static constexpr ck_tile::index_t M_Warp_Tile = 16;
12490
static constexpr ck_tile::index_t N_Warp_Tile = 16;
125-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
91+
static constexpr ck_tile::index_t K_Warp_Tile =
92+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
12693

12794
static constexpr bool DoubleSmemBuffer = false;
12895
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
141108

142109
static constexpr ck_tile::index_t M_Warp_Tile = 32;
143110
static constexpr ck_tile::index_t N_Warp_Tile = 32;
144-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
111+
static constexpr ck_tile::index_t K_Warp_Tile =
112+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
145113

146114
static constexpr bool DoubleSmemBuffer = false;
147115
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
160128

161129
static constexpr ck_tile::index_t M_Warp_Tile = 16;
162130
static constexpr ck_tile::index_t N_Warp_Tile = 16;
163-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
131+
static constexpr ck_tile::index_t K_Warp_Tile =
132+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
164133

165134
static constexpr bool DoubleSmemBuffer = false;
166135
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
204173

205174
static constexpr ck_tile::index_t M_Warp_Tile = 32;
206175
static constexpr ck_tile::index_t N_Warp_Tile = 32;
207-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
176+
static constexpr ck_tile::index_t K_Warp_Tile =
177+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
208178

209179
static constexpr bool DoubleSmemBuffer = true;
210180
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
223193

224194
static constexpr ck_tile::index_t M_Warp_Tile = 32;
225195
static constexpr ck_tile::index_t N_Warp_Tile = 32;
226-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
196+
static constexpr ck_tile::index_t K_Warp_Tile =
197+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
227198

228199
static constexpr bool DoubleSmemBuffer = true;
229200
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
242213

243214
static constexpr ck_tile::index_t M_Warp_Tile = 32;
244215
static constexpr ck_tile::index_t N_Warp_Tile = 32;
245-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
216+
static constexpr ck_tile::index_t K_Warp_Tile =
217+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
246218

247219
static constexpr bool DoubleSmemBuffer = false;
248220
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
282254

283255
static constexpr ck_tile::index_t M_Warp_Tile = 16;
284256
static constexpr ck_tile::index_t N_Warp_Tile = 16;
285-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
257+
static constexpr ck_tile::index_t K_Warp_Tile =
258+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
286259

287260
static constexpr int kBlockPerCu = 1;
288261
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
306279

307280
static constexpr ck_tile::index_t M_Warp_Tile = 16;
308281
static constexpr ck_tile::index_t N_Warp_Tile = 16;
309-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
282+
static constexpr ck_tile::index_t K_Warp_Tile =
283+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
310284

311285
static constexpr int kBlockPerCu = 2;
312286
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

example/ck_tile/17_grouped_gemm/grouped_gemm.hpp

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,6 @@
1111
#include "ck_tile/ops/gemm.hpp"
1212
#include "ck_tile/utility/json_dump.hpp"
1313

14-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
15-
constexpr ck_tile::index_t get_k_warp_tile()
16-
{
17-
#if defined(CK_GFX950_SUPPORT)
18-
constexpr bool is_8bit_float =
19-
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
20-
if constexpr(M_Warp_Tile == 32)
21-
return is_8bit_float ? 64 : 16;
22-
else
23-
return is_8bit_float ? 128 : 32;
24-
#else
25-
if constexpr(M_Warp_Tile == 32)
26-
return 16;
27-
else
28-
return 32;
29-
#endif
30-
}
31-
32-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
33-
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
34-
{
35-
#if defined(CK_GFX950_SUPPORT)
36-
if constexpr(M_Warp_Tile == 32)
37-
return sizeof(PrecType) == 2 ? 16 : 64;
38-
else
39-
return sizeof(PrecType) == 2 ? 32 : 128;
40-
#else
41-
if constexpr(M_Warp_Tile == 32)
42-
return sizeof(PrecType) == 2 ? 16 : 32;
43-
else
44-
return sizeof(PrecType) == 2 ? 32 : 64;
45-
#endif
46-
}
47-
4814
template <typename DataType>
4915
struct GemmTypeConfig;
5016

@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
11177

11278
static constexpr ck_tile::index_t M_Warp_Tile = 32;
11379
static constexpr ck_tile::index_t N_Warp_Tile = 32;
114-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
80+
static constexpr ck_tile::index_t K_Warp_Tile =
81+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
11582

11683
static constexpr bool DoubleSmemBuffer = false;
11784
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
134101

135102
static constexpr ck_tile::index_t M_Warp_Tile = 32;
136103
static constexpr ck_tile::index_t N_Warp_Tile = 32;
137-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
104+
static constexpr ck_tile::index_t K_Warp_Tile =
105+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
138106

139107
static constexpr bool DoubleSmemBuffer = true;
140108
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
157125

158126
static constexpr ck_tile::index_t M_Warp_Tile = 16;
159127
static constexpr ck_tile::index_t N_Warp_Tile = 16;
160-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
128+
static constexpr ck_tile::index_t K_Warp_Tile =
129+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
161130

162131
static constexpr bool DoubleSmemBuffer = true;
163132
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
178147

179148
static constexpr ck_tile::index_t M_Warp_Tile = 16;
180149
static constexpr ck_tile::index_t N_Warp_Tile = 16;
181-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
150+
static constexpr ck_tile::index_t K_Warp_Tile =
151+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
182152

183153
static constexpr bool kPadK = true;
184154

@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
203173

204174
static constexpr ck_tile::index_t M_Warp_Tile = 16;
205175
static constexpr ck_tile::index_t N_Warp_Tile = 16;
206-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
176+
static constexpr ck_tile::index_t K_Warp_Tile =
177+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
207178

208179
static constexpr int kBlockPerCu = 2;
209180
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,6 @@
1111
#include "ck_tile/ops/gemm.hpp"
1212
#include "ck_tile/utility/json_dump.hpp"
1313

14-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
15-
constexpr ck_tile::index_t get_k_warp_tile()
16-
{
17-
#if defined(CK_GFX950_SUPPORT)
18-
constexpr bool is_8bit_float =
19-
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
20-
if constexpr(M_Warp_Tile == 32)
21-
return is_8bit_float ? 64 : 16;
22-
else
23-
return is_8bit_float ? 128 : 32;
24-
#else
25-
if constexpr(M_Warp_Tile == 32)
26-
return 16;
27-
else
28-
return 32;
29-
#endif
30-
}
31-
3214
struct GemmConfigBase
3315
{
3416
static constexpr bool kPadM = false;

example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,6 @@
1010
#include "ck_tile/ops/gemm.hpp"
1111
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
1212

13-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
14-
constexpr ck_tile::index_t get_k_warp_tile()
15-
{
16-
#if defined(CK_GFX950_SUPPORT)
17-
constexpr bool is_8bit_float =
18-
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
19-
if constexpr(M_Warp_Tile == 32)
20-
return is_8bit_float ? 64 : 16;
21-
else
22-
return is_8bit_float ? 128 : 32;
23-
#else
24-
if constexpr(M_Warp_Tile == 32)
25-
return 16;
26-
else
27-
return 32;
28-
#endif
29-
}
30-
31-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
32-
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
33-
{
34-
#if defined(CK_GFX950_SUPPORT)
35-
if constexpr(M_Warp_Tile == 32)
36-
return sizeof(PrecType) == 2 ? 16 : 64;
37-
else
38-
return sizeof(PrecType) == 2 ? 32 : 128;
39-
#else
40-
if constexpr(M_Warp_Tile == 32)
41-
return sizeof(PrecType) == 2 ? 16 : 32;
42-
else
43-
return sizeof(PrecType) == 2 ? 32 : 64;
44-
#endif
45-
}
46-
4713
template <typename DataType>
4814
struct GemmTypeConfig;
4915

@@ -98,7 +64,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
9864

9965
static constexpr ck_tile::index_t M_Warp_Tile = 32;
10066
static constexpr ck_tile::index_t N_Warp_Tile = 32;
101-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
67+
static constexpr ck_tile::index_t K_Warp_Tile =
68+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
10269
};
10370

10471
template <typename PrecType>
@@ -115,7 +82,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
11582
static constexpr ck_tile::index_t M_Warp_Tile = 16;
11683
static constexpr ck_tile::index_t N_Warp_Tile = 16;
11784
static constexpr ck_tile::index_t K_Warp_Tile =
118-
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
85+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
11986

12087
static constexpr bool PreshuffleB = true;
12188
static constexpr bool DoubleSmemBuffer = true;

example/ck_tile/38_block_scale_gemm/gemm_utils.hpp

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
2424
return combined_hash;
2525
}
2626

27-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
28-
constexpr ck_tile::index_t get_k_warp_tile()
29-
{
30-
#if defined(CK_GFX950_SUPPORT)
31-
constexpr bool is_8bit_float =
32-
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
33-
if constexpr(M_Warp_Tile == 32)
34-
return is_8bit_float ? 64 : 16;
35-
else
36-
return is_8bit_float ? 128 : 32;
37-
#else
38-
if constexpr(M_Warp_Tile == 32)
39-
return 16;
40-
else
41-
return 32;
42-
#endif
43-
}
44-
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
45-
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
46-
{
47-
#if defined(CK_GFX950_SUPPORT)
48-
if constexpr(M_Warp_Tile == 32)
49-
return sizeof(PrecType) == 2 ? 16 : 64;
50-
else
51-
return sizeof(PrecType) == 2 ? 32 : 128;
52-
#else
53-
if constexpr(M_Warp_Tile == 32)
54-
return sizeof(PrecType) == 2 ? 16 : 32;
55-
else
56-
return sizeof(PrecType) == 2 ? 32 : 64;
57-
#endif
58-
}
59-
6027
template <typename Layout>
6128
static constexpr inline auto is_row_major(Layout layout_)
6229
{
@@ -122,7 +89,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
12289

12390
static constexpr ck_tile::index_t M_Warp_Tile = 16;
12491
static constexpr ck_tile::index_t N_Warp_Tile = 16;
125-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
92+
static constexpr ck_tile::index_t K_Warp_Tile =
93+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
12694
};
12795

12896
template <typename PrecType>
@@ -138,7 +106,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
138106

139107
static constexpr ck_tile::index_t M_Warp_Tile = 16;
140108
static constexpr ck_tile::index_t N_Warp_Tile = 16;
141-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
109+
static constexpr ck_tile::index_t K_Warp_Tile =
110+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
142111
};
143112

144113
template <typename PrecType>
@@ -155,7 +124,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
155124
static constexpr ck_tile::index_t M_Warp_Tile = 16;
156125
static constexpr ck_tile::index_t N_Warp_Tile = 16;
157126
static constexpr ck_tile::index_t K_Warp_Tile =
158-
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
127+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
159128

160129
static constexpr bool PreshuffleQuant = true;
161130
};
@@ -174,7 +143,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
174143
static constexpr ck_tile::index_t M_Warp_Tile = 16;
175144
static constexpr ck_tile::index_t N_Warp_Tile = 16;
176145
static constexpr ck_tile::index_t K_Warp_Tile =
177-
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
146+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
178147

179148
static constexpr bool PreshuffleB = true;
180149
static constexpr bool DoubleSmemBuffer = true;
@@ -204,7 +173,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
204173
static constexpr ck_tile::index_t M_Warp_Tile = 16;
205174
static constexpr ck_tile::index_t N_Warp_Tile = 16;
206175
static constexpr ck_tile::index_t K_Warp_Tile =
207-
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
176+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
208177

209178
static constexpr bool PreshuffleB = true;
210179
static constexpr bool DoubleSmemBuffer = true;
@@ -233,7 +202,8 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
233202

234203
static constexpr ck_tile::index_t M_Warp_Tile = 16;
235204
static constexpr ck_tile::index_t N_Warp_Tile = 16;
236-
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
205+
static constexpr ck_tile::index_t K_Warp_Tile =
206+
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
237207
};
238208

239209
template <typename PrecType>

0 commit comments

Comments
 (0)