Skip to content

Conversation

@EnricoDeg
Copy link
Contributor

Proposed changes

Summary:

  • Change EpilogueReduceCShuffle to support bias + add operations before reduction (multiple Ds)
  • Add wmma device struct for gemm_bias_add_reduce
  • Add instances (xdl parity)
  • Add tests

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered


void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not just use the instance factory get_device_gemm_add_add_mean_squaremean_instances() here instead of manually using the add_device_xxx_instances() functions?

@krithalith
Copy link
Contributor

In the profiler impl we have:

std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                       sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
                       sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
                       sizeof(ReduceDataType) * M;

But I thought the Bias was a simple 1D vector of size N?

make_tuple(I0, I0, I0, I0),
c01_thread_buf);

// c = c + c1_functior(c1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo


public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is int value initialization

@EnricoDeg EnricoDeg force-pushed the streamhpc/gemm_bias_add_reduce_wmma branch from 90fa9e7 to 7cd0696 Compare December 3, 2025 10:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants