Integrate CUTLASS Small Tile N Blockscaled GEMMs/Grouped GEMMs for SM120 and SM121 #3152
depaulmillz wants to merge 7 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Walkthrough

Updates the CUTLASS submodule pointer and extends SM120/121 GEMM support by adding a compile-time/runtime `swap_ab` dispatch path and additional small tile-N configurations.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Python as "Python API"
participant FFI as "FFI (host bindings)"
participant Host as "Host launcher"
participant ArgsPrep as "Args-prep kernel (CUDA)"
participant CUTLASS as "CUTLASS GEMM kernel"
Python->>FFI: call group_gemm(..., tile_m,tile_n,tile_k, swap_ab)
FFI->>Host: marshal args, pass swap_ab
Host->>Host: DISPATCH_TILE_N & DISPATCH_SWAP_AB -> select template instantiation
Host->>ArgsPrep: launch args-prep kernel (SwapAB compile-time)
ArgsPrep->>CUTLASS: prepare device args, then invoke CUTLASS kernel
CUTLASS-->>Host: return results (device completion)
Host-->>Python: return/collect outputs
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues
Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request introduces support for the swap_ab parameter and new tile configurations (N=32, N=64) for SM120/121 GEMM kernels across FP4, MXFP4, and MXFP8 formats. Key changes include updates to the CUTLASS subproject, dispatch logic, JIT module, and Python API. Feedback highlights a critical compilation error from a missing comma in a template macro, an outdated error message in Python validation logic, and several missing tile configurations in dispatch switches and candidate lists that should be included for completeness.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
501-511: ⚠️ Potential issue | 🟡 Minor

SM120 MoE dispatch switch missing handlers for two tile configurations.

The SM120 `CutlassTileConfigSM120` enum includes `CtaShape128x128x256B` and `CtaShape256x128x128B`, and in non-FAST_BUILD mode both are included in the `all_tiles[]` candidate list for MoE in `get_candidate_configs_sm120()` (cutlass_heuristic.cpp:629). However, the dispatch switch in moe_gemm_template_dispatch_tma_ws.h (lines 501–511) lacks corresponding `SHAPE_CASE` handlers for these two shapes. If either config passes workspace validation and is selected by the heuristic, the dispatcher will hit `DEFAULT_CASE(120)` and raise "Unsupported tile shape config".

The heuristic comment (lines 622–623) states that invalid tiles "are skipped gracefully by the try-catch in calcMaxWorkspaceSize", but this relies on workspace validation to fail for these shapes. To avoid fragility and silent filtering, ensure these two shapes are either:

- pruned from the heuristic candidate pool, or
- added as corresponding `SHAPE_CASE` entries in the dispatcher.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h` around lines 501 - 511, The dispatch switch in moe_gemm_template_dispatch_tma_ws.h is missing handlers for the CutlassTileConfigSM120 enum values CtaShape128x128x256B and CtaShape256x128x128B; update the dispatcher to add corresponding SHAPE_CASE entries for these two shapes so they won't fall through to DEFAULT_CASE(120). Locate the switch on gemm_config.tile_config_sm120 and add SHAPE_CASE(...) lines matching the missing enum variants (same pattern as the existing SHAPE_CASE entries), or alternatively remove those two enums from the candidate set produced by get_candidate_configs_sm120() in cutlass_heuristic.cpp so they are never considered by calcMaxWorkspaceSize; prefer adding the SHAPE_CASE entries to keep the heuristic intact.

include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h (1)
121-146: ⚠️ Potential issue | 🟡 Minor

Make the dispatcher fail-fast for unhandled tile configurations instead of silently defaulting.

The `CutlassTileConfigSM120` enum declares `CtaShape128x128x256B`, `CtaShape256x128x128B`, and `CtaShape128x256x128B`, but the switch statement in `dispatch_gemm_cta_shape` has no case for them. Currently, the `default:` branch silently dispatches a 128×128×128 kernel, which is inconsistent with the explicit throws for `Undefined` and `ChooseWithHeuristic`. While `getConfigs()` currently only emits the 8 supported tiles, making the dispatcher throw for unhandled enums provides better defensive design and consistency (the MXFP8 variant already handles all three).

Proposed fix

```diff
     default:
-      DISPATCH_WITH_SCHEDULER(128, 128, 128);  // Fallback
+      throw std::runtime_error(
+          "[Error][FP4][dispatch_gemm_cta_shape] Unsupported SM120 tile config.");
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h` around lines 121 - 146, The switch in dispatch_gemm_cta_shape over CutlassTileConfigSM120 silently falls back in default; update it to fail-fast for unhandled enum values by removing the silent DISPATCH_WITH_SCHEDULER fallback and instead throw a runtime_error for unknown/unsupported tile configs; ensure you add explicit cases or a clear error for the missing enums (CtaShape128x128x256B, CtaShape256x128x128B, CtaShape128x256x128B) or let the default throw with a message like "[Error][FP4][dispatch_gemm_cta_shape] Unsupported Gemm tile config: <enum>" so the behavior matches the existing throws for Undefined and ChooseWithHeuristic.

flashinfer/gemm/gemm_base.py (1)
6549-6561: ⚠️ Potential issue | 🟠 Major

Preserve positional-call compatibility for `out` and `out_dtype`.

Adding `swap_ab` before `out` changes the public Python call ABI. Existing positional callers that passed `out` or `out_dtype` after `tile_k` will now bind those values to `swap_ab`/`out`, which can raise confusing type errors or silently stop using the caller-provided output buffer. Please make `swap_ab` keyword-only or append it after the existing optional arguments.

🔧 Compatible signature adjustment

```diff
 def group_gemm_nvfp4_nt_groupwise(
     a: torch.Tensor,  # (cum_m, k)
     b: torch.Tensor,  # (batch_size, n, k // 2)
     a_scale: torch.Tensor,  # (cum_m_padded, k // 16)
     b_scale: torch.Tensor,  # (batch_size, n_padded, k // 16)
     m_indptr: torch.Tensor,  # (batch_size + 1, )
     alpha: Optional[torch.Tensor] = None,  # (batch_size, )
     tile_m: int = 128,
     tile_n: int = 128,
     tile_k: int = 128,
-    swap_ab: bool = True,
     out: Optional[torch.Tensor] = None,  # (cum_m, n)
     out_dtype: Optional[torch.dtype] = None,
+    *,
+    swap_ab: bool = True,
 ) -> torch.Tensor:
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 6549 - 6561, The function signature for group_gemm_nvfp4_nt_groupwise breaks positional-call compatibility by inserting swap_ab before out/out_dtype; make swap_ab keyword-only (e.g., introduce a positional-only separator so swap_ab must be passed by name) or move swap_ab to after out_dtype so existing positional callers still bind out and out_dtype correctly; update the function definition and any callers referenced by group_gemm_nvfp4_nt_groupwise to use the new keyword-only or reordered parameter accordingly.
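To make the binding shift concrete, here is a small self-contained sketch; `group_gemm_like` is a hypothetical stand-in that only mirrors the new parameter order, not the real flashinfer function:

```python
import inspect

# Hypothetical stand-in mirroring the new parameter order (illustration only).
def group_gemm_like(a, b, a_scale, b_scale, m_indptr, alpha=None,
                    tile_m=128, tile_n=128, tile_k=128,
                    swap_ab=True, out=None, out_dtype=None):
    pass

sig = inspect.signature(group_gemm_like)
# A pre-existing caller that passed the output buffer positionally right after tile_k:
bound = sig.bind("a", "b", "a_scale", "b_scale", "m_indptr", None, 128, 128, 128, "out_buf")
print(bound.arguments["swap_ab"])  # prints out_buf: the output buffer now binds to swap_ab
```

With `swap_ab` made keyword-only (or appended after `out_dtype`), the same positional call keeps binding the buffer to `out`.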
🧹 Nitpick comments (4)
include/flashinfer/gemm/mxfp8_gemm_template_sm120.h (1)
218-229: Consider documenting the AB-swap rationale in this hot path.

The `if constexpr (SWAP_AB_)` branch swaps A/B, SFA/SFB, and (m, n), a standard CUTLASS technique to improve warp utilization when the original M is small and N is large. A short comment explaining this choice (and why LayoutC flips to `ColumnMajor` at lines 95-96) would help future maintainers. As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

💡 Suggested comment

```diff
     Mxfp8GemmOperator gemm;                                                                \
     auto args = [&]() {                                                                    \
+      /* AB-swap: when the logical M is small relative to N, swapping A/B and */           \
+      /* (m,n) improves warp utilization by mapping the large dimension onto CTA_M. The */ \
+      /* matching LayoutC flip to ColumnMajor keeps writes to D in the correct logical layout. */ \
       if constexpr (SWAP_AB_) {                                                            \
         return prepareGemmArgsSm120_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##SWAP_AB_<        \
             Mxfp8GemmOperator>(D, B, A, weight_sf, input_sf, n, m, k, batch_count);        \
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/mxfp8_gemm_template_sm120.h` around lines 218 - 229, Add a concise comment above the hot-path branch that uses SWAP_AB_ (the block creating Mxfp8GemmOperator gemm and calling prepareGemmArgsSm120_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##SWAP_AB_) explaining that the if constexpr (SWAP_AB_) path intentionally swaps A and B, their scale factors (input_sf/weight_sf), and the m/n dimensions to improve warp utilization when M is small and N is large (a common CUTLASS optimization), and note that this is why LayoutC is flipped to ColumnMajor in the corresponding code path; include a short note about alternative approaches considered (e.g., tiling or different CTA shapes) and that this is a deliberate performance trade-off for this kernel.

include/flashinfer/gemm/cutlass_gemm_configs.h (2)
391-401: Add a default value for `swap_ab` to avoid a silent API break.

`swap_ab` is inserted before `use_stream_k` without a default, while the member on line 356 defaults to `false`. Any external/downstream caller that previously constructed an SM120 `CutlassGemmConfig` with the 4-argument form (or 5-argument with `use_stream_k`) will now fail to compile, and callers that omit `swap_ab` intentionally have no way to express the old behavior. Matching the member default also keeps the class ergonomic and consistent with `use_stream_k`.

♻️ Proposed fix

```diff
   CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,
                     EpilogueScheduleType epilogue_schedule,
-                    ClusterShape cluster_shape, bool swap_ab, bool use_stream_k = false)
+                    ClusterShape cluster_shape, bool swap_ab = false, bool use_stream_k = false)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/cutlass_gemm_configs.h` around lines 391 - 401, The CutlassGemmConfig constructor that sets sm_version=120 added the parameter swap_ab without a default, breaking callers; update the CutlassGemmConfig(CutlassTileConfigSM120..., bool swap_ab, bool use_stream_k = false) constructor signature to give swap_ab a default of false (matching the member default) so existing 4- or 5-argument call sites still compile and semantics remain consistent with the swap_ab member and use_stream_k parameter.
413-464: Surface `swap_ab` in `toString()`/`operator<<` for debuggability.

The SM120/SM121 branch in `toString()` reports `use_stream_k` but not `swap_ab`, and `operator<<` reports neither. Since `swap_ab` now changes which kernel instantiation is dispatched, including it in the serialized tactic string makes autotuner logs and crash diagnostics unambiguous.

♻️ Suggested addition (illustrative)

```diff
     if (sm_version == 120 || sm_version == 121) {
       tactic << "\n\tscheduler: " << (use_stream_k ? "StreamK (auto heuristic)" : "DP (default)");
+      tactic << "\n\tswap_ab: " << (swap_ab ? "true" : "false");
     }
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/cutlass_gemm_configs.h` around lines 413 - 464, The toString() and operator<< output for CutlassGemmConfig must include the swap_ab field so logs reflect the kernel instantiation choice; update CutlassGemmConfig::toString() (in the is_tma_warp_specialized branch and the non-TMA/enableCudaKernel branches) to append a line reporting "swap_ab: true/false" (similar to how use_stream_k and enableCudaKernel are printed), and update the free operator<< overload to include ", swap_ab: " << (config.swap_ab ? "true" : "false") in both the is_tma_warp_specialized and else branches (use getTileConfigAsInt(), is_tma_warp_specialized, use_stream_k, and enableCudaKernel to locate the relevant code).

csrc/group_gemm_mxfp4_groupwise_sm120.cu (1)
58-69: Optional: `DISPATCH_SWAP_AB` has a redundant branch and an unreachable check.

Since `swap_ab` is a `bool`, the `else if (swap_ab == false)` branch is always entered when the first isn't, which makes the trailing `TVM_FFI_ICHECK(false) << "Unsupported SWAP AB"` and `return false` dead code. A plain `if`/`else` is sufficient and avoids the misleading "unsupported" error path. The other `DISPATCH_*` macros in this file handle open-ended value sets, so this simplification only applies here.

♻️ Proposed simplification

```diff
-#define DISPATCH_SWAP_AB(swap_ab, SWAP_AB, ...)     \
-  [&]() -> bool {                                   \
-    if (swap_ab == true) {                          \
-      constexpr bool SWAP_AB = true;                \
-      return __VA_ARGS__();                         \
-    } else if (swap_ab == false) {                  \
-      constexpr bool SWAP_AB = false;               \
-      return __VA_ARGS__();                         \
-    }                                               \
-    TVM_FFI_ICHECK(false) << "Unsupported SWAP AB"; \
-    return false;                                   \
-  }()
+#define DISPATCH_SWAP_AB(swap_ab, SWAP_AB, ...) \
+  [&]() -> bool {                               \
+    if (swap_ab) {                              \
+      constexpr bool SWAP_AB = true;            \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      constexpr bool SWAP_AB = false;           \
+      return __VA_ARGS__();                     \
+    }                                           \
+  }()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/group_gemm_mxfp4_groupwise_sm120.cu` around lines 58 - 69, DISPATCH_SWAP_AB contains an unnecessary else-if and an unreachable TVM_FFI_ICHECK(false) path because swap_ab is a bool; change the branching to a simple if/else so the macro only tests if (swap_ab) { constexpr bool SWAP_AB = true; return __VA_ARGS__(); } else { constexpr bool SWAP_AB = false; return __VA_ARGS__(); } and remove the trailing TVM_FFI_ICHECK(false) and return false; lines to eliminate dead code while keeping the same behavior for DISPATCH_SWAP_AB, swap_ab, and the SWAP_AB constant.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gemm/test_mm_mxfp8_sm120.py`:
- Around line 75-77: The assertion message is inconsistent with the expected
value: update the f-string on the assert that checks num_tactics so the message
matches the expected 10 tactics (e.g., change "Expected 5 tactics" to "Expected
10 tactics") where the assertion referencing num_tactics is located in the test
(lines containing the comment about SM120 tile configs and the assert).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 23477909-64f7-4bda-b028-6d8a0fb75da8
📒 Files selected for processing (25)
- 3rdparty/cutlass
- csrc/fp4_gemm_cutlass_sm120.cu
- csrc/fp4_gemm_cutlass_sm120.jinja
- csrc/group_gemm_mxfp4_groupwise_sm120.cu
- csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_sm120_binding.cu
- csrc/mxfp8_gemm_cutlass_sm120.jinja
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/cutlass/generate_kernels.py
- include/flashinfer/gemm/cutlass_gemm_configs.h
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm120.h
- include/flashinfer/gemm/fp4_gemm_template_sm120.h
- include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
- include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h
- include/flashinfer/gemm/mxfp8_gemm_template_sm120.h
- tests/gemm/test_group_gemm_fp4.py
- tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
- tests/gemm/test_mm_mxfp8_sm120.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
/bot run
[FAILED] Pipeline #49347545: 10/20 passed
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
6243-6244: ⚠️ Potential issue | 🟡 Minor

Use strict bool checks for `swap_ab` validation.

`swap_ab not in [True, False]` accepts `0`/`1` because `bool` is a subclass of `int`. Since the function signature declares `swap_ab: bool`, use `isinstance(swap_ab, bool)` for strict validation.

🔧 Proposed fix

```diff
-    if swap_ab not in [True, False]:
+    if not isinstance(swap_ab, bool):
         raise ValueError(f"swap_ab must be a boolean value, but got {swap_ab}")
```

Also applies to: 6491-6492
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 6243 - 6244, Replace the current loose boolean check for swap_ab (which uses "swap_ab not in [True, False]") with a strict type check using isinstance(swap_ab, bool) in the validation logic; update the same pattern at the other occurrence as well (the second instance around the block that references swap_ab later in the file) so that functions/methods performing swap_ab validation enforce actual bool types rather than accepting ints like 0/1.
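As a quick self-contained illustration of why the membership check is loose (plain Python, independent of the flashinfer code):

```python
# bool is a subclass of int, so 0 == False and 1 == True; even 1.0 compares equal to True.
for x in (True, False, 0, 1, 1.0, "yes"):
    loose_ok = x in [True, False]    # also passes for 0, 1, and 1.0
    strict_ok = isinstance(x, bool)  # passes only for genuine True/False
    print(repr(x), loose_ok, strict_ok)
```

Only `True` and `False` satisfy the `isinstance` check, matching the `swap_ab: bool` annotation.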
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d9c92c71-6514-4c51-b38a-d4a17fd500ff
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
Changes and internal CI LGTM. Will wait for another pair of eyes to review.
📌 Description
This MR bumps the CUTLASS commit and adds support for the new small tile N Blockscaled GEMMs in a variety of interfaces, including `mm_fp4` and the TRT-LLM CUTLASS MoE plugin.

Compared to fb3bb44, when running nvidia/Qwen3-30B-A3B-NVFP4 on vLLM (commit 2463f00) with this MR, I can observe decent performance improvements for ISL=1024 OSL=1024 on DGX Spark:
For reference I run the following:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests
Chores