@@ -105,7 +105,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
105105 using INT8 = int8_t ;
106106 using F8 = ck::f8_t ;
107107 using BF8 = ck::bf8_t ;
108+ #if defined(__gfx9__)
108109 using TF32 = ck::tf32_t ;
110+ #endif
109111
110112 //
111113 using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -226,7 +228,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
226228 }
227229 else if (data_type == ConvDataType::F32_F32_F32_TF32)
228230 {
231+ #if defined(__gfx9__)
229232 return profile (I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
233+ #endif
230234 }
231235 }
232236 else if (num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -249,7 +253,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
249253 }
250254 else if (data_type == ConvDataType::F32_F32_F32_TF32)
251255 {
256+ #if defined(__gfx9__)
252257 return profile (I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
258+ #endif
253259 }
254260 }
255261 else if (num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -274,7 +280,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
274280 }
275281 else if (data_type == ConvDataType::F32_F32_F32_TF32)
276282 {
283+ #if defined(__gfx9__)
277284 return profile (I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
285+ #endif
278286 }
279287 }
280288 // NHWGC_GKYXC_NHWGK
@@ -298,7 +306,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
298306 }
299307 else if (data_type == ConvDataType::F32_F32_F32_TF32)
300308 {
309+ #if defined(__gfx9__)
301310 return profile (I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
311+ #endif
302312 }
303313 }
304314 else if (num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -321,7 +331,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
321331 }
322332 else if (data_type == ConvDataType::F32_F32_F32_TF32)
323333 {
334+ #if defined(__gfx9__)
324335 return profile (I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
336+ #endif
325337 }
326338 }
327339 else if (num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,7 +352,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
340352 }
341353 else if (data_type == ConvDataType::F32_F32_F32_TF32)
342354 {
355+ #if defined(__gfx9__)
343356 return profile (I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
357+ #endif
344358 }
345359 }
346360 else if (num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -359,7 +373,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
359373 }
360374 else if (data_type == ConvDataType::F32_F32_F32_TF32)
361375 {
376+ #if defined(__gfx9__)
362377 return profile (I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
378+ #endif
363379 }
364380 }
365381 else if (num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -400,7 +416,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
400416 }
401417 else if (data_type == ConvDataType::F32_F32_F32_TF32)
402418 {
419+ #if defined(__gfx9__)
403420 return profile (I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
421+ #endif
404422 }
405423 }
406424 // NGCDHW_GKCZYX_NGKDHW
@@ -421,7 +439,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
421439 }
422440 else if (data_type == ConvDataType::F32_F32_F32_TF32)
423441 {
442+ #if defined(__gfx9__)
424443 return profile (I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
444+ #endif
425445 }
426446 }
427447
0 commit comments