feat: Add backend="b12x" for mm_fp4 on SM120 (#3051)
Conversation
Port the b12x block-scaled NVFP4 dense GEMM kernel to FlashInfer as a backend='cute-dsl' option for mm_fp4 on SM120 and SM121 GPUs.

Key features:
- Warp-level MMA (MmaMXF4NVF4Op m16n8k64) with 8 MMA warps + 1 DMA warp
- Underfill tile selection (64x64, 64x128, 128x64) for small-M shapes
- SF tile decoupling: scale factor SMEM rounded up to 128-element blocks
- Epilogue sync elimination for single-store tiles
- Pipeline stage cap at 4 for bounded register pressure
- PersistentTileSchedulerParams workaround for CUTLASS DSL 4.4.1

Changes:
- New: flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
- Modified: flashinfer/cute_dsl/utils.py (sm120_make_smem_layout_sfa/sfb)
- Modified: flashinfer/gemm/gemm_base.py (SM120 dispatch, tactics, heuristic)
- Modified: flashinfer/gemm/__init__.py (export Sm120BlockScaledDenseGemmKernel)
- Modified: tests/gemm/test_mm_fp4.py (enable SM120/121 for cute-dsl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…o SM120 only

- Add 'b12x' as a new backend name for mm_fp4 (SM120 only)
- Revert 'cute-dsl' requirement back to SM100/SM103 only
- Remove SM121 support (pending CuTe-DSL 4.5 wheel fix)
- Add 'b12x' to benchmark CLI choices
- Update tests to parametrize 'b12x' backend separately

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
📝 Walkthrough

Adds SM120 (sm_120) block-scaled GEMM support.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant mmfp4 as mm_fp4()
    participant Selector as Heuristic Selector
    participant Runner as b12x Runner
    participant Compiler as cute.compile/cache
    participant Kernel as Sm120BlockScaledDenseGemmKernel
    participant GPU as GPU Device
    User->>mmfp4: call mm_fp4(..., backend="auto" or "b12x")
    mmfp4->>Selector: resolve backend (heuristic: prefer b12x on SM120+NVFP4)
    Selector-->>mmfp4: backend="b12x"
    mmfp4->>Runner: invoke _b12x_gemm_fp4_runner(...)
    Runner->>Compiler: compile/select tactic & layouts (cached)
    Compiler-->>Runner: compiled kernel callable
    Runner->>Kernel: prepare descriptors & launch on stream
    Kernel->>GPU: device execution (TMA DMA -> SMEM -> MMA warps -> epilogue -> TMA store)
    GPU-->>Kernel: kernel completes
    Kernel-->>Runner: return output tensor
    Runner-->>mmfp4: deliver result
    mmfp4-->>User: return tensor
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks: ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches: 🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Code Review
This pull request adds support for the "b12x" backend for block-scaled FP4 GEMM, targeting the SM120 (Blackwell) architecture. Key additions include a new kernel implementation using warp-level MMA, shared memory layout utilities for scale factors, and integration into the mm_fp4 API and test suite. Feedback suggests refactoring the shared memory layout functions and consolidating redundant tile selection logic to improve maintainability and adhere to DRY principles.
```python
def _select_default_mma_tiler_mn(m: int, n: int, sm_count: int) -> Tuple[int, int]:
    coarse_tile = (128, 128)
    coarse_tiles = ((m + coarse_tile[0] - 1) // coarse_tile[0]) * (
        (n + coarse_tile[1] - 1) // coarse_tile[1]
    )
    if m <= 128 and coarse_tiles < max(1, sm_count // 2):
        if n > 1536:
            return (64, 128)
        medium_tile = (128, 64)
        medium_tiles = ((m + medium_tile[0] - 1) // medium_tile[0]) * (
            (n + medium_tile[1] - 1) // medium_tile[1]
        )
        if medium_tiles < max(1, sm_count // 2):
            return (64, 64)
        return (128, 64)
    return (128, 128)
```
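To make the branch behavior of this heuristic concrete, here is a hedged standalone restatement (not the FlashInfer source itself) that can be run without a GPU; the 170-SM device is a hypothetical figure chosen for illustration:

```python
from typing import Tuple


def select_default_mma_tiler_mn(m: int, n: int, sm_count: int) -> Tuple[int, int]:
    # Prefer smaller "underfill" tiles when a 128x128 grid would leave more
    # than half of the SMs idle (mirrors the heuristic under review).
    coarse_tiles = ((m + 127) // 128) * ((n + 127) // 128)
    if m <= 128 and coarse_tiles < max(1, sm_count // 2):
        if n > 1536:
            return (64, 128)
        medium_tiles = ((m + 127) // 128) * ((n + 63) // 64)
        if medium_tiles < max(1, sm_count // 2):
            return (64, 64)
        return (128, 64)
    return (128, 128)


# Small decode shape: only 16 coarse tiles on a 170-SM part, wide N -> (64, 128)
print(select_default_mma_tiler_mn(64, 2048, 170))    # (64, 128)
# Narrow N, still underfilled even at (128, 64) -> (64, 64)
print(select_default_mma_tiler_mn(64, 512, 170))     # (64, 64)
# Large square GEMM saturates the grid -> (128, 128)
print(select_default_mma_tiler_mn(4096, 4096, 170))  # (128, 128)
```

This shows why small-M (decode) shapes end up on the underfill tiles while large shapes keep the full 128x128 tiler.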
The default for sm120 can change in the future, so let's keep a separate helper function.
@coderabbitai review

✅ Actions performed: Review triggered.
Actionable comments posted: 4
🧹 Nitpick comments (1)
tests/gemm/test_mm_fp4.py (1)
114-114: Add one targeted assertion for the new auto-selection behavior.

This adds explicit `b12x` coverage, but the PR also changes `backend="auto"` to prefer `b12x` on SM120/NVFP4. A fallback regression to `cutlass`/`cudnn` would still pass the current tests because they only check numerical output.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_mm_fp4.py` at line 114, Add a targeted assertion to the test in tests/gemm/test_mm_fp4.py to verify that when backend="auto" on SM120/NVFP4 the runtime chooses "b12x" (not a fallback like "cutlass" or "cudnn"); update the parametrized test (the pytest.mark.parametrize("backend", ... ) block) or add a separate case that runs the same codepath with backend="auto" and asserts the reported/returned selected backend equals "b12x" (use the existing test helper or function that returns the chosen backend from the mm invocation to perform the check).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/cute_dsl/utils.py`:
- Around line 444-469: The current K-layout construction can produce truncated
extents because it divides tile_shape_mnk[2] by sf_vec_size and blk_sf and also
uses blk_sf // mma_nsf; add stronger asserts: ensure blk_sf is divisible by
mma_nsf (assert blk_sf % mma_nsf == 0) and ensure the K dimension is divisible
by the SF vector and block factors (assert tile_shape_mnk[2] % (sf_vec_size *
blk_sf) == 0); apply these checks near the existing assert that references
tile_shape_mnk[2] and also make the same changes in the other helper (the block
that builds sSFA_shapeK / sSF_strideM around the other occurrence) so blk_sf //
mma_nsf and tile_shape_mnk[2] // sf_vec_size // blk_sf cannot silently floor.
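The prompt above boils down to two exact-divisibility preconditions. A hedged standalone sketch (the function name and tuple layout are illustrative, not the actual `utils.py` signature):

```python
def check_sf_layout_divisibility(tile_shape_mnk, sf_vec_size, blk_sf, mma_nsf):
    """Fail fast instead of silently flooring the scale-factor K extent.

    Illustrative only: parameter names follow the review comment, not
    necessarily the exact helper in flashinfer/cute_dsl/utils.py.
    """
    k = tile_shape_mnk[2]
    assert blk_sf % mma_nsf == 0, (
        f"blk_sf={blk_sf} must be a multiple of mma_nsf={mma_nsf}"
    )
    assert k % (sf_vec_size * blk_sf) == 0, (
        f"tile K={k} must be a multiple of sf_vec_size*blk_sf={sf_vec_size * blk_sf}"
    )
    # Safe now: both integer divisions below are exact, no silent truncation.
    return blk_sf // mma_nsf, k // sf_vec_size // blk_sf


print(check_sf_layout_divisibility((128, 128, 256), 16, 4, 4))  # (1, 4)
```

With the asserts in place, a K extent that is not a multiple of `sf_vec_size * blk_sf` raises immediately instead of producing a truncated layout.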
In `@flashinfer/gemm/__init__.py`:
- Around line 51-58: The new Sm120BlockScaledDenseGemmKernel import from
dense_blockscaled_gemm_sm120 is currently inside a broad try/except that on
failure suppresses all CuTe-DSL exports; separate the SM120 import into its own
try/except so an ImportError for dense_blockscaled_gemm_sm120 only skips
Sm120BlockScaledDenseGemmKernel, and ensure _cute_dsl_kernels and any existing
SM100 imports (e.g., Sm100BlockScaledPersistentDenseGemmKernel and
grouped_gemm_nt_masked) remain defined and exported regardless of the SM120
import outcome.
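The `__init__.py` fix above can be sketched as two independent import guards; the module paths are from the PR, but the `None` fallback is illustrative (the real code may simply skip the export instead). Run standalone, where neither module resolves, both guards fall back independently, showing that an SM120 import failure no longer masks the SM100 exports:

```python
# Guard 1: SM100 exports survive regardless of the SM120 import outcome.
try:
    from .kernels.dense_blockscaled_gemm_sm100 import (
        Sm100BlockScaledPersistentDenseGemmKernel,
    )
except ImportError:
    Sm100BlockScaledPersistentDenseGemmKernel = None

# Guard 2: a failure here only skips the SM120 kernel export.
try:
    from .kernels.dense_blockscaled_gemm_sm120 import (
        Sm120BlockScaledDenseGemmKernel,
    )
except ImportError:
    Sm120BlockScaledDenseGemmKernel = None
```

Outside a package (as in this sketch) both relative imports raise `ImportError`, so both names end up `None`; inside FlashInfer, each name is defined exactly when its own module imports cleanly.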
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py`:
- Around line 79-87: The docstring in dense_blockscaled_gemm_sm120.py advertises
MXF4/sf_vec_size=32 but can_implement() only accepts NVFP4 with sf_vec_size=16;
update the supported-combinations and tile-shape notes in the class/module
docstring to reflect only NVFP4 (A/B: Float4E2M1FN, SF: Float8E4M3FN,
sf_vec_size: 16) and adjust any tile_k constraint text accordingly, and ensure
the docstring language aligns with the runtime check in can_implement().
- Around line 1223-1260: can_implement currently allows any 64-aligned
mma_tiler_mn but _compute_stages then forces ab_stage >= 1 even when a single
(A+B+SF+epilogue) stage cannot fit in SM120; fix by adding a hard rejection when
per-stage shared-memory demand exceeds available SM memory. Specifically, in
_compute_stages (and/or can_implement) compute the raw ab_stage before clamping
by evaluating available_smem_per_stage = (smem_capacity - occupancy * 1024) //
occupancy - mbar_helpers_bytes - epi_bytes and comparing that to
(ab_bytes_per_stage + sf_bytes_per_stage); if available_smem_per_stage <
(ab_bytes_per_stage + sf_bytes_per_stage) then return a failure/indicate
non-implementable (e.g., have can_implement return False or have _compute_stages
raise/return (0,0)), otherwise proceed to compute ab_stage = max(1,
min(calculated_ab_stage,4)); update callers to treat a zero/exception result as
“tiler cannot fit” so bad tilers like (512,512) are rejected early.
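The stage-budget formula in the prompt above can be sketched as a small pure function. All byte figures in the example calls are illustrative, not measured SM120 values:

```python
def compute_ab_stage(smem_capacity, occupancy, mbar_helpers_bytes, epi_bytes,
                     ab_bytes_per_stage, sf_bytes_per_stage, max_stages=4):
    """Return the A/B pipeline depth, or 0 if even one stage cannot fit.

    Sketch of the budget check suggested in the review; the caller treats a
    zero result as "tiler cannot fit" and rejects the configuration early.
    """
    available = (smem_capacity - occupancy * 1024) // occupancy
    available -= mbar_helpers_bytes + epi_bytes
    per_stage = ab_bytes_per_stage + sf_bytes_per_stage
    if available < per_stage:
        return 0  # hard rejection instead of forcing ab_stage >= 1
    return max(1, min(available // per_stage, max_stages))


# ~100 KiB of SMEM, one CTA per SM: a modest tiler fits with 4 stages.
print(compute_ab_stage(100 * 1024, 1, 1024, 16 * 1024, 16 * 1024, 1 * 1024))   # 4
# An oversized tiler needing 256 KiB per stage is rejected outright.
print(compute_ab_stage(100 * 1024, 1, 1024, 16 * 1024, 256 * 1024, 4 * 1024))  # 0
```

The key change versus `max(1, ...)` alone is the early `return 0`: previously a tiler like (512, 512) would be clamped up to one stage even though that stage could never fit.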
---
Nitpick comments:
In `@tests/gemm/test_mm_fp4.py`:
- Line 114: Add a targeted assertion to the test in tests/gemm/test_mm_fp4.py to
verify that when backend="auto" on SM120/NVFP4 the runtime chooses "b12x" (not a
fallback like "cutlass" or "cudnn"); update the parametrized test (the
pytest.mark.parametrize("backend", ... ) block) or add a separate case that runs
the same codepath with backend="auto" and asserts the reported/returned selected
backend equals "b12x" (use the existing test helper or function that returns the
chosen backend from the mm invocation to perform the check).
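The selection logic that the suggested test should pin down can be modeled in isolation. This is a toy stand-in for the real FlashInfer resolver (function name and signature are hypothetical), showing the shape of the assertion the review asks for:

```python
def resolve_backend(backend: str, sm_arch: int, fp4_variant: str) -> str:
    """Toy model of the mm_fp4 heuristic described above, not the real API:
    "auto" prefers b12x on SM120 with NVFP4, else falls back to cutlass."""
    if backend != "auto":
        return backend
    if sm_arch == 120 and fp4_variant == "nvfp4":
        return "b12x"
    return "cutlass"


def test_auto_prefers_b12x_on_sm120_nvfp4():
    # The regression the comment warns about: a silent fallback to cutlass
    # would still pass numeric checks, so assert on the chosen backend itself.
    assert resolve_backend("auto", 120, "nvfp4") == "b12x"
    assert resolve_backend("auto", 100, "nvfp4") == "cutlass"
    assert resolve_backend("b12x", 120, "nvfp4") == "b12x"


test_auto_prefers_b12x_on_sm120_nvfp4()
print("ok")
```

In the real test suite the same assertion would target whatever helper exposes the backend chosen by the `mm_fp4` invocation.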
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bcfae6fd-d49b-4c9c-b8c1-54d5ca68f6c4
📒 Files selected for processing (6)
- benchmarks/routines/gemm.py
- flashinfer/cute_dsl/utils.py
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
- tests/gemm/test_mm_fp4.py
🧹 Nitpick comments (1)
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py (1)
143-221: Missing assertion for `ab_stage > 0`.

There's an assertion for `epi_stage > 0` at line 198, but no corresponding assertion for `ab_stage > 0`. While `can_implement()` now includes an SMEM budget check that should prevent configurations where `ab_stage` would be 0, adding a defensive assertion here would provide a clearer error message if the invariant is violated.

🛡️ Suggested defensive assertion

```diff
 self.ab_stage, self.epi_stage = self._compute_stages( ... )
+assert self.ab_stage > 0, (
+    "ab_stage <= 0, not enough shared memory. This configuration will be skipped."
+)
 assert self.epi_stage > 0, (
     "epi_stage <= 0, not enough shared memory. This configuration will be skipped."
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` around lines 143 - 221, The function _setup_attributes computes (self.ab_stage, self.epi_stage) via self._compute_stages but only asserts self.epi_stage > 0; add a defensive assertion that self.ab_stage > 0 immediately after _compute_stages to fail fast with a clear message if the invariant is violated (mirror the style of the existing epi_stage assertion and include a helpful message like "ab_stage <= 0, not enough shared memory. This configuration will be skipped."); this ensures configurations where ab_stage becomes zero are caught early and points to _compute_stages/_setup_attributes for debugging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py`:
- Around line 143-221: The function _setup_attributes computes (self.ab_stage,
self.epi_stage) via self._compute_stages but only asserts self.epi_stage > 0;
add a defensive assertion that self.ab_stage > 0 immediately after
_compute_stages to fail fast with a clear message if the invariant is violated
(mirror the style of the existing epi_stage assertion and include a helpful
message like "ab_stage <= 0, not enough shared memory. This configuration will
be skipped."); this ensures configurations where ab_stage becomes zero are
caught early and points to _compute_stages/_setup_attributes for debugging.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c68e6971-33a2-4f83-b94b-a2d111f0a9d4
📒 Files selected for processing (3)
- flashinfer/cute_dsl/utils.py
- flashinfer/gemm/__init__.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/gemm/__init__.py
|
/bot run

/bot stop

/bot run

The GitLab CI pipeline #48432589 has been cancelled.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
5257-5261: ⚠️ Potential issue | 🟡 Minor

Update the `enable_pdl` docs for `b12x`.

The docstring still says this flag is only used by `cute-dsl`, but lines 5323-5325 now forward it into `_b12x_gemm_fp4_runner(...)` as well. The generated API docs will be misleading unless the parameter description is broadened.

Also applies to: 5323-5325
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 5257 - 5261, The docstring for the parameter enable_pdl is now inaccurate because enable_pdl is forwarded into _b12x_gemm_fp4_runner (in addition to cute_dsl); update the enable_pdl parameter description to reflect that it controls Programmatic Dependent Launch behavior for both the cute_dsl backend and the b12x runner paths (or state that it is honored by any backend that supports PDL, including _b12x_gemm_fp4_runner), mentioning both symbols enable_pdl and _b12x_gemm_fp4_runner so readers know where the flag is used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4948-4959: The fallback path that sets tactic via
_select_default_sm120_mma_tiler can pick an invalid tactic because it skips the
same validation used in get_valid_tactics; update the tactic == -1 branch to
pick a tactic only after validating with
Sm120BlockScaledDenseGemmKernel.can_implement (or by selecting from the list
returned by get_valid_tactics) using the current real_k and c_cutlass_dtype (and
any runner state required), so replace the raw _select_default_sm120_mma_tiler
return with a validated choice or rerun can_implement against the candidate and
loop to the next valid option.
---
Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5257-5261: The docstring for the parameter enable_pdl is now
inaccurate because enable_pdl is forwarded into _b12x_gemm_fp4_runner (in
addition to cute_dsl); update the enable_pdl parameter description to reflect
that it controls Programmatic Dependent Launch behavior for both the cute_dsl
backend and the b12x runner paths (or state that it is honored by any backend
that supports PDL, including _b12x_gemm_fp4_runner), mentioning both symbols
enable_pdl and _b12x_gemm_fp4_runner so readers know where the flag is used.
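The broadened wording the prompt asks for might read as follows. This is a toy stub carrying only the docstring, not the real `mm_fp4` signature:

```python
def mm_fp4_stub(*, enable_pdl: bool = False):
    """Toy stand-in showing only the broadened parameter description.

    Parameters
    ----------
    enable_pdl : bool
        Whether to use Programmatic Dependent Launch (PDL). Honored by any
        backend that supports PDL, including the ``cute-dsl`` path and the
        ``b12x`` runner (``_b12x_gemm_fp4_runner``).
    """


# The description now names both consumers of the flag.
assert "b12x" in mm_fp4_stub.__doc__ and "cute-dsl" in mm_fp4_stub.__doc__
print("ok")
```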
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 154c3c6b-d722-4007-9674-00c4db9dbd56
📒 Files selected for processing (2)
flashinfer/gemm/gemm_base.pytests/gemm/test_mm_fp4.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gemm/test_mm_fp4.py
```python
if tactic is None or tactic == -1:
    _sm_count = torch.cuda.get_device_properties(
        a.device
    ).multi_processor_count
    tactic = (
        _select_default_sm120_mma_tiler(m, n, _sm_count),
        (1, 1),
        False,
        False,
        "sm120",
        None,
    )
```
Validate the fallback b12x tactic before using it.
get_valid_tactics() filters each tile through Sm120BlockScaledDenseGemmKernel.can_implement(...), but the tactic == -1 path skips that filter and picks a tile from _select_default_sm120_mma_tiler() using only m, n, and SM count. Since validity also depends on real_k and c_cutlass_dtype, this fallback can choose a tactic the runner itself would reject. Please derive the default from the validated set or rerun can_implement(...) here.
Suggested fix
```diff
 if tactic is None or tactic == -1:
-    _sm_count = torch.cuda.get_device_properties(
-        a.device
-    ).multi_processor_count
-    tactic = (
-        _select_default_sm120_mma_tiler(m, n, _sm_count),
-        (1, 1),
-        False,
-        False,
-        "sm120",
-        None,
-    )
+    sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
+    preferred_tiles = (
+        _select_default_sm120_mma_tiler(m, n, sm_count),
+        (128, 128),
+        (128, 64),
+        (64, 128),
+        (64, 64),
+    )
+    for preferred_tile in dict.fromkeys(preferred_tiles):
+        if Sm120BlockScaledDenseGemmKernel.can_implement(
+            cutlass.Float4E2M1FN,
+            cutlass.Float8E4M3FN,
+            sf_vec_size,
+            c_cutlass_dtype,
+            preferred_tile,
+            (1, 1),
+            m,
+            n,
+            real_k,
+            batch_size,
+            "k",
+            "k",
+            "n",
+        ):
+            tactic = (preferred_tile, (1, 1), False, False, "sm120", None)
+            break
+    else:
+        raise RuntimeError(
+            f"No valid b12x tactic for m={m}, n={n}, k={real_k}, out_dtype={out_dtype}."
+        )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if tactic is None or tactic == -1:
    sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
    preferred_tiles = (
        _select_default_sm120_mma_tiler(m, n, sm_count),
        (128, 128),
        (128, 64),
        (64, 128),
        (64, 64),
    )
    for preferred_tile in dict.fromkeys(preferred_tiles):
        if Sm120BlockScaledDenseGemmKernel.can_implement(
            cutlass.Float4E2M1FN,
            cutlass.Float8E4M3FN,
            sf_vec_size,
            c_cutlass_dtype,
            preferred_tile,
            (1, 1),
            m,
            n,
            real_k,
            batch_size,
            "k",
            "k",
            "n",
        ):
            tactic = (preferred_tile, (1, 1), False, False, "sm120", None)
            break
    else:
        raise RuntimeError(
            f"No valid b12x tactic for m={m}, n={n}, k={real_k}, out_dtype={out_dtype}."
        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` around lines 4948 - 4959, The fallback path
that sets tactic via _select_default_sm120_mma_tiler can pick an invalid tactic
because it skips the same validation used in get_valid_tactics; update the
tactic == -1 branch to pick a tactic only after validating with
Sm120BlockScaledDenseGemmKernel.can_implement (or by selecting from the list
returned by get_valid_tactics) using the current real_k and c_cutlass_dtype (and
any runner state required), so replace the raw _select_default_sm120_mma_tiler
return with a validated choice or rerun can_implement against the candidate and
loop to the next valid option.
nv-yunzheq left a comment:

LGTM. Left a few comments. Major question on SM121 support.
```python
    )
else:
    raise ValueError(f"Unsupported backend: {backend}")
return flashinfer.gemm.mm_fp4(
```
Why did the ValueError get removed?
I found this ValueError in the benchmark to be redundant, because support checks already ensure that only runnable backends are called here. I figured this was a good opportunity to simplify the code by removing the redundant check.
```python
if is_cute_dsl_available():
    from .kernels.dense_blockscaled_gemm_sm120 import (
        Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
```
It seems the kernel name and file name are meant to mirror the CuTe DSL kernel imported from the CUTLASS example, but the backend name is not cute-dsl. It might be hard for people to understand the code in the future. Do we want to add something like b12x? Or we could even just use cute-dsl as the backend for this kernel. The drawback of the latter is that if there is a CUTLASS block-scaled GEMM on SM120 in the examples in the future and we try to port it, it might cause confusion.
This is a good point. I can rename the class names in a follow-up PR.
```python
    return True


@supported_compute_capability([120])
```
This is a good catch and is intentional. SM121 (Spark) is not yet supported; support is expected with a future nvidia-cutlass-dsl release.
```python
batch_size = 1


if tactic is None or tactic == -1:
    _sm_count = torch.cuda.get_device_properties(
```
Is this a heavy query? If so, we probably want to put it in a separate function and cache the result so we don't need to query it every time.
Another good point. I believe we have a helper function to get the SM count. I can update this in a follow-up PR.
nv-yunzheq left a comment:

The suggestions are not blocking. Therefore, approving the PR.
## 📌 Description

Follow-up to #3051 (`backend="b12x"` for `mm_fp4` on SM120) and #3080 (`b12x_fused_moe` / `B12xMoEWrapper` SM120 APIs) addressing four reviewer comments that landed after merge. No public API changes; no kernel behavior changes.

- **Copyright**: bump `tests/moe/test_b12x_fused_moe.py` to 2026.
- **Benchmark split**: new `b12x_fused_moe` routine (SM120/121, BF16 input, SwiGLU + ReLU²); `cute_dsl_fp4_block_scale_moe` is now SM100/103-only. Aligns with the `B12xMoEWrapper` / `CuteDslMoEWrapper` Python API split.
- **Cache SM count**: replace a hot-path `torch.cuda.get_device_properties(...).multi_processor_count` in the `b12x` FP4 GEMM runner with the cached `get_device_sm_count()` helper.
- **Rename for provenance**: `dense_blockscaled_gemm_sm120.py` → `dense_blockscaled_gemm_sm120_b12x.py` and `Sm120BlockScaledDenseGemmKernel` → `Sm120B12xBlockScaledDenseGemmKernel` (via `git mv`, 6 import sites updated). `backend="b12x"` string unchanged.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Added new `b12x_fused_moe` benchmark routine for NVFP4 MoE inference with support for both SwiGLU and ReLU2 activation types.
  * Extended Blackwell architecture support with updated kernel implementations.
* **Documentation**
  * Updated benchmark samples with new `b12x_fused_moe` test configurations.
📌 Description
Summary
- Add `backend="b12x"` option for `mm_fp4` targeting SM120 GPUs. Supports nvfp4 only.
- `b12x` block-scaled NVFP4 dense GEMM kernel using CuTe DSL, ported from the b12x library.
- `backend="auto"` now prefers "b12x" over "cutlass" and "cudnn" for NVFP4.
- SM121 support is deferred pending the `nvidia-cutlass-dsl==4.5` wheel release.

Changes

- `flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py`: new kernel with a `wrapper()` method for FlashInfer's TVM-FFI compile interface
- `flashinfer/cute_dsl/utils.py`: `sm120_make_smem_layout_sfa/sfb` with 64-aligned tile support
- `flashinfer/gemm/gemm_base.py`: `_b12x_gemm_fp4_requirement`, `_b12x_gemm_fp4_runner` (separate cache and runner class), `_select_default_sm120_mma_tiler` heuristic, SM120 auto-selection in heuristic
- `flashinfer/gemm/__init__.py`: export `Sm120BlockScaledDenseGemmKernel`
- `tests/gemm/test_mm_fp4.py`: add `"b12x"` to backend parametrize with SM120-only skip
- `benchmarks/routines/gemm.py`: add `"b12x"` to CLI choices and autotune-supported backends, remove redundant backend guard in `run_backend`

Performance numbers on RTX 5090 (SM120)
Geomean speedup vs CUTLASS:
`b12x` performance is particularly strong on small-M (decode) shapes where the underfill tiles (64x64, 64x128) kick in — many of those show 1.3-1.6x speedup. The larger shapes are roughly at parity.

Click to view performance comparisons between backends
🔍 Related Issues
#3013
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).
Summary by CodeRabbit
New Features
Enhancements
Tests