
[BUG] HostRowReduce not init #2861

@muyudy

Description


Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
The row-reduce result is not correct for our test case. For example:
Row Reduce Reference =
8.00 -24.00 0.00 62.00 52.00 26.00 41.00 55.00

Row Reduce Computed = (incorrect)
8.00 8.00 32.00 29.81 51.99 -6.06 43.00 52.99

Steps/Code to reproduce bug

The debugging process:

  1. Use cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp and modify the tensor sizes for debugging:

bool TestAllEVT(bool check_relative_equality = false) {
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
#if !STANDARD_TEST
  std::vector<int> problem_size_m = {max_alignment, 512 - 3 * max_alignment};
  std::vector<int> problem_size_n = {max_alignment, 512 - 2 * max_alignment};

  if constexpr (cute::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
                cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
    problem_size_m.push_back(768);
    problem_size_n.push_back(768);
  }

#else
  std::vector<int> problem_size_m = {8};  // {max_alignment};
  std::vector<int> problem_size_n = {8};  // {max_alignment};

#endif

  constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});

#if !STANDARD_TEST
  std::vector<int> problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};
#else
  std::vector<int> problem_size_k = {8};  // {max_alignment};
#endif
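A minimal sketch of the toggle that selects the small 8x8x8 problem above. STANDARD_TEST is the macro introduced by this debug patch, not an upstream CUTLASS switch, and its exact placement near the top of the modified testbed is an assumption:

  // Hypothetical placement: near the top of gemm_testbed_3x_evt.hpp, before TestAllEVT().
  // Set to 1 to force the tiny 8x8x8 problem shape; set to 0 to restore the stock size sweep.
  #ifndef STANDARD_TEST
  #define STANDARD_TEST 1
  #endif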

  2. Add debug information to sm90_visitor_store_tma_warpspecialized.hpp

In cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp, instrument the member function reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) of struct Sm90RowReduction:

  //
  // 2. Atomic reduction
  //
  if constexpr (IsAtomic) {
    // Filter so we don't issue redundant copies over stride-0 modes
    Tensor tCrRow_flt = filter_zeros(tCrRow);
    Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCrRow_flt.stride()));

    auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN);

    Tensor tCgRow = sm90_partition_for_epilogue<ReferenceSrc>(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx);
    Tensor tCgRow_flt = filter_zeros(tCgRow);

......

    // NOTE: atomic reduction is performed in the output type
    using ConvertOutput = NumericConverter<ElementOutput, ElementCompute, RoundStyle>;
    using ReduceOutput = GmemReduceFn<ElementOutput>;
    ConvertOutput convert_output{};
    ReduceOutput reduce_output{};

    if constexpr (SwapShuffle) {
      CUTLASS_PRAGMA_UNROLL
      // FltFrgSizePerLaneM is 4 in this configuration
      for (int i = 0; i < FltFrgSizePerLaneM; ++i) {
        int idx = lane_m * FltFrgSizePerLaneM + i;
        // Only care about OOB for N mode
        // if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow))
        if (get<1>(tCcRow(idx)) < N)
        {
          // Print before the atomic op
          if (threadIdx.x >= 256 && threadIdx.x < 260) {
            printf("PERFORMING ATOMIC REDUCTION:\n");
            printf("  Address: %p, idx= %d\n", &tCgRow_flt(idx), idx);
            printf("  Old value: ");
            if constexpr (cute::is_same_v<ElementOutput, float>) {
              printf("%.2f, idx= %d\n", float(tCgRow_flt(idx)), idx);
            }
            printf("  Adding value: ");
            ElementOutput converted = convert_output(tCrRow_flt(i));
            if constexpr (cute::is_same_v<ElementOutput, float>) {
              printf("%.2f\n", float(converted));
            }
          }

          reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i)));

          // Print after the atomic op
          if (threadIdx.x >= 256 && threadIdx.x < 260) {
            printf("  New value: ");
            if constexpr (cute::is_same_v<ElementOutput, float>) {
              printf("%.2f, idx= %d\n", float(tCgRow_flt(idx)), idx);
            }
          }
        }
      }
    }

The resulting log is shown below:

PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
Address: 0x70b6ed60c000, idx= 0
Address: 0x70b6ed60c004, idx= 0
Address: 0x70b6ed60c008, idx= 0
Address: 0x70b6ed60c00c, idx= 0

Old value: 0.00, idx= 0
Old value: 32.00, idx= 0
Old value: 32.00, idx= 0
Old value: -32.19, idx= 0

Adding value: 8.00
Adding value: -24.00
Adding value: 0.00
Adding value: 62.00

New value: 8.00, idx= 0
New value: 8.00, idx= 0
New value: 32.00, idx= 0
New value: 29.81, idx= 0

PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
PERFORMING ATOMIC REDUCTION:
Address: 0x70b6ed60c010, idx= 1
Address: 0x70b6ed60c014, idx= 1
Address: 0x70b6ed60c018, idx= 1
Address: 0x70b6ed60c01c, idx= 1

Old value: -0.01, idx= 1
Old value: -32.06, idx= 1
Old value: 2.00, idx= 1
Old value: -2.01, idx= 1

Adding value: 52.00
Adding value: 26.00
Adding value: 41.00
Adding value: 55.00

New value: 51.99, idx= 1
New value: -6.06, idx= 1
New value: 43.00, idx= 1
New value: 52.99, idx= 1

From the log we can see that on the very first atomic add, the old value in global memory is not 0 but random garbage, while the value being added is correct. The uninitialized destination buffer is therefore the root cause of the incorrect result.
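As a standalone illustration (not CUTLASS code; the kernel and buffer names below are made up for this sketch), atomically accumulating into device memory that was never zeroed reproduces exactly this symptom, and a single cudaMemset before the launch fixes it:

  #include <cstdio>
  #include <cuda_runtime.h>

  // Each thread atomically adds its contribution into out[0].
  __global__ void accumulate(float* out, float contribution) {
    atomicAdd(out, contribution);
  }

  int main() {
    float* d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));      // contents are indeterminate, like the reduce buffer here
    // cudaMemset(d_out, 0, sizeof(float)); // <-- without this, the sum starts from garbage

    accumulate<<<1, 4>>>(d_out, 2.0f);      // expect 8.0 only if *d_out started at 0

    float h_out = 0.f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);
    cudaFree(d_out);
    return 0;
  }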

  3. Check the host EVT code.
    In cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp, the constructor only calls resize(); tensor_row_reduce_ (and reduce_buffer_) are never zero-initialized in the FinalReduction branch.

HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024):
  Base(check_relative_equality) {
  auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
  N_ = cute::get<1>(problem_shape_MNKL);
  if constexpr (FinalReduction) {
    tensor_row_reduce_.resize(cutlass::Coord<1>(N_));
    reference_row_reduce_.resize(cutlass::Coord<1>(N_));
    reduce_buffer_.resize(cutlass::Coord<1>(N_));
  }
  else {
    auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{}));
    extent_m_ = cute::get<0>(NumTile);
    extent_n_ = cute::get<1>(NumTile) * TileN;
    extent_l_ = cute::get<2>(NumTile);
    auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_);
    tensor_row_reduce_.resize(shape);
    reference_row_reduce_.resize(shape);
    reduce_buffer_.resize(shape);
    cutlass::reference::host::TensorFill(tensor_row_reduce_.host_view(), ElementDst(0));
  }

The fix is to add the following initialization code after the resize calls:

  // init tensor_row_reduce_
  tensor_row_reduce_.sync_host();
  for (int i = 0; i < tensor_row_reduce_.size(); ++i) {
      tensor_row_reduce_.host_data()[i] = ElementDst(0);
  }
  tensor_row_reduce_.sync_device();
  // init reduce_buffer_ 
  reduce_buffer_.sync_host();
  for (int i = 0; i < reduce_buffer_.size(); ++i) {
      reduce_buffer_.host_data()[i] = ElementCompute(0);
  }
  reduce_buffer_.sync_device();
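Equivalently (this is my sketch, not the author's patch), the same zero-fill helper that the non-final-reduction branch already uses for tensor_row_reduce_ can initialize both tensors, assuming reduce_buffer_ stores ElementCompute as in the loop above:

  // Equivalent zero-initialization using cutlass::reference::host::TensorFill,
  // mirroring the existing call in the else branch of the constructor.
  cutlass::reference::host::TensorFill(tensor_row_reduce_.host_view(), ElementDst(0));
  cutlass::reference::host::TensorFill(reduce_buffer_.host_view(), ElementCompute(0));
  tensor_row_reduce_.sync_device();
  reduce_buffer_.sync_device();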

Expected behavior
The computed row reduction should match the reference values shown above.

