Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
# Instance library for batched GEMM + reduce kernels.
# Sources are named by data types (f16 in/out, f32 accumulation — presumably;
# confirm against the instance headers) and by A/B/C layout tags
# (gmk/gkm x gkn/gnk -> gmn).
add_instance_library(device_batched_gemm_reduce_instance
# XDL cshuffle instances, one per A/B layout combination.
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
# WMMA cshuffle v3 instances, same layout coverage as the XDL set above.
device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
)

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperat

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
// c[g, m, n] = a[g, k, m] * b[g, k, n]
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances =
std::tuple<
// clang-format off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperat

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
// c[g, m, n] = a[g, k, m] * b[g, n, k]
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances =
std::tuple<
// clang-format off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperat

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
// c[g, m, n] = a[g, m, k] * b[g, k, n]
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances =
std::tuple<
// clang-format off
Expand Down
104 changes: 101 additions & 3 deletions profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,44 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

/// Relative tolerance for comparing device results against the host
/// reference, selected by data type at compile time.
///
/// Only double gets a tighter tolerance (1e-6); the original float branch
/// returned the same 1e-3 as the default, so it is folded into the
/// catch-all branch — return values are unchanged for every type.
template <typename DataType>
inline constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else
    {
        return 1e-3; // float, half, and all other types
    }
}

/// Absolute tolerance for comparing device results against the host
/// reference, selected by data type at compile time.
///
/// Only double gets a tighter tolerance (1e-6); the original float and
/// ck::half_t branches both returned the same 1e-3 as the default, so they
/// are folded into the catch-all branch — return values are unchanged for
/// every type.
template <typename DataType>
inline constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else
    {
        return 1e-3; // float, ck::half_t, and all other types
    }
}

namespace ck {
namespace tensor_operation {
namespace device {
Expand All @@ -33,6 +71,8 @@ using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;

#ifdef CK_ENABLE_FP16
#ifdef CK_USE_XDL
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

Expand All @@ -44,6 +84,22 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif // CK_USE_XDL

#ifdef CK_USE_WMMA
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif // CK_USE_WMMA
#endif // CK_ENABLE_FP16

} // namespace instance
} // namespace device
Expand Down Expand Up @@ -210,42 +266,72 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
// add device GEMM instances
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;

#ifdef CK_ENABLE_FP16
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs);
#endif
}
}
#endif // CK_ENABLE_FP16

if(gemm_ptrs.size() <= 0)
{
Expand Down Expand Up @@ -318,9 +404,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());

bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
bool c_error = ck::utils::check_err(c_g_m_n_device_result,
c_g_m_n_host_result,
"Error: Device and Host results do not match!",
get_rtol<CDataType>(),
get_atol<CDataType>());
bool d0_error = ck::utils::check_err(d0_g_m_device_result,
d0_g_m_host_result,
"Error: Device and Host results do not match!",
get_rtol<ReduceDataType>(),
get_atol<ReduceDataType>());
bool d1_error = ck::utils::check_err(d1_g_m_device_result,
d1_g_m_host_result,
"Error: Device and Host results do not match!",
get_rtol<ReduceDataType>(),
get_atol<ReduceDataType>());

pass = pass && (c_error == true);
pass = pass && (d0_error == true);
Expand Down
7 changes: 6 additions & 1 deletion test/batched_gemm_reduce/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,9 @@
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
endif()
endif()

# WMMA counterpart of the fp16 batched GEMM + reduce test; links against the
# shared instance library. NOTE(review): presumably `result` is set by
# add_test_executable to signal whether the target was created — confirm
# against the helper's definition before relying on it.
add_test_executable(test_batched_gemm_reduce_fp16_wmma batched_gemm_reduce_fp16_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16_wmma PRIVATE utility device_batched_gemm_reduce_instance)
endif()
67 changes: 67 additions & 0 deletions test/batched_gemm_reduce/batched_gemm_reduce_fp16_wmma.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this use a GTest suite like the other tests?

Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <iostream>

#include "profiler/profile_batched_gemm_reduce_impl.hpp"

// Smoke test: run the fp16 WMMA batched GEMM + reduce profiler with
// verification enabled over all four A/B layout combinations (C is always
// row-major). Returns 0 on success, -1 on any mismatch.
int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    const int M          = 512;
    const int N          = 256;
    const int K          = 128;
    const int BatchCount = 3;

    // The && chain short-circuits exactly like the original
    // `pass = pass && ...` accumulation: once one case fails, the
    // remaining cases are not run. Leading dimensions follow each layout:
    // row-major A uses lda=K, column-major A uses lda=M, etc.
    const bool all_ok =
        ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                       ck::half_t,
                                                       ck::half_t,
                                                       float,
                                                       Row,
                                                       Row,
                                                       Row>(
            true, 1, false, false, M, N, K, K, N, N, BatchCount) &&
        ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                       ck::half_t,
                                                       ck::half_t,
                                                       float,
                                                       Row,
                                                       Col,
                                                       Row>(
            true, 1, false, false, M, N, K, K, K, N, BatchCount) &&
        ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                       ck::half_t,
                                                       ck::half_t,
                                                       float,
                                                       Col,
                                                       Row,
                                                       Row>(
            true, 1, false, false, M, N, K, M, N, N, BatchCount) &&
        ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                       ck::half_t,
                                                       ck::half_t,
                                                       float,
                                                       Col,
                                                       Col,
                                                       Row>(
            true, 1, false, false, M, N, K, M, K, N, BatchCount);

    std::cout << (all_ok ? "test BatchedGEMM+Reduce fp16 WMMA: Pass"
                         : "test BatchedGEMM+Reduce fp16 WMMA: Fail")
              << std::endl;

    return all_ok ? 0 : -1;
}