Skip to content
23 changes: 13 additions & 10 deletions test/unit/gemm/device/gemm_universal_lincomb_per_rowbias_eltact.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,24 @@ struct MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_GemmConfig {
using ElementOutput = float;
using ElementBias = float;

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// When left unspecified (void), MainloopXeL1Staged automatically selects
// appropriate 2D block copy operations
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;

using TileShape = Shape<_256, _256, _32>;

// Use XE_DPAS_TT atom for new API
using TiledMma =
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>,
Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

constexpr static int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;

using EpilogueOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct<
using EpilogueOp = cutlass::epilogue::fusion::XeLinCombPerRowBiasEltAct<
cutlass::epilogue::thread::ReLu,
ElementOutput, ElementComputeEpilogue, ElementBias,
ElementAccumulator, ElementAccumulator,
Expand All @@ -81,15 +84,14 @@ struct MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_GemmConfig {
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
void, // Epilogue tile (void = automatic)
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U32x8x16_ST_N,
void, void>;
void, // Load copy atom (void = automatic)
void>; // Store copy atom (void = automatic)

using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
Expand Down Expand Up @@ -177,3 +179,4 @@ TEST(MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_NonParam, LargeKSmallMN) {
EXPECT_TRUE((test::gemm::device::TestXe<Gemm, cutlass::epilogue::thread::ReLu>(32, 32, 8192, 1, 1.0f, 0.0f)));
EXPECT_TRUE((test::gemm::device::TestXe<Gemm, cutlass::epilogue::thread::ReLu>(64, 64, 16384, 1, 1.0f, 0.0f)));
}

Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ struct MainloopIntelXeXMX16Group_GemmConfig {
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
using GmemTiledCopyA = void; //XE_2D_U16x32x32_LD_N
using GmemTiledCopyB = void; //XE_2D_U16x32x32_LD_V

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

using TiledMma =
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
TiledMMA<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, //XE_8x16x16_F32BF16BF16F32_TT
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;

constexpr static int PipelineStages = 2;
// Dispatch to grouped gemm algorithm
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group<PipelineStages>;
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
Expand Down Expand Up @@ -200,4 +200,4 @@ TEST(MainloopIntelXeXMX16Group_NonParam, LargeKSmallMN) {
{64, 64, 16384}
};
EXPECT_TRUE(test::gemm::device::TestXeGrouped<Gemm>(problem_sizes, 1.0f, 0.0f));
}
}