Conversation

@sanchitintel commented Nov 3, 2025

MoE GEMMs implemented without collectives; they can be called directly from another GPU kernel. The new copy and MMA atoms are used, and some of the code has been adapted from the new-API example.

  • Add back cutlass::DeviceAllocation in the launch code.
  • Add back accuracy checks (the individual GEMMs have been tested, though); they were removed along with cutlass::DeviceAllocation.
  • Reduce duplication across the GEMM kernels by sharing common code, except for scaling.
  • Add perf data.

cc @pengzhao-intel @CaoZhongZ

@pengzhao-intel

Which datatypes are covered by this PR, and what is the current performance?

}

reorder(tArA, tCrA);
reorder(tBrB, tCrB);


For fp16/bf16, what's the purpose of these two reorders? If we read the data of A and B in the desired layout, we don't need to add a reorder again.

@sanchitintel (Author) Nov 3, 2025

Thanks for reviewing!

If we read the data of A and B in the desired layout, we don't need to add a reorder again

In that case, the reorder is a no-op. This is explicitly mentioned in the rearch documentation:

reorder acts as a "pipe" connecting copy and MMA operations (or any other subgroup-scope operations). With reorders, the kernel writer does not need to worry about perfectly matching layouts between copy and MMA atoms. In case the layouts do match perfectly (as make_block_2d_copy_{A,B,C} try to do), the compiler is able to remove the reorder entirely, making it a no-op.
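For context, here is a minimal sketch of the pattern that documentation describes, in the style of the new-API example; the copy objects, global/register fragments, and mma object (copy_a, tAgA, tCrC, etc.) are illustrative placeholders, not names taken from this PR:

auto copy_a = make_block_2d_copy_A(mma, gA);  // copy atoms chosen to match the MMA layout
auto copy_b = make_block_2d_copy_B(mma, gB);

copy(copy_a, tAgA(_, _, _, k), tArA);         // load A tile into registers, copy layout
copy(copy_b, tBgB(_, _, _, k), tBrB);         // load B tile into registers, copy layout

reorder(tArA, tCrA);                          // re-lay-out into the MMA fragment layout;
reorder(tBrB, tCrB);                          // compiles to a no-op if the layouts already match

gemm(mma, tCrA, tCrB, tCrC);                  // accumulate into the C fragment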

}

// Assumes n-major SG layout
// TODO: This one is only for MXFP4. Add template specialization for dtype


For this PR, I don't think we need to mix bf16 and mxfp4 together. The mxfp4 part is not really ready yet, so we can target merging bf16 first. The common parts can be reused for mxfp4 later.

@sanchitintel (Author)

Thanks for clarifying! BTW, that TODO has been resolved. I forgot to remove that comment.

@tdeng5 left a comment

We already have folders for MoE-style examples, such as 09_bmg_grouped_gemm_fp8 and 10_bmg_grouped_gemm_mixed_dtype; is this PR doing something similar? If so, can we follow the existing naming convention for the examples folders?

@sanchitintel (Author)

We already have folders for MoE-style examples, such as 09_bmg_grouped_gemm_fp8 and 10_bmg_grouped_gemm_mixed_dtype; is this PR doing something similar? If so, can we follow the existing naming convention for the examples folders?

Will do, thanks!

@sanchitintel (Author) commented Nov 5, 2025

(attached: performance data)

@pengzhao-intel, as discussed offline, there are still some performance issues with the new store atom API, which are said to expose IGC instruction-scheduling issues. I'll revert to the mixed-API implementation. We'd still use the new copy atoms for loading A and B. In particular, the new B copy atoms give better performance for the ColumnMajor B case, because the new API supports more shapes for transposed loads (the legacy API only supports 16x16 transposed loads of B for the BF16 dtype).

I'll request your review again.

Thanks!

@Antonyvance requested a review from Copilot on November 5, 2025 07:17
Copilot AI left a comment

Pull Request Overview

This PR introduces a Mixture of Experts (MoE) GEMM implementation for Intel GPUs using SYCL and CuTe, enabling MoE computations without collectives that can be called directly from GPU kernels.

Key Changes

  • New persistent tile scheduler for MoE GEMM workloads with custom work distribution
  • Core MoE GEMM kernels supporting both standard 16-bit floating-point and quantized MXFP4 formats
  • Example implementation demonstrating multi-expert GEMM execution with real workload patterns

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.

Summary per file:

  • examples/cute/tutorial/moe/moe_tile_scheduler.hpp — Implements a persistent tile scheduler adapted for MoE workloads with per-expert work distribution
  • examples/cute/tutorial/moe/moe_grouped_gemm.hpp — Main MoE GEMM orchestration handling expert batching and tensor pointer updates
  • examples/cute/tutorial/moe/moe_gemms.hpp — Device-side GEMM kernels with support for bf16/fp16 and MXFP4 quantized operations
  • examples/cute/tutorial/moe/moe_example.cpp — Host-side launcher and example with realistic multi-layer expert workload patterns
  • examples/cute/tutorial/CMakeLists.txt — Adds a build target for the MoE GEMM example
Comments suppressed due to low confidence (1)

examples/cute/tutorial/moe/moe_tile_scheduler.hpp:1

  • Corrected spelling of 'Othwerwise' to 'Otherwise' in comment at line 301 of moe_gemms.hpp

Comment on lines +106 to +108
// TODO
// The current MMA collective for grouped GEMM currently seems to be missing
// batching Ask Ji Yang
Copilot AI Nov 5, 2025

The TODO comment lacks context and actionable information. It should specify what batching functionality is missing and what needs to be asked or investigated.

Suggested change
// TODO
// The current MMA collective for grouped GEMM currently seems to be missing
// batching Ask Ji Yang
// TODO: Investigate and implement batching support in the MMA collective for grouped GEMM.
// Currently, the MMA collective does not support batched execution across multiple groups.
// Consult with Ji Yang to clarify the requirements for batching, including:
// - What batching semantics are needed for grouped GEMM in this context?
// - How should the MMA collective handle multiple groups in a single kernel launch?
// - Are there hardware or API constraints to consider for batching?

inline constexpr bool is_16_bit_fp_v =
is_16_bit_fp<std::remove_cv_t<std::remove_reference_t<T>>>::value;

// Making sure I got this right
Copilot AI Nov 5, 2025

The comment is too informal and vague for production code. It should be removed or replaced with a clear explanation of what the static_asserts verify.

Suggested change
// Making sure I got this right
// Verify that is_16_bit_fp_v correctly identifies cutlass::bfloat16_t and cutlass::half_t as 16-bit floating point types.

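For reference, a self-contained sketch of what such checks could look like; the trait body below is reconstructed from the partial snippet above and is an assumption, not the PR's exact definition:

#include <type_traits>
#include "cutlass/numeric_types.h"

// Assumed trait definition: true only for the 16-bit floating-point types used here.
template <typename T> struct is_16_bit_fp : std::false_type {};
template <> struct is_16_bit_fp<cutlass::bfloat16_t> : std::true_type {};
template <> struct is_16_bit_fp<cutlass::half_t> : std::true_type {};

template <typename T>
inline constexpr bool is_16_bit_fp_v =
    is_16_bit_fp<std::remove_cv_t<std::remove_reference_t<T>>>::value;

// The checks the comment alludes to: 16-bit FP types are detected, others are not.
static_assert(is_16_bit_fp_v<cutlass::bfloat16_t>, "bfloat16_t must be detected as 16-bit FP");
static_assert(is_16_bit_fp_v<const cutlass::half_t&>, "cv/ref qualifiers must be stripped");
static_assert(!is_16_bit_fp_v<float>, "float must not be detected as 16-bit FP");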
Comment on lines +279 to +284
// TODO: Remove magic numbers
// Assumes scales are [K/32, N], as the reads are coalesced in that
// case (although shuffles can't be prevented later). Since loaded
// scales would be read columnwise when scales would be shaped [N,
// K/32], use UniversalCopy instead.
scales_e8m0[i] = *(S_tile.data() + (64 * (sg_id % 4)) + k_tile * 256 +
Copilot AI Nov 5, 2025

The scaling calculation uses multiple hardcoded values (64, 4, 256, 8) without explanation. These should be defined as named constants with clear documentation of their relationship to the tensor layout.

Comment on lines +298 to +326
// hardcoded for work-item B subgroup fragment sized 128
// I created 4 loops just so that I can easily copy-paste some of them to
// the case of 32 or 64 elements per B thread fragment with compile-time
// expression evaluation. Othwerwise, the loops below can be combined.
// However, instead of hardcoding, figure out CuTe algebra based
// transformations that can lead to generic code.
CUTE_UNROLL
for (int i = 0, j = 0; i < 15; i += 2) {
tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i]));
tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 1]));
tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i + 64]));
tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
for (int i = 16, j = 2; i < 31; i += 2) {
tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i]));
tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 1]));
tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i + 64]));
tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
Copilot AI Nov 5, 2025

The comment acknowledges that the hardcoded approach should be replaced with generic CuTe transformations. The four separate unrolled loops (lines 304-347) represent significant code duplication that reduces maintainability.

Suggested change
// hardcoded for work-item B subgroup fragment sized 128
// I created 4 loops just so that I can easily copy-paste some of them to
// the case of 32 or 64 elements per B thread fragment with compile-time
// expression evaluation. Othwerwise, the loops below can be combined.
// However, instead of hardcoding, figure out CuTe algebra based
// transformations that can lead to generic code.
CUTE_UNROLL
for (int i = 0, j = 0; i < 15; i += 2) {
tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i]));
tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 1]));
tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i + 64]));
tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
for (int i = 16, j = 2; i < 31; i += 2) {
tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i]));
tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 1]));
tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
static_cast<float>(tCrB[i + 64]));
tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
// Generic scaling using CuTe tensor algebra to avoid hardcoded loops
// This replaces the four unrolled loops with a single generic loop
// Assumes tCrB is of size 128 and scales_float_sg_tensor is of size 8
for (int i = 0; i < 128; ++i) {
// Compute the scale index for this element
// The original code uses j = (i % 32) / 8, mapping i to scale index
int scale_idx = (i % 32) / 8;
tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[scale_idx] *
static_cast<float>(tCrB[i]));
}
CUTE_UNROLL

cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
cutlass::KernelHardwareInfo hw_info{0, sm_count};
auto dummy_problem_shape = cute::Shape<int, int, int>{1, gemm_k, gemm_n};
// I forgot why I used this hack
Copilot AI Nov 5, 2025

The comment admits not understanding why this code exists, which is problematic for maintainability. The purpose of the dummy problem shape construction should be documented or the code should be refactored if it's unnecessary.

Suggested change
// I forgot why I used this hack
// The scheduler and kernel launch APIs require a GroupProblemShape, even for a single GEMM problem.
// We construct a dummy problem shape and group problem shape to satisfy these API requirements.
// This ensures correct setup of scheduler parameters and grid shape for the kernel launch.

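A minimal sketch of what that setup plausibly looks like, assuming CUTLASS's GroupProblemShape from group_array_problem_shape.hpp; the field names and the single-group wrapping are assumptions for illustration, not the PR's exact code:

#include "cutlass/gemm/group_array_problem_shape.hpp"

using ProblemShape = cute::Shape<int, int, int>;

// Wrap one dummy (M=1, K, N) problem so the grouped scheduler / launch APIs,
// which expect a GroupProblemShape, can be reused for a single GEMM.
ProblemShape dummy_problem_shape{1, gemm_k, gemm_n};
cutlass::gemm::GroupProblemShape<ProblemShape> group_problem_shape{};
group_problem_shape.num_groups          = 1;
group_problem_shape.problem_shapes      = &dummy_problem_shape;  // device pointer in real launch code
group_problem_shape.host_problem_shapes = &dummy_problem_shape;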
Comment on lines +267 to +268
// TODO: Update with broadcast loads recently added. It seems it's okay to
// create a new TiledCopy object in each mainloop iteration.
Copilot AI Nov 5, 2025

Creating a new TiledCopy object in each mainloop iteration (line 268 comment) could impact performance. This TODO should be addressed or a clearer justification provided for why it's acceptable.

template <typename T>
static constexpr bool is_complete_v = is_complete<T>::value;

// Some of this code has been authored by Peter caday
Copilot AI Nov 5, 2025

Corrected capitalization of name 'Peter caday' to 'Peter Caday'

Suggested change
// Some of this code has been authored by Peter caday
// Some of this code has been authored by Peter Caday
