chore: Address non-blocking review feedback for #3051 / #3080 (#3128)
aleozlx merged 2 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

The PR introduces …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Benchmark as Benchmark Runner
    participant DataGen as Data Generation
    participant Wrapper as B12xMoEWrapper/<br/>Functional API
    participant Kernel as B12x/CuTe-DSL Kernel
    participant Metrics as Metrics Collector

    alt b12x_fused_moe path
        Benchmark->>DataGen: Call _create_nvfp4_moe_test_data<br/>(backend="b12x", is_gated)
        DataGen->>DataGen: Create bf16 inputs<br/>Generate NVFP4 weights
        DataGen-->>Benchmark: Return prepared tensors
    else cute_dsl path
        Benchmark->>DataGen: Call _create_nvfp4_moe_test_data<br/>(backend="cute-dsl", is_gated)
        DataGen->>DataGen: Quantize inputs to fp4<br/>Generate interleaved FC1<br/>Enforce SwiGLU activation
        DataGen-->>Benchmark: Return prepared tensors
    end

    Benchmark->>Wrapper: Execute MoE with activation validation<br/>(Relu2/SwiGLU)
    Wrapper->>Kernel: Dispatch to appropriate kernel<br/>(Sm120B12x / CuTe-DSL)
    Kernel->>Kernel: Execute fused MoE computation
    Kernel-->>Wrapper: Return output tensors
    Wrapper->>Metrics: Collect timing & throughput data
    Metrics-->>Benchmark: Report performance results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces the b12x_fused_moe routine for SM120/SM121 Blackwell GPUs, supporting both SwiGLU and ReLU2 activations with bf16 inputs and NVFP4 weights. It generalizes the test data creation utility to support multiple backends and updates internal GEMM kernel aliases for consistency. Feedback points out potential NameError issues due to missing or late imports of flashinfer and get_device_sm_count, and suggests updating a docstring to reflect the generalized functionality of the test data creation helper.
| """ | ||
| if args.verbose >= 1: | ||
| print("[INFO] Running testB12xFusedMoe") | ||
| print(f"[INFO] FlashInfer version: {flashinfer.__version__}") |
The code uses flashinfer.__version__, but flashinfer is only imported via from flashinfer import ... at line 1620. This will result in a NameError unless import flashinfer is present at the top of the file. Please ensure flashinfer is imported as a module if you intend to access its attributes directly.
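A minimal sketch of the suggested fix, assuming the benchmark file keeps its existing `from flashinfer import ...` statements; the only addition is the plain module import so attribute access resolves:

```python
# Sketch only: add a module-level import at the top of benchmarks/routines/moe.py
# so attribute lookups like flashinfer.__version__ work regardless of which
# names are pulled in via "from flashinfer import ...".
import flashinfer

print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
```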
```diff
     tactic = (
-        _select_default_sm120_mma_tiler(m, n, _sm_count),
+        _select_default_sm120_mma_tiler(
+            m, n, get_device_sm_count(a.device)
```
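For context, the swap above replaces a module-level `_sm_count` read with a per-device lookup. A cached helper along these lines (an illustrative sketch; FlashInfer's actual `get_device_sm_count` may differ in location and signature) would behave equivalently:

```python
import functools

import torch


@functools.lru_cache(maxsize=None)
def get_device_sm_count(device: torch.device) -> int:
    # Query the streaming-multiprocessor count once per device and memoize it,
    # so repeated tactic selection does not re-read device properties.
    return torch.cuda.get_device_properties(device).multi_processor_count
```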
```diff
     is_gated: bool = True,
 ):
-    """Create NVFP4-quantized test data for CuteDSL MoE (Blackwell kernels).
+    """Create NVFP4-quantized test data for CuTe-DSL-family MoE kernels.
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
4989-4995: ⚠️ Potential issue | 🟡 Minor — Replace the lambda with a named function to satisfy Ruff E731.

Ruff flags this line: "Do not assign a `lambda` expression, use a `def`." While the lambda works correctly at runtime, the linting rule should be followed. Convert to a named function definition.

Proposed fix

```diff
-    make_kernel = lambda: Sm120B12xBlockScaledDenseGemmKernel(
-        sf_vec_size,
-        mma_tiler_mn,
-        cluster_shape_mn,
-        use_prefetch,
-        enable_pdl,
-    )
+    def make_kernel():
+        return Sm120B12xBlockScaledDenseGemmKernel(
+            sf_vec_size,
+            mma_tiler_mn,
+            cluster_shape_mn,
+            use_prefetch,
+            enable_pdl,
+        )
```

Note: Similar lambda assignments exist at lines 4786 and 4794 and should be fixed the same way.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 4989 - 4995, Convert the assigned lambda make_kernel into a proper named function: replace "make_kernel = lambda: Sm120B12xBlockScaledDenseGemmKernel(...)" with a def make_kernel(): that returns Sm120B12xBlockScaledDenseGemmKernel(...) using the same parameters (sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_prefetch, enable_pdl); do the same fix for the other similar lambda assignments around the earlier occurrences (the ones currently at the sites creating Sm120B12xBlockScaledDenseGemmKernel-like instances) so Ruff E731 is satisfied while preserving behavior.
🧹 Nitpick comments (2)
benchmarks/routines/moe.py (2)
1596-1854: `testB12xFusedMoe` looks correct and mirrors the Cute-DSL test structure.

Spot checks against the wrapper/functional signatures in `flashinfer/fused_moe/cute_dsl/b12x_moe.py`:

- `B12xMoEWrapper(...)` kwargs (`num_experts`, `top_k`, `hidden_size`, `intermediate_size`, `use_cuda_graph`, `max_num_tokens`, `num_local_experts`, `activation`) match the class signature; `device` defaults to `"cuda"`, so omitting it is fine.
- The functional `b12x_fused_moe` warmup call supplies all keyword-only args (`w1_alpha`, `w2_alpha`, `fc2_input_scale`) plus the baked-in `num_experts`/`top_k`/`num_local_experts`/`output`/`activation` via `partial`, which is signature-compatible.
- The `activation_str` mapping (Swiglu → `"silu"`, Relu2 → `"relu2"`) and `is_gated` propagation into TFLOPs/bandwidth are consistent with the kernel's documented SwiGLU vs ReLU² FC1 shapes.
- Not forwarding `local_expert_offset` to the wrapper/functional path is intentional on SM120/121 (EP unsupported), so no concern there.

One minor ergonomic note: `args.input_dtype`/`args.weight_dtype` are read into `input_dtype`/`weight_dtype` and only surface in the bandwidth accounting and result dict, while the actual kernel inputs are forced to bf16 + nvfp4. This matches `testCuteDslFp4BlockScaleMoe`'s existing behavior, but it means a user passing `--input_dtype float16` silently gets bf16 inputs and possibly misleading bandwidth numbers. A one-line warning when the CLI dtype disagrees with what the routine will actually run would be nice.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` around lines 1596 - 1854, The CLI dtypes (args.input_dtype / args.weight_dtype) can disagree with the actual formats used (bf16 inputs and nvfp4 weights) which may mislead bandwidth numbers; update testB12xFusedMoe to detect this mismatch and emit a one-line warning when they differ (e.g., if dtype_str_to_torch_dtype(args.input_dtype) != torch.bfloat16 or dtype_str_to_torch_dtype(args.weight_dtype) != NVFP4-equivalent) before creating tensors; place the check after input_dtype/weight_dtype are computed and print the warning when args.verbose >= 1 (or use processLogger), referencing input_dtype, weight_dtype, tensors["x_bf16"], and the b12x kernel expectation to guide the user.
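A hedged sketch of the suggested one-line warning; the helper name and exact attribute names are assumptions modeled on the review comment, not the benchmark's actual code:

```python
import torch


def warn_if_dtype_overridden(args, input_dtype: torch.dtype) -> None:
    # The b12x path always runs bf16 activations with NVFP4 weights, so a CLI
    # request for anything else is silently overridden; surface that once.
    if input_dtype != torch.bfloat16 and args.verbose >= 1:
        print(
            f"[WARN] --input_dtype {args.input_dtype} requested, but b12x_fused_moe "
            "uses bf16 inputs + nvfp4 weights; bandwidth numbers assume bf16."
        )
```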
1204-1236: Backend-aware test data helper is well-structured.

The merge into `_create_nvfp4_moe_test_data(..., backend, is_gated)` is clean: backend validation up front, explicit rejection of non-gated for cute-dsl, and interleaving gated only for cute-dsl all match the documented FC1 layout contracts in `b12x_fused_moe`/`cute_dsl_fused_moe_nvfp4`.

Optional nit: for `backend="b12x"` the fp4 `x_quantized`/`x_sf` are computed but never consumed by the b12x path (the kernel fuses input quantization internally). Guarding those lines behind `if backend == "cute-dsl":` and returning `None` for the unused keys would shave a bit off benchmark setup time and make the data-prep intent explicit.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` around lines 1204 - 1236, In _create_nvfp4_moe_test_data, avoid computing fp4-quantized inputs for the b12x path: move the fp4_quantize (and subsequent convert_sf_to_mma_layout) calls so they only run when backend == "cute-dsl", and for backend == "b12x" set the corresponding return entries (e.g., x_quantized, x_sf, token_final_scales converted via convert_sf_to_mma_layout) to None or omit them to make intent explicit; update any returned dict keys accordingly so callers of _create_nvfp4_moe_test_data (and tests using token_selected_experts / token_final_scales) still find the expected keys but with None values for the b12x path.
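A rough sketch of the guarded data prep this nit suggests. To stay self-contained it takes the quantization callables as parameters (e.g., the `fp4_quantize` and `convert_sf_to_mma_layout` helpers named in the comment); names and structure are assumptions, not the benchmark's actual code:

```python
def prepare_activation_inputs(x_bf16, input_global_scale, backend, quantize_fn, to_mma_layout_fn):
    # Only the cute-dsl path consumes host-side pre-quantized activations; the
    # b12x kernel fuses input quantization, so skip the work and return None.
    if backend == "cute-dsl":
        x_q, x_sf = quantize_fn(x_bf16, input_global_scale)
        return x_q, to_mma_layout_fn(x_sf)
    return None, None
```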
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 72ca1913-8b0d-4818-b6bf-28aa6d5beb4f
📒 Files selected for processing (10)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/moe.py
- benchmarks/samples/sample_testlist.txt
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120_b12x.py
- tests/moe/test_b12x_fused_moe.py
```diff
 from .kernels.dense_blockscaled_gemm_sm120_b12x import (
     Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
 )

-_cute_dsl_kernels.append("Sm120BlockScaledDenseGemmKernel")
+_cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
```
Keep the old exported kernel alias for compatibility.
Because `_cute_dsl_kernels` is folded into `__all__`, replacing the old `Sm120BlockScaledDenseGemmKernel` export with only `Sm120B12xBlockScaledDenseGemmKernel` can break existing `from flashinfer.gemm import Sm120BlockScaledDenseGemmKernel` users despite the PR's no-public-API-change goal.
🔁 Proposed compatibility alias
```diff
 from .kernels.dense_blockscaled_gemm_sm120_b12x import (
     Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
 )
+Sm120BlockScaledDenseGemmKernel = Sm120B12xBlockScaledDenseGemmKernel

-_cute_dsl_kernels.append("Sm120B12xBlockScaledDenseGemmKernel")
+_cute_dsl_kernels.extend(
+    [
+        "Sm120B12xBlockScaledDenseGemmKernel",
+        "Sm120BlockScaledDenseGemmKernel",
+    ]
+)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
from .kernels.dense_blockscaled_gemm_sm120_b12x import (
    Sm120B12xBlockScaledDenseGemmKernel as Sm120B12xBlockScaledDenseGemmKernel,
)

Sm120BlockScaledDenseGemmKernel = Sm120B12xBlockScaledDenseGemmKernel

_cute_dsl_kernels.extend(
    [
        "Sm120B12xBlockScaledDenseGemmKernel",
        "Sm120BlockScaledDenseGemmKernel",
    ]
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/__init__.py` around lines 64 - 68, The module replaced the
old exported name causing a compatibility break: add a backward-compatible alias
Sm120BlockScaledDenseGemmKernel that points to
Sm120B12xBlockScaledDenseGemmKernel and ensure the old name is included in the
public exports list (i.e., append "Sm120BlockScaledDenseGemmKernel" to
_cute_dsl_kernels or __all__ alongside "Sm120B12xBlockScaledDenseGemmKernel") so
existing imports like from flashinfer.gemm import
Sm120BlockScaledDenseGemmKernel continue to work.
/bot run
📌 Description
Follow-up to #3051 (`backend="b12x"` for `mm_fp4` on SM120) and #3080 (`b12x_fused_moe`/`B12xMoEWrapper` SM120 APIs), addressing four reviewer comments that landed after merge. No public API changes; no kernel behavior changes.

- Bumps the year in `tests/moe/test_b12x_fused_moe.py` to 2026.
- Adds a `b12x_fused_moe` benchmark routine (SM120/121, BF16 input, SwiGLU + ReLU²); `cute_dsl_fp4_block_scale_moe` is now SM100/103-only. Aligns with the `B12xMoEWrapper`/`CuteDslMoEWrapper` Python API split.
- Replaces `torch.cuda.get_device_properties(...).multi_processor_count` in the `b12x` FP4 GEMM runner with the cached `get_device_sm_count()` helper.
- Renames `dense_blockscaled_gemm_sm120.py` → `dense_blockscaled_gemm_sm120_b12x.py` and `Sm120BlockScaledDenseGemmKernel` → `Sm120B12xBlockScaledDenseGemmKernel` (via `git mv`, 6 import sites updated). Public `backend="b12x"` string unchanged.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed and pass locally (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
- New `b12x_fused_moe` benchmark routine for NVFP4 MoE inference with support for both SwiGLU and ReLU2 activation types.

Documentation
- Sample test list updated with `b12x_fused_moe` test configurations.
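For context on the new routine, a hedged sketch of constructing the SM120/121 wrapper it exercises; the kwargs mirror the reviewer's signature spot-check earlier in this thread, while the import path, parameter values, and activation string are illustrative assumptions:

```python
# Illustrative only: import path and sizes are assumptions; kwargs follow the
# review comment's spot-check of B12xMoEWrapper's signature.
from flashinfer.fused_moe.cute_dsl.b12x_moe import B12xMoEWrapper

moe = B12xMoEWrapper(
    num_experts=128,
    top_k=8,
    hidden_size=4096,
    intermediate_size=2048,
    use_cuda_graph=False,
    max_num_tokens=1024,
    num_local_experts=128,
    activation="silu",  # "silu" (SwiGLU) or "relu2", per the activation mapping above
)  # device defaults to "cuda", so it can be omitted
```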