Add Relu2 activation support in CUTLASS MoE backend and fix autotuner async CUDA error handling#2897
Conversation
…patch

- Introduced `EpilogueOpDefaultRelu2` struct in `epilogue_helpers.h` for Relu2 activation.
- Updated `moe_gemm_template_dispatch.h` to handle `ActivationType::Relu2`, enabling the new Relu2 activation in GEMM operations.
- Enhanced the autotuner to clear pending CUDA errors during profiling, improving robustness of error handling.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Summary of Changes

This pull request enhances the FlashInfer CUTLASS backend by introducing compatibility for models that use Relu2 as their Mixture-of-Experts (MoE) gate activation, expanding the range of supported architectures. It also improves the robustness of the autotuner by ensuring that transient CUDA errors are properly handled and cleared, preventing them from destabilizing later operations.
📝 Walkthrough

Adds Relu2 epilogue support to CUTLASS and dispatch in MoE GEMM, tweaks autotuner tactic-failure handling to attempt a CUDA synchronize, enables CUTLASS GDC compile flags for additional SM targets, and introduces SM121-specific kernel generation and heuristic selection with SM12.x normalization updates.
Code Review
This pull request introduces a new Relu2 activation function by extending the CUTLASS epilogue helpers and integrating it into the Mixture-of-Experts (MoE) GEMM kernel dispatch logic. Additionally, it enhances the flashinfer autotuner by adding a mechanism to clear pending asynchronous CUDA errors after failed profiling runs, which prevents these errors from affecting subsequent CUDA graph capture. I have no feedback to provide as there were no review comments.
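For reference, Relu2 (squared ReLU) is a simple elementwise activation. A minimal Python sketch of the math — illustrative only; the actual implementation is the CUTLASS `Relu2` functor in `fused_activations.h`:

```python
def relu2(x: float) -> float:
    """Squared ReLU: relu(x)^2. Zero for non-positive inputs, x*x otherwise."""
    r = max(x, 0.0)
    return r * r
```

Unlike gated activations such as Swiglu, Relu2 operates on a single input value, which is why it slots into the default (non-gated) epilogue path.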
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
1019-1040: Add a fused-activation invariant guard before dispatch.

If `inputs.use_fused_moe` is ever true for `Relu2` (or other non-gated activations), execution can hit a non-executing path in `genericMoeGemmKernelLauncher::call`. A fast-fail check here would prevent silent misconfiguration.

Proposed guard:
```diff
 void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType, IsMXFPX>::moeGemmBiasAct(
     GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
     TmaWarpSpecializedGroupedGemmInput hopper_inputs) {
+  TLLM_CHECK_WITH_INFO(
+      !inputs.use_fused_moe ||
+          inputs.activation_type == ActivationType::Swiglu ||
+          inputs.activation_type == ActivationType::Geglu,
+      "use_fused_moe is only valid for gated activations (Swiglu/Geglu)");
+
   switch (inputs.activation_type) {
```
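The shape of the invariant the comment asks for can also be sketched outside C++. This hypothetical Python mirror (enum member names are borrowed from the C++ `ActivationType` for illustration; this is not FlashInfer's actual Python API) shows the fast-fail check:

```python
from enum import Enum, auto

class ActivationType(Enum):
    # Mirrors the C++ ActivationType values referenced above (illustrative).
    Relu = auto()
    Relu2 = auto()
    Gelu = auto()
    Silu = auto()
    Identity = auto()
    Swiglu = auto()
    Geglu = auto()

# Only gated activations are valid with the fused MoE path.
GATED_ACTIVATIONS = {ActivationType.Swiglu, ActivationType.Geglu}

def check_fused_moe_invariant(use_fused_moe: bool, activation: ActivationType) -> None:
    """Fast-fail before dispatch so misconfiguration never reaches the launcher."""
    if use_fused_moe and activation not in GATED_ACTIVATIONS:
        raise ValueError(
            f"use_fused_moe is only valid for gated activations (Swiglu/Geglu), "
            f"got {activation.name}"
        )
```

Raising before the activation switch means the error names both `use_fused_moe` and the offending activation, which is exactly the debugging aid the review asks for.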
📒 Files selected for processing (3)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
- flashinfer/autotuner.py
- Added compilation flag `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to the `gen_cutlass_fused_moe_sm120`, `gen_cutlass_fused_moe_sm103`, and `gen_cutlass_fused_moe_sm100` functions to support GDC for the SM100 architecture.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
```diff
   "-DENABLE_FP8",
   "-DENABLE_FP4",
   "-DUSING_OSS_CUTLASS_MOE_GEMM",
+  "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
```
General question: why is the change in #1954 not enough for your case? It probably goes through a different code path, but I'd like to make sure we're not missing a different bug that causes this.
```diff
         f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
     )

+    # Clear any pending async CUDA errors (e.g.
```
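The intent of the new comment — draining a sticky asynchronous error so it cannot poison later work — can be illustrated with a toy model. This is pure Python with no CUDA; `FakeCudaContext` and its sticky-error behavior are an assumption for illustration, not torch's real semantics:

```python
class FakeCudaContext:
    """Toy stand-in for a CUDA context whose async errors only surface on sync."""

    def __init__(self):
        self.pending_error = None  # sticky error string, or None

    def launch_bad_kernel(self):
        # An async launch "succeeds" immediately; the failure is reported later.
        self.pending_error = "cudaErrorIllegalInstruction"

    def synchronize(self):
        # Like a synchronize call: raise and drain any pending error.
        if self.pending_error is not None:
            err, self.pending_error = self.pending_error, None
            raise RuntimeError(err)

def profile_tactic(ctx, run_tactic):
    """Probe one tactic; on failure, drain pending errors before returning."""
    try:
        run_tactic(ctx)
        ctx.synchronize()
        return True
    except RuntimeError:
        try:
            ctx.synchronize()  # drain anything still pending
        except RuntimeError:
            pass  # expected and recoverable; log at DEBUG, not WARNING
        return False
```

With the drain in place, a failed probe leaves no pending error behind, so a later operation (in the real autotuner, CUDA graph capture) does not trip over the stale failure.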
A bit of nitpicking but consider splitting this into a different PR
- Introduced `get_candidate_configs_sm121` function to handle GEMM configurations for the SM121 architecture, which has a reduced shared memory budget.
- Updated `generate_sm120_grouped_gemm_operations` to accommodate the specific tile size constraints for SM121.
- Enhanced `CompilationContext` to differentiate between SM120 and SM121 in the JIT cache.
- Adjusted kernel generation logic to ensure compatibility with the new architecture.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents

Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gemm/cutlass/generate_kernels.py` around lines 1059-1062: the SM121-only branch is dead because `gen_cutlass_fused_moe_sm120_module()` forces `device_arch="120"`, so `has_arch(121)` is never true. Either remove the conditional and the `generate_sm121_operations` call to eliminate the dead code, or implement a parallel SM121 generator: add a `gen_cutlass_fused_moe_sm121_module()` mirroring `gen_cutlass_fused_moe_sm120_module()` that passes `device_arch="121"` (and ensure call sites use it), then keep the `if has_arch(121) and not has_arch(120)` branch to invoke `generate_sm121_operations(True)`; update or remove references to `has_arch(121)` accordingly.
📒 Files selected for processing (4)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
- flashinfer/autotuner.py
- flashinfer/compilation_context.py
- flashinfer/jit/gemm/cutlass/generate_kernels.py
```cpp
    return get_candidate_configs_sm120(config);
  }
```
FAST_BUILD non-GROUPED_GEMM path for SM121 may be incorrect.
In the FAST_BUILD branch (lines 645-647), when the config does NOT have GROUPED_GEMM, it falls back to `get_candidate_configs_sm120(config)`. However, `get_candidate_configs_sm120` in FAST_BUILD mode (lines 589-599) returns `CtaShape128x128x256B` for non-grouped GEMM, which may exceed SM121's SMEM budget.

The non-FAST_BUILD path (lines 663-668) correctly delegates to SM120 for `FP4_ONLY`, but the FAST_BUILD path doesn't check `FP4_ONLY` before delegating.
Consider adding FP4_ONLY guard in FAST_BUILD path
```diff
 #ifdef FAST_BUILD
   if (config & CutlassGemmConfig::GROUPED_GEMM) {
     return {CutlassGemmConfig{CutlassTileConfigSM120::CtaShape128x128x64B,
                               MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                               ClusterShape::ClusterShape_1x1x1}};
   } else {
+    if ((config & CutlassGemmConfig::FP4_ONLY) != 0) {
+      return get_candidate_configs_sm120(config);
+    } else {
+      TLLM_THROW("Not Implemented: SM121 non-group GEMM only supports nvfp4 in FAST_BUILD.");
+    }
-    return get_candidate_configs_sm120(config);
   }
 #else
```
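The control flow of the suggested fix can be modeled with plain flags. This Python sketch (flag names and tile strings are illustrative stand-ins, not the C++ `CutlassGemmConfig` API) mirrors the FAST_BUILD branch with the `FP4_ONLY` guard added:

```python
GROUPED_GEMM = 1 << 0
FP4_ONLY = 1 << 1

# Tile returned by the SM120 FAST_BUILD path; too large for SM121's ~99 KB SMEM.
SM120_FAST_BUILD_TILE = "CtaShape128x128x256B"
# The one SM120 tile that fits SM121's shared-memory budget.
SM121_SAFE_TILE = "CtaShape128x128x64B"

def candidate_configs_sm121_fast_build(config):
    """FAST_BUILD candidates for SM121, guarding the SM120 delegation."""
    if config & GROUPED_GEMM:
        return [SM121_SAFE_TILE]
    if config & FP4_ONLY:
        # Only the FP4 path may safely delegate to the SM120 candidate list,
        # mirroring the non-FAST_BUILD guard behavior.
        return [SM120_FAST_BUILD_TILE]
    raise NotImplementedError(
        "SM121 non-grouped GEMM only supports nvfp4 in FAST_BUILD"
    )
```

The point of the guard is that the non-FP4 case fails loudly at config-selection time instead of silently choosing a tile that exceeds the SMEM budget.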
- Implemented `gen_cutlass_fused_moe_sm121_module` to generate modules for the SM121 architecture, ensuring compatibility with its shared memory constraints.
- Updated the `get_cutlass_fused_moe_module` function to handle the new SM121 backend.
- Refactored `get_candidate_configs_sm121` to streamline GEMM configuration retrieval.

This enhances the framework's capability to leverage the SM121 architecture effectively.

- Removed `gen_cutlass_fused_moe_sm121_module` and its references from the codebase, simplifying the architecture support.
- Updated `get_cutlass_fused_moe_module` to handle only SM120 and SM103 backends.
- Adjusted kernel generation logic to ensure compatibility with the remaining architectures.

This change streamlines the code and focuses on maintaining support for the more widely used architectures.

- Introduced a filtering mechanism in `get_candidate_tiles` to exclude tile configurations where both M and N are greater than or equal to 128 for the SM121 architecture, addressing shared memory constraints.
- Updated the return statements for various GEMM types to utilize the new filtering function, ensuring the autotuner does not consider invalid configurations.

This change enhances the efficiency of the autotuner by preventing it from evaluating known-bad configurations for SM121.
/bot run
@askliar - according to CUDA documentation, sm120 also has 99KB of shared memory, which makes sense, because sm120/sm121 are pretty much identical. Not sure where 228KB for sm120 comes from? Also, I haven't looked at more detail, but I assume it will build correctly for Spark if FLASHINFER_CUDA_ARCH_LIST=12.1a (not 12.1f), right?
@eugr just FYI, I have split this PR into two separate ones: I will keep updating the latter with your comments!
Thank you for your contributions! |
[FAILED] Pipeline #47388152: 11/20 passed |
Thanks! |
Summary

- Relu2 activation: the CUTLASS MoE backend had no epilogue for Relu2, so dispatch fell through to `InvalidType` at runtime.
- Autotuner error handling: when a tactic probe failed with an async CUDA error (e.g. `cudaErrorIllegalInstruction`), the sticky error was not cleared and would surface later during CUDA graph capture or inference. Added a `torch.cuda.synchronize()` drain and demoted the "Skipping tactic" log from WARNING to DEBUG since these failures are expected and recoverable.
- JIT cache: SM12.1 now uses `121f` instead of sharing `120f` with SM12.0, ensuring the JIT cache is isolated per architecture.

Depends on
This PR should be merged after #2913 (enable GDC for CUTLASS fused MoE PDL on SM12x), which adds the `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flags that this PR's SM120 module generator relies on.

Changes
Relu2 activation (`epilogue_helpers.h`, `moe_gemm_template_dispatch.h`)

Two pieces were missing:

- `epilogue_helpers.h` had no `EpilogueOpDefaultRelu2` tag struct or `Epilogue` partial specialization, so there was no CUTLASS epilogue type for Relu2.
- `moeGemmBiasAct()` had no `case ActivationType::Relu2`, causing it to fall through to `InvalidType` and throw.

The `Relu2` functor itself (relu(x)²) already existed in `fused_activations.h`; this PR just wires it into the epilogue dispatch.

SM121 tile filtering (`cutlass_heuristic.cpp`, `compilation_context.py`)

- Added `get_candidate_configs_sm121()`, which returns only `CtaShape128x128x64B` (~73 KB, fits in 99 KB). The other three SM120 tiles (128x128x128B, 256x128x64B, 128x256x64B) all exceed 99 KB.
- Changed `sm >= 120` to `sm == 120` plus a separate `sm == 121` branch.
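The SM121 candidate selection above reduces to a budget check over the SM120 tile set. A hedged Python sketch — the ~73 KB figure is from this PR's description, while the other footprints are placeholders standing in for "exceeds 99 KB":

```python
# Approximate SMEM footprints (KB). Only the first value (~73 KB) is stated in
# the PR; the 100s are placeholders meaning "exceeds the 99 KB budget".
SM120_TILE_SMEM_KB = {
    "CtaShape128x128x64B": 73,
    "CtaShape128x128x128B": 100,
    "CtaShape256x128x64B": 100,
    "CtaShape128x256x64B": 100,
}

SM121_SMEM_BUDGET_KB = 99

def sm121_candidates():
    """Keep only SM120 tiles whose shared-memory footprint fits SM121."""
    return [t for t, kb in SM120_TILE_SMEM_KB.items() if kb <= SM121_SMEM_BUDGET_KB]
```

Under these assumed footprints only `CtaShape128x128x64B` survives, matching the single-tile list the new heuristic returns.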
filter_sm121lambda inget_candidate_tiles()that removes tiles where both M >= 128 and N >= 128, since these also exceed SM121's SMEM budget.JIT cache separation:
compilation_context.py: SM12.1 now returns(12, "1f")instead of(12, "0f"), giving GB10 cache path~/.cache/flashinfer/<version>/121f/separate from GB200's120f/.Autotuner robustness (
autotuner.py)WARNINGtoDEBUG— tile-incompatibility failures are expected on SM121 and should not spam logs.torch.cuda.synchronize()drain after failed tactic probes to clear sticky async CUDA errors (e.g.cudaErrorIllegalInstructionfrom failed TMA WS GEMM probes) before they surface during CUDA graph capture.Test plan
pytest tests/moe/)pytest tests/moe/test_trtllm_gen_fused_moe.py -k Relu2WARNING [Autotuner]: Skipping tacticlog spamCtaShape128x128x64BFP4 grouped GEMM tile selected~/.cache/flashinfer/0.6.7/121f/