From 22706ea0d4ee6eeaedc8030fc870838c7153653a Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 20 Nov 2025 15:50:26 +0000 Subject: [PATCH 1/7] Add tests for gemm_bias_add_reduce --- .../profile_gemm_bias_add_reduce_impl.hpp | 18 ++++-- test/CMakeLists.txt | 1 + test/gemm_bias_add_reduce/CMakeLists.txt | 4 ++ .../test_gemm_bias_add_reduce_fp16.cpp | 52 ++++++++++++++++ .../gemm_bias_add_reduce/test_gemm_common.hpp | 61 +++++++++++++++++++ 5 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 test/gemm_bias_add_reduce/CMakeLists.txt create mode 100644 test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp create mode 100644 test/gemm_bias_add_reduce/test_gemm_common.hpp diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index 1930cf9eb6..a150f4e808 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -63,7 +63,7 @@ template -void profile_gemm_bias_add_reduce_impl(int do_verification, +bool profile_gemm_bias_add_reduce_impl(int do_verification, int init_method, bool do_log, bool time_kernel, @@ -75,6 +75,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, int StrideC, int StrideD0) { + bool pass = true; + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { return HostTensorDescriptor({len}, {stride}); }; @@ -343,9 +345,13 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); - ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); - ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); - ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + if(!pass) + { + std::cout << gemm_ptr->GetTypeString() << " failed" << std::endl; + } if(do_log) { @@ -372,12 +378,14 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; } } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + return pass; } } // namespace profiler diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f8498c6c03..8998da62ee 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -248,6 +248,7 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) +add_subdirectory(gemm_bias_add_reduce) add_subdirectory(gemm_blockscale_wp) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) diff --git a/test/gemm_bias_add_reduce/CMakeLists.txt b/test/gemm_bias_add_reduce/CMakeLists.txt new file mode 100644 index 0000000000..22061c6708 --- /dev/null +++ b/test/gemm_bias_add_reduce/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility 
device_gemm_bias_add_reduce_instance) +endif() diff --git a/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp new file mode 100644 index 0000000000..9ff60725ae --- /dev/null +++ b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBiasAddReduce_FP16_MK_NK + : public ck::test::TestGemmBiasAddReduceCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + std::tuple< F16, F16, F16, F16, F16, F32> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_bias_add_reduce/test_gemm_common.hpp b/test/gemm_bias_add_reduce/test_gemm_common.hpp new file mode 100644 index 0000000000..9a6fe5651d --- /dev/null +++ b/test/gemm_bias_add_reduce/test_gemm_common.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_bias_add_reduce_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmBiasAddReduceCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + using BiasDataType = std::tuple_element_t<5, Tuple>; + using D0DataType = std::tuple_element_t<6, Tuple>; + using ReduceDataType = std::tuple_element_t<7, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + void Run(const int M, const int N, const int K) + { + bool all_success = true; + + int StrideA = std::is_same_v, Row> ? K : M; + int StrideB = std::is_same_v, Row> ? N : K; + int StrideD0 = std::is_same_v, Row> ? N : M; + int StrideC = std::is_same_v ? 
N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_bias_add_reduce_impl( + verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, StrideD0); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck From 92a42df62a92e0465f98936ca32fed548255973a Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 21 Nov 2025 14:17:59 +0000 Subject: [PATCH 2/7] Initial working implementation --- ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 682 ++++++++++++++++++ .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 438 ++++++++++- 2 files changed, 1105 insertions(+), 15 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..c64a1d504d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,682 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_bias_add_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops, + const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = EpilogueType( + p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = d0_element_op; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 + : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using 
CDEShuffleBlockTransferScalarPerVectors = Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using ReduceTrait = ReduceTrait_; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + StrideC1_{StrideC1}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + index_t StrideC1_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + 
arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.d0_element_op_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Even) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + else if(TailNum == TailNumber::Odd) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array 
p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index 942d4351b3..4e898a1424 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -10,6 +10,7 @@ namespace ck { template + static auto GetReduceBlockDescriptor(const CShuffleBlockDesc& c_shuffle_block_desc) + { + return transform_tensor_descriptor( + c_shuffle_block_desc, + make_tuple(make_freeze_transform(I0), + make_pass_through_transform(c_shuffle_block_desc.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform(c_shuffle_block_desc.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + } + + // Specialization of CShuffle + Reduce epilogue without D matrices template + typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename enable_if::type = false> __device__ void Run(CThreadBuf& 
c_thread_buf, DsGridPointer p_ds_grid, EDataType* p_e_grid, @@ -240,19 +259,21 @@ struct EpilogueReduceCShuffle Number{})); // LDS c_reduce_block_desc_mperblock_nperblock - constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple( - make_freeze_transform(I0), - make_pass_through_transform( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( - I1)), - make_freeze_transform(I0), - make_pass_through_transform( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( - I3))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + constexpr auto c_reduce_block_desc_mperblock_nperblock = GetReduceBlockDescriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat); + // transform_tensor_descriptor( + // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + // make_tuple( + // make_freeze_transform(I0), + // make_pass_through_transform( + // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + // I1)), + // make_freeze_transform(I0), + // make_pass_through_transform( + // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + // I3))), + // make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + // make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); static_assert( ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * @@ -460,10 +481,397 @@ struct EpilogueReduceCShuffle }); } + // Specialization of CShuffle + Bias + Add + Reduce epilogue + // The Bias and Add are applied on the Ds matrices + template ::type = false> + __device__ void Run(CThreadBuf& c_thread_buf, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + auto reduce_grid_desc_mblock_mperblock = + MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. 
+ constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // LDS buffer + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + // Thread transfer Vgpr to LDS + auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); + + // Space Filling Curve Vgpr + constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; + + // Space Filling Curve Vmem + constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; + + // Block descriptor + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + GetCShuffleLDSDescriptor(); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = + // GetReduceBlockDescriptor(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat); + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple(make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + static_assert( + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + BlockSize, + "wrong!"); + + static_assert( + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) == + 0 && + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) / + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) / + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1); + + static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size(); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = + make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto 
c_reduce_thread_cluster_desc = make_cluster_descriptor( + typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{}, + Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + CShuffleDataType, + typename ReduceTrait::ReduceAccDataType_, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_, + ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_m_id, // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = + make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + remove_cvref_t>, + typename ReduceTrait::ReduceAccDataType_, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + remove_cvref_t>, + typename ReduceTrait::ReduceAccDataType_, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto 
c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + EDataType, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + EGlobalMemoryDataOperation, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + // CShuffle and Store + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + // d0 / d1 operations + c0_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], + ds_grid_buf[I0], + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + typename ReduceTrait::ReduceAccDataType_ out; + cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], + ds_grid_buf[I1], + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_functior(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + d0_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + static_for<0, NumReduce, 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = + remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = ReduceOperation::template GetIdentityValue< + 
typename ReduceTrait::ReduceAccDataType_>(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Number<0>{}], cde_global_step); + + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Number<1>{}], cde_global_step); + + // move on E + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step); + } + }); + } + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid; typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops; typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops; index_t MRaw; + typename ReduceTrait::D0ElementwiseOperation_ d0_element_op; ReduceGridDesc_M reduce_grid_desc_m; }; From e5d3cf0232b4d8729b899cc0a94d5b3b8bb649c6 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 21 Nov 2025 15:57:45 +0000 Subject: [PATCH 3/7] Generalize implementation of reduce epilogue --- .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 8 +- .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 496 ++++-------------- 2 files changed, 95 insertions(+), 409 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 0240fcb619..b64b72f4d4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -49,8 +49,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = - EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M); + auto epilogue_args = EpilogueType(p_reduces_grid, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -188,6 +191,7 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera using ReduceTrait = ReduceTrait_ - static auto GetReduceBlockDescriptor(const CShuffleBlockDesc& c_shuffle_block_desc) - { - return transform_tensor_descriptor( - c_shuffle_block_desc, - make_tuple(make_freeze_transform(I0), - 
make_pass_through_transform(c_shuffle_block_desc.GetLength(I1)), - make_freeze_transform(I0), - make_pass_through_transform(c_shuffle_block_desc.GetLength(I3))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); - } - - // Specialization of CShuffle + Reduce epilogue without D matrices - template ::type = false> - __device__ void Run(CThreadBuf& c_thread_buf, - DsGridPointer p_ds_grid, - EDataType* p_e_grid, - void* p_shared, - const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - e_grid_desc_mblock_mperblock_nblock_nperblock, - CDEElementwiseOperation& cde_element_op, - const index_t& block_m_id, - const index_t& block_n_id) - { - auto reduce_grid_desc_mblock_mperblock = - MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - // C mapping in single thread. - constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - BlockwiseGemmPipe:: - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // LDS buffer - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize()); - - // Thread transfer Vgpr to LDS - auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); - - // Space Filling Curve Vgpr - constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; - - // Space Filling Curve Vmem - constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; - - // Block descriptor - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - GetCShuffleLDSDescriptor(); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // Thread transfer LDS to Vmem - auto cde_shuffle_block_copy_lds_and_global = - Base::template GetLDSToVmemEpilogueDescriptor( - c_ds_desc_refs, - e_grid_desc_mblock_mperblock_nblock_nperblock, - cde_element_op, - block_m_id, - block_n_id); - - // tuple of reference to C/Ds tensor buffers - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // LDS c_reduce_block_desc_mperblock_nperblock - constexpr auto c_reduce_block_desc_mperblock_nperblock = GetReduceBlockDescriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat); - // transform_tensor_descriptor( - // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - // make_tuple( - // make_freeze_transform(I0), - // 
make_pass_through_transform( - // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( - // I1)), - // make_freeze_transform(I0), - // make_pass_through_transform( - // c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( - // I3))), - // make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - // make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); - - static_assert( - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == - BlockSize, - "wrong!"); - - static_assert( - (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) % - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) == - 0 && - (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) % - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == - 0, - "wrong!"); - - constexpr index_t mreduce_per_thread = - (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) / - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0); - - constexpr index_t nreduce_per_thread = - (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) / - ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1); - - static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size(); - - constexpr auto c_reduce_thread_lengths_mperblock_nperblock = - Sequence{}; - - // VGPR c_reduce_thread_desc_mperblock_nperblock - constexpr auto c_reduce_thread_desc_mperblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - // VGPR reduce_thread_desc_mperblock - constexpr auto reduce_thread_desc_mperblock = - make_naive_tensor_descriptor_packed(make_tuple(Number{})); - - // VGPR reduce_thread_desc_mblock_mperblock - constexpr auto reduce_thread_desc_mblock_mperblock = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); - - auto c_reduce_thread_buf = - make_static_buffer( - c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); - - // reduce: threadwise copy from LDS to VGPR - constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( - typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{}, - Sequence<1, 0>{}); - - const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - - const auto c_reduce_thread_data_idx_begin = - c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; - - auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< - CShuffleDataType, - typename ReduceTrait::ReduceAccDataType_, - decltype(c_reduce_block_desc_mperblock_nperblock), - decltype(c_reduce_thread_desc_mperblock_nperblock), - decltype(c_reduce_thread_lengths_mperblock_nperblock), - Sequence<0, 1>, - 1, - ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, - 1, - true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; - - auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( - [&](auto I) { - auto p_reduce_grid = p_reduces_grid[I]; - auto reduce_acc_element_op = reduce_out_element_ops[I]; - - return ThreadwiseTensorSliceTransfer_v1r3< - typename ReduceTrait::ReduceAccDataType_, - remove_pointer_t, - decltype(reduce_thread_desc_mblock_mperblock), - decltype(reduce_grid_desc_mblock_mperblock), - decltype(reduce_acc_element_op), - Sequence<1, 
mreduce_per_thread>, - Sequence<0, 1>, - 1, - ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_, - ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I), - 1, - false>{reduce_grid_desc_mblock_mperblock, - make_multi_index(block_m_id, // mblock - c_reduce_thread_data_idx_begin[I0]), // mperblock - reduce_acc_element_op}; - }, - Number{}); - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); - - // CShuffle and Store - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block loads its C data from LDS, D from global, applies elementwise - // operation and stores result E to global - cde_shuffle_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - { - c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, - c_shuffle_block_buf, - c_reduce_thread_desc_mperblock_nperblock, - make_tuple(I0, I0), - c_reduce_thread_buf); - - static_for<0, NumReduce, 1>{}([&](auto In) { - auto& p_reduce_grid = p_reduces_grid[In]; - - auto reduce_grid_buf = make_dynamic_buffer( - p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); - - auto reduce_thread_buf = - make_static_buffer( - reduce_thread_desc_mperblock.GetElementSpaceSize()); - - auto& reduce_in_element_op = reduce_in_element_ops[In]; - - auto& reduce_thread_copy_vgpr_to_global = - reduce_tuple_thread_copy_vgpr_to_global(In); - - using ReduceOperation = - remove_cvref_t; - using ThreadwiseReduce = - ThreadwiseReduction; - - // Global write Gemm shuffle + reduction - const auto reduce_identityVal = ReduceOperation::template GetIdentityValue< - typename ReduceTrait::ReduceAccDataType_>(); - - static_for<0, mreduce_per_thread, 1>{}( - [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); - - // reduce in VGPR - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; - - reduce_in_element_op(c_reduce_thread_buf(offset), - c_reduce_thread_buf(offset)); - }); - }); - - ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); - - // copy from VGPR to Global - reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - reduce_thread_buf, - reduce_grid_desc_mblock_mperblock, - reduce_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id); - reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( - reduce_grid_desc_mblock_mperblock, - make_tuple(c_global_step[I0], c_global_step[I1])); - } - }); - } - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_global_step); - }); - - // move on E - 
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); - } - }); - } - - // Specialization of CShuffle + Bias + Add + Reduce epilogue - // The Bias and Add are applied on the Ds matrices template ::type = false> + typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock> __device__ void Run(CThreadBuf& c_thread_buf, DsGridPointer p_ds_grid, EDataType* p_e_grid, @@ -551,20 +228,20 @@ struct EpilogueReduceCShuffle GetCShuffleLDSDescriptor(); // LDS c_reduce_block_desc_mperblock_nperblock - constexpr auto c_reduce_block_desc_mperblock_nperblock = - // GetReduceBlockDescriptor(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat); - transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple(make_freeze_transform(I0), - make_pass_through_transform( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetLength(I1)), - make_freeze_transform(I0), - make_pass_through_transform( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetLength(I3))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + static_assert( ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == @@ -657,54 +334,49 @@ struct EpilogueReduceCShuffle }, Number{}); - // c0 and c1 - constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + // multiple Ds + constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, I1, Number{})); - constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = - c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple( + [&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; }, + Number{}); + + constexpr auto ds_thread_buf_size = + d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); auto c01_thread_buf = make_static_buffer( - c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + Number{}); - auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< - remove_cvref_t>, - typename ReduceTrait::ReduceAccDataType_, - decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]), - decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), - Sequence, - Sequence<0, 1, 2, 3>, - 3, - ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, - 1, - true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], - make_multi_index(I0, - m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], - I0, - n_block_data_idx_on_grid + 
c_reduce_thread_data_idx_begin[I1])); - - auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< - remove_cvref_t>, - typename ReduceTrait::ReduceAccDataType_, - decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]), - decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), - Sequence, - Sequence<0, 1, 2, 3>, - 3, - ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, - 1, - true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], - make_multi_index(I0, - m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], - I0, - n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + return ThreadwiseTensorSliceTransfer_v2< + remove_cvref_t>, + typename ReduceTrait::ReduceAccDataType_, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + remove_cvref_t< + decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>, + Sequence, + Sequence<0, 1, 2, 3>, + 3, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + }, + Number{}); constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, I1, Number{})); + // Write E from Vgpr to Vmem auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< typename ReduceTrait::ReduceAccDataType_, EDataType, @@ -722,7 +394,7 @@ struct EpilogueReduceCShuffle m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], I0, n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), - tensor_operation::element_wise::PassThrough{}}; + NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op}; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -750,36 +422,45 @@ struct EpilogueReduceCShuffle make_tuple(I0, I0), c_reduce_thread_buf); - // d0 / d1 operations - c0_thread_copy_global_to_vgpr.Run( - ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], - ds_grid_buf[I0], - c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, - make_tuple(I0, I0, I0, I0), - c01_thread_buf); - - // c = activation(c + bias) - static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( - [&](auto i) { - typename ReduceTrait::ReduceAccDataType_ out; - cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); - c_reduce_thread_buf(i) = out; - }); + // Note: currently multiple Ds supports only Bias + Add. 
+            // It needs to be generalized for other operations (currently not needed)
+            if constexpr(NumDTensor > 0)
+            {
+                auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
+                // d0 / d1 operations
+                d0_thread_copy_global_to_vgpr.Run(
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
+                    ds_grid_buf[I0],
+                    ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
+                    make_tuple(I0, I0, I0, I0),
+                    c01_thread_buf);
+
+                // c = activation(c + bias)
+                static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                    [&](auto i) {
+                        typename ReduceTrait::ReduceAccDataType_ out;
+                        cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
+                        c_reduce_thread_buf(i) = out;
+                    });

-            c1_thread_copy_global_to_vgpr.Run(
-                ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
-                ds_grid_buf[I1],
-                c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(I0, I0, I0, I0),
-                c01_thread_buf);
+                auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);

-            // c = c + c1_functior(c1)
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
-                [&](auto i) {
-                    d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
-                    c_reduce_thread_buf(i) += c01_thread_buf(i);
-                });
+                d1_thread_copy_global_to_vgpr.Run(
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
+                    ds_grid_buf[I1],
+                    ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
+                    make_tuple(I0, I0, I0, I0),
+                    c01_thread_buf);

+                // c = c + c1_functor(c1)
+                static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                    [&](auto i) {
+                        d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
+                        c_reduce_thread_buf(i) += c01_thread_buf(i);
+                    });
+            }
+
+            // Write E
             c_reduce_thread_copy_vgpr_to_global.Run(
                 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
                 make_tuple(I0, I0, I0, I0),
@@ -787,6 +468,7 @@ struct EpilogueReduceCShuffle
                 e_grid_desc_mblock_mperblock_nblock_nperblock,
                 e_grid_buf);

+            // Reduction
             static_for<0, NumReduce, 1>{}([&](auto In) {
                 auto& p_reduce_grid = p_reduces_grid[In];

@@ -854,11 +536,11 @@ struct EpilogueReduceCShuffle
                 {
                     constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                     // move on Ds
-                    c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
-                        ds_grid_desc_mblock_mperblock_nblock_nperblock[Number<0>{}], cde_global_step);
-
-                    c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
-                        ds_grid_desc_mblock_mperblock_nblock_nperblock[Number<1>{}], cde_global_step);
+                    static_for<0, NumDTensor, 1>{}([&](auto I) {
+                        auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
+                        d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
+                            ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
+                    });

                     // move on E
                     c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(

From cdf40e8aa53828a728e6fb2fa96f4dbe3fcdcb62 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Fri, 21 Nov 2025 17:29:09 +0000
Subject: [PATCH 4/7] Add tests for all layouts

---
 .../test_gemm_bias_add_reduce_fp16.cpp        | 60 ++++++++++++++++++-
 .../gemm_bias_add_reduce/test_gemm_common.hpp |  2 +-
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
index 9ff60725ae..c0206e9218 100644
--- a/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
+++ b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
@@ -1,5 +1,5 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -33,13 +33,37 @@ class TestGemmBiasAddReduce_FP16_MK_NK { }; +template +class TestGemmBiasAddReduce_FP16_MK_KN + : public ck::test::TestGemmBiasAddReduceCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmBiasAddReduce_FP16_KM_KN + : public ck::test::TestGemmBiasAddReduceCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmBiasAddReduce_FP16_KM_NK + : public ck::test::TestGemmBiasAddReduceCommon< + typename tuple_concat, Tuple>::type> +{ +}; + // clang-format off -using KernelTypes_MK_NK = ::testing::Types< +using KernelTypes = ::testing::Types< std::tuple< F16, F16, F16, F16, F16, F32> >; // clang-format on -TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_NK, KernelTypes); TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular) { @@ -50,3 +74,33 @@ TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K); } + +TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 1024; + constexpr int K = 1024; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_KN, Regular) +{ + std::vector Ms{256}; + constexpr int N = 512; + constexpr int K = 1024; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_NK, Regular) +{ + std::vector Ms{256}; + constexpr int N = 1024; + constexpr int K = 1024; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_bias_add_reduce/test_gemm_common.hpp b/test/gemm_bias_add_reduce/test_gemm_common.hpp index 9a6fe5651d..fd13deb878 100644 --- a/test/gemm_bias_add_reduce/test_gemm_common.hpp +++ b/test/gemm_bias_add_reduce/test_gemm_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" #include "ck/ck.hpp" From 0436ae63ddce5ae0c81bf06b503d4d83644dfda5 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 21 Nov 2025 17:31:40 +0000 Subject: [PATCH 5/7] Add instances --- .../device_gemm_mean_squaremean_instance.hpp | 41 +++++++++ .../gpu/gemm_bias_add_reduce/CMakeLists.txt | 7 +- ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 85 +++++++++++++++++++ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 84 ++++++++++++++++++ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 84 ++++++++++++++++++ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 81 ++++++++++++++++++ .../profile_gemm_bias_add_reduce_impl.hpp | 44 ++++++++++ 7 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp index 6d23cd8745..c448a51cfc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -19,6 +19,7 @@ namespace instance { using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>; +#if defined(CK_USE_XDL) void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( @@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f std::vector&); void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); +#endif // CK_USE_WMMA template ::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && 
is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + op_ptrs); +#endif } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt index a82e95d8d1..8be1dc6b45 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,10 +1,15 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp + + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c736fae147 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 
2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, 
ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..a702503e7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| 
CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 8, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, 
BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 
128, 64, 32, 2, 8, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 8, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e27cb9d630 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 
1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..a2d0e0ba9c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| 
CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, 
+ DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8, S<32, 2>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 
+        DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>
+        // clang-format on
+        >;
+
+void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
+    std::vector<DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
index a150f4e808..362cb278a8 100644
--- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
+++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp
@@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple;
 using DeviceGemmBiasAddReduceNoOpPtr =
     ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>;

+#if defined(CK_USE_XDL)
 void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
     std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

@@ -45,6 +46,21 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
 void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
     std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
+#endif
+
+#if defined(CK_USE_WMMA)
+void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
+    std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
+
+void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
+    std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
+
+void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
+    std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
+
+void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
+    std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
+#endif

 } // namespace instance
 } // namespace device
@@ -243,33 +259,61 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
                          is_same<BLayout, Row>::value && is_same<CLayout, Row>::value)
        {
+#if defined(CK_USE_XDL)
            ck::tensor_operation::device::instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
                    gemm_ptrs);
+#endif
+#if defined(CK_USE_WMMA)
+            ck::tensor_operation::device::instance::
+                add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, Row>::value && is_same<BLayout, Col>::value &&
                          is_same<CLayout, Row>::value)
        {
+#if defined(CK_USE_XDL)
            ck::tensor_operation::device::instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
                    gemm_ptrs);
+#endif
+#if defined(CK_USE_WMMA)
+            ck::tensor_operation::device::instance::
+                add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, Col>::value && is_same<BLayout, Row>::value &&
                          is_same<CLayout, Row>::value)
        {
+#if defined(CK_USE_XDL)
            ck::tensor_operation::device::instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
                    gemm_ptrs);
+#endif
+#if defined(CK_USE_WMMA)
+            ck::tensor_operation::device::instance::
+                add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, Col>::value && is_same<BLayout, Col>::value &&
                          is_same<CLayout, Row>::value)
        {
+#if defined(CK_USE_XDL)
            ck::tensor_operation::device::instance::
                add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
                    gemm_ptrs);
+#endif
+#if defined(CK_USE_WMMA)
+            ck::tensor_operation::device::instance::
+                add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
+                    gemm_ptrs);
+#endif
        }
    }

From 60884a35bd9e1761e9a8628314c8bcc0398b0877 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Wed, 26 Nov 2025 08:41:30 +0000
Subject: [PATCH 6/7] Fix test archs

---
 test/gemm_bias_add_reduce/CMakeLists.txt | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/test/gemm_bias_add_reduce/CMakeLists.txt b/test/gemm_bias_add_reduce/CMakeLists.txt
index 22061c6708..d713a3e255 100644
--- a/test/gemm_bias_add_reduce/CMakeLists.txt
+++ b/test/gemm_bias_add_reduce/CMakeLists.txt
@@ -1,4 +1,6 @@
-add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp)
-if(result EQUAL 0)
-    target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance)
+if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
+    add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance)
+    endif()
 endif()

From 7cd06964456dbcb12a78996f4876b4f3a273e18a Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Wed, 26 Nov 2025 15:52:10 +0000
Subject: [PATCH 7/7] Fix xdl bug

---
 .../gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
index 64f50d13df..c168ca9d18 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
@@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

        static_for<0, num_access, 1>{}([&](auto access_id) {
+            block_sync_lds();
+
            // each thread write its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),