diff --git a/CHANGELOG.md b/CHANGELOG.md
index b07e322fe1..7bca1af2af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -156,7 +156,7 @@ None
 * New CMake flags:
   * "DL_KERNELS" -- Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances
-  * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types
+  * "DTYPES" -- Can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build an instance of the specified data types
   * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler
 * New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
 * Support for MI300A/MI300X
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d0c4d79f9..acae1f5ece 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,6 +92,10 @@ if (DTYPES)
         add_definitions(-DCK_ENABLE_FP32)
         set(CK_ENABLE_FP32 "ON")
     endif()
+    if (DTYPES MATCHES "tf32")
+        # definition will be added based on the GPU target in the following section
+        set(CK_ENABLE_TF32 "ON")
+    endif()
     if (DTYPES MATCHES "fp64")
         add_definitions(-DCK_ENABLE_FP64)
         set(CK_ENABLE_FP64 "ON")
@@ -106,6 +110,7 @@ else()
     set(CK_ENABLE_INT8 "ON")
     set(CK_ENABLE_FP16 "ON")
     set(CK_ENABLE_FP32 "ON")
+    set(CK_ENABLE_TF32 "ON")
    set(CK_ENABLE_FP64 "ON")
    set(CK_ENABLE_BF16 "ON")
    set(CK_ENABLE_FP8 "ON")
@@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
     set(CK_GFX950_SUPPORT "ON")
 endif()
+if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
+    add_definitions(-DCK_ENABLE_TF32)
+    set(CK_ENABLE_TF32 "ON")
+else()
+    message(STATUS "Disabling TF32 instances")
+    remove_definitions(-DCK_ENABLE_TF32)
+    set(CK_ENABLE_TF32 "OFF")
+endif()
+
 option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
 if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
     add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
@@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu")
         if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
             set(add_inst 1)
         endif()
+        if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
+            set(add_inst 1)
+        endif()
         if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
             set(add_inst 1)
         endif()
diff --git a/README.md b/README.md
index 01d523c2ab..8a5258bab6 100644
--- a/README.md
+++ b/README.md
@@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 GB
 Additional cmake flags can be used to significantly speed-up the build:
-* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
+* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
   instances of select data types only. The main default data types are fp32 and fp16; you can safely
   skip other data types.
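The net effect of the CHANGELOG/CMake/README changes above is that `CK_ENABLE_TF32` ends up defined only when `DTYPES` includes `tf32` (or is unset) and the target list contains gfx942/gfx95. Below is a minimal sketch of how downstream C++ code can observe that decision; it assumes the generated `ck/config.h` (produced from the `config.h.in` edited later in this patch) is on the include path.

```cpp
// Minimal sketch: querying the configure-time dtype selection from C++.
// Assumption: the build generated ck/config.h from config.h.in and it is
// visible on the include path; CK_ENABLE_* are the macros wired up above.
#include <cstdio>

#include "ck/config.h"

int main()
{
#ifdef CK_ENABLE_TF32
    std::puts("tf32 instances were built (DTYPES included tf32 and the GPU target supports it)");
#else
    std::puts("tf32 instances were not built for this configuration");
#endif
    return 0;
}
```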
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
index 2ed338d08a..cab84f5c6c 100644
--- a/client_example/CMakeLists.txt
+++ b/client_example/CMakeLists.txt
@@ -27,6 +27,9 @@ if (DTYPES)
         add_definitions(-DCK_ENABLE_FP32)
         set(CK_ENABLE_FP32 "ON")
     endif()
+    if (DTYPES MATCHES "tf32")
+        set(CK_ENABLE_TF32 "ON")
+    endif()
     if (DTYPES MATCHES "fp64")
         add_definitions(-DCK_ENABLE_FP64)
         set(CK_ENABLE_FP64 "ON")
@@ -41,6 +44,7 @@ else()
     set(CK_ENABLE_INT8 "ON")
     set(CK_ENABLE_FP16 "ON")
     set(CK_ENABLE_FP32 "ON")
+    set(CK_ENABLE_TF32 "ON")
     set(CK_ENABLE_FP64 "ON")
     set(CK_ENABLE_BF16 "ON")
     if (GPU_TARGETS MATCHES "gfx94")
@@ -67,6 +71,14 @@ if (GPU_TARGETS)
         add_definitions(-DCK_USE_FNUZ_FP8)
         set(CK_USE_FNUZ_FP8 "ON")
     endif()
+    if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
+        add_definitions(-DCK_ENABLE_TF32)
+        set(CK_ENABLE_TF32 "ON")
+    else()
+        message(STATUS "Disabling TF32 instances for this target")
+        remove_definitions(-DCK_ENABLE_TF32)
+        set(CK_ENABLE_TF32 "OFF")
+    endif()
 else()
     add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
     set(CK_USE_XDL "ON")
diff --git a/include/ck/config.h.in b/include/ck/config.h.in
index 306a6c2ff1..113bf99243 100644
--- a/include/ck/config.h.in
+++ b/include/ck/config.h.in
@@ -55,6 +55,11 @@
 #ifndef CK_ENABLE_FP32
 #define CK_ENABLE_FP32 "ON"
 #endif
+#ifndef CK_ENABLE_TF32
+#if defined(__gfx942__) || defined(__gfx95__)
+#define CK_ENABLE_TF32 "ON"
+#endif
+#endif
 #ifndef CK_ENABLE_FP64
 #define CK_ENABLE_FP64 "ON"
 #endif
@@ -85,6 +90,12 @@
 #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
 #endif
+#ifndef CK_ENABLE_TF32
+#if defined(__gfx942__) || defined(__gfx95__)
+#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
+#endif
+#endif
+
 #ifndef CK_ENABLE_FP64
 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
 #endif
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
index 03e3ae88a3..89009c6d0b 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
         }
 #endif
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
                     op_ptrs);
@@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory<
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_BF16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
@@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory<
                 op_ptrs);
         }
 #endif
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
                     op_ptrs);
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                     op_ptrs);
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_BF16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
index cd65a2285a..84a715b70a 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
     std::vector
-    if constexpr(is_same_v && is_same_v &&
-                 is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v)
     {
         static_assert(is_same_v, "ComputeTypeA and ComputeTypeB must be the same");
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                 op_ptrs);
         }
-        else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+        if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
-    else if constexpr(is_same_v && is_same_v &&
-                      is_same_v && is_same_v &&
-                      is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v && is_same_v &&
+                 is_same_v)
     {
         add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
             op_ptrs);
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
index 36980e5935..c898dbf781 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
                                                     PassThrough,
                                                     PassThrough,
                                                     Scale>>>& instances);
+#endif
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
     std::vector
-    if constexpr(is_same_v && is_same_v &&
-                 is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v)
     {
         static_assert(is_same_v, " only support same compute type");
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                 op_ptrs);
         }
-        else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+        if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
-    else if constexpr(is_same_v && is_same_v &&
-                      is_same_v && is_same_v &&
-                      is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v && is_same_v &&
+                 is_same_v)
     {
         add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
             op_ptrs);
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index e677f6f848..3fe8fa9c5a 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
@@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_BF16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_BF16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
index 448a6b5d51..a0e8e46570 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
                                                     PassThrough,
                                                     Bilinear,
                                                     PassThrough>>>& instances);
+#endif
+
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
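A note on the recurring change in the factory headers above: the old code nested the TF32 branch as an `else` inside `#ifdef CK_ENABLE_FP32`, so disabling FP32 silently removed the TF32 path too. The patch gives each compute type its own independently compilable block. The toy below models that pattern; all names in it (`tf32_t`, `add_f32_instances`, ...) are illustrative stand-ins, not CK's API.

```cpp
// Toy model of the dispatch pattern adopted above: one #ifdef-gated
// if-constexpr block per compute type, so either block can be compiled out
// without leaving a dangling `else`.
#include <string>
#include <type_traits>
#include <vector>

struct tf32_t // stand-in for ck::tf32_t
{
};

inline void add_f32_instances(std::vector<std::string>& ops) { ops.push_back("f32"); }
inline void add_f32_tf32_instances(std::vector<std::string>& ops) { ops.push_back("f32_tf32"); }

template <typename ComputeType>
std::vector<std::string> get_instances()
{
    std::vector<std::string> ops;
#ifdef CK_ENABLE_TF32
    // tf32 block: present only when tf32 instances were built
    if constexpr(std::is_same_v<ComputeType, tf32_t>)
        add_f32_tf32_instances(ops);
#endif
#ifdef CK_ENABLE_FP32
    // fp32 block: independent of the tf32 block, unlike the old else-chain
    if constexpr(std::is_same_v<ComputeType, float>)
        add_f32_instances(ops);
#endif
    return ops;
}
```

With this shape, `get_instances<tf32_t>()` still returns the TF32 instances when the library is built with `-DCK_ENABLE_TF32` but without FP32, which is exactly the configuration the old nesting broke.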
index acf9c9e150..64bbdf6ec5 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
@@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
                                                     PassThrough,
                                                     Scale,
                                                     PassThrough>>>& instances);
+#endif
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
index ba2f6b921a..5089ea2c1e 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
@@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same!");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
@@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs);
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
@@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v &&
@@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory
     if constexpr(is_same_v && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v && is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v &&
+                         is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
@@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
index 46bc0d2320..d4729f4d13 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "A and B compute types should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
@@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v &&
@@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "A and B compute types should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
@@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                     op_ptrs);
@@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v &&
@@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
@@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
                                                     >>& instances);
+#endif
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
     if constexpr(is_same_v && is_same_v && DLayouts::Size() == 1 &&
                  is_same_v, NDHWGK>)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v)
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
index 90852d2945..090c99819f 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
@@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                     op_ptrs);
@@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v &&
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
             static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
@@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp
                                                     >>& instances);
+#endif
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
     if constexpr(is_same_v && is_same_v && DLayouts::Size() == 0)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v && is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v && is_same_v && is_same_v)
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
index eeaf269394..ef037526ca 100644
--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
@@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME)
         set(type1 "_f16")
     elseif(type MATCHES "fp32")
         set(type1 "_f32")
+    elseif(type MATCHES "tf32")
+        set(type1 "_tf32")
     elseif(type MATCHES "fp8")
         set(type1 "_f8")
     elseif(type MATCHES "bf16")
@@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME)
             #if filename matches any selected type, exit type loop and do not exclude the file from the list
             set(test 0)
             break()
-        elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
-                source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
+        elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
+                source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
                NOT (source_name MATCHES type OR source_name MATCHES type1))
            #if filename contains a type which doesn't match any selected type, mark it for removal
            set(test 1)
@@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME)
            list(REMOVE_ITEM ARGN "${source}")
        endif()
        # Only build tf32 instances for gfx942 & gfx950
-       if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_")
-           message(DEBUG "removing tf32 instance ${source} ")
-           list(REMOVE_ITEM ARGN "${source}")
+       if(source_name MATCHES "_tf32_")
+           if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32))
+               message(DEBUG "removing tf32 instance ${source} ")
+               list(REMOVE_ITEM ARGN "${source}")
+           endif()
        endif()
    endforeach()
@@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list})
            message(DEBUG "fp32 instance found!")
            set(add_inst 1)
        endif()
+       if(("${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
+           message(DEBUG "tf32 instance found!")
+           set(add_inst 1)
+       endif()
        if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
            message(DEBUG "fp64 instance found!")
            set(add_inst 1)
@@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list})
            "${cmake_instance}" MATCHES "_f16" OR
            "${cmake_instance}" MATCHES "_fp32" OR
            "${cmake_instance}" MATCHES "_f32" OR
+           "${cmake_instance}" MATCHES "_tf32" OR
            "${cmake_instance}" MATCHES "_fp64" OR
            "${cmake_instance}" MATCHES "_f64" OR
            "${cmake_instance}" MATCHES "_bf16" OR
@@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list})
            list(APPEND CK_DEVICE_OTHER_INSTANCES $)
        endif()
        message(DEBUG "add_instance_directory ${subdir_path}")
-       endif()
+        endif()
    else()
        message(DEBUG "skip_instance_directory ${subdir_path}")
    endif()
diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp
index 62d6e860f9..cbf763fc13 100644
--- a/profiler/src/profile_grouped_conv_bwd_data.cpp
+++ b/profiler/src/profile_grouped_conv_bwd_data.cpp
@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
     using F32  = float;
     using F16  = ck::half_t;
     using BF16 = ck::bhalf_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using namespace ck::tensor_layout::convolution;
 
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
 }
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
 }
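The profiler changes above and below all delete `#if defined(__gfx942__)`-style guards from host code. Architecture macros like `__gfx942__` are defined only while the compiler generates device code for that target; in the host pass, which is what compiles these `profile_*` functions, they are undefined, so every guarded TF32 branch silently vanished from the binary. The sketch below illustrates the failure mode; `tf32_t` and `profile_with` are illustrative stand-ins, not CK's API.

```cpp
// Sketch: why host-side dispatch must not be gated on device-arch macros.
struct tf32_t // stand-in for ck::tf32_t
{
};

template <typename ComputeType>
int profile_with(ComputeType)
{
    return 0; // placeholder for the real profiling call
}

int dispatch_tf32()
{
    // Before this patch (broken): the call was wrapped in
    //   #if defined(__gfx942__) ... #endif
    // which is false during host compilation, so the TF32 path was never
    // reachable at runtime.
    // After: the call is always compiled on the host; whether TF32 device
    // instances exist is decided at configure time via CK_ENABLE_TF32 and
    // the CMake source filters above.
    return profile_with(tf32_t{});
}
```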
diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp
index a18aab41a5..c4f154e180 100644
--- a/profiler/src/profile_grouped_conv_bwd_weight.cpp
+++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp
@@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
     using BF16 = ck::bhalf_t;
     using F8   = ck::f8_t;
     using BF8  = ck::bf8_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using namespace ck::tensor_layout::convolution;
 
@@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp
index c94b77dd4f..4319d849c8 100644
--- a/profiler/src/profile_grouped_conv_fwd.cpp
+++ b/profiler/src/profile_grouped_conv_fwd.cpp
@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using INT8 = int8_t;
     using F8   = ck::f8_t;
     using BF8  = ck::bf8_t;
-#if defined(__gfx942__) || defined(__gfx950__)
     using TF32 = ck::tf32_t;
-#endif
 
     // using GNWC = ck::tensor_layout::convolution::GNWC;
 
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     // NHWGC_GKYXC_NHWGK
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     // NGCDHW_GKCZYX_NGKDHW
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
index 4eb12e6e19..79b9beb8c7 100644
--- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
+++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
@@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
     using F32  = float;
     using BF16 = ck::bhalf_t;
     using F16  = ck::half_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp
index 7df9fd6167..f497ee8da5 100644
--- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp
+++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp
@@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
     using F32  = float;
     using BF16 = ck::bhalf_t;
     using F16  = ck::half_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f8498c6c03..c221f11f46 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -65,6 +65,9 @@ function(add_test_executable TEST_NAME)
         if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
             set(test 1)
         endif()
+        if((source_name MATCHES "_tf32") AND NOT "tf32" IN_LIST DTYPES)
+            set(test 1)
+        endif()
         if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
             set(test 1)
         endif()
@@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME)
         if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
             set(test 1)
         endif()
+        if((source_name MATCHES "_tf32") AND NOT "tf32" IN_LIST DTYPES)
+            set(test 1)
+        endif()
         if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
             set(test 1)
         endif()
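The test/CMakeLists.txt hunks above filter `_tf32` test sources out of the build when "tf32" is absent from `DTYPES`. A built test can additionally guard itself at runtime; the following is a hypothetical sketch of such a self-gating test body, not one added by this patch. The GTest macros are standard googletest, and `CK_ENABLE_TF32` is the macro from the generated `ck/config.h` discussed earlier.

```cpp
// Hypothetical tf32-gated test: even when the _tf32 source file is compiled,
// it skips cleanly on configurations where CK_ENABLE_TF32 was turned off.
#include <gtest/gtest.h>

TEST(GroupedConvTF32, BuildGate)
{
#ifdef CK_ENABLE_TF32
    SUCCEED() << "tf32 instances enabled for this build";
#else
    GTEST_SKIP() << "tf32 instances disabled (target or DTYPES excludes tf32)";
#endif
}
```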