@@ -49,8 +49,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

auto epilogue_args =
EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M);
auto epilogue_args = EpilogueType(p_reduces_grid,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});

GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);
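
For reference, tensor_operation::element_wise::PassThrough is CK's identity elementwise functor, so the extra constructor argument added above forwards a no-op default for the new D0 operation. A minimal sketch of its semantics (an assumption for illustration; the actual CK functor uses per-type specializations rather than a single template):

```cpp
// Sketch of PassThrough semantics following the CK out-parameter convention;
// illustrative only, not the verbatim CK definition.
struct PassThrough
{
    template <typename Y, typename X>
    __host__ __device__ constexpr void operator()(Y& y, const X& x) const
    {
        y = x; // identity: forward the input unchanged
    }
};
```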
@@ -188,6 +191,7 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera

using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,
tensor_operation::element_wise::PassThrough,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
@@ -10,6 +10,7 @@ namespace ck {

template <typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
@@ -21,6 +22,7 @@ struct ReduceTrait_
{
using ReduceAccDataType_ = ReduceAccDataType;
using ReducePtrsGlobal_ = ReducePtrsGlobal;
using D0ElementwiseOperation_ = D0ElementwiseOperation;
using ReduceOperations_ = ReduceOperations;
using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations;
using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations;
@@ -148,11 +150,13 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
const index_t MRaw_)
const index_t MRaw_,
const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
: p_reduces_grid(p_reduces_grid_),
reduce_in_element_ops(reduce_in_element_ops_),
reduce_out_element_ops(reduce_out_element_ops_),
MRaw(MRaw_),
d0_element_op{d0_element_op_},
reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)}
{
}
@@ -174,6 +178,13 @@ struct EpilogueReduceCShuffle
const index_t& block_m_id,
const index_t& block_n_id)
{
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
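// readfirstlane broadcasts lane 0's value, so the offsets are wave-uniform
// and the compiler can keep them in scalar registers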

auto reduce_grid_desc_mblock_mperblock =
MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);

@@ -216,29 +227,6 @@ struct EpilogueReduceCShuffle
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();

// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));

// Thread transfer LDS to Vmem
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);

// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));

// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle
},
Number<NumReduce>{});

// multiple Ds
constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple(
[&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
Number<NumDTensor>{});

constexpr auto ds_thread_buf_size =
d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

auto c01_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
Number<ds_thread_buf_size>{});
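// (c01_thread_buf stages the D0 tile first and is reused for the D1 tile below)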

auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
return ThreadwiseTensorSliceTransfer_v2<
remove_cvref_t<tuple_element_t<I.value, DsDataType>>,
typename ReduceTrait::ReduceAccDataType_,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
remove_cvref_t<
decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
1,
true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
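// (one global->VGPR copy object per D tensor, each positioned at this
// thread's slice of the current MBlock/NBlock output tile)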

constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

// Write E from Vgpr to Vmem
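// (when Ds are present, cde_element_op is applied in the fused loops below,
// so this store's elementwise op is just PassThrough)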
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
typename ReduceTrait::ReduceAccDataType_,
EDataType,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
EGlobalMemoryDataOperation,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op};

constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
@@ -365,22 +415,60 @@ struct EpilogueReduceCShuffle

// make sure it's safe to read from LDS
block_sync_lds();

// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));

{
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);

// Note: currently multiple Ds support only Bias + Add.
// It needs to be generalized for other operations (currently not needed)
if constexpr(NumDTensor > 0)
{
auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
// d0 / d1 operations
d0_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);

// c = activation(c + bias)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
typename ReduceTrait::ReduceAccDataType_ out;
cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});

auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);

d1_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);

// c = c + c1_functor(c1)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
c_reduce_thread_buf(i) += c01_thread_buf(i);
});
}
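// Net effect of the branch above: c = activation(c + d0), then
// c += d0_element_op(d1); with PassThrough for d0_element_op this is a
// plain residual add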

// Write E
c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);

// Reduction
static_for<0, NumReduce, 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];

Expand Down Expand Up @@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
static_for<0, NumDTensor, 1>{}([&](auto I) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
});

// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
}
});
}
@@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
index_t MRaw;
typename ReduceTrait::D0ElementwiseOperation_ d0_element_op;
ReduceGridDesc_M reduce_grid_desc_m;
};

@@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

static_for<0, num_access, 1>{}([&](auto access_id) {
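// make sure the previous iteration's LDS reads have completed before
// this iteration overwrites the LDS tile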
block_sync_lds();

// each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -19,6 +19,7 @@ namespace instance {

using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;

#if defined(CK_USE_XDL)
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
@@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
#endif // CK_USE_XDL

#if defined(CK_USE_WMMA)
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
#endif // CK_USE_WMMA

template <typename ADataType,
typename BDataType,
@@ -45,33 +58,61 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
#endif
}
}
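
For context, a hypothetical usage sketch of this getter. The getter's full template parameter list is collapsed in this diff, so the template arguments, type aliases, and parameter order below are assumptions for illustration only:

```cpp
// Assumed instantiation: half-precision GEMM + bias + add + mean/square-mean
// reduction, with row-major A, B and C; parameter order is an assumption.
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;

auto op_ptrs = ck::tensor_operation::device::instance::
    get_device_gemm_add_add_mean_squaremean_instances<F16, F16, F16,
                                                      Row, Row, Row>();

// When both CK_USE_XDL and CK_USE_WMMA are defined, op_ptrs now contains the
// XDL and the WMMA instances for this layout combination; each entry is a
// DeviceGemmReducePtr<1, 2>.
```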

@@ -1,10 +1,15 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_bias_add_reduce_instance
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp

device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)