2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -155,7 +155,7 @@ None

* New CMake flags:
* "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances
* "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types
* "DTYPES" -- Can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build an instance of the specified data types
* "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler
* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
* Support for MI300A/MI300X
17 changes: 17 additions & 0 deletions CMakeLists.txt
@@ -90,6 +90,10 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
# definition will be added based on the GPU target in the following section
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -104,6 +108,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
@@ -280,6 +285,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
set(CK_GFX950_SUPPORT "ON")
endif()

if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()

option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
@@ -649,6 +663,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
set(add_inst 1)
endif()
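With this gating, listing tf32 in DTYPES only takes effect when the resolved target list contains gfx942 or a gfx95x target; for any other target the configure step prints "Disabling TF32 instances" and forces CK_ENABLE_TF32 to OFF. As a rough illustration (target names and flag values here are placeholders, not taken from this PR), a configure line such as

    cmake -D DTYPES="fp32;tf32" -D GPU_TARGETS=gfx90a ..

still succeeds but skips the TF32 instances, while the same command with GPU_TARGETS=gfx942 builds them.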
2 changes: 1 addition & 1 deletion README.md
@@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb

Additional cmake flags can be used to significantly speed-up the build:

* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
other data types.

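As a concrete sketch of the flag described above, a build restricted to the fp32, tf32, and fp16 instances on an MI300X-class target might be configured roughly like this (compiler path, build type, and target are illustrative assumptions, not prescribed by this change):

    cmake -D CMAKE_CXX_COMPILER=/opt/rocm/bin/amdclang++ \
          -D CMAKE_BUILD_TYPE=Release \
          -D DTYPES="fp32;tf32;fp16" \
          -D GPU_TARGETS=gfx942 ..

Leaving DTYPES unset keeps the previous behaviour of building all data types, which now includes TF32 on targets that support it.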
12 changes: 12 additions & 0 deletions client_example/CMakeLists.txt
@@ -27,6 +27,9 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -41,6 +44,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
@@ -67,6 +71,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances for this target")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")
11 changes: 11 additions & 0 deletions include/ck/config.h.in
@@ -55,6 +55,11 @@
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#define CK_ENABLE_TF32 "ON"
#endif
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
@@ -85,6 +90,12 @@
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif

#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
#endif
#endif

#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif
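Because these are #cmakedefine lines, the generated header only defines CK_ENABLE_TF32 when CMake enabled it, so downstream code can branch on the macro with a plain #ifdef. A minimal sketch, assuming the generated file is included as ck/config.h (the name suggested by the .in path):

    // Hypothetical downstream translation unit; only the guard matters here.
    #include "ck/config.h"

    void register_conv_instances()
    {
    #ifdef CK_ENABLE_TF32
        // Safe to reference the f32_tf32 instance sets here.
    #else
        // Fall back to the plain f32 instances.
    #endif
    }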
@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
op_ptrs);
@@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
@@ -284,6 +286,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
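The restructuring above splits the old single CK_ENABLE_FP32 region into one preprocessor guard per compute type, so disabling either FP32 or TF32 removes exactly that if constexpr branch. A stripped-down, self-contained sketch of the pattern (stand-in types and a toy add_instances function, not CK's real aliases or factory API; needs C++17):

    #include <iostream>
    #include <type_traits>

    struct F32  {};
    struct TF32 {};                // stand-in for CK's TF32 compute type

    #define CK_ENABLE_FP32
    #define CK_ENABLE_TF32         // comment out to emulate a non-gfx942/gfx95x build

    template <typename ComputeType>
    void add_instances()
    {
    #ifdef CK_ENABLE_TF32
        if constexpr(std::is_same_v<ComputeType, TF32>)
            std::cout << "adding f32_tf32 instances\n";
    #endif
    #ifdef CK_ENABLE_FP32
        if constexpr(std::is_same_v<ComputeType, F32>)
            std::cout << "adding f32 instances\n";
    #endif
    }

    int main()
    {
        add_instances<TF32>();     // no-op unless CK_ENABLE_TF32 is defined
        add_instances<F32>();      // adds the plain f32 instances
    }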
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif

#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"ComputeTypeA and ComputeTypeB must be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
PassThrough,
PassThrough,
Scale>>>& instances);
#endif

#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
" only support same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
@@ -367,7 +369,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
PassThrough,
Bilinear,
PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
@@ -151,24 +154,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&