Skip to content

Commit 583fafc

Browse files
arai713ThomasNing
andauthored
[CK_TILE] Fix for Moving DataTypeTraits into a Common File (#3335)
This PR fixes a mismatch caused when PR #3146 was merged out of sync with develop, which made its intended changes ineffective. This PR reapplies those changes to move DataTypeTraits into a common file to mitigate code duplication. Co-authored-by: Thomas Ning <[email protected]>
1 parent ffc3120 commit 583fafc

File tree

3 files changed

+11
-68
lines changed

3 files changed

+11
-68
lines changed

tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
8080
{
8181
// Use DataTypeTraits to get the actual type names from the generated header
8282
// The generated header defines ADataType, BDataType, AccDataType, CDataType
83-
std::string dtype_a = DataTypeTraits<ADataType>::name;
84-
std::string dtype_b = DataTypeTraits<BDataType>::name;
85-
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
86-
std::string dtype_c = DataTypeTraits<CDataType>::name;
87-
std::string dtype_d0 = DataTypeTraits<D0DataType>::name;
88-
std::string dtype_d1 = DataTypeTraits<D1DataType>::name;
83+
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
84+
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
85+
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
86+
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
87+
std::string dtype_d0 = ck_tile::DataTypeTraits<D0DataType>::name;
88+
std::string dtype_d1 = ck_tile::DataTypeTraits<D1DataType>::name;
8989

9090
// Layout names from the layout types
9191
std::string layout_a = ALayout::name;

tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
8383
{
8484
// Use DataTypeTraits to get the actual type names from the generated header
8585
// The generated header defines ADataType, BDataType, AccDataType, CDataType
86-
std::string dtype_a = DataTypeTraits<ADataType>::name;
87-
std::string dtype_b = DataTypeTraits<BDataType>::name;
88-
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
89-
std::string dtype_c = DataTypeTraits<CDataType>::name;
86+
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
87+
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
88+
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
89+
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
9090

9191
// Layout names from the layout types
9292
std::string layout_a = ALayout::name;

tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,67 +6,10 @@
66
#include <string>
77
#include "ck_tile/core.hpp"
88
#include "ck_tile/host.hpp"
9+
#include "ck_tile/ops/common/utils.hpp"
910
#include "ck_tile/core/numeric/integer.hpp"
1011
#include "ck_tile/core/numeric/pk_int4.hpp"
1112

12-
// DataTypeTraits for all supported types
13-
template <typename T>
14-
struct DataTypeTraits;
15-
16-
template <>
17-
struct DataTypeTraits<float>
18-
{
19-
static constexpr const char* name = "fp32";
20-
};
21-
22-
template <>
23-
struct DataTypeTraits<double>
24-
{
25-
static constexpr const char* name = "fp64";
26-
};
27-
28-
template <>
29-
struct DataTypeTraits<ck_tile::half_t>
30-
{
31-
static constexpr const char* name = "fp16";
32-
};
33-
34-
template <>
35-
struct DataTypeTraits<ck_tile::bf16_t>
36-
{
37-
static constexpr const char* name = "bf16";
38-
};
39-
40-
template <>
41-
struct DataTypeTraits<ck_tile::fp8_t>
42-
{
43-
static constexpr const char* name = "fp8";
44-
};
45-
46-
template <>
47-
struct DataTypeTraits<ck_tile::bf8_t>
48-
{
49-
static constexpr const char* name = "bf8";
50-
};
51-
52-
template <>
53-
struct DataTypeTraits<ck_tile::int8_t>
54-
{
55-
static constexpr const char* name = "int8";
56-
};
57-
58-
template <>
59-
struct DataTypeTraits<ck_tile::int32_t>
60-
{
61-
static constexpr const char* name = "int32";
62-
};
63-
64-
template <>
65-
struct DataTypeTraits<ck_tile::pk_int4_t>
66-
{
67-
static constexpr const char* name = "pk_int4_t";
68-
};
69-
7013
// Helper function to determine if a layout is row-major
7114
template <typename Layout>
7215
constexpr auto is_row_major(Layout)

0 commit comments

Comments
 (0)