Description
Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
We are debugging the EVT DAG test case below on our platform:
```cpp
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_f32_tma_tma_warpspecialized_EVTDAG, 128x128x64) {
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using TileShape_MNK = Shape<_128,_128,_64>;
  using ClusterShape_MNK = Shape<_1,_1,_1>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // C, D
  using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
      TileShape_MNK, EpilogueTileType, float, float, EpilogueSchedule>;
  // bias
  using AuxLoadDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
      EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t>;
  // ElementOutput(C), ElementCompute(ACC), ElementScalar = ElementCompute(ACC)
  using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombEVTDAG<
      EpilogueDescriptor, AuxLoadDescriptor, float, float, float>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape_MNK, ClusterShape_MNK,
      EpilogueTileType,
      float, float,
      float, LayoutC, 4,
      float, LayoutC, 4,
      EpilogueSchedule,
      FusionCallbacks
  >::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::half_t, LayoutA, 8,
      cutlass::half_t, LayoutB, 8,
      float,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::gemm::collective::StageCountAutoCarveout<
          static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelTmaMmaWarpSpecialized
  >::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveMainloop,
      CollectiveEpilogue
  >;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  bool passed = test::gemm::device::TestAllEVT<Gemm, test::gemm::device::HostEVTDAG<Gemm>>();
  EXPECT_TRUE(passed);
}
```
However, the result is not correct.
Steps/Code to reproduce bug
1. We tested an (8, 8, 8) problem shape using cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp, modifying the problem sizes in TestAllEVT for debugging (STANDARD_TEST is a local toggle we added; see the note after the code):

```cpp
bool TestAllEVT(bool check_relative_equality = false) {
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
  int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
#if !STANDARD_TEST
  std::vector<int> problem_size_m = {max_alignment, 512 - 3 * max_alignment};
  std::vector<int> problem_size_n = {max_alignment, 512 - 2 * max_alignment};
  if constexpr (cute::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
                                cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
    problem_size_m.push_back(768);
    problem_size_n.push_back(768);
  }
#else
  std::vector<int> problem_size_m = {8};  // {max_alignment};
  std::vector<int> problem_size_n = {8};  // {max_alignment};
#endif
  constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});
#if !STANDARD_TEST
  std::vector<int> problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};
#else
  std::vector<int> problem_size_k = {8};  // {max_alignment};
#endif
```
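For reference, STANDARD_TEST is a macro we added locally for this investigation, not something from upstream CUTLASS; defining it nonzero collapses every problem dimension to 8 so a single tile can be checked by hand:

```cpp
// Local debug toggle (our addition): nonzero selects the #else branches
// above and pins the problem shape to 8x8x8.
#define STANDARD_TEST 1
```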
2. We then checked the result of each op in the EVT.

2.1 The accumulator (ACC) result is correct. Note the GPU dump prints the 8x8 tile as 16 rows of 4: rows 1-8 hold columns 0-3 and rows 9-16 hold columns 4-7, which matches the CPU reference below.

```
ACC - gpu side
-8.00 2.00 17.00 6.00
7.00 -20.00 2.00 20.00
8.00 6.00 3.00 15.00
-11.00 18.00 -6.00 -19.00
-14.00 15.00 -8.00 4.00
4.00 -27.00 -4.00 11.00
12.00 12.00 9.00 -2.00
18.00 -25.00 -16.00 22.00
16.00 3.00 11.00 16.00
38.00 1.00 29.00 19.00
10.00 6.00 13.00 28.00
-12.00 2.00 -18.00 -10.00
2.00 9.00 -10.00 14.00
31.00 -11.00 28.00 14.00
-29.00 5.00 -16.00 -14.00
0.00 5.00 9.00 -5.00

ACC - cpu side
First 8x8 ACC values at (m=0,n=0,l=0):
 -8.00   2.00  17.00   6.00  16.00   3.00  11.00  16.00
  7.00 -20.00   2.00  20.00  38.00   1.00  29.00  19.00
  8.00   6.00   3.00  15.00  10.00   6.00  13.00  28.00
-11.00  18.00  -6.00 -19.00 -12.00   2.00 -18.00 -10.00
-14.00  15.00  -8.00   4.00   2.00   9.00 -10.00  14.00
  4.00 -27.00  -4.00  11.00  31.00 -11.00  28.00  14.00
 12.00  12.00   9.00  -2.00 -29.00   5.00 -16.00 -14.00
 18.00 -25.00 -16.00  22.00   0.00   5.00   9.00  -5.00
```
2.2 The C matrix is wrong: what the GPU reads bears no relation to the host-side C matrix.

```
C mat - gpu side
0.00 8.00 -4.51 340.00
8.66 -1417.00 -0.00 2.46
-2.23 -0.00 0.00 -0.00
-2186.00 -0.00 -550.50 -0.00
9.69 nan -0.00 -68.31
0.00 -0.00 -34976.00 -0.00
-2048.00 -0.06 0.00 -0.00
38.22 -0.00 -19.05 -19.19
-19.19 -19.19 -9288.00 -0.00
8.91 -0.00 -8872.00 -0.00
-35232.00 -0.00 -143.62 5.54
9.66 nan 0.00 -0.00
-8744.00 -0.00 0.00 -0.00
0.00 0.59 -2.23 -0.00
-8.98 -2048.00 -0.06 0.00
0.00 0.00 0.00 0.00
C mat - cpu side
C matrix shape: (8, 8, 1)
First 8x8 values:
0 1 2 3 4 5 6 7
0: 1.00 3.00 4.00 -2.00 4.00 -4.00 4.00 3.00
1: -1.00 -1.00 -1.00 2.00 -1.00 1.00 2.00 1.00
2: 2.00 3.00 1.00 0.00 2.00 2.00 2.00 0.00
3: -3.00 3.00 -3.00 -2.00 -4.00 -3.00 -1.00 -1.00
4: -4.00 -2.00 1.00 -4.00 -1.00 1.00 3.00 2.00
5: 4.00 -2.00 -1.00 -1.00 3.00 -3.00 -4.00 1.00
6: -4.00 1.00 -3.00 2.00 -1.00 4.00 -3.00 0.00
7: 3.00 -2.00 2.00 3.00 0.00 -3.00 -2.00 0.00
```

The C matrix address observed goes from 0x75d7ab9fe014 to 0x612605879d00.
3. The testbed fills the epilogue thread arguments with a raw memcpy from the host-side EVT arguments, as below:

```cpp
class Testbed3xEVT {
  // run function
  EVTModule host_reference(problem_size, check_relative_equality, 2024);

  /// Initialize the epilogue arguments
  arguments = typename Gemm::Arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    problem_size,
    {
      impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a,
      impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b
    },
    { // Epilogue arguments
      {}, // thread
      static_cast<ElementC*>(host_reference.get_tensor_C_ptr()),
      impl_.collective_epilogue.stride_c,
      static_cast<ElementD*>(host_reference.get_tensor_D_ptr()),
      impl_.collective_epilogue.stride_d
    }, // Epilogue arguments end
    hw_info,
    scheduler_args
  };
  std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg));
  ....
}
```
However, we find that sizeof(epilogue_args.arg) does not equal sizeof(arguments.epilogue.thread): epilogue_args.arg is 96 B while arguments.epilogue.thread is 104 B. Because the two sides disagree on the layout of the thread arguments, the copy ends up clobbering the C matrix address stored right after thread.
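A guard at the memcpy would surface this at build time. A minimal sketch, assuming it is placed in the testbed where both objects are in scope (the static_assert is our addition, not upstream CUTLASS):

```cpp
// Hypothetical guard (our addition): sizeof is a constant expression, so a
// layout mismatch fails compilation instead of silently corrupting ptr_C.
static_assert(sizeof(epilogue_args.arg) == sizeof(arguments.epilogue.thread),
              "EVT thread-argument layout disagrees between the host EVT "
              "module and the kernel epilogue Arguments");
std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg));
```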
Our test depends on the epilogue Arguments struct in cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp:
```cpp
struct Arguments {
  typename FusionCallbacks::Arguments thread{};
  ElementC const* ptr_C;
  StrideC dC;
  ElementD const* ptr_D;
  StrideD dD;
};
```
In this struct the C matrix address (ptr_C) sits immediately after the thread{} member, so when the host-side thread arguments occupy a different number of bytes than the layout the kernel expects, the stored address is overridden; the sketch below illustrates the mechanism.
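To make the failure mode concrete, here is a self-contained sketch with hypothetical stand-in types; the 104 B and 96 B sizes mirror the numbers we measured, and nothing here is CUTLASS code:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for the writer's view of the struct: thread packs to 104 bytes.
struct WriterView {
  std::uint8_t thread[104];
  float const* ptr_C;  // written at offset 104
};

// Stand-in for the reader's view of the same struct: thread packs to 96 bytes.
struct ReaderView {
  std::uint8_t thread[96];
  float const* ptr_C;  // read at offset 96
};

int main() {
  float c = 0.f;
  WriterView writer{};
  writer.ptr_C = &c;  // the real C matrix address

  // The argument blob travels byte-for-byte, like the memcpy above.
  ReaderView reader{};
  std::memcpy(&reader, &writer, sizeof(reader));

  // The reader picks up ptr_C at offset 96, i.e. the tail of the writer's
  // thread arguments rather than the real pointer.
  std::printf("ptr_C as read: %p, real: %p\n",
              static_cast<void const*>(reader.ptr_C),
              static_cast<void const*>(&c));
  return 0;
}
```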
4. Our modification plan:

4.1 Device side:
```cpp
struct Arguments {
  typename FusionCallbacks::Arguments thread{};
  using PaddingType = std::array<uint8_t, 32>;
  PaddingType padding{};  // padding so ptr_C lands at the offset the kernel expects
  ElementC const* ptr_C;
  StrideC dC;
  ElementD const* ptr_D;
  StrideD dD;
};
```

4.2 Host side, in cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp:

```cpp
arguments = typename Gemm::Arguments{
  cutlass::gemm::GemmUniversalMode::kGemm,
  problem_size,
  {
    impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a,
    impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b
  },
  { // Epilogue arguments
    {}, // thread
#if defined(__AICA__) || defined(__AICA_FOR_MODIFY__)
    {}, // padding
#endif
    static_cast<ElementC*>(host_reference.get_tensor_C_ptr()),
    impl_.collective_epilogue.stride_c,
    static_cast<ElementD*>(host_reference.get_tensor_D_ptr()),
    impl_.collective_epilogue.stride_d
  }, // Epilogue arguments end
  hw_info,
  scheduler_args
};
```
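The padding only helps if it actually moves ptr_C to the offset the kernel reads from, so it is worth probing the layout on both sides. A hypothetical probe (not upstream CUTLASS; offsetof is only conditionally supported here because Arguments is not standard-layout), compiled once for the host build and once for the device build where a device printf is available:

```cpp
// Both builds should print identical numbers once the padding is in place.
using EpilogueArguments = typename Gemm::GemmKernel::CollectiveEpilogue::Arguments;
printf("sizeof(thread)    = %zu\n", sizeof(typename FusionCallbacks::Arguments));
printf("sizeof(Arguments) = %zu\n", sizeof(EpilogueArguments));
printf("offsetof(ptr_C)   = %zu\n", offsetof(EpilogueArguments, ptr_C));
```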
Expected behavior

The EVTDAG test should pass: the epilogue-argument copy should not overwrite ptr_C, and the kernel should read the same C matrix address the host wrote.
Environment details (please complete the following information):
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]