MoE GEMM example #600
Conversation
What datatype is covered by this PR, and what is the current performance?
```cpp
reorder(tArA, tCrA);
reorder(tBrB, tCrB);
```
For fp16/bf16, what's the purpose of these two reorders? If we read the data of A and B with the demanded layout, we don't need to add a reorder again.
Thanks for reviewing!
> If we read the data of A and B with the demanded layout, we don't need to add a reorder again
In that case, it's a no-op. It's explicitly mentioned in the rearch documentation:
reorder acts as a "pipe" connecting copy and MMA operations (or any other subgroup-scope operations). With reorders, the kernel writer does not need to worry about perfectly matching layouts between copy and MMA atoms. In case the layouts do match perfectly (as make_block_2d_copy_{A,B,C} try to do), the compiler is able to remove the reorder entirely, making it a no-op.
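To make that data flow concrete, here is a minimal structural sketch of the copy -> reorder -> MMA "pipe". The fragment names (tArA, tBrB, tCrA, tCrB, tCrC) follow the snippet under review; the function signature, the tensor mode counts, and the surrounding kernel setup are assumptions for illustration, not code from this PR.

```cpp
#include <cute/tensor.hpp>

// Illustrative sketch only: the copy atoms fill tArA/tBrB in their own layout,
// reorder re-lays them into the MMA atom's layout, and the MMA consumes the
// result. When make_block_2d_copy_{A,B} already produce the MMA layout, the
// reorders compile away to no-ops.
template <class CopyA, class CopyB, class Mma,
          class GA, class GB, class RA, class RB,
          class MA, class MB, class MC>
CUTE_DEVICE void mainloop_step(CopyA const& copy_a, CopyB const& copy_b,
                               Mma const& mma,
                               GA const& tAgA, GB const& tBgB,  // global-memory tiles
                               RA& tArA, RB& tBrB,              // fragments in copy layout
                               MA& tCrA, MB& tCrB, MC& tCrC,    // fragments in MMA layout
                               int k_tile) {
  cute::copy(copy_a, tAgA(cute::_, cute::_, k_tile), tArA);  // load A tile
  cute::copy(copy_b, tBgB(cute::_, cute::_, k_tile), tBrB);  // load B tile
  reorder(tArA, tCrA);  // "pipe" into the MMA layout (no-op if layouts already match)
  reorder(tBrB, tCrB);
  cute::gemm(mma, tCrA, tCrB, tCrC);  // accumulate into tCrC
}
```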
```cpp
// Assumes n-major SG layout
// TODO: This one is only for MXFP4. Add template specialization for dtype
```
For this PR, I don't think we need to mix bf16 and mxfp4 together. The major part of mxfp4 is not really ready yet, and we can target merging bf16 first. The common part can be reused for mxfp4 later as well.
Thanks for clarifying! BTW, that TODO has been resolved. I forgot to remove that comment.
We have folders for MoE-style examples, like 09_bmg_grouped_gemm_fp8 and 10_bmg_grouped_gemm_mixed_dtype; are you doing something similar? If yes, can we follow the existing naming convention for the examples folder?
Will do, thanks!
@pengzhao-intel, as discussed offline, there are still some performance issues with the new store atom API, which are said to be exposing some IGC instruction scheduling issues. I'll revert to the mixed-API implementation; we'd still use the new copy atoms for loading. I'll request your review again. Thanks!
Pull Request Overview
This PR introduces a Mixture of Experts (MoE) GEMM implementation for Intel GPUs using SYCL and CuTe, enabling MoE computations without collectives that can be called directly from GPU kernels.
Key Changes
- New persistent tile scheduler for MoE GEMM workloads with custom work distribution
- Core MoE GEMM kernels supporting both standard 16-bit floating-point and quantized MXFP4 formats
- Example implementation demonstrating multi-expert GEMM execution with real workload patterns
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.
Summary per file:
| File | Description |
|---|---|
| examples/cute/tutorial/moe/moe_tile_scheduler.hpp | Implements persistent tile scheduler adapted for MoE workloads with per-expert work distribution |
| examples/cute/tutorial/moe/moe_grouped_gemm.hpp | Main MoE GEMM orchestration handling expert batching and tensor pointer updates |
| examples/cute/tutorial/moe/moe_gemms.hpp | Device-side GEMM kernels with support for bf16/fp16 and MXFP4 quantized operations |
| examples/cute/tutorial/moe/moe_example.cpp | Host-side launcher and example with realistic multi-layer expert workload patterns |
| examples/cute/tutorial/CMakeLists.txt | Adds build target for MoE GEMM example |
Comments suppressed due to low confidence (1)
examples/cute/tutorial/moe/moe_tile_scheduler.hpp:1
- Corrected spelling of 'Othwerwise' to 'Otherwise' in comment at line 301 of moe_gemms.hpp
```cpp
// TODO
// The current MMA collective for grouped GEMM currently seems to be missing
// batching Ask Ji Yang
```
Copilot AI, Nov 5, 2025
The TODO comment lacks context and actionable information. It should specify what batching functionality is missing and what needs to be asked or investigated.
Suggested change:
```diff
- // TODO
- // The current MMA collective for grouped GEMM currently seems to be missing
- // batching Ask Ji Yang
+ // TODO: Investigate and implement batching support in the MMA collective for grouped GEMM.
+ // Currently, the MMA collective does not support batched execution across multiple groups.
+ // Consult with Ji Yang to clarify the requirements for batching, including:
+ // - What batching semantics are needed for grouped GEMM in this context?
+ // - How should the MMA collective handle multiple groups in a single kernel launch?
+ // - Are there hardware or API constraints to consider for batching?
```
```cpp
inline constexpr bool is_16_bit_fp_v =
    is_16_bit_fp<std::remove_cv_t<std::remove_reference_t<T>>>::value;

// Making sure I got this right
```
Copilot AI, Nov 5, 2025
The comment is too informal and vague for production code. It should be removed or replaced with a clear explanation of what the static_asserts verify.
Suggested change:
```diff
- // Making sure I got this right
+ // Verify that is_16_bit_fp_v correctly identifies cutlass::bfloat16_t and cutlass::half_t as 16-bit floating point types.
```
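For reference, a minimal sketch of what the suggested verification could look like; the exact asserts, and the negative check on float in particular, are assumptions about the intended semantics, while is_16_bit_fp_v is the trait shown above.

```cpp
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>

// Sanity checks for the trait above (illustrative; not copied from the PR).
static_assert(is_16_bit_fp_v<cutlass::bfloat16_t>,
              "bfloat16_t should be classified as a 16-bit floating-point type");
static_assert(is_16_bit_fp_v<cutlass::half_t>,
              "half_t should be classified as a 16-bit floating-point type");
static_assert(!is_16_bit_fp_v<float>,
              "float should not be classified as a 16-bit floating-point type");
```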
```cpp
// TODO: Remove magic numbers
// Assumes scales are [K/32, N], as the reads are coalesced in that
// case (although shuffles can't be prevented later). Since loaded
// scales would be read columnwise when scales would be shaped [N,
// K/32], use UniversalCopy instead.
scales_e8m0[i] = *(S_tile.data() + (64 * (sg_id % 4)) + k_tile * 256 +
```
Copilot AI, Nov 5, 2025
The scaling calculation uses multiple hardcoded values (64, 4, 256, 8) without explanation. These should be defined as named constants with clear documentation of their relationship to the tensor layout.
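As one possible shape of that fix, here is a hedged sketch that names the constants; the constant names, the layout relationships they describe, and the helper function are assumptions for illustration, not the PR's actual values or code.

```cpp
// Assumed meanings for the magic numbers in the indexing expression above;
// the real relationships to the [K/32, N] scale layout need to be confirmed.
constexpr int kSubgroupsAlongN   = 4;    // subgroups tiled along N
constexpr int kScalesPerSubgroup = 64;   // e8m0 scales owned by one subgroup
constexpr int kScalesPerKTile    = 256;  // e8m0 scales advanced per K tile

// Base offset into the scale tensor for a given subgroup and K tile; the
// per-element remainder of the original (truncated) expression is left out.
constexpr int scale_base_offset(int sg_id, int k_tile) {
  return kScalesPerSubgroup * (sg_id % kSubgroupsAlongN) +
         kScalesPerKTile * k_tile;
}
```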
```cpp
// hardcoded for work-item B subgroup fragment sized 128
// I created 4 loops just so that I can easily copy-paste some of them to
// the case of 32 or 64 elements per B thread fragment with compile-time
// expression evaluation. Othwerwise, the loops below can be combined.
// However, instead of hardcoding, figure out CuTe algebra based
// transformations that can lead to generic code.
CUTE_UNROLL
for (int i = 0, j = 0; i < 15; i += 2) {
  tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
                                    static_cast<float>(tCrB[i]));
  tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
                                        static_cast<float>(tCrB[i + 1]));
  tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
                                         static_cast<float>(tCrB[i + 64]));
  tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
                                         static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
for (int i = 16, j = 2; i < 31; i += 2) {
  tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
                                    static_cast<float>(tCrB[i]));
  tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
                                        static_cast<float>(tCrB[i + 1]));
  tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
                                         static_cast<float>(tCrB[i + 64]));
  tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
                                         static_cast<float>(tCrB[i + 65]));
}
CUTE_UNROLL
```
Copilot AI, Nov 5, 2025
The comment acknowledges that the hardcoded approach should be replaced with generic CuTe transformations. The four separate unrolled loops (lines 304-347) represent significant code duplication that reduces maintainability.
Suggested change:
```diff
- // hardcoded for work-item B subgroup fragment sized 128
- // I created 4 loops just so that I can easily copy-paste some of them to
- // the case of 32 or 64 elements per B thread fragment with compile-time
- // expression evaluation. Othwerwise, the loops below can be combined.
- // However, instead of hardcoding, figure out CuTe algebra based
- // transformations that can lead to generic code.
- CUTE_UNROLL
- for (int i = 0, j = 0; i < 15; i += 2) {
-   tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
-                                     static_cast<float>(tCrB[i]));
-   tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
-                                         static_cast<float>(tCrB[i + 1]));
-   tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
-                                          static_cast<float>(tCrB[i + 64]));
-   tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
-                                          static_cast<float>(tCrB[i + 65]));
- }
- CUTE_UNROLL
- for (int i = 16, j = 2; i < 31; i += 2) {
-   tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
-                                     static_cast<float>(tCrB[i]));
-   tCrB[i + 1] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
-                                         static_cast<float>(tCrB[i + 1]));
-   tCrB[i + 64] = static_cast<bfloat16_t>(scales_float_sg_tensor[j] *
-                                          static_cast<float>(tCrB[i + 64]));
-   tCrB[i + 65] = static_cast<bfloat16_t>(scales_float_sg_tensor[j + 1] *
-                                          static_cast<float>(tCrB[i + 65]));
- }
- CUTE_UNROLL
+ // Generic scaling using CuTe tensor algebra to avoid hardcoded loops
+ // This replaces the four unrolled loops with a single generic loop
+ // Assumes tCrB is of size 128 and scales_float_sg_tensor is of size 8
+ for (int i = 0; i < 128; ++i) {
+   // Compute the scale index for this element
+   // The original code uses j = (i % 32) / 8, mapping i to scale index
+   int scale_idx = (i % 32) / 8;
+   tCrB[i] = static_cast<bfloat16_t>(scales_float_sg_tensor[scale_idx] *
+                                     static_cast<float>(tCrB[i]));
+ }
+ CUTE_UNROLL
```
```cpp
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
cutlass::KernelHardwareInfo hw_info{0, sm_count};
auto dummy_problem_shape = cute::Shape<int, int, int>{1, gemm_k, gemm_n};
// I forgot why I used this hack
```
Copilot AI, Nov 5, 2025
The comment admits not understanding why this code exists, which is problematic for maintainability. The purpose of the dummy problem shape construction should be documented or the code should be refactored if it's unnecessary.
Suggested change:
```diff
- // I forgot why I used this hack
+ // The scheduler and kernel launch APIs require a GroupProblemShape, even for a single GEMM problem.
+ // We construct a dummy problem shape and group problem shape to satisfy these API requirements.
+ // This ensures correct setup of scheduler parameters and grid shape for the kernel launch.
```
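A hedged sketch of the single-group setup the suggested comment describes; the GroupProblemShape header path, member order, and pointer semantics here are assumptions based on the CUTLASS grouped-GEMM API, and gemm_k/gemm_n come from the snippet above.

```cpp
#include <cute/layout.hpp>
#include <cutlass/gemm/group_array_problem_shape.hpp>

// Even for a single GEMM, the grouped scheduler expects a GroupProblemShape,
// so one dummy (M, K, N) shape is wrapped as a one-group problem. In real
// code the problem_shapes pointer must be device-visible; a plain pointer is
// used here only to keep the sketch short.
using ProblemShape = cute::Shape<int, int, int>;

inline auto make_single_group_shape(int gemm_k, int gemm_n,
                                    ProblemShape* shape_ptr) {
  *shape_ptr = ProblemShape{1, gemm_k, gemm_n};  // dummy M = 1, as in the example
  return cutlass::gemm::GroupProblemShape<ProblemShape>{
      /*num_groups=*/1, shape_ptr, /*host_problem_shapes=*/nullptr};
}
```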
```cpp
// TODO: Update with broadcast loads recently added. It seems it's okay to
// create a new TiledCopy object in each mainloop iteration.
```
Copilot AI, Nov 5, 2025
Creating a new TiledCopy object in each mainloop iteration (line 268 comment) could impact performance. This TODO should be addressed or a clearer justification provided for why it's acceptable.
```cpp
template <typename T>
static constexpr bool is_complete_v = is_complete<T>::value;

// Some of this code has been authored by Peter caday
```
Copilot AI, Nov 5, 2025
Corrected capitalization of name 'Peter caday' to 'Peter Caday'
Suggested change:
```diff
- // Some of this code has been authored by Peter caday
+ // Some of this code has been authored by Peter Caday
```
MoE GEMMs without using collectives. Can be called directly from another GPU kernel.
New copy and MMA atoms have been used. Some of the code has been adapted from the new API example.
cutlass::DeviceAllocation in launch code. cc @pengzhao-intel @CaoZhongZ