diff --git a/test/unit/gemm/device/gemm_universal_lincomb_per_rowbias_eltact.cpp b/test/unit/gemm/device/gemm_universal_lincomb_per_rowbias_eltact.cpp
index 559f734b53..1168217226 100644
--- a/test/unit/gemm/device/gemm_universal_lincomb_per_rowbias_eltact.cpp
+++ b/test/unit/gemm/device/gemm_universal_lincomb_per_rowbias_eltact.cpp
@@ -53,21 +53,24 @@ struct MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_GemmConfig {
   using ElementOutput = float;
   using ElementBias = float;
 
-  using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
+  // When left unspecified (void), MainloopXeL1Staged automatically selects
+  // appropriate 2D block copy operations
+  using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
+  using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
 
   using TileShape = Shape<_256, _256, _32>;
 
+  // Use XE_DPAS_TT atom for new API
   using TiledMma =
-      typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
+      typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, bfloat16_t>>,
                               Layout<TileShape>,
                               Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
 
   constexpr static int PipelineStages = 2;
-  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
-  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
 
-  using EpilogueOp = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct<
+  using EpilogueOp = cutlass::epilogue::fusion::XeLinCombPerRowBiasEltAct<
       cutlass::epilogue::thread::ReLu, ElementOutput, ElementComputeEpilogue,
       ElementBias, ElementAccumulator, ElementAccumulator,
@@ -81,15 +84,14 @@ struct MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_GemmConfig {
   using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
           EpilogueDispatchPolicy,
           TileShape,
+          void, // Epilogue tile (void = automatic)
           ElementAccumulator,
           cutlass::gemm::TagToStrideC_t<LayoutC>,
           ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutD>,
           FusionCallBacks,
-          XE_2D_U32x8x16_LD_N,
-          void, void,
-          XE_2D_U32x8x16_ST_N,
-          void, void>;
+          void,  // Load copy atom (void = automatic)
+          void>; // Store copy atom (void = automatic)
 
   using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
           GEMMDispatchPolicy,
@@ -177,3 +179,4 @@ TEST(MainloopIntelXeXMX16_LinCombPerRowBiasEltAct_NonParam, LargeKSmallMN) {
   EXPECT_TRUE((test::gemm::device::TestXe<Gemm>(32, 32, 8192, 1, 1.0f, 0.0f)));
   EXPECT_TRUE((test::gemm::device::TestXe<Gemm>(64, 64, 16384, 1, 1.0f, 0.0f)));
 }
+
diff --git a/test/unit/gemm/device/gemm_universal_mainloopintelxexmx16group.cpp b/test/unit/gemm/device/gemm_universal_mainloopintelxexmx16group.cpp
index e68ade6dd1..1d66f155ef 100644
--- a/test/unit/gemm/device/gemm_universal_mainloopintelxexmx16group.cpp
+++ b/test/unit/gemm/device/gemm_universal_mainloopintelxexmx16group.cpp
@@ -57,21 +57,21 @@ struct MainloopIntelXeXMX16Group_GemmConfig {
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
 
-  using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
+  using GmemTiledCopyA = void; //XE_2D_U16x32x32_LD_N
+  using GmemTiledCopyB = void; //XE_2D_U16x32x32_LD_V
 
   // Workgroup-level tile
   using TileShape = Shape<_256, _256, _32>;
 
   using TiledMma =
-      TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
+      TiledMMA<MMA_Atom<XE_DPAS_TT<8, float, bfloat16_t>>, //XE_8x16x16_F32BF16BF16F32_TT
               Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
               Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
                    Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
 
   constexpr static int PipelineStages = 2;
   // Dispatch to grouped gemm algorithm
-  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group<PipelineStages>;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup<PipelineStages>;
   using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
 
   using EpilogueOp = cutlass::epilogue::fusion::LinearCombination