
chore: Address non-blocking review feedback for #3051 / #3080 #3128

Merged
aleozlx merged 2 commits into flashinfer-ai:main from bkryu:b12x_followup
Apr 24, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Apr 21, 2026

📌 Description

Follow-up to #3051 (backend="b12x" for mm_fp4 on SM120) and #3080 (b12x_fused_moe / B12xMoEWrapper SM120 APIs) addressing four reviewer comments that landed after merge. No public API changes; no kernel behavior changes.

  • Copyright: bump tests/moe/test_b12x_fused_moe.py to 2026.
  • Benchmark split: new b12x_fused_moe routine (SM120/121, BF16 input, SwiGLU + ReLU²); cute_dsl_fp4_block_scale_moe is now SM100/103-only. Aligns with the B12xMoEWrapper / CuteDslMoEWrapper Python API split.
  • Cache SM count: replace a hot-path torch.cuda.get_device_properties(...).multi_processor_count lookup in the b12x FP4 GEMM runner with the cached get_device_sm_count() helper (a sketch of the caching pattern follows this list).
  • Rename for provenance: dense_blockscaled_gemm_sm120.py → dense_blockscaled_gemm_sm120_b12x.py and Sm120BlockScaledDenseGemmKernel → Sm120B12xBlockScaledDenseGemmKernel (via git mv, 6 import sites updated). The backend="b12x" string is unchanged.
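
The caching pattern referenced above, as a minimal sketch: FlashInfer's actual get_device_sm_count() helper lives elsewhere in the codebase and may differ in detail; this only illustrates why a per-device cache removes the repeated device-property query from the hot GEMM path.

import functools

import torch


@functools.lru_cache(maxsize=None)
def get_device_sm_count(device: torch.device) -> int:
    # Query CUDA device properties once per device; subsequent calls on the
    # hot GEMM path return the cached multiprocessor count.
    return torch.cuda.get_device_properties(device).multi_processor_count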

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added new b12x_fused_moe benchmark routine for NVFP4 MoE inference with support for both SwiGLU and ReLU2 activation types.
    • Extended Blackwell architecture support with updated kernel implementations.
  • Documentation

    • Updated benchmark samples with new b12x_fused_moe test configurations.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough


The PR introduces b12x_fused_moe, a new MoE benchmark routine targeting Blackwell SM12x architectures. It refactors benchmark test data generation into a backend-aware helper supporting both CuTe-DSL and B12x variants, updates GEMM kernel references across SM12x MoE modules, and adds corresponding test infrastructure.

Changes

Cohort / File(s) Summary
Benchmark Registration
benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/samples/sample_testlist.txt
Added b12x_fused_moe to MoE benchmark APIs with SM120 backend support; removed CuTe-DSL support from cute_dsl_fp4_block_scale_moe for compute capabilities 12.0/12.1; added sample test entries for both SwiGLU and ReLU2 activation modes.
Benchmark Test Implementation
benchmarks/routines/moe.py
Refactored _create_cute_dsl_moe_test_data into generalized _create_nvfp4_moe_test_data with backend/activation-type parameters; added testB12xFusedMoe function supporting bf16-input + NVFP4-weight benchmarking with activation validation and performance metrics; updated testCuteDslFp4BlockScaleMoe to use new helper with SwiGLU enforcement.
MoE SM12x Kernel Modules
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py, moe_micro_kernel.py, moe_static_kernel.py
Updated GEMM kernel imports to use Sm120B12xBlockScaledDenseGemmKernel from dense_blockscaled_gemm_sm120_b12x instead of Sm120BlockScaledDenseGemmKernel.
GEMM Kernel Infrastructure
flashinfer/gemm/__init__.py, flashinfer/gemm/gemm_base.py, flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
Updated kernel symbol exports and imports to reference Sm120B12xBlockScaledDenseGemmKernel; refactored SM count retrieval in _b12x_gemm_fp4_runner to use helper function instead of direct CUDA property query.
Test Updates
tests/moe/test_b12x_fused_moe.py
Updated copyright header from 2025 to 2026.

Sequence Diagram(s)

sequenceDiagram
    participant Benchmark as Benchmark Runner
    participant DataGen as Data Generation
    participant Wrapper as B12xMoEWrapper/<br/>Functional API
    participant Kernel as B12x/CuTe-DSL Kernel
    participant Metrics as Metrics Collector

    alt b12x_fused_moe path
        Benchmark->>DataGen: Call _create_nvfp4_moe_test_data<br/>(backend="b12x", is_gated)
        DataGen->>DataGen: Create bf16 inputs<br/>Generate NVFP4 weights
        DataGen-->>Benchmark: Return prepared tensors
    else cute_dsl path
        Benchmark->>DataGen: Call _create_nvfp4_moe_test_data<br/>(backend="cute-dsl", is_gated)
        DataGen->>DataGen: Quantize inputs to fp4<br/>Generate interleaved FC1<br/>Enforce SwiGLU activation
        DataGen-->>Benchmark: Return prepared tensors
    end

    Benchmark->>Wrapper: Execute MoE with activation validation<br/>(Relu2/SwiGLU)
    Wrapper->>Kernel: Dispatch to appropriate kernel<br/>(Sm120B12x / CuTe-DSL)
    Kernel->>Kernel: Execute fused MoE computation
    Kernel-->>Wrapper: Return output tensors
    Wrapper->>Metrics: Collect timing & throughput data
    Metrics-->>Benchmark: Report performance results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

  • PR #2725: Adds SM120 compute-capability checks and nvcc arch handling to enable Blackwell NVFP4 MoE support infrastructure.
  • PR #3066: Implements b12x CuTe-DSL fused MoE kernel changes and B12x variant kernel aliasing overlapping with this PR's kernel symbol updates.
  • PR #3080: Introduces B12xMoEWrapper and SM120 kernel refactoring that directly aligns with this PR's b12x backend routing and kernel references.

Suggested labels

run-ci, op: moe, benchmark, ready

Suggested reviewers

  • yzh119
  • jiahanc
  • cyx-6
  • yongwww
  • aleozlx
  • jimmyzho
  • nv-yunzheq

Poem

🐰 A new rabbit hops through Blackwell's gate,
With b12x fused kernels running great!
Backend-aware data flows both ways,
CuTe and B12x in benchmark maze,
Performance metrics shine—hip-hip-hooray! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 63.64%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately describes the PR as addressing non-blocking review feedback from two related PRs (#3051 and #3080), which aligns with the changeset's four targeted follow-up fixes.
  • Description check (✅ Passed): The PR description provides a clear, structured explanation of changes addressing follow-up feedback. It covers four distinct areas: copyright updates, benchmark split, SM count caching, and file/class renaming, with specific details on each.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the b12x_fused_moe routine for SM120/SM121 Blackwell GPUs, supporting both SwiGLU and ReLU2 activations with bf16 inputs and NVFP4 weights. It generalizes the test data creation utility to support multiple backends and updates internal GEMM kernel aliases for consistency. Feedback points out potential NameError issues due to missing or late imports of flashinfer and get_device_sm_count, and suggests updating a docstring to reflect the generalized functionality of the test data creation helper.

"""
if args.verbose >= 1:
print("[INFO] Running testB12xFusedMoe")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

high

The code uses flashinfer.__version__, but flashinfer is only imported via from flashinfer import ... at line 1620. This will result in a NameError unless import flashinfer is present at the top of the file. Please ensure flashinfer is imported as a module if you intend to access its attributes directly.
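
A minimal sketch of the suggested fix, assuming the benchmark file currently relies only on from flashinfer import ... style imports (the helper name below is hypothetical; the module-level import is the point):

# Add a plain module import so attribute access like flashinfer.__version__ resolves,
# alongside whatever "from flashinfer import ..." names the benchmark already uses.
import flashinfer


def print_benchmark_header(verbose: int) -> None:
    # Hypothetical helper mirroring the quoted snippet above.
    if verbose >= 1:
        print("[INFO] Running testB12xFusedMoe")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")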

        tactic = (
            _select_default_sm120_mma_tiler(m, n, _sm_count),
            _select_default_sm120_mma_tiler(
                m, n, get_device_sm_count(a.device)

high

The function get_device_sm_count is used here but it doesn't appear to be imported in this file's diff. If it's not already imported at the top of flashinfer/gemm/gemm_base.py, this will cause a NameError.
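
If the helper really is missing from the module, the fix is a single import; a sketch follows, with the module path (flashinfer.utils) stated as an assumption since the review does not pin down where the cached helper is defined:

# Sketch only: the module exporting get_device_sm_count is assumed to be
# flashinfer.utils; adjust to wherever the cached helper actually lives.
import torch

from flashinfer.utils import get_device_sm_count

device = torch.device("cuda:0")
sm_count = get_device_sm_count(device)  # cached per device, cheap on the hot path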

    is_gated: bool = True,
):
    """Create NVFP4-quantized test data for CuteDSL MoE (Blackwell kernels).
    """Create NVFP4-quantized test data for CuTe-DSL-family MoE kernels.

medium

The docstring for _create_nvfp4_moe_test_data still mentions CuteDslMoEWrapper.run() in its summary, but the function is now also used for B12xMoEWrapper. It would be more accurate to say it returns tensors needed by the MoE wrappers.

    Returns a dict with all tensors needed by the MoE wrappers.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

4989-4995: ⚠️ Potential issue | 🟡 Minor

Replace the lambda with a named function to satisfy Ruff E731.

Ruff flags this line: "Do not assign a lambda expression, use a def." While the lambda works correctly at runtime, the linting rule should be followed. Convert to a named function definition.

Proposed fix
-            make_kernel = lambda: Sm120B12xBlockScaledDenseGemmKernel(
-                sf_vec_size,
-                mma_tiler_mn,
-                cluster_shape_mn,
-                use_prefetch,
-                enable_pdl,
-            )
+            def make_kernel():
+                return Sm120B12xBlockScaledDenseGemmKernel(
+                    sf_vec_size,
+                    mma_tiler_mn,
+                    cluster_shape_mn,
+                    use_prefetch,
+                    enable_pdl,
+                )

Note: Similar lambda assignments exist at lines 4786 and 4794 and should be fixed the same way.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4989 - 4995, Convert the assigned
lambda make_kernel into a proper named function: replace "make_kernel = lambda:
Sm120B12xBlockScaledDenseGemmKernel(...)" with a def make_kernel(): that returns
Sm120B12xBlockScaledDenseGemmKernel(...) using the same parameters (sf_vec_size,
mma_tiler_mn, cluster_shape_mn, use_prefetch, enable_pdl); do the same fix for
the other similar lambda assignments around the earlier occurrences (the ones
currently at the sites creating Sm120B12xBlockScaledDenseGemmKernel-like
instances) so Ruff E731 is satisfied while preserving behavior.
🧹 Nitpick comments (2)
benchmarks/routines/moe.py (2)

1596-1854: testB12xFusedMoe looks correct and mirrors the Cute-DSL test structure.

Spot checks against the wrapper/functional signatures in flashinfer/fused_moe/cute_dsl/b12x_moe.py:

  • B12xMoEWrapper(...) kwargs (num_experts, top_k, hidden_size, intermediate_size, use_cuda_graph, max_num_tokens, num_local_experts, activation) match the class signature; device defaults to "cuda", so omitting it is fine.
  • The functional b12x_fused_moe warmup call supplies all keyword-only args (w1_alpha, w2_alpha, fc2_input_scale) plus the baked-in num_experts/top_k/num_local_experts/output/activation via partial, which is signature-compatible.
  • activation_str mapping (Swiglu→"silu", Relu2→"relu2") and is_gated propagation into TFLOPs/bandwidth are consistent with the kernel's documented SwiGLU vs ReLU² FC1 shapes.
  • Not forwarding local_expert_offset to the wrapper/functional path is intentional on SM120/121 (EP unsupported), so no concern there.

One minor ergonomic note: args.input_dtype / args.weight_dtype are read into input_dtype / weight_dtype and only surface in the bandwidth-accounting and result dict, while the actual kernel inputs are forced to bf16 + nvfp4. This matches testCuteDslFp4BlockScaleMoe's existing behavior, but it means a user passing --input_dtype float16 silently gets bf16 inputs and possibly misleading bandwidth numbers. A one-line warning when the CLI dtype disagrees with what the routine will actually run would be nice.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1596 - 1854, The CLI dtypes
(args.input_dtype / args.weight_dtype) can disagree with the actual formats used
(bf16 inputs and nvfp4 weights) which may mislead bandwidth numbers; update
testB12xFusedMoe to detect this mismatch and emit a one-line warning when they
differ (e.g., if dtype_str_to_torch_dtype(args.input_dtype) != torch.bfloat16 or
dtype_str_to_torch_dtype(args.weight_dtype) != NVFP4-equivalent) before creating
tensors; place the check after input_dtype/weight_dtype are computed and print
the warning when args.verbose >= 1 (or use processLogger), referencing
input_dtype, weight_dtype, tensors["x_bf16"], and the b12x kernel expectation to
guide the user.
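
A sketch of what such a warning could look like (the benchmark's real dtype-parsing helper is assumed, so a tiny stand-in mapping keeps the example self-contained):

import torch

# Stand-in for the benchmark's dtype parser; the real helper is assumed to exist.
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16}


def warn_on_dtype_mismatch(input_dtype: str, weight_dtype: str, verbose: int) -> None:
    # b12x_fused_moe always runs bf16 activations with nvfp4 weights, so flag any
    # CLI request that would be silently overridden.
    if verbose >= 1 and (
        _DTYPE_MAP.get(input_dtype) != torch.bfloat16 or weight_dtype != "nvfp4"
    ):
        print(
            "[WARNING] b12x_fused_moe runs bf16 inputs with nvfp4 weights; "
            f"requested input_dtype={input_dtype} / weight_dtype={weight_dtype} "
            "is not honored and bandwidth numbers may be misleading."
        )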

1204-1236: Backend-aware test data helper is well-structured.

The merge into _create_nvfp4_moe_test_data(..., backend, is_gated) is clean: backend validation up front, explicit rejection of non-gated for cute-dsl, and interleaving gated only for cute-dsl all match the documented FC1 layout contracts in b12x_fused_moe / cute_dsl_fused_moe_nvfp4.

Optional nit: for backend="b12x" the fp4 x_quantized/x_sf are computed but never consumed by the b12x path (the kernel fuses input quantization internally). Guarding those lines behind if backend == "cute-dsl": and returning None for the unused keys would shave a bit off benchmark setup time and make the data-prep intent explicit.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1204 - 1236, In
_create_nvfp4_moe_test_data, avoid computing fp4-quantized inputs for the b12x
path: move the fp4_quantize (and subsequent convert_sf_to_mma_layout) calls so
they only run when backend == "cute-dsl", and for backend == "b12x" set the
corresponding return entries (e.g., x_quantized, x_sf, token_final_scales
converted via convert_sf_to_mma_layout) to None or omit them to make intent
explicit; update any returned dict keys accordingly so callers of
_create_nvfp4_moe_test_data (and tests using token_selected_experts /
token_final_scales) still find the expected keys but with None values for the
b12x path.
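
The shape of that guard, as a sketch: the quantization and layout-conversion callables are passed in rather than named, because the exact helpers (fp4_quantize, convert_sf_to_mma_layout) and dict keys come from the review text and are not verified here.

def maybe_quantize_input(x_bf16, backend, quantize_fn, to_mma_layout_fn):
    """Quantize activations only on the cute-dsl path; b12x fuses this in-kernel."""
    if backend == "cute-dsl":
        x_quantized, x_sf = quantize_fn(x_bf16)
        return x_quantized, to_mma_layout_fn(x_sf)
    # b12x path: the kernel quantizes its own inputs, so skip the host-side work
    # and return None for the unused dict entries.
    return None, None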
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 64-68: The module replaced the old exported name causing a
compatibility break: add a backward-compatible alias
Sm120BlockScaledDenseGemmKernel that points to
Sm120B12xBlockScaledDenseGemmKernel and ensure the old name is included in the
public exports list (i.e., append "Sm120BlockScaledDenseGemmKernel" to
_cute_dsl_kernels or __all__ alongside "Sm120B12xBlockScaledDenseGemmKernel") so
existing imports like from flashinfer.gemm import
Sm120BlockScaledDenseGemmKernel continue to work.

---

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4989-4995: Convert the assigned lambda make_kernel into a proper
named function: replace "make_kernel = lambda:
Sm120B12xBlockScaledDenseGemmKernel(...)" with a def make_kernel(): that returns
Sm120B12xBlockScaledDenseGemmKernel(...) using the same parameters (sf_vec_size,
mma_tiler_mn, cluster_shape_mn, use_prefetch, enable_pdl); do the same fix for
the other similar lambda assignments around the earlier occurrences (the ones
currently at the sites creating Sm120B12xBlockScaledDenseGemmKernel-like
instances) so Ruff E731 is satisfied while preserving behavior.

---

Nitpick comments:
In `@benchmarks/routines/moe.py`:
- Around line 1596-1854: The CLI dtypes (args.input_dtype / args.weight_dtype)
can disagree with the actual formats used (bf16 inputs and nvfp4 weights) which
may mislead bandwidth numbers; update testB12xFusedMoe to detect this mismatch
and emit a one-line warning when they differ (e.g., if
dtype_str_to_torch_dtype(args.input_dtype) != torch.bfloat16 or
dtype_str_to_torch_dtype(args.weight_dtype) != NVFP4-equivalent) before creating
tensors; place the check after input_dtype/weight_dtype are computed and print
the warning when args.verbose >= 1 (or use processLogger), referencing
input_dtype, weight_dtype, tensors["x_bf16"], and the b12x kernel expectation to
guide the user.
- Around line 1204-1236: In _create_nvfp4_moe_test_data, avoid computing
fp4-quantized inputs for the b12x path: move the fp4_quantize (and subsequent
convert_sf_to_mma_layout) calls so they only run when backend == "cute-dsl", and
for backend == "b12x" set the corresponding return entries (e.g., x_quantized,
x_sf, token_final_scales converted via convert_sf_to_mma_layout) to None or omit
them to make intent explicit; update any returned dict keys accordingly so
callers of _create_nvfp4_moe_test_data (and tests using token_selected_experts /
token_final_scales) still find the expected keys but with None values for the
b12x path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 72ca1913-8b0d-4818-b6bf-28aa6d5beb4f

📥 Commits

Reviewing files that changed from the base of the PR and between 8a9970b and e16e85b.

📒 Files selected for processing (10)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/moe.py
  • benchmarks/samples/sample_testlist.txt
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
  • tests/moe/test_b12x_fused_moe.py

Comment on lines +64 to +68
from .kernels.dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)

_cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
_cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")

⚠️ Potential issue | 🟠 Major

Keep the old exported kernel alias for compatibility.

Because _cute_dsl_kernels is folded into __all__, replacing the old Sm120BlockScaledDenseGemmKernel export with only Sm120B12xBlockScaledDenseGemmKernel can break existing from flashinfer.gemm import Sm120BlockScaledDenseGemmKernel users despite the PR’s no-public-API-change goal.

🔁 Proposed compatibility alias
         from .kernels.dense_blockscaled_gemm_sm120_b12x import (
             Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
         )
+        Sm120BlockScaledDenseGemmKernel = Sm120B12xBlockScaledDenseGemmKernel
 
-        _cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
+        _cute_dsl_kernels.extend(
+            [
+                "Sm120B12xBlockScaledDenseGemmKernel",
+                "Sm120BlockScaledDenseGemmKernel",
+            ]
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

from .kernels.dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)
_cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
_cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")

from .kernels.dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)
Sm120BlockScaledDenseGemmKernel = Sm120B12xBlockScaledDenseGemmKernel
_cute_dsl_kernels.extend(
    [
        "Sm120B12xBlockScaledDenseGemmKernel",
        "Sm120BlockScaledDenseGemmKernel",
    ]
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/__init__.py` around lines 64 - 68, The module replaced the
old exported name causing a compatibility break: add a backward-compatible alias
Sm120BlockScaledDenseGemmKernel that points to
Sm120B12xBlockScaledDenseGemmKernel and ensure the old name is included in the
public exports list (i.e., append "Sm120BlockScaledDenseGemmKernel" to
_cute_dsl_kernels or __all__ alongside "Sm120B12xBlockScaledDenseGemmKernel") so
existing imports like from flashinfer.gemm import
Sm120BlockScaledDenseGemmKernel continue to work.

@bkryu bkryu added the run-ci label Apr 21, 2026
@bkryu
Collaborator Author

bkryu commented Apr 21, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !575 has been created, and the CI pipeline #49044922 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu self-assigned this Apr 21, 2026
@aleozlx aleozlx enabled auto-merge (squash) April 24, 2026 02:28
@aleozlx aleozlx merged commit 223f2a4 into flashinfer-ai:main Apr 24, 2026
56 of 101 checks passed