Skip to content

Commit 85a34fd

Browse files
committed
replace with ck defined __gfx9__
1 parent 80197ca commit 85a34fd

File tree

5 files changed

+66
-0
lines changed

5 files changed

+66
-0
lines changed

profiler/src/profile_grouped_conv_bwd_data.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
8484
using F32 = float;
8585
using F16 = ck::half_t;
8686
using BF16 = ck::bhalf_t;
87+
#if defined(__gfx9__)
8788
using TF32 = ck::tf32_t;
89+
#endif
8890

8991
using namespace ck::tensor_layout::convolution;
9092

@@ -141,7 +143,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
141143
}
142144
else if(data_type == ConvDataType::F32_F32_F32_TF32)
143145
{
146+
#if defined(__gfx9__)
144147
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
148+
#endif
145149
}
146150
}
147151
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -160,7 +164,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
160164
}
161165
else if(data_type == ConvDataType::F32_F32_F32_TF32)
162166
{
167+
#if defined(__gfx9__)
163168
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
169+
#endif
164170
}
165171
}
166172
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -179,7 +185,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
179185
}
180186
else if(data_type == ConvDataType::F32_F32_F32_TF32)
181187
{
188+
#if defined(__gfx9__)
182189
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
190+
#endif
183191
}
184192
}
185193
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -198,7 +206,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
198206
}
199207
else if(data_type == ConvDataType::F32_F32_F32_TF32)
200208
{
209+
#if defined(__gfx9__)
201210
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
211+
#endif
202212
}
203213
}
204214
}
@@ -220,7 +230,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
220230
}
221231
else if(data_type == ConvDataType::F32_F32_F32_TF32)
222232
{
233+
#if defined(__gfx9__)
223234
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
235+
#endif
224236
}
225237
}
226238
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -239,7 +251,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
239251
}
240252
else if(data_type == ConvDataType::F32_F32_F32_TF32)
241253
{
254+
#if defined(__gfx9__)
242255
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
256+
#endif
243257
}
244258
}
245259
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -258,7 +272,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
258272
}
259273
else if(data_type == ConvDataType::F32_F32_F32_TF32)
260274
{
275+
#if defined(__gfx9__)
261276
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
277+
#endif
262278
}
263279
}
264280
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -277,7 +293,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
277293
}
278294
else if(data_type == ConvDataType::F32_F32_F32_TF32)
279295
{
296+
#if defined(__gfx9__)
280297
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
298+
#endif
281299
}
282300
}
283301
}

profiler/src/profile_grouped_conv_bwd_weight.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
9999
using BF16 = ck::bhalf_t;
100100
using F8 = ck::f8_t;
101101
using BF8 = ck::bf8_t;
102+
#if defined(__gfx9__)
102103
using TF32 = ck::tf32_t;
104+
#endif
103105

104106
using namespace ck::tensor_layout::convolution;
105107

@@ -160,7 +162,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
160162
}
161163
else if(data_type == ConvDataType::F32_F32_F32_TF32)
162164
{
165+
#if defined(__gfx9__)
163166
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
167+
#endif
164168
}
165169
}
166170
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -180,7 +184,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
180184
}
181185
else if(data_type == ConvDataType::F32_F32_F32_TF32)
182186
{
187+
#if defined(__gfx9__)
183188
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
189+
#endif
184190
}
185191
}
186192
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -204,7 +210,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
204210
}
205211
else if(data_type == ConvDataType::F32_F32_F32_TF32)
206212
{
213+
#if defined(__gfx9__)
207214
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
215+
#endif
208216
}
209217
}
210218
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -235,7 +243,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
235243
}
236244
else if(data_type == ConvDataType::F32_F32_F32_TF32)
237245
{
246+
#if defined(__gfx9__)
238247
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
248+
#endif
239249
}
240250
}
241251
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -260,7 +270,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
260270
}
261271
else if(data_type == ConvDataType::F32_F32_F32_TF32)
262272
{
273+
#if defined(__gfx9__)
263274
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
275+
#endif
264276
}
265277
}
266278
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -294,7 +306,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
294306
}
295307
else if(data_type == ConvDataType::F32_F32_F32_TF32)
296308
{
309+
#if defined(__gfx9__)
297310
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
311+
#endif
298312
}
299313
}
300314
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -326,7 +340,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
326340
}
327341
else if(data_type == ConvDataType::F32_F32_F32_TF32)
328342
{
343+
#if defined(__gfx9__)
329344
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
345+
#endif
330346
}
331347
}
332348

profiler/src/profile_grouped_conv_fwd.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
105105
using INT8 = int8_t;
106106
using F8 = ck::f8_t;
107107
using BF8 = ck::bf8_t;
108+
#if defined(__gfx9__)
108109
using TF32 = ck::tf32_t;
110+
#endif
109111

110112
//
111113
using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -226,7 +228,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
226228
}
227229
else if(data_type == ConvDataType::F32_F32_F32_TF32)
228230
{
231+
#if defined(__gfx9__)
229232
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
233+
#endif
230234
}
231235
}
232236
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -249,7 +253,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
249253
}
250254
else if(data_type == ConvDataType::F32_F32_F32_TF32)
251255
{
256+
#if defined(__gfx9__)
252257
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
258+
#endif
253259
}
254260
}
255261
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -274,7 +280,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
274280
}
275281
else if(data_type == ConvDataType::F32_F32_F32_TF32)
276282
{
283+
#if defined(__gfx9__)
277284
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
285+
#endif
278286
}
279287
}
280288
// NHWGC_GKYXC_NHWGK
@@ -298,7 +306,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
298306
}
299307
else if(data_type == ConvDataType::F32_F32_F32_TF32)
300308
{
309+
#if defined(__gfx9__)
301310
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
311+
#endif
302312
}
303313
}
304314
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -321,7 +331,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
321331
}
322332
else if(data_type == ConvDataType::F32_F32_F32_TF32)
323333
{
334+
#if defined(__gfx9__)
324335
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
336+
#endif
325337
}
326338
}
327339
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,7 +352,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
340352
}
341353
else if(data_type == ConvDataType::F32_F32_F32_TF32)
342354
{
355+
#if defined(__gfx9__)
343356
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
357+
#endif
344358
}
345359
}
346360
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -359,7 +373,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
359373
}
360374
else if(data_type == ConvDataType::F32_F32_F32_TF32)
361375
{
376+
#if defined(__gfx9__)
362377
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
378+
#endif
363379
}
364380
}
365381
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -400,7 +416,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
400416
}
401417
else if(data_type == ConvDataType::F32_F32_F32_TF32)
402418
{
419+
#if defined(__gfx9__)
403420
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
421+
#endif
404422
}
405423
}
406424
// NGCDHW_GKCZYX_NGKDHW
@@ -421,7 +439,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
421439
}
422440
else if(data_type == ConvDataType::F32_F32_F32_TF32)
423441
{
442+
#if defined(__gfx9__)
424443
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
444+
#endif
425445
}
426446
}
427447

profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
105105
using F32 = float;
106106
using BF16 = ck::bhalf_t;
107107
using F16 = ck::half_t;
108+
#if defined(__gfx9__)
108109
using TF32 = ck::tf32_t;
110+
#endif
109111

110112
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
111113
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -170,7 +172,9 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
170172
}
171173
else if(data_type == ConvDataType::F32_F32_F32_TF32)
172174
{
175+
#if defined(__gfx9__)
173176
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
177+
#endif
174178
}
175179
}
176180
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -190,7 +194,9 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
190194
}
191195
else if(data_type == ConvDataType::F32_F32_F32_TF32)
192196
{
197+
#if defined(__gfx9__)
193198
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
199+
#endif
194200
}
195201
}
196202

profiler/src/profile_grouped_conv_fwd_clamp.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
105105
using F32 = float;
106106
using BF16 = ck::bhalf_t;
107107
using F16 = ck::half_t;
108+
#if defined(__gfx9__)
108109
using TF32 = ck::tf32_t;
110+
#endif
109111

110112
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
111113
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -173,7 +175,9 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
173175
}
174176
else if(data_type == ConvDataType::F32_F32_F32_TF32)
175177
{
178+
#if defined(__gfx9__)
176179
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
180+
#endif
177181
}
178182
}
179183
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -193,7 +197,9 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
193197
}
194198
else if(data_type == ConvDataType::F32_F32_F32_TF32)
195199
{
200+
#if defined(__gfx9__)
196201
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
202+
#endif
197203
}
198204
}
199205

0 commit comments

Comments
 (0)