
Integrate CUTLASS Small Tile N Blockscaled GEMMs/Grouped GEMMs for SM120 and SM121#3152

Open
depaulmillz wants to merge 7 commits into flashinfer-ai:main from depaulmillz:small_blockscaled_n

Conversation

@depaulmillz
Contributor

@depaulmillz depaulmillz commented Apr 23, 2026

📌 Description

This MR bumps the CUTLASS commit and adds support for the new small tile N Blockscaled GEMMs in a variety of interfaces, including mm_fp4 and the TRT-LLM CUTLASS MoE plugin.

Compared to fb3bb44, when running nvidia/Qwen3-30B-A3B-NVFP4 on vLLM (commit 2463f00) with this MR, I observe decent performance improvements for ISL=1024, OSL=1024 on DGX Spark:

| Concurrency | Throughput Speedup | Inter-token Latency Speedup |
|-------------|--------------------|-----------------------------|
| 1           | 1.21x              | 1.21x                       |
| 2           | 1.13x              | 1.14x                       |
| 4           | 1.27x              | 1.26x                       |

For reference, I run the following:

vllm serve nvidia/Qwen3-30B-A3B-NVFP4
vllm bench serve --dataset-name random --num-prompts 10 --num-warmups 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 1
vllm bench serve --dataset-name random --num-prompts 10 --num-warmups 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 2
vllm bench serve --dataset-name random --num-prompts 10 --num-warmups 1 --random-input-len 1024 --random-output-len 1024 --max-concurrency 4

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • SM120/121 GEMM adds new CTA shapes and supports tile N = 32, 64, 128.
    • Added operand-swap (swap_ab) option across FP4/MXFP8/NVFP4 GEMM paths; kernels now include both swap variants.
  • Improvements

    • Kernel generation broadened to produce more mixed-precision FP4/MXFP8 variants.
    • Dispatch and validation extended to accept and route new tile and swap combinations.
  • Tests

    • GEMM tests expanded to cover new tile sizes and both swap_ab settings.
  • Chores

    • Updated third-party submodule pointer.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

📝 Walkthrough

Updates CUTLASS submodule pointer and extends SM120/121 GEMM support by adding a compile-time/runtime swap_ab option, broadening CTA tile N from 128 to {32,64,128}, and plumbing these changes through dispatch, templates, heuristics, bindings, JIT generation, and tests.
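
A minimal C++ sketch of that runtime-to-compile-time plumbing follows; the names launch_gemm and dispatch are hypothetical, while the actual code uses the DISPATCH_TILE_N / DISPATCH_SWAP_AB macros discussed later in this review:

```cpp
#include <stdexcept>
#include <type_traits>

// Hypothetical launcher: in the real code, SwapAB and the CTA tile shape are
// template parameters baked into each kernel instantiation.
template <bool SwapAB, int TileN>
void launch_gemm(/* device pointers, problem shape, stream, ... */) {}

// Fold the runtime (swap_ab, tile_n) values into template arguments, the same
// shape of dispatch the DISPATCH_SWAP_AB / DISPATCH_TILE_N macros perform.
void dispatch(bool swap_ab, int tile_n) {
  auto with_swap = [&](auto swap_tag) {
    constexpr bool kSwap = decltype(swap_tag)::value;
    switch (tile_n) {
      case 32:  launch_gemm<kSwap, 32>();  break;
      case 64:  launch_gemm<kSwap, 64>();  break;
      case 128: launch_gemm<kSwap, 128>(); break;
      default:  throw std::invalid_argument("unsupported tile_n");
    }
  };
  if (swap_ab) with_swap(std::true_type{});
  else         with_swap(std::false_type{});
}
```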

Changes

  • Submodule Update (3rdparty/cutlass): Updated the CUTLASS submodule pointer to a different upstream commit.
  • FP4 GEMM Core, SM120 (csrc/fp4_gemm_cutlass_sm120.cu, csrc/fp4_gemm_cutlass_sm120.jinja): Added an explicit swap_ab=false to a fallback config; the template now instantiates the FP4 GEMM launcher for both swap_ab=true and swap_ab=false.
  • Group GEMM Launchers, NVFP4/MXFP4 (csrc/group_gemm_nvfp4_groupwise_sm120.cu, csrc/group_gemm_mxfp4_groupwise_sm120.cu, csrc/group_gemm_sm120_binding.cu): Added a runtime swap_ab argument to the public launchers, introduced DISPATCH_SWAP_AB to convert it to a compile-time SwapAB, extended tile_n dispatch to {32,64,128}, and updated template signatures and instantiations to include SwapAB.
  • Group GEMM Instantiation Templates (csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja, csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja): Expanded the tile_n loop to {32,64,128} and added a swap_ab loop (true/false), forwarding swap_ab into the instantiation macros.
  • MXFP8 GEMM Templates & Jinja (csrc/mxfp8_gemm_cutlass_sm120.jinja, include/flashinfer/gemm/mxfp8_gemm_template_sm120.h, include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h): Added a SwapAB template parameter, conditional A/B swapping and LayoutC selection, expanded CTA shapes to include N={32,64,128}, and generated both swap variants per CTA shape.
  • FP4 GEMM Templates, SM120 (include/flashinfer/gemm/fp4_gemm_template_sm120.h, include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h): Added a SwapAB template parameter to the launchers, switched to packed CUTE strides, added conditional A/B pointer and dimension swapping, and adjusted dispatch to include swap_ab variants.
  • Group GEMM Headers, device args prep (include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh, include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh): Added a SwapAB template parameter to the args-prep kernels and host templates; implemented compile-time if constexpr (SwapAB) paths that swap the problem shape, strides, pointers, scale-factor layouts, and instantiation names (see the sketch after this list).
  • CUTLASS Config & Heuristics (include/flashinfer/gemm/cutlass_gemm_configs.h, csrc/.../cutlass_extensions/.../gemm_configs.h, csrc/.../cutlass_kernels/cutlass_heuristic.cpp): Added SM120 CTA enum members for the new tile shapes (128x{32,64}x{64,128}B), added a swap_ab field to CutlassGemmConfig and the SM120 constructor, and expanded FP4 candidate generation to include the new shapes and both swap_ab variants.
  • Template Dispatch for MoE & TMA (csrc/.../moe_gemm/moe_gemm_template_dispatch_tma_ws.h): Extended SM120 tile-shape validation and the dispatch switch to accept additional 128x{32,64}x{128,256} combinations.
  • Python API & JIT Generation (flashinfer/gemm/gemm_base.py, flashinfer/jit/gemm/core.py, flashinfer/jit/gemm/cutlass/generate_kernels.py): Expanded tile_n support to {32,64,128}, added and validated swap_ab in the NVFP4 group GEMM APIs (default True), and widened the JIT CTA candidate lists and emission filters to include the new shapes and swap variants.
  • Tests (tests/gemm/test_group_gemm_fp4.py, tests/gemm/test_groupwise_scaled_gemm_mxfp4.py, tests/gemm/test_mm_mxfp8_sm120.py): Expanded the test matrix to iterate tile_n ∈ {32,64,128} and swap_ab ∈ {True,False}; updated expected tactic counts to reflect the added CTA and swap variants.
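
The compile-time swap mentioned in the "Group GEMM Headers" entry can be pictured with a short, hedged C++ sketch; ProblemArgs and prepare_problem are hypothetical names, and the real args-prep kernels additionally swap strides and scale-factor layouts:

```cpp
// Hypothetical argument container; the real code fills CUTLASS group-GEMM
// arguments (problem shapes, strides, scale-factor layouts) on device.
struct ProblemArgs {
  int m, n, k;
  void const* ptr_a;
  void const* ptr_b;
};

template <bool SwapAB>
void prepare_problem(int m, int n, int k, void const* A, void const* B,
                     ProblemArgs& out) {
  if constexpr (SwapAB) {
    // Operand swap: exchange A/B and (m, n) so the large dimension maps onto
    // CTA_M when the logical M is small; the C/D layout is flipped to match.
    out = {n, m, k, B, A};
  } else {
    out = {m, n, k, A, B};
  }
}
```

The untaken if constexpr branch is discarded at compile time, which is why the runtime swap_ab bool must be lifted into a template parameter before this code is reached.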

Sequence Diagram(s)

sequenceDiagram
    participant Python as "Python API"
    participant FFI as "FFI (host bindings)"
    participant Host as "Host launcher"
    participant ArgsPrep as "Args-prep kernel (CUDA)"
    participant CUTLASS as "CUTLASS GEMM kernel"
    Python->>FFI: call group_gemm(..., tile_m,tile_n,tile_k, swap_ab)
    FFI->>Host: marshal args, pass swap_ab
    Host->>Host: DISPATCH_TILE_N & DISPATCH_SWAP_AB -> select template instantiation
    Host->>ArgsPrep: launch args-prep kernel (SwapAB compile-time)
    ArgsPrep->>CUTLASS: prepare device args, then invoke CUTLASS kernel
    CUTLASS-->>Host: return results (device completion)
    Host-->>Python: return/collect outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • yzh119
  • aleozlx
  • yongwww
  • dhiraj113
  • sricketts
  • cyx-6
  • bkryu
  • samuellees
  • jimmyzho
  • nv-yunzheq

Poem

🐰
I hopped through tiles of thirty-two and more,
Swapped A with B and bounded every core;
Kernels spawned in pairs, both false and true,
CUTLASS hummed, the compile lights grew;
A rabbit’s cheer — new paths for me and you! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 25.93%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed. The title clearly and specifically describes the main change: integrating CUTLASS blockscaled GEMMs with small tile N support for the SM120/121 architectures.
  • Description check: ✅ Passed. The PR description includes the key sections from the template: a detailed description of changes, a related-issues link (though empty), completed pre-commit checks, and test confirmations; all critical sections are present and properly filled out.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for the swap_ab parameter and new tile configurations (N=32, N=64) for SM120/121 GEMM kernels across FP4, MXFP4, and MXFP8 formats. Key changes include updates to the CUTLASS subproject, dispatch logic, JIT module, and Python API. Feedback highlights a critical compilation error from a missing comma in a template macro, an outdated error message in Python validation logic, and several missing tile configurations in dispatch switches and candidate lists that should be included for completeness.

Comment thread include/flashinfer/gemm/mxfp8_gemm_template_sm120.h Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
Comment thread include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h Outdated
Comment thread flashinfer/jit/gemm/core.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)

501-511: ⚠️ Potential issue | 🟡 Minor

SM120 MoE dispatch switch missing handlers for two tile configurations.

The SM120 CutlassTileConfigSM120 enum includes CtaShape128x128x256B and CtaShape256x128x128B, and in non-FAST_BUILD mode, both are included in the all_tiles[] candidate list for MoE in get_candidate_configs_sm120() (cutlass_heuristic.cpp:629). However, the dispatch switch in moe_gemm_template_dispatch_tma_ws.h (lines 501–511) lacks corresponding SHAPE_CASE handlers for these two shapes. If either config passes workspace validation and is selected by the heuristic, the dispatcher will hit DEFAULT_CASE(120) and raise "Unsupported tile shape config".

The heuristic comment (line 622–623) states that invalid tiles "are skipped gracefully by the try-catch in calcMaxWorkspaceSize", but this relies on workspace validation to fail for these shapes. To avoid fragility and silent filtering, ensure these two shapes are either:

  • pruned from the heuristic candidate pool, or
  • added as corresponding SHAPE_CASE entries in the dispatcher.
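
To make the failure mode concrete, here is a self-contained C++ toy model (not the actual dispatcher; the enum values and error string follow the description above, everything else is illustrative):

```cpp
#include <stdexcept>

enum class CutlassTileConfigSM120 {
  CtaShape128x128x128B,
  CtaShape128x128x256B,  // candidate emitted by the heuristic...
  CtaShape256x128x128B,  // ...but with no matching case below
};

void dispatch_sm120(CutlassTileConfigSM120 tile) {
  switch (tile) {
    case CutlassTileConfigSM120::CtaShape128x128x128B:
      /* launch the <128,128,128> kernel */
      break;
    default:
      // Any enum value without a SHAPE_CASE-style handler lands here.
      throw std::runtime_error("Unsupported tile shape config");
  }
}
```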
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`
around lines 501 - 511, The dispatch switch in
moe_gemm_template_dispatch_tma_ws.h is missing handlers for the
CutlassTileConfigSM120 enum values CtaShape128x128x256B and
CtaShape256x128x128B; update the dispatcher to add corresponding SHAPE_CASE
entries for these two shapes so they won't fall through to DEFAULT_CASE(120).
Locate the switch on gemm_config.tile_config_sm120 and add SHAPE_CASE(...) lines
matching the missing enum variants (same pattern as the existing SHAPE_CASE
entries), or alternatively remove those two enums from the candidate set
produced by get_candidate_configs_sm120() in cutlass_heuristic.cpp so they are
never considered by calcMaxWorkspaceSize; prefer adding the SHAPE_CASE entries
to keep the heuristic intact.
include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h (1)

121-146: ⚠️ Potential issue | 🟡 Minor

Make the dispatcher fail-fast for unhandled tile configurations instead of silently defaulting.

The CutlassTileConfigSM120 enum declares CtaShape128x128x256B, CtaShape256x128x128B, and CtaShape128x256x128B, but the switch statement in dispatch_gemm_cta_shape has no case for them. Currently, the default: branch silently dispatches a 128×128×128 kernel—inconsistent with the explicit throws for Undefined and ChooseWithHeuristic. While getConfigs() currently only emits the 8 supported tiles, making the dispatcher throw for unhandled enums provides better defensive design and consistency (the MXFP8 variant already handles all three).

Proposed fix
     default:
-      DISPATCH_WITH_SCHEDULER(128, 128, 128);  // Fallback
+      throw std::runtime_error(
+          "[Error][FP4][dispatch_gemm_cta_shape] Unsupported SM120 tile config.");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h` around lines 121 -
146, The switch in dispatch_gemm_cta_shape over CutlassTileConfigSM120 silently
falls back in default; update it to fail-fast for unhandled enum values by
removing the silent DISPATCH_WITH_SCHEDULER fallback and instead throw a
runtime_error for unknown/unsupported tile configs; ensure you add explicit
cases or a clear error for the missing enums (CtaShape128x128x256B,
CtaShape256x128x128B, CtaShape128x256x128B) or let the default throw with a
message like "[Error][FP4][dispatch_gemm_cta_shape] Unsupported Gemm tile
config: <enum>" so the behavior matches the existing throws for Undefined and
ChooseWithHeuristic.
flashinfer/gemm/gemm_base.py (1)

6549-6561: ⚠️ Potential issue | 🟠 Major

Preserve positional-call compatibility for out and out_dtype.

Adding swap_ab before out changes the public Python call ABI. Existing positional callers that passed out or out_dtype after tile_k will now bind those values to swap_ab / out, which can raise confusing type errors or silently stop using the caller-provided output buffer. Please make swap_ab keyword-only or append it after the existing optional arguments.

🔧 Compatible signature adjustment
 def group_gemm_nvfp4_nt_groupwise(
     a: torch.Tensor,  # (cum_m, k)
     b: torch.Tensor,  # (batch_size, n, k // 2)
     a_scale: torch.Tensor,  # (cum_m_padded, k // 16)
     b_scale: torch.Tensor,  # (batch_size, n_padded, k // 16)
     m_indptr: torch.Tensor,  # (batch_size + 1, )
     alpha: Optional[torch.Tensor] = None,  # (batch_size, )
     tile_m: int = 128,
     tile_n: int = 128,
     tile_k: int = 128,
-    swap_ab: bool = True,
     out: Optional[torch.Tensor] = None,  # (cum_m, n)
     out_dtype: Optional[torch.dtype] = None,
+    *,
+    swap_ab: bool = True,
 ) -> torch.Tensor:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 6549 - 6561, The function
signature for group_gemm_nvfp4_nt_groupwise breaks positional-call compatibility
by inserting swap_ab before out/out_dtype; make swap_ab keyword-only (e.g.,
introduce a positional-only separator so swap_ab must be passed by name) or move
swap_ab to after out_dtype so existing positional callers still bind out and
out_dtype correctly; update the function definition and any callers referenced
by group_gemm_nvfp4_nt_groupwise to use the new keyword-only or reordered
parameter accordingly.
🧹 Nitpick comments (4)
include/flashinfer/gemm/mxfp8_gemm_template_sm120.h (1)

218-229: Consider documenting the AB-swap rationale in this hot path.

The if constexpr (SWAP_AB_) branch swaps A/B, SFA/SFB, and (m, n) — a standard CUTLASS technique to improve warp utilization when the original M is small and N is large. A short comment explaining this choice (and why LayoutC flips to ColumnMajor at line 95-96) would help future maintainers. As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

💡 Suggested comment
     Mxfp8GemmOperator gemm;                                                                       \
     auto args = [&]() {                                                                           \
+      /* AB-swap: when the logical M is small relative to N, swapping A/B and              */ \
+      /* (m,n) improves warp utilization by mapping the large dimension onto CTA_M. The    */ \
+      /* matching LayoutC flip to ColumnMajor keeps writes to D in the correct logical layout. */ \
       if constexpr (SWAP_AB_) {                                                                   \
         return prepareGemmArgsSm120_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##SWAP_AB_<               \
             Mxfp8GemmOperator>(D, B, A, weight_sf, input_sf, n, m, k, batch_count);               \
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/gemm/mxfp8_gemm_template_sm120.h` around lines 218 - 229,
Add a concise comment above the hot-path branch that uses SWAP_AB_ (the block
creating Mxfp8GemmOperator gemm and calling
prepareGemmArgsSm120_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##SWAP_AB_) explaining
that the if constexpr (SWAP_AB_) path intentionally swaps A and B, their scale
factors (input_sf/weight_sf), and the m/n dimensions to improve warp utilization
when M is small and N is large (a common CUTLASS optimization), and note that
this is why LayoutC is flipped to ColumnMajor in the corresponding code path;
include a short note about alternative approaches considered (e.g., tiling or
different CTA shapes) and that this is a deliberate performance trade-off for
this kernel.
include/flashinfer/gemm/cutlass_gemm_configs.h (2)

391-401: Add a default value for swap_ab to avoid a silent API break.

swap_ab is inserted before use_stream_k without a default, while the member on line 356 defaults to false. Any external/downstream caller that previously constructed an SM120 CutlassGemmConfig with the 4‑argument form (or 5‑argument with use_stream_k) will now fail to compile, and callers that omit swap_ab intentionally have no way to express the old behavior. Matching the member default also keeps the class ergonomic and consistent with use_stream_k.

♻️ Proposed fix
   CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120,
                     MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule,
-                    ClusterShape cluster_shape, bool swap_ab, bool use_stream_k = false)
+                    ClusterShape cluster_shape, bool swap_ab = false, bool use_stream_k = false)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/gemm/cutlass_gemm_configs.h` around lines 391 - 401, The
CutlassGemmConfig constructor that sets sm_version=120 added the parameter
swap_ab without a default, breaking callers; update the
CutlassGemmConfig(CutlassTileConfigSM120..., bool swap_ab, bool use_stream_k =
false) constructor signature to give swap_ab a default of false (matching the
member default) so existing 4- or 5-argument call sites still compile and
semantics remain consistent with the swap_ab member and use_stream_k parameter.

413-464: Surface swap_ab in toString() / operator<< for debuggability.

The SM120/SM121 branch in toString() reports use_stream_k but not swap_ab, and operator<< reports neither. Since swap_ab now changes which kernel instantiation is dispatched, including it in the serialized tactic string makes autotuner logs and crash diagnostics unambiguous.

♻️ Suggested addition (illustrative)
       if (sm_version == 120 || sm_version == 121) {
         tactic << "\n\tscheduler: " << (use_stream_k ? "StreamK (auto heuristic)" : "DP (default)");
+        tactic << "\n\tswap_ab: " << (swap_ab ? "true" : "false");
       }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/gemm/cutlass_gemm_configs.h` around lines 413 - 464, The
toString() and operator<< output for CutlassGemmConfig must include the swap_ab
field so logs reflect the kernel instantiation choice; update
CutlassGemmConfig::toString() (in the is_tma_warp_specialized branch and the
non-TMA/enableCudaKernel branches) to append a line reporting "swap_ab:
true/false" (similar to how use_stream_k and enableCudaKernel are printed), and
update the free operator<< overload to include ", swap_ab: " << (config.swap_ab
? "true" : "false") in both the is_tma_warp_specialized and else branches (use
getTileConfigAsInt(), is_tma_warp_specialized, use_stream_k, and
enableCudaKernel to locate the relevant code).
csrc/group_gemm_mxfp4_groupwise_sm120.cu (1)

58-69: Optional: DISPATCH_SWAP_AB has a redundant branch and an unreachable check.

Since swap_ab is a bool, the else if (swap_ab == false) branch is always entered when the first isn't, which makes the trailing TVM_FFI_ICHECK(false) << "Unsupported SWAP AB" and return false dead code. A plain if/else is sufficient and avoids the misleading "unsupported" error path. The other DISPATCH_* macros in this file handle open-ended value sets, so this simplification only applies here.

♻️ Proposed simplification
-#define DISPATCH_SWAP_AB(swap_ab, SWAP_AB, ...)     \
-  [&]() -> bool {                                   \
-    if (swap_ab == true) {                          \
-      constexpr bool SWAP_AB = true;                \
-      return __VA_ARGS__();                         \
-    } else if (swap_ab == false) {                  \
-      constexpr bool SWAP_AB = false;               \
-      return __VA_ARGS__();                         \
-    }                                               \
-    TVM_FFI_ICHECK(false) << "Unsupported SWAP AB"; \
-    return false;                                   \
-  }()
+#define DISPATCH_SWAP_AB(swap_ab, SWAP_AB, ...) \
+  [&]() -> bool {                               \
+    if (swap_ab) {                              \
+      constexpr bool SWAP_AB = true;            \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      constexpr bool SWAP_AB = false;           \
+      return __VA_ARGS__();                     \
+    }                                           \
+  }()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/group_gemm_mxfp4_groupwise_sm120.cu` around lines 58 - 69,
DISPATCH_SWAP_AB contains an unnecessary else-if and an unreachable
TVM_FFI_ICHECK(false) path because swap_ab is a bool; change the branching to a
simple if/else so the macro only tests if (swap_ab) { constexpr bool SWAP_AB =
true; return __VA_ARGS__(); } else { constexpr bool SWAP_AB = false; return
__VA_ARGS__(); } and remove the trailing TVM_FFI_ICHECK(false) and return false;
lines to eliminate dead code while keeping the same behavior for
DISPATCH_SWAP_AB, swap_ab, and the SWAP_AB constant.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gemm/test_mm_mxfp8_sm120.py`:
- Around line 75-77: The assertion message is inconsistent with the expected
value: update the f-string on the assert that checks num_tactics so the message
matches the expected 10 tactics (e.g., change "Expected 5 tactics" to "Expected
10 tactics") where the assertion referencing num_tactics is located in the test
(lines containing the comment about SM120 tile configs and the assert).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 23477909-64f7-4bda-b028-6d8a0fb75da8

📥 Commits

Reviewing files that changed from the base of the PR and between 498e837 and 73f3910.

📒 Files selected for processing (25)
  • 3rdparty/cutlass
  • csrc/fp4_gemm_cutlass_sm120.cu
  • csrc/fp4_gemm_cutlass_sm120.jinja
  • csrc/group_gemm_mxfp4_groupwise_sm120.cu
  • csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_sm120_binding.cu
  • csrc/mxfp8_gemm_cutlass_sm120.jinja
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
  • flashinfer/jit/gemm/cutlass/generate_kernels.py
  • include/flashinfer/gemm/cutlass_gemm_configs.h
  • include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
  • include/flashinfer/gemm/fp4_gemm_template_sm120.h
  • include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
  • include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
  • include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h
  • include/flashinfer/gemm/mxfp8_gemm_template_sm120.h
  • tests/gemm/test_group_gemm_fp4.py
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
  • tests/gemm/test_mm_mxfp8_sm120.py

Comment thread tests/gemm/test_mm_mxfp8_sm120.py Outdated
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
@bkryu
Collaborator

bkryu commented Apr 24, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !598 has been created, and the CI pipeline #49347545 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #49347545: 10/20 passed

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

6243-6244: ⚠️ Potential issue | 🟡 Minor

Use strict bool checks for swap_ab validation.

swap_ab not in [True, False] accepts 0/1 because bool is a subclass of int. Since the function signature declares swap_ab: bool, use isinstance(swap_ab, bool) for strict validation.

🔧 Proposed fix
-    if swap_ab not in [True, False]:
+    if not isinstance(swap_ab, bool):
         raise ValueError(f"swap_ab must be a boolean value, but got {swap_ab}")

Also applies to: 6491-6492

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 6243 - 6244, Replace the current
loose boolean check for swap_ab (which uses "swap_ab not in [True, False]") with
a strict type check using isinstance(swap_ab, bool) in the validation logic;
update the same pattern at the other occurrence as well (the second instance
around the block that references swap_ab later in the file) so that
functions/methods performing swap_ab validation enforce actual bool types rather
than accepting ints like 0/1.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d9c92c71-6514-4c51-b38a-d4a17fd500ff

📥 Commits

Reviewing files that changed from the base of the PR and between e44f175 and 010a611.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

@bkryu
Collaborator

bkryu commented Apr 28, 2026

Changes and internal CI LGTM. Will wait for another pair of eyes to review.

Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx aleozlx added the v0.6.11 release blocker label Apr 28, 2026

Labels

op: gemm, run-ci, v0.6.11 release blocker (label for 0.6.11)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants