
Fix MXFP4/MXFP8 failures in SM120 FAST_BUILD and expand all_tiles[] #2994

Merged
aleozlx merged 32 commits into flashinfer-ai:main from askliar:feature/sm121-tile-filtering on Apr 13, 2026

Conversation

Contributor

@askliar askliar commented Apr 6, 2026

Problem

MXFP4 and MXFP8 GEMM operations were failing on SM120 because:

  • The FAST_BUILD path returned a single hardcoded CtaShape128x128x64B tile regardless of GROUPED_GEMM, and that tile is not valid for all MXFP4/MXFP8 configurations
  • The full-build all_tiles[] table was missing tiles needed by those dtypes (128x128x128B, 128x128x64B, 256x128x64B),
    leaving the autotuner with no viable candidate in some cases

Fix

  • FAST_BUILD: differentiate the grouped and non-grouped paths, each returning tiles known to work for MXFP4/MXFP8 (see the sketch after this list):
      • Grouped: 128x128x128B + 128x128x64B
      • Non-grouped: 128x128x256B + 128x128x64B
  • Full-build all_tiles[]: add the three missing tiles so the autotuner has a complete candidate set for MXFP4/MXFP8 workloads
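
The sketch below illustrates the FAST_BUILD branching described above. It is a simplified, hypothetical rendering — the real code in `cutlass_heuristic.cpp` sits inside a larger function with more parameters and enum cases — but the tile choices match the fix:

```cpp
#include <vector>

// Simplified stand-ins for the real SM120 tile enum in cutlass_heuristic.cpp.
enum class CutlassTileConfigSM120 {
  CtaShape128x128x64B,
  CtaShape128x128x128B,
  CtaShape128x128x256B,
};

// Hypothetical helper: FAST_BUILD now returns a small, dtype-safe candidate
// set that differs for grouped vs. non-grouped GEMM, instead of a single
// hardcoded CtaShape128x128x64B tile.
std::vector<CutlassTileConfigSM120> fast_build_tiles_sm120(bool grouped_gemm) {
  if (grouped_gemm) {
    return {CutlassTileConfigSM120::CtaShape128x128x128B,
            CutlassTileConfigSM120::CtaShape128x128x64B};
  }
  return {CutlassTileConfigSM120::CtaShape128x128x256B,
          CutlassTileConfigSM120::CtaShape128x128x64B};
}
```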

Summary by CodeRabbit

  • Performance & Optimizations

    • More predictable kernel candidate selection and expanded tile/configuration options for SM120-class GPUs to improve tuning and performance.
    • Broadened handling of grouped computation patterns to enable additional configuration choices.
  • Build/Compatibility

    • Refined CUDA 12.9+ architecture suffixing for more accurate build targeting.
  • Chores

    • Added type annotations and minor signature clarifications (no runtime behavior changes).
  • Bug Fixes

    • MoE fusion path now forwards additional tensors/parameters to improve fused operation correctness.

Co-authored-by: samuellees <lsam@nvidia.com>

Andrii Skliar and others added 19 commits March 31, 2026 20:21
Added torch.cuda.synchronize() drain after failed tactic probes to clear
sticky async CUDA errors (e.g. cudaErrorIllegalInstruction from failed TMA
WS GEMM probes) before they surface during CUDA graph capture.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
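
At the CUDA-runtime level, the drain pattern this commit applies via `torch.cuda.synchronize()` looks roughly like the sketch below. This is an illustrative C++ sketch, not code from the PR, and the function name is hypothetical:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical sketch: after a failed tactic probe, synchronize so any
// pending async error surfaces here, then consume the error code so it is
// reported at the probe site rather than later, e.g. during graph capture.
void drain_probe_errors() {
  cudaError_t sync_err = cudaDeviceSynchronize();  // surfaces async failures
  cudaError_t last_err = cudaGetLastError();       // consumes the pending status
  if (sync_err != cudaSuccess || last_err != cudaSuccess) {
    std::fprintf(stderr, "tactic probe failed: %s\n",
                 cudaGetErrorString(last_err != cudaSuccess ? last_err : sync_err));
  }
}
```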
- Introduced `get_candidate_configs_sm121` function to handle GEMM configurations for the SM121 architecture, which has a reduced shared memory budget.
- Updated `generate_sm120_grouped_gemm_operations` to accommodate the specific tile size constraints for SM121.
- Enhanced `CompilationContext` to differentiate between SM120 and SM121 in the JIT cache.
- Adjusted kernel generation logic to ensure compatibility with the new architecture.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
- Implemented `gen_cutlass_fused_moe_sm121_module` to generate modules for the SM121 architecture, ensuring compatibility with its shared memory constraints.
- Updated the `get_cutlass_fused_moe_module` function to handle the new SM121 backend.
- Refactored `get_candidate_configs_sm121` to streamline GEMM configuration retrieval.

This enhances the framework's capability to leverage the SM121 architecture effectively.
- Removed `gen_cutlass_fused_moe_sm121_module` and its references from the codebase, simplifying the architecture support.
- Updated `get_cutlass_fused_moe_module` to handle only SM120 and SM103 backends.
- Adjusted kernel generation logic to ensure compatibility with the remaining architectures.

This change streamlines the code and focuses on maintaining support for the more widely used architectures.
- Introduced a filtering mechanism in `get_candidate_tiles` to exclude tile configurations where both M and N are greater than or equal to 128 for the SM121 architecture, addressing shared memory constraints.
- Updated the return statements for various GEMM types to utilize the new filtering function, ensuring the autotuner does not consider invalid configurations.

This change enhances the efficiency of the autotuner by preventing it from evaluating known-bad configurations for SM121.
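
A minimal sketch of the filter this commit describes (it was superseded by an explicit shared-memory check in later commits); the struct and function names here are illustrative, not the actual ones in the file:

```cpp
#include <algorithm>
#include <vector>

struct TileShape { int m, n; };  // illustrative stand-in

// Drop candidate tiles whose M and N are both >= 128: on SM121's reduced
// shared-memory budget these configurations cannot be staged.
std::vector<TileShape> filter_tiles_for_sm121(std::vector<TileShape> tiles) {
  tiles.erase(std::remove_if(tiles.begin(), tiles.end(),
                             [](const TileShape& t) {
                               return t.m >= 128 && t.n >= 128;
                             }),
              tiles.end());
  return tiles;
}
```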
- Introduced a `skipped_count` variable to track the number of unsupported tactics during profiling in the `AutoTuner` class.
- Added logging to inform users when tactics are skipped, enhancing visibility into the autotuning process.

This change improves the debugging experience by providing insights into the profiling process and potential issues with unsupported tactics.
…ility

- Updated the `TileShape` structure to include a third dimension `k` for various tile configurations, ensuring accurate representation of tile shapes in the `get_cta_shape_for_config` function.
- Removed the filtering mechanism for large tile configurations specific to SM121, simplifying the candidate tile retrieval process across different GEMM types.
- Adjusted return statements in `get_candidate_tiles` to directly return valid configurations without filtering, enhancing the efficiency of the autotuner.

This change improves the flexibility and accuracy of tile shape configurations in CUTLASS, facilitating better performance across supported architectures.
…imits

- Introduced a new function `tile_fits_smem` to evaluate if tile configurations fit within the shared memory limits for different architectures, improving memory management.
- Updated `get_candidate_configs_sm120` to include the new shared memory fitting logic, ensuring only valid configurations are considered for SM120.
- Adjusted the `get_candidate_configs` function to utilize the shared memory fitting check, enhancing the robustness of GEMM configuration retrieval.

This change optimizes the autotuning process by preventing the selection of configurations that exceed shared memory constraints, leading to better performance across supported architectures.
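
A sketch of what a `tile_fits_smem`-style check can look like (this function was removed again in a later commit; the byte estimate below is an assumption for illustration, not the actual formula used):

```cpp
#include <cstddef>

struct TileShapeMNK { int m, n, k; };  // illustrative stand-in

// Rough estimate: each pipeline stage stages an MxK A-tile and an NxK
// B-tile in shared memory; reject configurations that exceed the limit.
bool tile_fits_smem(const TileShapeMNK& tile, int stages, int elem_bytes,
                    std::size_t smem_limit_bytes) {
  std::size_t per_stage =
      (static_cast<std::size_t>(tile.m) + tile.n) * tile.k * elem_bytes;
  return per_stage * static_cast<std::size_t>(stages) <= smem_limit_bytes;
}
```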
- Updated the shared memory limit for SM 8.6, 8.9, and 12.x from 102400 bytes (100 KiB) to 101376 bytes (99 KiB, the opt-in per-thread-block maximum on these architectures).
- Cleaned up logging statements in `get_candidate_configs_sm120` and `get_candidate_configs` for better readability and consistency.

This change ensures that the shared memory calculations are precise, enhancing the reliability of GEMM configuration evaluations.
…heuristic.cpp

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit removes the `tile_fits_smem` function from `cutlass_heuristic.cpp`, which checked whether a given tile/stage pair fit within the device shared memory limit; that check has been deemed unnecessary for the current candidate configuration functions. It also updates `compilation_context.py` to clarify suffix handling for SM 12.x architectures so that each variant is correctly represented, and modifies `autotuner.py` to add error handling for CUDA operations, improving robustness during profiling. Overall, these changes streamline the code and enhance error management.
This commit modifies the tile configuration in the `cutlass_heuristic.cpp` file by changing the structure of the `all_tiles` array. The K dimension has been removed from the configuration, simplifying the tile representation to only include the tile enumeration and dimensions M and N. This change streamlines the candidate configuration process for SM120 GEMM operations.
This commit updates the `cutlass_heuristic.cpp` file by removing the K dimension from the TileShape structure, streamlining the tile configuration for various CutlassTileConfig cases. The changes enhance the clarity and efficiency of the candidate configuration process for GEMM operations, particularly for SM120, by focusing solely on the M and N dimensions.
This commit further simplifies the tile shape configuration in `cutlass_heuristic.cpp` by removing unnecessary return statements and streamlining the candidate configuration logic. The changes enhance code clarity and maintain the focus on M and N dimensions, aligning with previous refactoring efforts to optimize GEMM operations for SM120.
…heuristic.cpp

This commit updates the candidate configuration logic in `cutlass_heuristic.cpp` for SM120 by introducing additional tile shapes based on the `GROUPED_GEMM` configuration. The changes provide a more comprehensive set of configurations, improving flexibility and performance for GEMM operations. The logic now distinguishes between grouped and non-grouped configurations, ensuring appropriate tile shapes are returned based on the input parameters.
…lass_heuristic.cpp

This commit eliminates outdated candidate configurations related to the `FP4_ONLY` and `GROUPED_GEMM` settings in `cutlass_heuristic.cpp`. The removal streamlines the candidate configuration logic, focusing on relevant tile shapes and enhancing code clarity for GEMM operations on SM120.
@samuellees samuellees added run-ci and removed run-ci labels Apr 9, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)

618-622: Pre-size result for the fixed tile table.

Tiny cleanup: reserve capacity before the push loop to avoid reallocations.

♻️ Suggested tweak
```diff
   std::vector<CutlassGemmConfig> result;
+  result.reserve(std::size(all_tiles));
   for (auto tile : all_tiles) {
     result.push_back(CutlassGemmConfig{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
                                        ClusterShape::ClusterShape_1x1x1});
   }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`
around lines 618-622: the vector 'result' should pre-size or reserve capacity
to avoid repeated reallocations when populating from 'all_tiles'; before the
loop that pushes CutlassGemmConfig entries, call reserve(all_tiles.size()) (or
resize and assign) on 'result' so the push_back loop that creates
CutlassGemmConfig{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x1x1} does not trigger reallocations.

📥 Commits

Reviewing files that changed from the base of the PR and between 8ff5d38 and b259686.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
  • flashinfer/compilation_context.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/compilation_context.py

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !525 has been updated with latest changes, and the CI pipeline #48123477 is currently running. I'll report back once the pipeline job completes.

Collaborator

aleozlx commented Apr 13, 2026

tests look good

@aleozlx aleozlx enabled auto-merge (squash) April 13, 2026 04:28
Collaborator

aleozlx commented Apr 13, 2026

public CI seemed cancelled for some reason. restarted and waiting for auto merge

Collaborator

aleozlx commented Apr 13, 2026

wait, seems the pre-commit check has failed. pls address that by re-running pre-commit

auto-merge was automatically disabled April 13, 2026 09:29

Head branch was pushed to by a user without write access

Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
flashinfer/comm/allreduce.py (1)

729-729: Consider exposing routed_scaling_factor as a function parameter.

This is hardcoded to None with no way for callers to pass a different value, unlike other optional MOE Finalize parameters (expert_scale_factor, shared_expert_output) which are exposed in the function signature. If this is intentional (feature not yet ready), a brief comment would help clarify.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/allreduce.py` at line 729: the call currently hardcodes
routed_scaling_factor=None; expose routed_scaling_factor as an optional
parameter (default None) on the same function that already accepts
expert_scale_factor and shared_expert_output, add it to the function signature
and docstring, and forward that parameter into the call (replacing
routed_scaling_factor=None with routed_scaling_factor=routed_scaling_factor) so
callers can override it; if leaving it fixed was intentional instead, add a
short comment next to the call explaining why it must remain None.

📥 Commits

Reviewing files that changed from the base of the PR and between b259686 and 9a78337.

📒 Files selected for processing (4)
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/comm/allreduce.py
  • flashinfer/jit/core.py
✅ Files skipped from review due to trivial changes (3)
  • flashinfer/jit/core.py
  • flashinfer/aot.py
  • flashinfer/autotuner.py

@aleozlx aleozlx enabled auto-merge (squash) April 13, 2026 09:37
auto-merge was automatically disabled April 13, 2026 09:40

Head branch was pushed to by a user without write access

Contributor Author

askliar commented Apr 13, 2026

@aleozlx I have looked more into the pre-commit changes - those are also on main. I will do a separate PR.

@askliar askliar force-pushed the feature/sm121-tile-filtering branch from 862593c to 216802d on April 13, 2026 09:46
Contributor Author

askliar commented Apr 13, 2026

@aleozlx here it is: #3043

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !525 has been updated with latest changes, and the CI pipeline #48411992 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx merged commit 04f4c0c into flashinfer-ai:main Apr 13, 2026
29 of 30 checks passed

Labels

op: comm, run-ci, v0.6.8 release blocker (label for 0.6.8)
