fix: add support check for gemm config for cutlass moe #2495
yzh119 merged 5 commits into flashinfer-ai:main
Summary of Changes
This pull request enhances the robustness of the CUTLASS Mixture of Experts (MoE) GEMM kernels by integrating critical validation checks. The primary goal is to prevent misconfigurations, particularly when utilizing the …
Walkthrough
Added runtime validation in MoE GEMM dispatch to enforce NO_SMEM epilogue constraints: output N must be aligned according to the bit-width of OutputType, and FINALIZE epilogue fusion is disallowed when NO_SMEM is selected. The checks run before dispatch and do not change public APIs.
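To make the alignment rule concrete, here is a minimal standalone sketch of the arithmetic behind the check. It does not use CUTLASS: `sizeofBits` is a hypothetical stand-in for `cutlass::sizeof_bits<T>::value` that only handles byte-sized types, and `minNoSmemAlignment` and `isNAligned` are illustrative names that do not appear in the PR.

```cpp
#include <cstdint>

// Hypothetical stand-in for cutlass::sizeof_bits<T>::value.
// Assumption: only byte-sized types are considered here.
template <typename T>
constexpr int sizeofBits() {
  return static_cast<int>(sizeof(T) * 8);
}

// Number of output elements that make up 256 bits for the given type.
template <typename OutputType>
constexpr int minNoSmemAlignment() {
  return 256 / sizeofBits<OutputType>();
}

// True when N is a multiple of the 256-bit element count.
template <typename OutputType>
constexpr bool isNAligned(std::int64_t n) {
  return n % minNoSmemAlignment<OutputType>() == 0;
}
```

Under these assumptions a 16-bit output type needs N to be a multiple of 16 elements, and a 32-bit output type needs a multiple of 8.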
Estimated code review effort: 2 (Simple), ~12 minutes
Code Review
This pull request adds necessary support checks for GEMM configurations in CUTLASS MoE kernels, specifically for NO_SMEM epilogue schedules. The checks ensure output alignment and prevent unsupported fusions, which is important for correctness and for the autotuner. While the changes are good, there's an opportunity to improve code clarity and reduce duplication. The alignment calculation is repeated, and I've suggested extracting it into a constant. Furthermore, since the alignment check logic is now duplicated in two functions, consider refactoring it into a shared helper function for better maintainability.
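The shared helper suggested in this review could look roughly like the sketch below. It is an illustration only, not code from the PR: the enums and `GemmConfig` struct are simplified stand-ins for the `cutlass_extensions` types, `checkNoSmemSupport` and `isSupported` are hypothetical names, and a thrown `std::invalid_argument` stands in for `TLLM_CHECK_WITH_INFO`. The output bit-width is passed as a plain integer instead of being derived from `OutputType`.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Simplified stand-ins for the cutlass_extensions enums (assumption).
enum class EpilogueScheduleType { AUTO, NO_SMEM };
enum class EpilogueFusionType { NONE, FINALIZE };

// Hypothetical reduced config carrying only the fields the check reads.
struct GemmConfig {
  EpilogueScheduleType epilogue_schedule;
  EpilogueFusionType epilogue_fusion_type;
};

// Throws if a NO_SMEM configuration violates either constraint.
// output_bits is the bit-width of the output element type.
inline void checkNoSmemSupport(GemmConfig const& config, std::int64_t n, int output_bits) {
  if (config.epilogue_schedule != EpilogueScheduleType::NO_SMEM) {
    return;  // Constraints apply only to the NO_SMEM schedule.
  }
  if (config.epilogue_fusion_type == EpilogueFusionType::FINALIZE) {
    throw std::invalid_argument("NO_SMEM epilogue schedule is not supported with FINALIZE fusion");
  }
  int const min_alignment = 256 / output_bits;
  if (n % min_alignment != 0) {
    throw std::invalid_argument("Output N " + std::to_string(n) +
                                " does not meet minimum alignment requirements for NO_SMEM epilogue " +
                                std::to_string(min_alignment));
  }
}

// Convenience wrapper: reports whether the configuration passes the check.
inline bool isSupported(GemmConfig const& config, std::int64_t n, int output_bits) {
  try {
    checkNoSmemSupport(config, n, output_bits);
    return true;
  } catch (std::invalid_argument const&) {
    return false;
  }
}
```

Keeping both conditions in one function means each call site only has to pass its own N, which is the duplication the review points out.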
if (inputs.gemm_config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM &&
    !isGatedActivation(inputs.activation_type)) {
  TLLM_CHECK_WITH_INFO(
      inputs.n % (256 / cutlass::sizeof_bits<OutputType>::value) == 0,
      "Output N %ld does not meet minimum alignment requirements for NO_SMEM epilogue %d",
      (long)inputs.n, (int)(256 / cutlass::sizeof_bits<OutputType>::value));
}
The calculation for the minimum alignment is performed twice inside this check. To improve readability and avoid this repetition, you can store the result in a const auto variable. This makes the code cleaner and easier to understand.
if (inputs.gemm_config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM &&
    !isGatedActivation(inputs.activation_type)) {
  const auto min_alignment = 256 / cutlass::sizeof_bits<OutputType>::value;
  TLLM_CHECK_WITH_INFO(
      inputs.n % min_alignment == 0,
      "Output N %ld does not meet minimum alignment requirements for NO_SMEM epilogue %d",
      (long)inputs.n, (int)min_alignment);
}

if (inputs.gemm_config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM) {
  TLLM_CHECK_WITH_INFO(inputs.gemm_config.epilogue_fusion_type !=
                           cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE,
                       "NO_SMEM epilogue schedule is not supported with FINALIZE fusion");
  TLLM_CHECK_WITH_INFO(
      inputs.n % (256 / cutlass::sizeof_bits<OutputType>::value) == 0,
      "Output N %ld does not meet minimum alignment requirements for NO_SMEM epilogue %d",
      (long)inputs.n, (int)(256 / cutlass::sizeof_bits<OutputType>::value));
}
Similar to the check in moeGemmBiasAct, the alignment calculation is repeated here. Extracting it into a const auto variable will improve readability. Since this alignment check logic is now present in two places, you might also consider creating a private helper function to encapsulate this check and avoid code duplication.
if (inputs.gemm_config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM) {
  TLLM_CHECK_WITH_INFO(inputs.gemm_config.epilogue_fusion_type !=
                           cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE,
                       "NO_SMEM epilogue schedule is not supported with FINALIZE fusion");
  const auto min_alignment = 256 / cutlass::sizeof_bits<OutputType>::value;
  TLLM_CHECK_WITH_INFO(
      inputs.n % min_alignment == 0,
      "Output N %ld does not meet minimum alignment requirements for NO_SMEM epilogue %d",
      (long)inputs.n, (int)min_alignment);
}

// For NoSmem epilogue schedule, output N must be 256-bit aligned.
// For gated activation, this is automatic if the usual alignment requirement is met.
// This check is here so the autotuner can catch invalid tactics during profiling.
if (inputs.gemm_config.epilogue_schedule == cutlass_extensions::EpilogueScheduleType::NO_SMEM &&
If we put this in runGemm/dispatchToArch we don't need to have two copies of this check.
Maybe here since this is only relevant for SM90+
dispatchToArch doesn't work because that function cannot tell whether the activation is gated.
runGemm works, but to match the logic in the moe runner code, I think it's clearer to keep the gemm1 and gemm2 checks separate, as in the original logic.
We don't need to check isGatedActivation here
size_t const fc1_out_size =
    ((!use_ampere_activation_fusion) && is_gated_activation) ? inter_size * 2 : inter_size;
This line already sets the value of N correctly. The original check only works with inter_size, so it has to explicitly check that we are in the non-gated case.
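A small sketch of why checking N directly makes the gated test unnecessary. `fc1OutSize` is a hypothetical wrapper around the quoted expression, and the value 1032 is chosen purely for illustration as an `inter_size` that is a multiple of 8 but not of 16, with a 16-bit output type assumed (so the 256-bit alignment is 16 elements).

```cpp
#include <cstddef>

// Hypothetical wrapper around the expression quoted in the comment above:
// gated activations double the FC1 output width unless the Ampere
// activation fusion path is used.
constexpr std::size_t fc1OutSize(std::size_t inter_size, bool is_gated_activation,
                                 bool use_ampere_activation_fusion) {
  return (!use_ampere_activation_fusion && is_gated_activation) ? inter_size * 2 : inter_size;
}

// Illustration, assuming a 16-bit output type (alignment of 16 elements):
// the gated N is 2 * 1032 = 2064, which is a multiple of 16, while the
// non-gated N stays at 1032, which is not.
static_assert(fc1OutSize(1032, true, false) % 16 == 0, "gated N is aligned");
static_assert(fc1OutSize(1032, false, false) % 16 != 0, "non-gated N is not aligned");
```

Because the doubling is already folded into N, a check of the form `n % min_alignment == 0` covers both cases without consulting the activation type.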
Thanks. Updated to dispatchToArch
/bot run
[CANCELING] Pipeline #43308250: canceled
This reverts commit 6ccb7f3.
/bot run
[CANCELING] Pipeline #43311572: canceled
/bot run
[FAILED] Pipeline #43312738: 10/20 passed
aleozlx left a comment
tests clean
approved again