diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp
new file mode 100644
index 0000000000..2d31b6476b
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp
@@ -0,0 +1,803 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include
+#include
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+#endif
+    kernel_batched_gemm_reduce_wmma_cshuffle_v3(
+        typename GridwiseGemm::Argument karg,
+        typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid,
+        const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops,
+        const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops,
+        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
+{
+#if(defined(__gfx11__) || defined(__gfx12__))
+#if defined(__gfx11__)
+    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
+    using e_data_type = remove_cvref_t>;
+    if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
+                   (std::is_same_v ||
+                    std::is_same_v)))
+    {
+#endif
+        using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle;
+        constexpr index_t LDS_size =
+            GridwiseGemm::template GetSharedMemoryNumberOfByte();
+        __shared__ char p_shared[LDS_size];
+
+        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+
+        const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
+
+        const long_index_t a_batch_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+        const long_index_t b_batch_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+        const long_index_t c_batch_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
+
+        auto reduces_batch = p_reduces_grid;
+        compute_ptr_offset_of_batch.OffsetReducePtrs(g_idx, reduces_batch);
+
+        typename GridwiseGemm::AsGridPointer p_as_grid_shift;
+        static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
+            using ADataType_ =
+                remove_cvref_t>;
+            p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) +
+                                 splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
+        });
+
+        typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
+        static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
+            using BDataType_ =
+                remove_cvref_t>;
+            p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) +
+                                 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
+        });
+
+        auto epilogue_args =
+            EpilogueType(reduces_batch, reduce_in_element_ops, reduce_out_element_ops, karg.M);
+
+        GridwiseGemm::template Run(
+            p_as_grid_shift,
+            p_bs_grid_shift,
+            karg.p_ds_grid,
+            karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
+            p_shared,
karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 + : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceBatchedGemmReduce_Wmma_CShuffleV3; + + static_assert(PermuteA == false, + "Permute A functionality not supported by DeviceBatchedGemm operations.\n"); + static_assert(PermuteB == false, + "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); + + using CDEShuffleBlockTransferScalarPerVectors = + Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using ReduceTrait = ReduceTrait_; + + static constexpr index_t NumReduce = ReduceOperations::Size(); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + long_index_t BatchStrideC, + std::array BatchStrideReduce) + : BatchStrideA_{BatchStrideA}, + BatchStrideB_{BatchStrideB}, + BatchStrideC_{BatchStrideC}, + BatchStrideReduce_{BatchStrideReduce} + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideA_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideB_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * BatchStrideC_; + } + + template + __host__ __device__ void OffsetReducePtrs(index_t g_idx, ReducePtrs& ptrs) const + { + static_for<0, NumReduce, 1>{}( + [&](auto I) { ptrs(I) = ptrs(I) + g_idx * BatchStrideReduce_[I.value]; }); + } + + private: + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + long_index_t BatchStrideC_; + std::array BatchStrideReduce_{}; + }; + + private: + static long_index_t ComputeABatchStride(index_t MRaw, index_t KRaw, index_t StrideA) + { + if constexpr(is_same_v) + { + return static_cast(MRaw) * StrideA; + } + else + { + return static_cast(KRaw) * StrideA; + } + } + + static long_index_t ComputeBBatchStride(index_t KRaw, index_t NRaw, index_t StrideB) + { + if constexpr(is_same_v) + { + 
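+            // Descriptive comment (added): this branch advances one batch of B by KRaw rows, each StrideB elements apart, i.e. KRaw * StrideB.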
return static_cast(KRaw) * StrideB; + } + else + { + return static_cast(NRaw) * StrideB; + } + } + + static long_index_t ComputeCBatchStride(index_t MRaw, index_t NRaw, index_t StrideC) + { + if constexpr(is_same_v) + { + return static_cast(MRaw) * StrideC; + } + else + { + return static_cast(NRaw) * StrideC; + } + } + + public: + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops, + std::array batch_stride_reduce) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + Batch_{Batch}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops}, + batch_stride_reduce_{batch_stride_reduce}, + compute_ptr_offset_of_batch_( + ComputePtrOffsetOfStridedBatch{ComputeABatchStride(MRaw, KRaw, StrideA), + ComputeBBatchStride(KRaw, NRaw, StrideB), + ComputeCBatchStride(MRaw, NRaw, StrideC), + batch_stride_reduce}) + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + index_t Batch_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + std::array batch_stride_reduce_{}; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideC + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + gdy *= arg.Batch_; + + float ave_time = 0; + + const index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.compute_ptr_offset_of_batch_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(tail_num == TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(tail_num == TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(tail_num == TailNumber::Even) + { + const auto kernel = kernel_batched_gemm_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + else if(tail_num == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() { return true; } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideC + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_e, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t Batch) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I.value]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I.value])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I.value])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + std::array batch_stride_reduce{}; + static_for<0, NumReduce, 1>{}( + [&](auto I) { batch_stride_reduce[I.value] = static_cast(M); }); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + Batch, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + batch_stride_reduce}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* 
p_e, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t Batch = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I.value]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I.value])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I.value])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + std::array batch_stride_reduce{}; + static_for<0, NumReduce, 1>{}( + [&](auto I) { batch_stride_reduce[I.value] = static_cast(M); }); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + Batch, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + batch_stride_reduce); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + str << "DeviceBatchedGemmReduce_Wmma_CShuffleV3" << "<" << BlockSize << ", " << MPerBlock + << ", " << NPerBlock << ", " << KPerBlock << ", " << AK1 << ", " << BK1 << ", " + << MPerWmma << ", " << NPerWmma << ", " << MRepeat << ", " << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " << CShuffleNRepeatPerShuffle << ">"; + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index a098a0a7e5..89626f1afa 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,10 +1,14 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..e604c358cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, k, m] * b[g, k, n] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + 
DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, 
BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..6bd538750b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, k, m] * b[g, n, k] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, 
ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..7a3d3d98e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, m, k] * b[g, k, n] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, 
PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( 
+ std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..7386ab3bf7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + // v1 Interwave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, 
PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + // v3 Intrawave + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index c8564c120c..689657a505 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -39,7 +39,7 @@ using ReduceMemOp = ck::InMemoryDataOperationEnumSequence +inline constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else + { + return 1e-3; + } +} + +template +inline constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else + { + return 1e-3; + } +} + namespace ck { namespace tensor_operation { namespace device { @@ -33,6 +71,8 @@ using ReduceOutElementOps = ck::Tuple; using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; +#ifdef CK_ENABLE_FP16 +#ifdef CK_USE_XDL void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( std::vector&); @@ -44,6 +84,22 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( std::vector&); +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); +#endif // CK_USE_WMMA +#endif // CK_ENABLE_FP16 } // namespace instance } // namespace device @@ -210,6 +266,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, // add device GEMM instances std::vector gemm_ptrs; +#ifdef CK_ENABLE_FP16 if constexpr(is_same::value && is_same::value && is_same::value) { @@ -217,35 +274,64 @@ bool profile_batched_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: 
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); +#endif } } +#endif // CK_ENABLE_FP16 if(gemm_ptrs.size() <= 0) { @@ -318,9 +404,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); - bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); - bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result); - bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result); + bool c_error = ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + bool d0_error = ck::utils::check_err(d0_g_m_device_result, + d0_g_m_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + bool d1_error = ck::utils::check_err(d1_g_m_device_result, + d1_g_m_host_result, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); pass = pass && (c_error == true); pass = pass && (d0_error == true); diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 4348c4b536..f1af78f05d 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -4,4 +4,9 @@ add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - endif() +endif() + +add_test_executable(test_batched_gemm_reduce_fp16_wmma batched_gemm_reduce_fp16_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16_wmma PRIVATE utility device_batched_gemm_reduce_instance) +endif() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16_wmma.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_wmma.cpp new file mode 100644 index 0000000000..ef6ade65f7 --- /dev/null +++ 
b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_wmma.cpp
@@ -0,0 +1,67 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <iostream>
+
+#include "profiler/profile_batched_gemm_reduce_impl.hpp"
+
+int main()
+{
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    int M = 512;
+    int N = 256;
+    int K = 128;
+
+    int BatchCount = 3;
+
+    bool pass = true;
+
+    // gmk * gkn -> gmn (A row-major, B row-major)
+    pass = pass && ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
+                       true, 1, false, false, M, N, K, K, N, N, BatchCount);
+
+    // gmk * gnk -> gmn (A row-major, B column-major)
+    pass = pass && ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Col, Row>(
+                       true, 1, false, false, M, N, K, K, K, N, BatchCount);
+
+    // gkm * gkn -> gmn (A column-major, B row-major)
+    pass = pass && ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Row, Row>(
+                       true, 1, false, false, M, N, K, M, N, N, BatchCount);
+
+    // gkm * gnk -> gmn (A column-major, B column-major)
+    pass = pass && ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Col, Row>(
+                       true, 1, false, false, M, N, K, M, K, N, BatchCount);
+
+    if(pass)
+    {
+        std::cout << "test BatchedGEMM+Reduce fp16 WMMA: Pass" << std::endl;
+        return 0;
+    }
+    else
+    {
+        std::cout << "test BatchedGEMM+Reduce fp16 WMMA: Fail" << std::endl;
+        return -1;
+    }
+}
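
Beyond the test above, a quick way to sanity-check that the new WMMA instances are actually registered is to enumerate them through the same entry points the profiler uses. The sketch below is illustrative only and is not part of the patch: it assumes the profiler header touched by this PR is on the include path (it defines `DeviceGemmReduceNoOpPtr` and declares the `add_device_batched_gemm_reduce_wmma_cshuffle_v3_*_instances` functions in `ck::tensor_operation::device::instance`), and it assumes `GetTypeString()` from CK's base operator class is available on each instance pointer.

```cpp
// Illustrative sketch only (not part of the patch). Assumes the profiler
// header from this PR is on the include path; DeviceGemmReduceNoOpPtr and the
// add_*_instances declarations live in ck::tensor_operation::device::instance.
#include <iostream>
#include <vector>

#include "profiler/profile_batched_gemm_reduce_impl.hpp"

int main()
{
    using ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr;

    std::vector<DeviceGemmReduceNoOpPtr> gemm_ptrs;

#if defined(CK_ENABLE_FP16) && defined(CK_USE_WMMA)
    // Row-major A, row-major B variants added by this patch (gmk_gkn_gmn).
    ck::tensor_operation::device::instance::
        add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
            gemm_ptrs);
#endif

    std::cout << "registered WMMA batched GEMM+reduce instances: " << gemm_ptrs.size() << '\n';
    for(const auto& op : gemm_ptrs)
    {
        // GetTypeString() is assumed to be provided by CK's BaseOperator.
        std::cout << op->GetTypeString() << '\n';
    }

    return 0;
}
```

Printing the type strings is a convenient way to confirm which tile configurations (block size, MPerBlock/NPerBlock, pipeline scheduler and version) were compiled in before running the full profiler sweep.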