feat: Add backend="b12x" for mm_fp4 on SM120 (#3051)

Merged
bkryu merged 10 commits into flashinfer-ai:main from bkryu:b12x_mm_fp4 on Apr 14, 2026

Conversation

@bkryu (Collaborator) commented Apr 13, 2026

📌 Description

Summary

  • Add a new backend="b12x" option for mm_fp4 targeting SM120 GPUs. Supports nvfp4 only (a minimal usage sketch follows this list).
  • Port the b12x block-scaled NVFP4 dense GEMM kernel, written in the CuTe DSL, from the b12x library.
  • On SM120, backend="auto" now prefers "b12x" over "cutlass" and "cudnn" for NVFP4.
  • SM121 (Spark) is not yet supported, pending an nvidia-cutlass-dsl==4.5 wheel release.
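
For orientation, a minimal call sketch is below. Everything besides backend="b12x" is an assumption from context: fp4_quantize is a hypothetical stand-in for whatever NVFP4 quantization path produces the packed operands and block scale factors, and the argument order mirrors the existing mm_fp4 tests rather than a verified signature.

import torch
import flashinfer

m, n, k = 16, 512, 7168
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Hypothetical helper (assumption): packs to fp4 and returns per-block scale
# factors plus a global descale (sf_vec_size=16 for nvfp4).
a_fp4, a_sf, a_global = fp4_quantize(a)
b_fp4, b_sf, b_global = fp4_quantize(b)

out = flashinfer.mm_fp4(
    a_fp4, b_fp4.T, a_sf, b_sf.T,   # operand order assumed
    a_global * b_global,            # alpha: combined global descale (assumed)
    torch.bfloat16,                 # out_dtype
    backend="b12x",                 # new in this PR; SM120 + nvfp4 only
)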

Changes

File | Change
---- | ------
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py | New SM120 kernel with a wrapper() method for FlashInfer's TVM-FFI compile interface
flashinfer/cute_dsl/utils.py | Add sm120_make_smem_layout_sfa/sfb with 64-aligned tile support
flashinfer/gemm/gemm_base.py | New _b12x_gemm_fp4_requirement and _b12x_gemm_fp4_runner (separate cache and runner class), a _select_default_sm120_mma_tiler heuristic, and SM120 auto-selection in the backend heuristic
flashinfer/gemm/__init__.py | Export Sm120BlockScaledDenseGemmKernel
tests/gemm/test_mm_fp4.py | Add "b12x" to the backend parametrization with an SM120-only skip
benchmarks/routines/gemm.py | Add "b12x" to CLI choices and autotune-supported backends; remove a redundant backend guard in run_backend

Performance numbers on RTX 5090 (SM120)

Geomean speedup vs CUTLASS:

  • cuDNN: 1.02x
  • b12x: 1.20x

b12x performance is particularly strong on small-M (decode) shapes, where the underfill tiles (64x64, 64x128) kick in; many of those show a 1.3-1.6x speedup. The larger shapes are roughly at parity.
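
For reference, a geomean speedup like the 1.20x above is presumably computed as below; the PR does not show the exact aggregation script, so this is an assumed methodology.

import math

def geomean_speedup(baseline_us, candidate_us):
    # Geometric mean of per-shape speedup ratios baseline/candidate.
    ratios = [b / c for b, c in zip(baseline_us, candidate_us)]
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))

# e.g., the first three decode rows of the table below (CUTLASS vs b12x):
print(geomean_speedup([24.240, 24.400, 24.304], [13.680, 13.983, 14.303]))
# ~1.74x for these small-M rows; 1.20x is the geomean over all shapes.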

Click to view performance comparisons between backends
M N K cuDNN (us) CUTLASS (us) b12x (us) best
1 512 7168 18.448 24.240 13.680 b12x
4 512 7168 18.976 24.400 13.983 b12x
16 512 7168 19.152 24.304 14.303 b12x
64 512 7168 19.088 24.352 15.360 b12x
256 512 7168 21.728 24.575 25.392 cuDNN
1024 512 7168 23.519 25.232 26.592 cuDNN
1 896 1024 7.056 6.512 4.720 b12x
4 896 1024 7.232 6.608 4.976 b12x
16 896 1024 7.232 6.544 4.944 b12x
64 896 1024 7.152 6.432 4.993 b12x
256 896 1024 7.360 6.640 6.656 CUTLASS
1024 896 1024 7.664 6.736 6.928 CUTLASS
1 896 5120 20.192 18.144 11.200 b12x
8 896 5120 18.624 18.208 11.424 b12x
64 896 5120 19.008 18.224 12.560 b12x
512 896 5120 20.192 18.864 20.032 CUTLASS
1 1024 7168 18.895 24.256 14.880 b12x
4 1024 7168 19.231 24.592 15.744 b12x
16 1024 7168 19.488 24.671 16.432 b12x
256 1024 7168 23.103 24.720 25.824 cuDNN
1024 1024 7168 28.928 26.160 28.816 CUTLASS
1 1280 8192 18.287 27.872 18.191 b12x
8 1280 8192 18.688 27.856 19.296 cuDNN
64 1280 8192 18.528 27.920 19.936 cuDNN
512 1280 8192 21.328 29.040 31.599 cuDNN
1 1792 5120 17.600 18.864 14.592 b12x
8 1792 5120 17.488 18.912 14.272 b12x
64 1792 5120 20.063 18.832 14.976 b12x
512 1792 5120 23.536 19.904 22.080 CUTLASS
1 2560 8192 20.304 29.328 25.520 cuDNN
8 2560 8192 22.288 29.151 26.064 cuDNN
64 2560 8192 19.472 29.184 26.272 cuDNN
512 2560 8192 34.128 30.752 33.871 CUTLASS
1 3584 5120 21.456 22.128 19.504 b12x
8 3584 5120 21.071 22.111 19.520 b12x
64 3584 5120 20.800 22.463 20.128 b12x
512 3584 5120 26.079 24.608 24.736 CUTLASS
1 4608 7168 29.040 32.512 28.016 b12x
4 4608 7168 28.816 32.560 28.192 b12x
16 4608 7168 32.399 32.992 28.559 b12x
64 4608 7168 24.688 32.479 28.496 cuDNN
256 4608 7168 35.776 34.255 35.888 CUTLASS
1024 4608 7168 68.367 71.200 67.919 b12x
1 5120 640 7.632 6.768 5.456 b12x
8 5120 640 7.648 6.624 5.200 b12x
64 5120 640 7.808 6.752 5.296 b12x
512 5120 640 8.928 8.288 7.504 b12x
1 5120 1024 8.240 7.728 6.272 b12x
8 5120 1024 8.079 7.680 6.176 b12x
64 5120 1024 8.160 7.536 5.952 b12x
512 5120 1024 9.136 10.368 8.784 b12x
1 5120 1280 9.968 9.088 7.456 b12x
8 5120 1280 9.680 9.328 7.504 b12x
64 5120 1280 9.775 9.008 7.280 b12x
512 5120 1280 10.848 11.776 10.128 b12x
1 5120 2048 12.816 12.512 10.032 b12x
8 5120 2048 12.640 12.160 9.888 b12x
64 5120 2048 12.816 11.536 9.872 b12x
512 5120 2048 13.872 14.624 13.232 b12x
1 5120 2560 15.024 14.384 11.760 b12x
8 5120 2560 14.816 14.336 11.968 b12x
64 5120 2560 14.944 13.872 11.584 b12x
512 5120 2560 16.400 16.592 15.887 b12x
1 5120 4096 20.608 19.040 15.584 b12x
8 5120 4096 20.527 18.800 15.984 b12x
64 5120 4096 20.784 19.280 17.168 b12x
512 5120 4096 23.696 23.967 23.376 b12x
1 5120 5120 24.303 23.856 19.823 b12x
8 5120 5120 22.415 23.663 19.840 b12x
64 5120 5120 25.568 23.552 20.560 b12x
512 5120 5120 29.328 30.016 29.152 b12x
1 5120 8192 35.887 33.903 27.616 b12x
8 5120 8192 31.424 34.336 29.200 b12x
64 5120 8192 36.224 34.591 31.696 b12x
512 5120 8192 42.111 43.791 41.440 b12x
1 5120 16384 59.119 58.911 49.968 b12x
8 5120 16384 59.103 58.959 50.175 b12x
64 5120 16384 51.456 58.864 51.407 b12x
512 5120 16384 81.375 84.191 81.871 cuDNN
1 7168 256 5.424 4.976 3.968 b12x
4 7168 256 5.728 5.168 3.935 b12x
16 7168 256 5.456 4.992 4.096 b12x
64 7168 256 5.855 5.184 4.240 b12x
256 7168 256 5.856 5.488 5.040 b12x
1024 7168 256 9.696 12.912 10.048 cuDNN
1 7168 512 7.072 7.008 5.184 b12x
4 7168 512 7.120 6.864 5.408 b12x
16 7168 512 6.976 6.720 5.120 b12x
64 7168 512 7.071 6.559 5.456 b12x
256 7168 512 7.296 7.104 6.576 b12x
1024 7168 512 15.280 18.000 14.752 b12x
4 7168 2304 15.168 14.800 11.551 b12x
16 7168 2304 14.848 14.479 11.232 b12x
64 7168 2304 14.528 13.712 11.423 b12x
256 7168 2304 15.136 15.840 14.320 b12x
1 7168 4608 25.552 25.792 19.264 b12x
4 7168 4608 25.200 25.184 18.560 b12x
16 7168 4608 27.648 24.320 19.120 b12x
64 7168 4608 25.952 24.383 22.544 b12x
256 7168 4608 27.984 27.328 26.480 b12x
1024 7168 4608 70.480 72.335 69.743 b12x
1 7168 5120 28.063 27.552 19.552 b12x
8 7168 5120 26.048 26.287 20.192 b12x
64 7168 5120 27.824 25.215 23.312 b12x
512 7168 5120 47.567 48.175 47.648 cuDNN
1 8192 1024 10.352 9.775 7.376 b12x
8 8192 1024 10.367 9.456 7.696 b12x
64 8192 1024 9.568 9.233 7.712 b12x
512 8192 1024 14.016 15.264 13.776 b12x
1 8192 2048 14.832 14.688 10.640 b12x
8 8192 2048 14.591 14.080 10.560 b12x
64 8192 2048 14.128 13.104 11.936 b12x
512 8192 2048 22.272 23.280 22.080 b12x
1 8192 3584 21.087 21.872 15.727 b12x
8 8192 3584 22.272 21.264 15.808 b12x
64 8192 3584 20.448 19.200 18.336 b12x
512 8192 3584 35.423 35.023 35.344 CUTLASS
1 8192 4096 22.880 23.872 16.416 b12x
8 8192 4096 25.312 23.663 18.352 b12x
64 8192 4096 22.176 21.184 20.624 b12x
512 8192 4096 39.392 38.928 39.552 CUTLASS
1 8192 7168 37.663 38.783 26.863 b12x
8 8192 7168 37.888 37.296 27.663 b12x
64 8192 7168 37.584 34.224 32.416 b12x
512 8192 7168 67.824 70.751 67.759 b12x
1 8192 8192 40.288 42.784 29.472 b12x
8 8192 8192 39.487 42.191 31.904 b12x
64 8192 8192 38.159 35.663 39.263 CUTLASS
512 8192 8192 76.895 80.655 77.391 cuDNN
1 8192 14336 69.615 70.624 52.720 b12x
8 8192 14336 69.647 70.175 53.615 b12x
64 8192 14336 61.919 67.904 58.223 b12x
512 8192 14336 124.814 128.014 124.735 b12x
1 8192 28672 134.031 135.231 100.127 b12x
8 8192 28672 133.935 131.982 102.047 b12x
64 8192 28672 108.863 127.887 122.463 cuDNN
512 8192 28672 233.486 243.533 232.621 b12x
1 9216 7168 41.200 45.840 29.296 b12x
4 9216 7168 41.232 46.031 29.584 b12x
16 9216 7168 39.008 42.911 32.912 b12x
64 9216 7168 37.999 35.135 33.151 b12x
256 9216 7168 42.655 39.839 41.536 CUTLASS
1024 9216 7168 125.534 128.511 125.279 b12x
1 10240 8192 47.807 50.847 35.328 b12x
8 10240 8192 47.535 51.615 35.887 b12x
64 10240 8192 38.944 38.336 44.496 CUTLASS
512 10240 8192 84.383 86.399 83.343 b12x

🔍 Related Issues

#3013

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added "b12x" FP4 backend with SM120-optimized execution and a new SM120 block-scaled dense GEMM kernel.
    • Registered SM120 kernel for CuTe-DSL-enabled builds and added SM120-specific shared-memory layouts.
  • Enhancements

    • Improved SM120 tiling/auto-selection heuristics for small-M and low-occupancy scenarios.
  • Tests

    • Extended FP4 tests to include "b12x" with SM120-specific preconditions and skip logic.

bkryu and others added 6 commits April 13, 2026 10:41
Port the b12x block-scaled NVFP4 dense GEMM kernel to FlashInfer as a
backend='cute-dsl' option for mm_fp4 on SM120 and SM121 GPUs.

Key features:
- Warp-level MMA (MmaMXF4NVF4Op m16n8k64) with 8 MMA warps + 1 DMA warp
- Underfill tile selection (64x64, 64x128, 128x64) for small-M shapes
- SF tile decoupling: scale factor SMEM rounded up to 128-element blocks
- Epilogue sync elimination for single-store tiles
- Pipeline stage cap at 4 for bounded register pressure
- PersistentTileSchedulerParams workaround for CUTLASS DSL 4.4.1

Changes:
- New: flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
- Modified: flashinfer/cute_dsl/utils.py (sm120_make_smem_layout_sfa/sfb)
- Modified: flashinfer/gemm/gemm_base.py (SM120 dispatch, tactics, heuristic)
- Modified: flashinfer/gemm/__init__.py (export Sm120BlockScaledDenseGemmKernel)
- Modified: tests/gemm/test_mm_fp4.py (enable SM120/121 for cute-dsl)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…o SM120 only

- Add 'b12x' as a new backend name for mm_fp4 (SM120 only)
- Revert 'cute-dsl' requirement back to SM100/SM103 only
- Remove SM121 support (pending CuTe-DSL 4.5 wheel fix)
- Add 'b12x' to benchmark CLI choices
- Update tests to parametrize 'b12x' backend separately

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 13, 2026

📝 Walkthrough

Adds SM120 (sm_120) block‑scaled GEMM support: new "b12x" backend and runner, SM120 CuTe‑DSL shared‑memory layout ops, an SM120 block‑scaled DenseGemm kernel and host launch path, plus tests/benchmarks and mm_fp4 backend wiring and heuristics.

Changes

Cohort / File(s) Summary
CuTe‑DSL Shared Memory Layout Utilities
flashinfer/cute_dsl/utils.py
Added sm120_make_smem_layout_sfa and sm120_make_smem_layout_sfb user‑ops to construct staged SMEM layouts for SFA/SFB with divisibility checks, quantized tile rounding, and shape/stride assembly.
Public API exports
flashinfer/gemm/__init__.py
Conditionally register Sm120BlockScaledDenseGemmKernel in _cute_dsl_kernels / __all__ when CuTe‑DSL is available.
Core GEMM backend & heuristics
flashinfer/gemm/gemm_base.py
Add "b12x" backend literal to mm_fp4 API, implement _b12x_gemm_fp4_requirement and _b12x_gemm_fp4_runner, add _select_default_sm120_mma_tiler, and prefer "b12x" in the SM120+NVFP4 heuristic.
SM120 kernel implementation
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
New SM120 block‑scaled DenseGemm kernel (device warp MMA + DMA/TMA staging + epilogue), compile/cache helpers, host _DenseGemmLaunch, dense_gemm API, and Sm120BlockScaledDenseGemmKernel alias.
Tests & Benchmarks
tests/gemm/test_mm_fp4.py, benchmarks/routines/gemm.py
Add "b12x" to backend parametrizations and benchmark CLI choices; tests add SM120/NVFP4/layout preconditions and benchmark helper removed backend allowlist so runner is invoked for the backend.

Sequence Diagram

sequenceDiagram
    actor User
    participant mmfp4 as mm_fp4()
    participant Selector as Heuristic Selector
    participant Runner as b12x Runner
    participant Compiler as cute.compile/cache
    participant Kernel as Sm120BlockScaledDenseGemmKernel
    participant GPU as GPU Device

    User->>mmfp4: call mm_fp4(..., backend="auto" or "b12x")
    mmfp4->>Selector: resolve backend (heuristic: prefer b12x on SM120+NVFP4)
    Selector-->>mmfp4: backend="b12x"
    mmfp4->>Runner: invoke _b12x_gemm_fp4_runner(...)
    Runner->>Compiler: compile/select tactic & layouts (cached)
    Compiler-->>Runner: compiled kernel callable
    Runner->>Kernel: prepare descriptors & launch on stream
    Kernel->>GPU: device execution (TMA DMA -> SMEM -> MMA warps -> epilogue -> TMA store)
    GPU-->>Kernel: kernel completes
    Kernel-->>Runner: return output tensor
    Runner-->>mmfp4: deliver result
    mmfp4-->>User: return tensor

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes


Suggested labels

cute-dsl

Suggested reviewers

  • samuellees
  • yongwww
  • aleozlx
  • dhiraj113
  • nvmbreughe
  • jimmyzho
  • yzh119
  • cyx-6

Poem

🐰 In shared memory I tuck and weave,
For SM120 a kernel I conceive,
CuTe layouts and TMA in flight,
b12x hums through day and night,
hop, compute, deliver—what a sight!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 30.43%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check | ❓ Inconclusive | The PR description includes a comprehensive summary with a changes overview, related issues, and performance benchmarks; however, the test checklist items (tests added/updated, all tests passing) required by the template remain unchecked. | Verify test status: confirm whether tests have been added/updated for the b12x backend and that all tests pass, then update the checklist accordingly.

✅ Passed checks (1 passed)

Check name | Status | Explanation
Title check | ✅ Passed | The title clearly and specifically summarizes the main change: adding a new backend option "b12x" for the mm_fp4 function targeting SM120 GPUs.


@gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request adds support for the "b12x" backend for block-scaled FP4 GEMM, targeting the SM120 (Blackwell) architecture. Key additions include a new kernel implementation using warp-level MMA, shared memory layout utilities for scale factors, and integration into the mm_fp4 API and test suite. Feedback suggests refactoring the shared memory layout functions and consolidating redundant tile selection logic to improve maintainability and adhere to DRY principles.

Comment thread flashinfer/cute_dsl/utils.py
Comment on lines +1796 to +1811
def _select_default_mma_tiler_mn(m: int, n: int, sm_count: int) -> Tuple[int, int]:
    coarse_tile = (128, 128)
    coarse_tiles = ((m + coarse_tile[0] - 1) // coarse_tile[0]) * (
        (n + coarse_tile[1] - 1) // coarse_tile[1]
    )
    if m <= 128 and coarse_tiles < max(1, sm_count // 2):
        if n > 1536:
            return (64, 128)
        medium_tile = (128, 64)
        medium_tiles = ((m + medium_tile[0] - 1) // medium_tile[0]) * (
            (n + medium_tile[1] - 1) // medium_tile[1]
        )
        if medium_tiles < max(1, sm_count // 2):
            return (64, 64)
        return (128, 64)
    return (128, 128)
Contributor (gemini-code-assist), severity medium:

This function _select_default_mma_tiler_mn is a duplicate of _select_default_sm120_mma_tiler defined in flashinfer/gemm/gemm_base.py.

To avoid code duplication, please remove this local implementation and import the function from gemm_base.py.

Collaborator Author (@bkryu):

The default for sm120 can change in the future so let's keep a separate helper function.
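
To make the underfill behavior concrete, here is a worked trace of the heuristic above for one decode shape from the benchmark table, assuming the RTX 5090's 170 SMs (the SM count is an assumption) and that the function is importable as written:

# m=16, n=512, sm_count=170 (RTX 5090):
#   coarse (128, 128) tiling -> ceil(16/128) * ceil(512/128) = 1 * 4 = 4 tiles,
#   which is < sm_count // 2 = 85, so the launch would underfill the GPU.
#   n = 512 <= 1536, so try the (128, 64) tiling:
#   ceil(16/128) * ceil(512/64) = 1 * 8 = 8 tiles, still < 85,
#   so the smallest underfill tile is chosen:
assert _select_default_mma_tiler_mn(16, 512, 170) == (64, 64)

# A large compute-bound shape stays on the default tile:
assert _select_default_mma_tiler_mn(1024, 7168, 170) == (128, 128)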


bkryu commented Apr 13, 2026

@coderabbitai review


coderabbitai Bot commented Apr 13, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@bkryu bkryu self-assigned this Apr 13, 2026
@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
tests/gemm/test_mm_fp4.py (1)

114-114: Add one targeted assertion for the new auto-selection behavior.

This adds explicit b12x coverage, but the PR also changes backend="auto" to prefer b12x on SM120/NVFP4. A fallback regression to cutlass/cudnn would still pass the current tests because they only check numerical output.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_mm_fp4.py` at line 114, Add a targeted assertion to the test
in tests/gemm/test_mm_fp4.py to verify that when backend="auto" on SM120/NVFP4
the runtime chooses "b12x" (not a fallback like "cutlass" or "cudnn"); update
the parametrized test (the pytest.mark.parametrize("backend", ... ) block) or
add a separate case that runs the same codepath with backend="auto" and asserts
the reported/returned selected backend equals "b12x" (use the existing test
helper or function that returns the chosen backend from the mm invocation to
perform the check).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/cute_dsl/utils.py`:
- Around line 444-469: The current K-layout construction can produce truncated
extents because it divides tile_shape_mnk[2] by sf_vec_size and blk_sf and also
uses blk_sf // mma_nsf; add stronger asserts: ensure blk_sf is divisible by
mma_nsf (assert blk_sf % mma_nsf == 0) and ensure the K dimension is divisible
by the SF vector and block factors (assert tile_shape_mnk[2] % (sf_vec_size *
blk_sf) == 0); apply these checks near the existing assert that references
tile_shape_mnk[2] and also make the same changes in the other helper (the block
that builds sSFA_shapeK / sSF_strideM around the other occurrence) so blk_sf //
mma_nsf and tile_shape_mnk[2] // sf_vec_size // blk_sf cannot silently floor.

In `@flashinfer/gemm/__init__.py`:
- Around line 51-58: The new Sm120BlockScaledDenseGemmKernel import from
dense_blockscaled_gemm_sm120 is currently inside a broad try/except that on
failure suppresses all CuTe-DSL exports; separate the SM120 import into its own
try/except so an ImportError for dense_blockscaled_gemm_sm120 only skips
Sm120BlockScaledDenseGemmKernel, and ensure _cute_dsl_kernels and any existing
SM100 imports (e.g., Sm100BlockScaledPersistentDenseGemmKernel and
grouped_gemm_nt_masked) remain defined and exported regardless of the SM120
import outcome.

In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py`:
- Around line 79-87: The docstring in dense_blockscaled_gemm_sm120.py advertises
MXF4/sf_vec_size=32 but can_implement() only accepts NVFP4 with sf_vec_size=16;
update the supported-combinations and tile-shape notes in the class/module
docstring to reflect only NVFP4 (A/B: Float4E2M1FN, SF: Float8E4M3FN,
sf_vec_size: 16) and adjust any tile_k constraint text accordingly, and ensure
the docstring language aligns with the runtime check in can_implement().
- Around line 1223-1260: can_implement currently allows any 64-aligned
mma_tiler_mn but _compute_stages then forces ab_stage >= 1 even when a single
(A+B+SF+epilogue) stage cannot fit in SM120; fix by adding a hard rejection when
per-stage shared-memory demand exceeds available SM memory. Specifically, in
_compute_stages (and/or can_implement) compute the raw ab_stage before clamping
by evaluating available_smem_per_stage = (smem_capacity - occupancy * 1024) //
occupancy - mbar_helpers_bytes - epi_bytes and comparing that to
(ab_bytes_per_stage + sf_bytes_per_stage); if available_smem_per_stage <
(ab_bytes_per_stage + sf_bytes_per_stage) then return a failure/indicate
non-implementable (e.g., have can_implement return False or have _compute_stages
raise/return (0,0)), otherwise proceed to compute ab_stage = max(1,
min(calculated_ab_stage,4)); update callers to treat a zero/exception result as
“tiler cannot fit” so bad tilers like (512,512) are rejected early.

---

Nitpick comments:
In `@tests/gemm/test_mm_fp4.py`:
- Line 114: Add a targeted assertion to the test in tests/gemm/test_mm_fp4.py to
verify that when backend="auto" on SM120/NVFP4 the runtime chooses "b12x" (not a
fallback like "cutlass" or "cudnn"); update the parametrized test (the
pytest.mark.parametrize("backend", ... ) block) or add a separate case that runs
the same codepath with backend="auto" and asserts the reported/returned selected
backend equals "b12x" (use the existing test helper or function that returns the
chosen backend from the mm invocation to perform the check).
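
A hedged sketch of such a test is below. resolve_mm_fp4_backend is a hypothetical name for the internal heuristic that maps backend="auto" to a concrete backend; a real test would call whatever helper in flashinfer/gemm/gemm_base.py actually performs this resolution.

import pytest
import torch

@pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() != (12, 0),
    reason="auto -> b12x preference only applies on SM120",
)
def test_auto_selects_b12x_for_nvfp4_on_sm120():
    # Hypothetical import: the actual resolver name and signature must be
    # taken from flashinfer/gemm/gemm_base.py.
    from flashinfer.gemm.gemm_base import resolve_mm_fp4_backend

    # On SM120 with NVFP4 inputs, "auto" should resolve to "b12x" rather
    # than falling back to "cutlass" or "cudnn".
    assert resolve_mm_fp4_backend("auto", fp4_variant="nvfp4") == "b12x"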

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bcfae6fd-d49b-4c9c-b8c1-54d5ca68f6c4

📥 Commits

Reviewing files that changed from the base of the PR and between e64ae8b and 1b3715b.

📒 Files selected for processing (6)
  • benchmarks/routines/gemm.py
  • flashinfer/cute_dsl/utils.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
  • tests/gemm/test_mm_fp4.py

@coderabbitai Bot (Contributor) left a comment

🧹 Nitpick comments (1)
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py (1)

143-221: Missing assertion for ab_stage > 0.

There's an assertion for epi_stage > 0 at line 198, but no corresponding assertion for ab_stage > 0. While can_implement() now includes an SMEM budget check that should prevent configurations where ab_stage would be 0, adding a defensive assertion here would provide a clearer error message if the invariant is violated.

🛡️ Suggested defensive assertion
         self.ab_stage, self.epi_stage = self._compute_stages(
             ...
         )

+        assert self.ab_stage > 0, (
+            "ab_stage <= 0, not enough shared memory. This configuration will be skipped."
+        )
         assert self.epi_stage > 0, (
             "epi_stage <= 0, not enough shared memory. This configuration will be skipped."
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py` around lines 143 -
221, The function _setup_attributes computes (self.ab_stage, self.epi_stage) via
self._compute_stages but only asserts self.epi_stage > 0; add a defensive
assertion that self.ab_stage > 0 immediately after _compute_stages to fail fast
with a clear message if the invariant is violated (mirror the style of the
existing epi_stage assertion and include a helpful message like "ab_stage <= 0,
not enough shared memory. This configuration will be skipped."); this ensures
configurations where ab_stage becomes zero are caught early and points to
_compute_stages/_setup_attributes for debugging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py`:
- Around line 143-221: The function _setup_attributes computes (self.ab_stage,
self.epi_stage) via self._compute_stages but only asserts self.epi_stage > 0;
add a defensive assertion that self.ab_stage > 0 immediately after
_compute_stages to fail fast with a clear message if the invariant is violated
(mirror the style of the existing epi_stage assertion and include a helpful
message like "ab_stage <= 0, not enough shared memory. This configuration will
be skipped."); this ensures configurations where ab_stage becomes zero are
caught early and points to _compute_stages/_setup_attributes for debugging.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c68e6971-33a2-4f83-b94b-a2d111f0a9d4

📥 Commits

Reviewing files that changed from the base of the PR and between 1b3715b and 652b178.

📒 Files selected for processing (3)
  • flashinfer/cute_dsl/utils.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gemm/__init__.py


bkryu commented Apr 13, 2026

/bot run

@flashinfer-bot (Collaborator):

GitLab MR !541 has been created, and the CI pipeline #48432589 is currently running. I'll report back once the pipeline job completes.


bkryu commented Apr 14, 2026

/bot stop


bkryu commented Apr 14, 2026

/bot run

@flashinfer-bot (Collaborator):

The GitLab CI pipeline #48432589 has been cancelled.

@flashinfer-bot (Collaborator):

GitLab MR !541 has been updated with latest changes, and the CI pipeline #48452126 is currently running. I'll report back once the pipeline job completes.

@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

5257-5261: ⚠️ Potential issue | 🟡 Minor

Update the enable_pdl docs for b12x.

The docstring still says this flag is only used by cute-dsl, but Lines 5323-5325 now forward it into _b12x_gemm_fp4_runner(...) as well. The generated API docs will be misleading unless the parameter description is broadened.

Also applies to: 5323-5325

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 5257 - 5261, The docstring for the
parameter enable_pdl is now inaccurate because enable_pdl is forwarded into
_b12x_gemm_fp4_runner (in addition to cute_dsl); update the enable_pdl parameter
description to reflect that it controls Programmatic Dependent Launch behavior
for both the cute_dsl backend and the b12x runner paths (or state that it is
honored by any backend that supports PDL, including _b12x_gemm_fp4_runner),
mentioning both symbols enable_pdl and _b12x_gemm_fp4_runner so readers know
where the flag is used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4948-4959: The fallback path that sets tactic via
_select_default_sm120_mma_tiler can pick an invalid tactic because it skips the
same validation used in get_valid_tactics; update the tactic == -1 branch to
pick a tactic only after validating with
Sm120BlockScaledDenseGemmKernel.can_implement (or by selecting from the list
returned by get_valid_tactics) using the current real_k and c_cutlass_dtype (and
any runner state required), so replace the raw _select_default_sm120_mma_tiler
return with a validated choice or rerun can_implement against the candidate and
loop to the next valid option.

---

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5257-5261: The docstring for the parameter enable_pdl is now
inaccurate because enable_pdl is forwarded into _b12x_gemm_fp4_runner (in
addition to cute_dsl); update the enable_pdl parameter description to reflect
that it controls Programmatic Dependent Launch behavior for both the cute_dsl
backend and the b12x runner paths (or state that it is honored by any backend
that supports PDL, including _b12x_gemm_fp4_runner), mentioning both symbols
enable_pdl and _b12x_gemm_fp4_runner so readers know where the flag is used.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 154c3c6b-d722-4007-9674-00c4db9dbd56

📥 Commits

Reviewing files that changed from the base of the PR and between 652b178 and 315e23e.

📒 Files selected for processing (2)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_mm_fp4.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gemm/test_mm_fp4.py

Comment on lines +4948 to +4959
if tactic is None or tactic == -1:
    _sm_count = torch.cuda.get_device_properties(
        a.device
    ).multi_processor_count
    tactic = (
        _select_default_sm120_mma_tiler(m, n, _sm_count),
        (1, 1),
        False,
        False,
        "sm120",
        None,
    )
Contributor (coderabbitai):

⚠️ Potential issue | 🟠 Major

Validate the fallback b12x tactic before using it.

get_valid_tactics() filters each tile through Sm120BlockScaledDenseGemmKernel.can_implement(...), but the tactic == -1 path skips that filter and picks a tile from _select_default_sm120_mma_tiler() using only m, n, and SM count. Since validity also depends on real_k and c_cutlass_dtype, this fallback can choose a tactic the runner itself would reject. Please derive the default from the validated set or rerun can_implement(...) here.

Suggested fix
             if tactic is None or tactic == -1:
-                _sm_count = torch.cuda.get_device_properties(
-                    a.device
-                ).multi_processor_count
-                tactic = (
-                    _select_default_sm120_mma_tiler(m, n, _sm_count),
-                    (1, 1),
-                    False,
-                    False,
-                    "sm120",
-                    None,
-                )
+                sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
+                preferred_tiles = (
+                    _select_default_sm120_mma_tiler(m, n, sm_count),
+                    (128, 128),
+                    (128, 64),
+                    (64, 128),
+                    (64, 64),
+                )
+                for preferred_tile in dict.fromkeys(preferred_tiles):
+                    if Sm120BlockScaledDenseGemmKernel.can_implement(
+                        cutlass.Float4E2M1FN,
+                        cutlass.Float8E4M3FN,
+                        sf_vec_size,
+                        c_cutlass_dtype,
+                        preferred_tile,
+                        (1, 1),
+                        m,
+                        n,
+                        real_k,
+                        batch_size,
+                        "k",
+                        "k",
+                        "n",
+                    ):
+                        tactic = (preferred_tile, (1, 1), False, False, "sm120", None)
+                        break
+                else:
+                    raise RuntimeError(
+                        f"No valid b12x tactic for m={m}, n={n}, k={real_k}, out_dtype={out_dtype}."
+                    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4948 - 4959, The fallback path
that sets tactic via _select_default_sm120_mma_tiler can pick an invalid tactic
because it skips the same validation used in get_valid_tactics; update the
tactic == -1 branch to pick a tactic only after validating with
Sm120BlockScaledDenseGemmKernel.can_implement (or by selecting from the list
returned by get_valid_tactics) using the current real_k and c_cutlass_dtype (and
any runner state required), so replace the raw _select_default_sm120_mma_tiler
return with a validated choice or rerun can_implement against the candidate and
loop to the next valid option.

@nv-yunzheq (Collaborator) left a comment

LGTM. Left a few comments. Major question on SM121 support.

)
else:
raise ValueError(f"Unsupported backend: {backend}")
return flashinfer.gemm.mm_fp4(
Collaborator (@nv-yunzheq):

Why did the ValueError get removed?

Collaborator Author (@bkryu):

I found that this ValueError in the benchmark is redundant, because the support checks already ensure that only runnable backends are called here.

I figured it was a good opportunity to simplify the code by removing the redundant check.


if is_cute_dsl_available():
    from .kernels.dense_blockscaled_gemm_sm120 import (
        Sm120BlockScaledDenseGemmKernel as Sm120BlockScaledDenseGemmKernel,
Collaborator (@nv-yunzheq):

It seems the kernel name and file name mirror the CuTe DSL kernel imported from the CUTLASS examples, but the backend name is not cute-dsl. It might be hard for people to understand the code in the future. Do we want to add something like b12x to the names? Or we could even just use cute-dsl as the backend for this kernel. The drawback of the latter is that if a CUTLASS block-scaled GEMM for SM120 shows up in the examples in the future and we try to port it, it might cause confusion.

Collaborator Author (@bkryu):

This is a good point. I can rename the classes in a follow-up PR.

return True


@supported_compute_capability([120])
Collaborator (@nv-yunzheq):

What about 121?

Collaborator Author (@bkryu):

This is a good catch, and it is intentional. SM121 (Spark) is not yet supported; support is expected with a future nvidia-cutlass-dsl release.

batch_size = 1

if tactic is None or tactic == -1:
    _sm_count = torch.cuda.get_device_properties(
Collaborator (@nv-yunzheq):

Is this a heavy query? If so, we probably want to put it in a separate function and cache the result so we don't need to query it every time.

Collaborator Author (@bkryu):

Another good point. I believe we have a helper function to get the SM count. I can update this in a follow-up PR.
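
For illustration, a minimal sketch of such a cached helper follows; the follow-up PR notes the real helper is get_device_sm_count(), so this only shows the shape of the idea, assuming per-device caching suffices.

from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def cached_sm_count(device_index: int) -> int:
    # get_device_properties goes through the CUDA runtime, so query once per
    # device instead of on every GEMM dispatch.
    return torch.cuda.get_device_properties(device_index).multi_processor_count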

@nv-yunzheq (Collaborator) left a comment

The suggestions are not blocking. Therefore, approving the PR.

@bkryu bkryu merged commit 8c93f92 into flashinfer-ai:main Apr 14, 2026
65 of 71 checks passed
@coderabbitai coderabbitai Bot mentioned this pull request Apr 15, 2026
5 tasks
@bkryu bkryu deleted the b12x_mm_fp4 branch April 15, 2026 16:18
aleozlx pushed a commit that referenced this pull request Apr 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Follow-up to #3051 (`backend="b12x"` for `mm_fp4` on SM120) and #3080
(`b12x_fused_moe` / `B12xMoEWrapper` SM120 APIs) addressing four
reviewer comments that landed after merge. No public API changes; no
kernel behavior changes.

- **Copyright**: bump `tests/moe/test_b12x_fused_moe.py` to 2026.
- **Benchmark split**: new `b12x_fused_moe` routine (SM120/121, BF16
input, SwiGLU + ReLU²); `cute_dsl_fp4_block_scale_moe` is now
SM100/103-only. Aligns with the `B12xMoEWrapper` / `CuteDslMoEWrapper`
Python API split.
- **Cache SM count**: replace a hot-path
`torch.cuda.get_device_properties(...).multi_processor_count` in the
`b12x` FP4 GEMM runner with the cached `get_device_sm_count()` helper.
- **Rename for provenance**: `dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py` and
`Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel` (via `git mv`, 6 import sites
updated). `backend="b12x"` string unchanged.


## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added new `b12x_fused_moe` benchmark routine for NVFP4 MoE inference
with support for both SwiGLU and ReLU2 activation types.
* Extended Blackwell architecture support with updated kernel
implementations.

* **Documentation**
* Updated benchmark samples with new `b12x_fused_moe` test
configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->