
feat: support mxfp4 & mxfp8 entrypoint for blackwell cutedsl dense gemm#2660

Merged
flashinfer-bot merged 8 commits into flashinfer-ai:main from bzhng-development:brayden/mxfp4-blockscale-dense-gemm on Mar 6, 2026
Conversation


@b8zhong b8zhong commented Feb 28, 2026

📌 Description

This kernel supports MXFP4 and MXFP8: pass in a block size of 32 and an E8M0 scale.
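As a rough illustration of the MX block-scale scheme (block size 32, one shared E8M0 power-of-two scale per block), here is a minimal sketch; `FP4_MAX` and the ceil rounding choice are illustrative assumptions, not this kernel's exact quantizer:

```python
import math

FP4_MAX = 6.0     # largest magnitude representable in FP4 (E2M1)
BLOCK_SIZE = 32   # MX block size used by this entrypoint

def e8m0_scale(block):
    # One shared power-of-two scale per 32-element block, stored as a
    # biased 8-bit exponent (E8M0): byte = exponent + 127.
    assert len(block) == BLOCK_SIZE
    amax = max(abs(x) for x in block)
    # smallest power of two such that amax / scale fits in the FP4 range
    exp = math.ceil(math.log2(amax / FP4_MAX)) if amax > 0 else -127
    exp = max(-127, min(127, exp))
    return exp + 127, 2.0 ** exp

byte, scale = e8m0_scale([0.5] * 31 + [12.0])
```

With amax = 12.0 the block needs scale 2.0, so the stored byte is 128 (biased exponent 1).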

For most MXFP4 shapes, performance is quite good: the geomean speedup is around 1.20x, and the reference check passes.


For MXFP8, performance is a bit lacking, since I recycled most of the cute-dsl heuristics from MXFP4; this change is mainly for enablement first.

Local testing result (B200):

python tests/gemm/test_mm_fp4.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 4512 items                                                                                                                                                                                                                                                 

tests/gemm/test_mm_fp4.py [per-test pass/skip progress trimmed]  [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_fp4.py: 588 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_fp4.py: 294 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1440 passed, 3072 skipped, 882 warnings in 535.13s (0:08:55) ====================================================================================================
pytest tests/gemm/test_mm_mxfp8.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 2131 items                                                                                                                                                                                                                                                 

tests/gemm/test_mm_mxfp8.py [per-test pass/skip progress trimmed]  [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_mxfp8.py: 314 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_mxfp8.py: 157 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1633 passed, 498 skipped, 471 warnings in 245.76s (0:04:05) =====================================================================================================

Summary by CodeRabbit

  • New Features

    • Added CuTe DSL ("cute-dsl") as a selectable backend for MXFP8 and FP4 matrix-multiply paths and support for both FP4 and MXFP8 input encodings.
  • Behavior

    • Backend-specific alpha handling and availability/layout checks for the CuTe DSL path; runtime skips when layout/scale semantics aren’t supported.
  • Tests

    • Tests and benchmarks expanded to include "cute-dsl" and "auto" in backend selections and validate encoding/layout behaviors.

coderabbitai Bot commented Feb 28, 2026

📝 Walkthrough

Adds CuTe DSL backend support to FP4 and MXFP8 GEMM: backend wiring, availability checks, requirement validators, CuTe DSL runners and kernel caches, dual-input FP4/MXFP8 kernel paths, and updated benchmarks/tests to include cute-dsl and auto.

Changes

  • Benchmark Backend Configuration (benchmarks/routines/flashinfer_benchmark_utils.py): extended the mm_mxfp8 dtype→backend mapping to include "cute-dsl" for CC 10.0 and 10.3.
  • Benchmark Tests / Routines (benchmarks/routines/gemm.py): added conditional alpha handling for cute-dsl in FP4 paths; expanded the autotune/run backend allowlists to include "cute-dsl" and "auto"; unified backend execution flows.
  • Core GEMM Implementation (flashinfer/gemm/gemm_base.py): wired in CuTe DSL: availability checks, requirement validators, _cute_dsl_gemm_mxfp8_runner and _cute_dsl_gemm_fp4_runner, kernel caches, dtype mapping, heuristic/backend updates, and updated public signatures to accept "cute-dsl".
  • Kernel Implementation (flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py): added dual-input encoding support: the FP4 path accepts uint8 (2 FP4 values per byte, k = k_raw * 2) with recast pointers; the MXFP8 path uses Float8 inputs directly (k = k_raw); runtime branching selects the path.
  • GEMM Tests (tests/gemm/test_mm_fp4.py, tests/gemm/test_mm_mxfp8.py): removed the nvfp4-only skip for cute-dsl in the FP4 test; added a runtime skip for cute-dsl when non-swizzled MXFP8 layouts are used; expanded test parametrizations to include "cute-dsl" and "auto".

Sequence Diagram(s)

sequenceDiagram
    actor User as User Code
    participant Dispatcher as GEMM Dispatcher
    participant BackendSel as Backend Selector
    participant ReqCheck as Requirement Validator
    participant Runner as CuTe DSL Runner
    participant Kernel as Kernel Executor

    User->>Dispatcher: call mm_fp4 / mm_mxfp8 (backend="cute-dsl")
    Dispatcher->>BackendSel: resolve backend
    BackendSel->>ReqCheck: validate requirements (_cute_dsl_gemm_*_requirement)
    alt requirements valid
        ReqCheck-->>BackendSel: valid
        BackendSel->>Runner: instantiate / fetch runner
        Runner->>Runner: generate tactics & compile if needed
        Runner->>Kernel: launch kernel with chosen tactic
        Kernel->>Kernel: branch: FP4 recast (uint8→fp4, k=k_raw*2) or MXFP8 direct (float8, k=k_raw)
        Kernel-->>Runner: complete
        Runner-->>Dispatcher: return output tensor
    else requirements invalid
        ReqCheck-->>BackendSel: invalid
        BackendSel-->>Dispatcher: fallback / error
    end
    Dispatcher-->>User: return result or error

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • nv-yunzheq
  • jimmyzho
  • nvmbreughe
  • yongwww
  • jiahanc
  • aleozlx
  • bkryu

Poem

🐰 I hopped through code with nimble feet,
CuTe DSL now joins the feat.
FP4 recast, MXFP8 aligned,
Kernels cached and tactics mined.
Hop—merge—celebrate, a rabbit’s treat!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 31.03%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the title accurately describes the main change, adding mxfp4 and mxfp8 support to the Blackwell cutedsl dense GEMM kernel.
  • Description check ✅ Passed: the description covers the main functionality (mxfp4/mxfp8 support, block size, performance metrics) but lacks explicit sections for related issues and pre-commit checks, which are part of the repository template.



@b8zhong b8zhong changed the title from "tiny support mxfp4 & mxfp8 entrypoint for blackwell cutedsl dense gemm" to "feat: support mxfp4 & mxfp8 entrypoint for blackwell cutedsl dense gemm" on Feb 28, 2026
@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the mixed-precision GEMM capabilities by integrating the cute-dsl backend for both MXFP4 and MXFP8 data types, primarily targeting Blackwell GPUs. This enablement allows for experimentation and potential performance gains with these new precision formats, with MXFP4 already showing promising speedups. The changes involve core GEMM logic, backend registration, and comprehensive test updates to ensure functionality and correctness.

Highlights

  • CuTe DSL Backend Integration: Introduced cute-dsl backend support for mixed-precision FP4 (MXFP4) and mixed-precision FP8 (MXFP8) General Matrix Multiply (GEMM) operations, specifically targeting Blackwell architecture (compute capabilities 10.0 and 10.3).
  • Unified Kernel for FP4 and MXFP8: Implemented a unified kernel in dense_blockscaled_gemm_sm100.py that dynamically handles both FP4 (packed Uint8) and MXFP8 (Float8) input tensor types.
  • Enhanced GEMM Functions: Updated mm_fp4 and mm_mxfp8 functions to integrate the new cute-dsl runner, including specific requirement checks for scale factor layouts and dynamic parameter handling based on FP4 type.
  • Benchmark and Test Coverage: Adjusted benchmark utilities and test configurations to include and validate the cute-dsl backend for MXFP4 and MXFP8, ensuring proper functionality and performance.
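The validate-then-dispatch flow described in these highlights can be sketched in miniature; `make_backend_dispatcher` and the requirement shown are hypothetical stand-ins, not flashinfer's actual API:

```python
def make_backend_dispatcher(requirements, runners):
    # Mimics the requirement-validator / runner-factory wiring:
    # a backend runs only if its requirement check passes.
    def dispatch(backend, problem):
        check = requirements.get(backend)
        if check is None or not check(problem):
            raise ValueError(f"backend {backend!r} unavailable for this problem")
        return runners[backend](problem)
    return dispatch

# Hypothetical requirement: cute-dsl needs swizzled 1D block scales.
requirements = {"cute-dsl": lambda p: p.get("sf_swizzled", False)}
runners = {"cute-dsl": lambda p: "ran cute-dsl gemm"}
dispatch = make_backend_dispatcher(requirements, runners)
```

When the requirement fails, the real code falls back to another backend or raises, as in the sequence diagram earlier in this thread.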


Changelog
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added "cute-dsl" as a supported backend for compute capabilities 10.0 and 10.3.
  • benchmarks/routines/gemm.py
    • Modified the alpha parameter handling in testMmFp4 and run_backend to set alpha to 1.0 for cute-dsl when not using nvfp4.
    • Included "cute-dsl" in the autotune_supported_backends list and run_backend logic for MXFP8 tests.
  • flashinfer/gemm/gemm_base.py
    • Extended backend literal types in _check_mm_mxfp8_problem_size, _cutlass_gemm_mxfp8_requirement, _heuristic_func_mm_mxfp8, and mm_mxfp8 to include "cute-dsl".
    • Introduced _cute_dsl_gemm_mxfp8_requirement to define conditions for cute-dsl MXFP8 GEMM, requiring swizzled 1D block scales.
    • Added _cute_dsl_gemm_mxfp8_runner to implement the cute-dsl MXFP8 GEMM logic, including tactic selection and kernel compilation.
    • Updated _cute_dsl_gemm_fp4_requirement to remove the use_nvfp4 restriction, allowing MXFP4.
    • Modified _cute_dsl_gemm_fp4_runner to accept use_nvfp4 as a parameter, enabling dynamic sf_vec_size and sf_dtype based on FP4 type.
    • Integrated _cute_dsl_gemm_mxfp8_requirement and _cute_dsl_gemm_mxfp8_runner into the backend_requirement and backend_to_runner_factory maps for mm_mxfp8.
    • Updated the mm_fp4 docstring to clarify cute-dsl quantization requirements for NVFP4 and MXFP4.
    • Passed use_nvfp4 to the _cute_dsl_gemm_fp4_runner factory.
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
    • Refactored the wrapper function to dynamically handle input tensors (mA, mB) as either Uint8 (for FP4 packed data) or Float8 (for MXFP8 data), adjusting k and ptr accordingly.
    • Updated docstrings for mA and mB to reflect support for both FP4 and MXFP8 paths.
  • tests/gemm/test_mm_fp4.py
    • Removed the pytest.skip condition that restricted cute-dsl backend to nvfp4 only.
    • Added "cute-dsl" to the list of backends supported for MXFP4 tests.
  • tests/gemm/test_mm_mxfp8.py
    • Generalized the _skip_if_unsupported message from "cutlass" to "backend".
    • Added a pytest.skip condition for cute-dsl if is_sf_swizzled_layout is False, as it currently requires swizzled 1D scale layout.
    • Included "cute-dsl" in the backend parameterizations for test_mm_mxfp8 and test_mm_mxfp8_large_dimensions.
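To make the "2 FP4 values per byte" packing concrete, here is a sketch that decodes one packed uint8 into two E2M1 values; the nibble order (low nibble first) is an assumption for illustration, not necessarily the kernel's layout:

```python
def unpack_fp4_byte(byte: int):
    # Decode one uint8 holding two packed FP4 (E2M1) values:
    # 1 sign bit, 2 exponent bits, 1 mantissa bit per nibble.
    def e2m1(nib: int) -> float:
        sign = -1.0 if nib & 0b1000 else 1.0
        exp = (nib >> 1) & 0b11
        man = nib & 0b1
        if exp == 0:                       # subnormal: 0.0 or 0.5
            return sign * man * 0.5
        return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)
    return e2m1(byte & 0xF), e2m1(byte >> 4)
```

Representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4, and 6, which is why the shared E8M0 scale is chosen so each block fits within ±6.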
Activity
  • The author, b8zhong, provided detailed local testing results on a B200 GPU for both test_mm_fp4.py and test_mm_mxfp8.py, showing a mix of passed and skipped tests, along with deprecation warnings.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for mxfp4 and mxfp8 data types to the cute-dsl backend for dense GEMM operations on Blackwell GPUs. The changes include updating backend lists, adding new kernel runners and requirement checks, and modifying existing FP4 runners to handle the new data types. The test suite is also updated to cover these new features. My review focuses on improving code maintainability by reducing duplication and simplifying logic. Overall, the changes are well-structured and align with the PR's objectives.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

2787-2792: Cache SM count once during tactic search.

torch.cuda.get_device_properties(...).multi_processor_count is queried repeatedly inside nested loops. Hoisting it once per call reduces Python/CUDA overhead during autotune search.

♻️ Suggested change
         def get_valid_tactics(
             self,
             inputs: List[torch.Tensor],
             profile: OptimizationProfile,
         ) -> list:
             (a, b, a_descale, b_descale, _, out, _) = inputs
             m = a.shape[0]
             real_k = a.shape[1]
             n = b.shape[1]
+            sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
@@
                             if use_prefetch:
                                 cta_nums = self._get_approximate_cta_nums(
                                     kernel_m, kernel_n, mma_tiler_mn, cluster_shape_mn
                                 )
-                                sm_count = torch.cuda.get_device_properties(
-                                    a.device
-                                ).multi_processor_count
                                 cta_wave_ratio = cta_nums / sm_count
                                 if not (0.5 < cta_wave_ratio < 1.0 or real_k >= 8192):
                                     continue
Additional review comments:

In benchmarks/routines/gemm.py:
  • Line 1300: the autotune filtering excludes "auto", so a user-requested "auto" backend is dropped when --autotune is set. Include "auto" in autotune_supported_backends (and in the duplicated logic around lines 1353-1361), or let "auto" pass through before applying autotune-specific backend restrictions.
  • Lines 1094-1098: torch.tensor([1.0], dtype=torch.float32, device=device) is allocated inside the timed path whenever (not use_nvfp4 and backend == "cute-dsl"), creating a new CUDA tensor on every invocation and skewing timings. Create the scalar tensor once and reuse it (or pass a Python float where the backend accepts it) in both occurrences (here and the similar block at lines 1136-1140).

In flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:
  • Lines 2139-2151: the FP4 packed path is selected based only on mA.element_type, which can silently corrupt results if mB has a different dtype. Require both mA.element_type and mB.element_type to be cutlass.Uint8 before setting k = k_raw * 2 and recasting both iterators to cutlass.Float4E2M1FN; otherwise treat the inputs as MXFP8 (k = k_raw, raw iterators), and raise a clear error on mismatched dtypes.
ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f521fe1 and 949c488.

📒 Files selected for processing (6)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/gemm.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
  • tests/gemm/test_mm_fp4.py
  • tests/gemm/test_mm_mxfp8.py


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
flashinfer/gemm/gemm_base.py (3)

2700-2702: Remove dead code assignment.

The sm_major and sm_minor parameters are passed to the function but only used in a no-op assignment. If they're reserved for future use, consider adding a comment. Otherwise, the parameters could be removed from the signature or prefixed with underscore.

🧹 Suggested cleanup
-    _ = sm_major, sm_minor
+    # Reserved for future SM-specific kernel selection
+    _ = sm_major, sm_minor

Or if truly unused:

 def _cute_dsl_gemm_mxfp8_runner(
-    sm_major: int,
-    sm_minor: int,
+    sm_major: int,  # noqa: ARG001 - reserved for future SM-specific paths
+    sm_minor: int,  # noqa: ARG001 - reserved for future SM-specific paths
     enable_pdl: bool,
     out_dtype: torch.dtype,
 ):
-    _ = sm_major, sm_minor

2719-2721: Consider prefixing unused variables with underscore.

The variables a_descale, b_descale, and out are unpacked but never used. Prefixing them with _ would suppress linter warnings and clarify intent.

🧹 Suggested cleanup
-            (a, b, a_descale, b_descale, _, out, _) = inputs
+            (a, b, _a_descale, _b_descale, _, _out, _) = inputs

2961-2983: Clarify: cute-dsl is not auto-selected for mm_mxfp8.

The heuristic function only returns ["cutlass"] and doesn't include "cute-dsl". This means users must explicitly specify backend="cute-dsl" to use the CuTe DSL path; it won't be auto-selected.

Based on the PR description mentioning "weaker performance" for mxfp8, this appears intentional. Consider adding a brief comment in _heuristic_func_mm_mxfp8 to document this design choice:

📝 Suggested documentation
 def _heuristic_func_mm_mxfp8(
     ...
 ) -> List[str]:
+    # Note: cute-dsl is available but not auto-selected due to performance characteristics.
+    # Users can explicitly specify backend="cute-dsl" if desired.
     if "cutlass" in suitable_backends:
         return ["cutlass"]
     return []
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2961 - 2983, The heuristic in
_heuristic_func_mm_mxfp8 currently returns only ["cutlass"], intentionally
excluding "cute-dsl" since CuTe DSL is not auto-selected for mxFP8 due to weaker
performance; update the function by adding a concise comment above or inside
_heuristic_func_mm_mxfp8 (and mention mm_mxfp8.suitable_auto_backends in the
comment) stating that "cute-dsl" is intentionally omitted from auto-selection
and must be explicitly requested via backend="cute-dsl", and briefly note the
rationale (weaker performance) so future readers understand this design choice.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py`:
- Around line 2142-2159: The else branch currently assumes any non-Uint8
matching dtype is MXFP8 and proceeds, but you must explicitly guard against
unsupported dtypes: call the same allowlist logic used by
is_valid_dtypes_and_scale_factor_vec_size to verify mA.element_type (and
mB.element_type) is one of the supported FP8/FP4 types (cutlass.Float4E2M1FN,
cutlass.Float8E5M2, cutlass.Float8E4M3FN) before taking the MXFP8 path that sets
k = k_raw and uses mA.iterator/mB.iterator; if the element types are not in that
set, raise a TypeError with a clear message rather than letting downstream MXFP8
logic fail.

---

Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2700-2702: The assignment "_ = sm_major, sm_minor" is dead code;
either remove that line entirely or make the unused parameters explicit by
renaming them (e.g., _sm_major, _sm_minor) or adding a clarifying comment that
they're reserved for future use. Update the function signature accordingly (if
renaming) and adjust any callers if you change parameter names; otherwise just
delete the no-op assignment in gemm_base.py near where c_cutlass_dtype is set so
sm_major and sm_minor are not silently ignored.
- Around line 2719-2721: The unpacking in the method that processes "inputs"
currently binds a_descale, b_descale, and out but never uses them; rename those
to start with an underscore (e.g., _a_descale, _b_descale, _out or simply _ ) in
the tuple assignment (the line "(a, b, a_descale, b_descale, _, out, _) =
inputs") so linters know they are intentionally unused; update only the variable
names in that unpacking and leave all other logic (variables a, b, m, etc.)
unchanged.
- Around line 2961-2983: The heuristic in _heuristic_func_mm_mxfp8 currently
returns only ["cutlass"], intentionally excluding "cute-dsl" since CuTe DSL is
not auto-selected for mxFP8 due to weaker performance; update the function by
adding a concise comment above or inside _heuristic_func_mm_mxfp8 (and mention
mm_mxfp8.suitable_auto_backends in the comment) stating that "cute-dsl" is
intentionally omitted from auto-selection and must be explicitly requested via
backend="cute-dsl", and briefly note the rationale (weaker performance) so
future readers understand this design choice.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 949c488 and 6608b2f.

📒 Files selected for processing (3)
  • benchmarks/routines/gemm.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/routines/gemm.py

Comment on lines +2142 to +2159
        if cutlass.const_expr(
            mA.element_type == cutlass.Uint8 and mB.element_type == cutlass.Uint8
        ):
            # FP4 packed path: 2 FP4 values per uint8 byte
            k = k_raw * 2
            a_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN)
            b_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN)
        elif cutlass.const_expr(mA.element_type != mB.element_type):
            raise TypeError(
                "Unsupported mixed input dtypes for block-scaled GEMM: "
                "mA and mB must have matching element_type "
                "(both Uint8 for FP4 path, or both FP8 for MXFP8 path)."
            )
        else:
            # MXFP8 path: input tensors are already FP8.
            k = k_raw
            a_ptr = mA.iterator
            b_ptr = mB.iterator
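The `k = k_raw * 2` adjustment in the FP4 branch follows from the packing: each uint8 byte holds two 4-bit E2M1 values. A standalone sketch of the byte layout (the low/high nibble ordering here is an assumption for illustration, not verified against the kernel):

```python
# Each uint8 byte packs two FP4 (E2M1) values, so a packed tensor with
# k_raw bytes along K represents 2 * k_raw logical FP4 elements.
def unpack_fp4_byte(byte: int) -> tuple[int, int]:
    lo = byte & 0x0F          # low nibble (assumed first element)
    hi = (byte >> 4) & 0x0F   # high nibble (assumed second element)
    return lo, hi

k_raw = 64                    # bytes along K in the packed uint8 layout
k = k_raw * 2                 # logical FP4 K extent, as computed in the kernel
```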
⚠️ Potential issue | 🟠 Major

Add explicit FP8 allowlist guard in the else path to reject unsupported non-Uint8 input types before MXFP8 processing.

After the mixed-dtype check, the else path accepts any matching non-Uint8 dtype without validating it's one of the supported FP8 types. This allows unsupported pairs like Int8/Int8 or Float32/Float32 to route into the MXFP8 flow, where they fail later with less actionable errors. The is_valid_dtypes_and_scale_factor_vec_size static method (line 1865) defines the allowlist: only Float4E2M1FN, Float8E5M2, and Float8E4M3FN are valid. Add an explicit guard to enforce this before the MXFP8 path:

🔧 Proposed fix
         if cutlass.const_expr(
             mA.element_type == cutlass.Uint8 and mB.element_type == cutlass.Uint8
         ):
             # FP4 packed path: 2 FP4 values per uint8 byte
             k = k_raw * 2
             a_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN)
             b_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN)
         elif cutlass.const_expr(mA.element_type != mB.element_type):
             raise TypeError(
                 "Unsupported mixed input dtypes for block-scaled GEMM: "
                 "mA and mB must have matching element_type "
                 "(both Uint8 for FP4 path, or both FP8 for MXFP8 path)."
             )
+        elif cutlass.const_expr(
+            mA.element_type != cutlass.Float8E4M3FN
+            and mA.element_type != cutlass.Float8E5M2
+        ):
+            raise TypeError(
+                "Unsupported input dtype for MXFP8 path: expected Float8E4M3FN or Float8E5M2 when inputs are not Uint8."
+            )
         else:
             # MXFP8 path: input tensors are already FP8.
             k = k_raw
             a_ptr = mA.iterator
             b_ptr = mB.iterator
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 2150-2154: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py` around lines 2142 -
2159, The else branch currently assumes any non-Uint8 matching dtype is MXFP8
and proceeds, but you must explicitly guard against unsupported dtypes: call the
same allowlist logic used by is_valid_dtypes_and_scale_factor_vec_size to verify
mA.element_type (and mB.element_type) is one of the supported FP8/FP4 types
(cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN) before taking
the MXFP8 path that sets k = k_raw and uses mA.iterator/mB.iterator; if the
element types are not in that set, raise a TypeError with a clear message rather
than letting downstream MXFP8 logic fail.
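As a plain-Python sketch of the routing logic the guard above enforces (dtype names are strings standing in for the cutlass types; this is not the real cute-dsl API):

```python
# Sketch of dtype routing with an explicit FP8 allowlist; string names stand
# in for cutlass dtypes (cutlass.Uint8, cutlass.Float8E4M3FN, ...).
FP8_ALLOWLIST = {"Float8E4M3FN", "Float8E5M2"}

def route_blockscaled_inputs(a_dtype: str, b_dtype: str) -> str:
    if a_dtype == "Uint8" and b_dtype == "Uint8":
        return "fp4"       # packed FP4 path: k = k_raw * 2
    if a_dtype != b_dtype:
        raise TypeError("mA and mB must have matching element_type")
    if a_dtype not in FP8_ALLOWLIST:
        # Reject e.g. Int8/Int8 or Float32/Float32 up front with a clear
        # error instead of failing deep in the MXFP8 flow.
        raise TypeError(f"Unsupported input dtype for MXFP8 path: {a_dtype}")
    return "mxfp8"
```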

@nvpohanh
Contributor

nvpohanh commented Mar 4, 2026

@bkryu this is a contribution by the SGLang team. Could you assign it to someone from the FlashInfer team for review? Thanks!

@bkryu
Collaborator

bkryu commented Mar 4, 2026

@nv-yunzheq, this PR expands the cute dsl FP4 GEMM you integrated. Can you help review?

print(f"[VVERBOSE] {mat2_fp4.dtype = }")

alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
alpha_for_cute_dsl_mxfp4 = (
Collaborator

Is there a reason why mxfp4 cutedsl backend have to be using a device tensor with value 1.0?

Contributor Author

When I removed it, I encountered compilation errors relating to make_fake_compact_tensor(cutlass.Float32, (1,)), as I believe both paths still share the exact same code. To avoid touching too much of the cute-dsl code, which I am not super familiar with, I left it passing 1.0 for now. But I also think it can be removed.

Collaborator

Could you add a TODO comment so that people can understand why this was added, and so we remember to remove it in the future?

Comment thread flashinfer/gemm/gemm_base.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
flashinfer/gemm/gemm_base.py (2)

3774-3789: Consider refactoring lambda assignments to function definitions.

Per static analysis (E731), prefer def statements over lambda assignments for clarity:

♻️ Suggested refactor
-            if kernel_type == "sm103" and Sm103Kernel is not None:
-                make_kernel = lambda: Sm103Kernel(
-                    sf_vec_size,
-                    mma_tiler_mn,
-                    cluster_shape_mn,
-                    use_tma_store,
-                    enable_pdl,
-                )
-            else:
-                make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
-                    sf_vec_size,
-                    mma_tiler_mn,
-                    cluster_shape_mn,
-                    use_prefetch,
-                    enable_pdl,
-                )
+            def make_kernel():
+                if kernel_type == "sm103" and Sm103Kernel is not None:
+                    return Sm103Kernel(
+                        sf_vec_size,
+                        mma_tiler_mn,
+                        cluster_shape_mn,
+                        use_tma_store,
+                        enable_pdl,
+                    )
+                else:
+                    return Sm100BlockScaledPersistentDenseGemmKernel(
+                        sf_vec_size,
+                        mma_tiler_mn,
+                        cluster_shape_mn,
+                        use_prefetch,
+                        enable_pdl,
+                    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 3774 - 3789, Refactor the two
inline lambda assignments to named functions to satisfy E731: replace the lambda
assigned to make_kernel in the kernel_type == "sm103" branch with a def
make_kernel(): that returns Sm103Kernel(sf_vec_size, mma_tiler_mn,
cluster_shape_mn, use_tma_store, enable_pdl), and similarly replace the
else-branch lambda with a def make_kernel(): that returns
Sm100BlockScaledPersistentDenseGemmKernel(sf_vec_size, mma_tiler_mn,
cluster_shape_mn, use_prefetch, enable_pdl); keep the same function name
make_kernel so callers are unchanged and ensure both definitions are placed in
the same scope where make_kernel was originally assigned.

2931-2932: Consider removing or documenting the unused SM version assignment.

The line _ = sm_major, sm_minor silences the unused variable warning but these parameters might be needed for SM-specific kernel selection (e.g., SM103-specific tactics). If they're reserved for future use, add a comment; otherwise, consider removing them from the function signature.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2931 - 2932, The assignment `_ =
sm_major, sm_minor` in gemm_base.py silently discards SM version parameters
(sm_major, sm_minor); either remove these parameters from the function signature
if they are not needed, or document their intentional reservation for future
SM-specific kernel selection by replacing that discard line with a brief comment
referencing SM-specific tactics (e.g., "reserved for SM-specific kernel
selection / SM103 tactics") and keep the parameters; locate the usage near
c_cutlass_dtype / cutlass_dtype_attr in the same function and update the
function signature and docstring accordingly (remove sm_major/sm_minor if
unused, or add the explanatory comment and leave them).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Line 2894: The kernel cache dictionaries _CUTE_DSL_MM_MXFP8_KERNEL_CACHE and
_CUTE_DSL_MM_FP4_KERNEL_CACHE are missing SM version in their keys which causes
kernel collisions across devices with different (sm_major, sm_minor); update the
cache key construction and any lookup/insert sites that use these caches to
include the device SM tuple (sm_major, sm_minor) as part of the key (i.e.,
change key tuples to prepend or append (sm_major, sm_minor) so compiled kernels
for Sm100 and Sm103 are distinct), and ensure all places that read from or write
to these caches (the code that constructs the current cache key and calls into
_CUTE_DSL_MM_MXFP8_KERNEL_CACHE and _CUTE_DSL_MM_FP4_KERNEL_CACHE) are modified
consistently to use the new tuple form.

---

Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3774-3789: Refactor the two inline lambda assignments to named
functions to satisfy E731: replace the lambda assigned to make_kernel in the
kernel_type == "sm103" branch with a def make_kernel(): that returns
Sm103Kernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store,
enable_pdl), and similarly replace the else-branch lambda with a def
make_kernel(): that returns
Sm100BlockScaledPersistentDenseGemmKernel(sf_vec_size, mma_tiler_mn,
cluster_shape_mn, use_prefetch, enable_pdl); keep the same function name
make_kernel so callers are unchanged and ensure both definitions are placed in
the same scope where make_kernel was originally assigned.
- Around line 2931-2932: The assignment `_ = sm_major, sm_minor` in gemm_base.py
silently discards SM version parameters (sm_major, sm_minor); either remove
these parameters from the function signature if they are not needed, or document
their intentional reservation for future SM-specific kernel selection by
replacing that discard line with a brief comment referencing SM-specific tactics
(e.g., "reserved for SM-specific kernel selection / SM103 tactics") and keep the
parameters; locate the usage near c_cutlass_dtype / cutlass_dtype_attr in the
same function and update the function signature and docstring accordingly
(remove sm_major/sm_minor if unused, or add the explanatory comment and leave
them).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1c366cea-a4ec-422a-b78b-5244a8322d2d

📥 Commits

Reviewing files that changed from the base of the PR and between 6608b2f and 0703627.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

return alpha_tensor.reshape(1)


_CUTE_DSL_MM_MXFP8_KERNEL_CACHE: dict[tuple, tuple] = {}
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Scripts executed:

# First, let's examine the actual kernel cache usage at the mentioned lines
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '2880,2910p'

# Also check the second location mentioned
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '3560,3580p'

# Search for how kernel caches are used - key construction and retrieval
rg -n "KERNEL_CACHE\[|KERNEL_CACHE =" flashinfer/gemm/ --type py -B 2 -A 2

# Check if there are any device-related parameters in the cache key tuples
rg -n "get_compute_capability|device|torch.cuda" flashinfer/gemm/gemm_base.py | head -20

# See how kernels are actually compiled to understand if they're device-specific
rg -n "def.*compile|cupy_module|cute_dsl" flashinfer/gemm/gemm_base.py | head -20

# Find where MXFP8 and FP4 kernel caches are accessed for get/set
rg -n "_CUTE_DSL_MM_MXFP8_KERNEL_CACHE\[|_CUTE_DSL_MM_FP4_KERNEL_CACHE\[" flashinfer/gemm/gemm_base.py -B 3 -A 3

# Look at the kernel compilation logic - find where kernels are actually compiled
rg -n "def _cute_dsl_gemm_mxfp8_runner|def _cute_dsl_gemm_fp4_runner" flashinfer/gemm/gemm_base.py -A 50 | head -100

# Look at the ALPHA cache pattern - it IS device-keyed
rg -n "_CUTE_DSL_ALPHA_ONE_CACHE" flashinfer/gemm/gemm_base.py -B 5 -A 5

# Get more context around the runner functions to see cache access patterns
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '2950,3050p'

# Also check FP4 runner's cache usage
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '3620,3720p'

# Check FP4 forward method cache key construction
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '3717,3820p'

# Look at _compile_block_scaled_gemm to understand what device info it uses
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '2767,2850p'

# Search for how SM version is used at runtime
rg -n "get_compute_capability" flashinfer/gemm/gemm_base.py -A 5 -B 5

# Check if kernel compilation depends on SM version or produces device-agnostic code
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '2800,2870p'

# Check if runners are reused across calls or created fresh each time
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '3130,3220p'

# Verify if both SM100 and SM103 kernels could be selected for the same cache key
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '2930,2935p'

# Final verification: confirm SM103 is disabled and would cause cache collision if re-enabled
cd flashinfer && cat -n gemm/gemm_base.py | sed -n '3595,3610p'

Repository: flashinfer-ai/flashinfer


Add SM version to kernel cache keys to prevent kernel type mismatch on multi-SM systems.

The kernel caches don't include SM version in their keys. While compiled CuTe kernels are device-agnostic within the same SM version, when SM103 support is re-enabled, Sm100 and Sm103 kernels would share identical cache keys, causing incorrect kernel reuse. For example, a kernel compiled for SM103 would be replaced by or collide with an SM100 kernel in the module-level cache.

Add (sm_major, sm_minor) to the cache key tuple to prevent this collision:

  • Line 2894 (_CUTE_DSL_MM_MXFP8_KERNEL_CACHE)
  • Line 3570 (_CUTE_DSL_MM_FP4_KERNEL_CACHE)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` at line 2894, The kernel cache dictionaries
_CUTE_DSL_MM_MXFP8_KERNEL_CACHE and _CUTE_DSL_MM_FP4_KERNEL_CACHE are missing SM
version in their keys which causes kernel collisions across devices with
different (sm_major, sm_minor); update the cache key construction and any
lookup/insert sites that use these caches to include the device SM tuple
(sm_major, sm_minor) as part of the key (i.e., change key tuples to prepend or
append (sm_major, sm_minor) so compiled kernels for Sm100 and Sm103 are
distinct), and ensure all places that read from or write to these caches (the
code that constructs the current cache key and calls into
_CUTE_DSL_MM_MXFP8_KERNEL_CACHE and _CUTE_DSL_MM_FP4_KERNEL_CACHE) are modified
consistently to use the new tuple form.

Collaborator

@nv-yunzheq nv-yunzheq left a comment

LGTM.

Comment thread flashinfer/gemm/gemm_base.py Outdated
batch_size = 1

if tactic is None or tactic == -1:
tactic = ((128, 128), (1, 1), False, False)
Collaborator

How about using something like `_SM100_MMA_TILER_MN_CANDIDATES[0]` instead of hardcoding the values here?
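A sketch of the reviewer's suggestion; the candidate lists and their values below are illustrative stand-ins for the real module-level constants in gemm_base.py:

```python
# Hypothetical stand-ins for the real candidate lists in gemm_base.py.
_SM100_MMA_TILER_MN_CANDIDATES = [(128, 128), (256, 128)]
_SM100_CLUSTER_SHAPE_MN_CANDIDATES = [(1, 1), (2, 1)]


def resolve_tactic(tactic):
    # Default to the first candidates instead of hardcoded literals, so the
    # fallback stays in sync if the candidate lists ever change.
    if tactic is None or tactic == -1:
        return (
            _SM100_MMA_TILER_MN_CANDIDATES[0],
            _SM100_CLUSTER_SHAPE_MN_CANDIDATES[0],
            False,
            False,
        )
    return tactic
```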

print(f"[VVERBOSE] {mat2_fp4.dtype = }")

alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
alpha_for_cute_dsl_mxfp4 = (
Collaborator

Could you add a TODO comment so that people can later understand why you added this and that it should eventually be removed?

@nv-yunzheq
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !385 has been created, and the CI pipeline #45467505 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (3)
benchmarks/routines/gemm.py (1)

1074-1080: The CuTe DSL alpha workaround looks obsolete now.

flashinfer/gemm/gemm_base.py:2880-2894 already normalizes alpha=None to a cached device tensor([1.0]), so this benchmark-only alpha_for_cute_dsl_mxfp4 path adds extra state without changing the launch contract. Passing alpha through unchanged would simplify the benchmark harness and remove one more backend-specific exception.

♻️ Proposed simplification
-    # TODO: for MXFP4, we don't need a global scale, we should change the compile interface to make
-    # alpha optional.
-    alpha_for_cute_dsl_mxfp4 = (
-        torch.tensor([1.0], dtype=torch.float32, device=device)
-        if not use_nvfp4
-        else None
-    )
@@
-                alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
+                alpha=alpha,
@@
-                alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
+                alpha=alpha,

Also applies to: 1101-1101, 1139-1139

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/gemm.py` around lines 1074 - 1080, Remove the
benchmark-only CuTe DSL alpha workaround: delete the special-case variable
alpha_for_cute_dsl_mxfp4 and its conditional construction (the use_nvfp4/device
torch.tensor branch) and instead pass the original alpha argument through
unchanged to the CuTe DSL calls; rely on gemm_base.py's normalization which
converts alpha=None to a cached device tensor, so update all call sites that
used alpha_for_cute_dsl_mxfp4 (the occurrences around the current blocks
including the ones referenced) to use the existing alpha variable directly and
remove any extra state or backend-specific branching (symbols to update:
alpha_for_cute_dsl_mxfp4, use_nvfp4, and the call sites that previously
substituted that variable).
flashinfer/gemm/gemm_base.py (2)

2943-2943: Use underscore prefix for unused unpacked variables.

Static analysis correctly flagged that a_descale, b_descale, and out are unpacked but never used in get_valid_tactics. Use underscore prefix to indicate intentionally unused variables.

♻️ Suggested fix
-            (a, b, a_descale, b_descale, _, out, _) = inputs
+            (a, b, _a_descale, _b_descale, _, _out, _) = inputs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` at line 2943, In get_valid_tactics update the
tuple unpack at "(a, b, a_descale, b_descale, _, out, _) = inputs" to mark
unused variables with a leading underscore so static analysis knows they're
intentionally unused; rename a_descale -> _a_descale, b_descale -> _b_descale
and out -> _out (leave the existing anonymous "_" entries as-is) so the function
signature and subsequent references remain correct.

3790-3804: Consider converting lambdas to local functions for clarity.

Static analysis flagged these lambda assignments. While they work correctly (called immediately, no late binding issues), converting to local def statements would be cleaner and follow Python best practices.

♻️ Suggested refactor
             if kernel_type == "sm103" and Sm103Kernel is not None:
-                make_kernel = lambda: Sm103Kernel(
-                    sf_vec_size,
-                    mma_tiler_mn,
-                    cluster_shape_mn,
-                    use_tma_store,
-                    enable_pdl,
-                )
+                def make_kernel():
+                    return Sm103Kernel(
+                        sf_vec_size,
+                        mma_tiler_mn,
+                        cluster_shape_mn,
+                        use_tma_store,
+                        enable_pdl,
+                    )
             else:
-                make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
-                    sf_vec_size,
-                    mma_tiler_mn,
-                    cluster_shape_mn,
-                    use_prefetch,
-                    enable_pdl,
-                )
+                def make_kernel():
+                    return Sm100BlockScaledPersistentDenseGemmKernel(
+                        sf_vec_size,
+                        mma_tiler_mn,
+                        cluster_shape_mn,
+                        use_prefetch,
+                        enable_pdl,
+                    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 3790 - 3804, The two lambda
assignments to make_kernel should be replaced with local def functions for
clarity: in the branch that currently assigns make_kernel = lambda:
Sm103Kernel(...), define a local function def make_kernel() that returns
Sm103Kernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store,
enable_pdl), and in the else branch define def make_kernel() that returns
Sm100BlockScaledPersistentDenseGemmKernel(sf_vec_size, mma_tiler_mn,
cluster_shape_mn, use_prefetch, enable_pdl); ensure the local function names
match the original symbol make_kernel and capture the same variables
(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store/use_prefetch,
enable_pdl) so callers of make_kernel remain unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: eb68dc46-c987-4a22-8bd4-7e563375cbba

📥 Commits

Reviewing files that changed from the base of the PR and between 0703627 and 59de382.

📒 Files selected for processing (2)
  • benchmarks/routines/gemm.py
  • flashinfer/gemm/gemm_base.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45467505: 9/20 passed

Collaborator

@nv-yunzheq nv-yunzheq left a comment

The CI result looks good. Approved.
Thanks for contributing to the project!

@flashinfer-bot flashinfer-bot merged commit 825c7e0 into flashinfer-ai:main Mar 6, 2026
30 of 31 checks passed
@b8zhong b8zhong deleted the brayden/mxfp4-blockscale-dense-gemm branch March 7, 2026 13:41
@YangXu1990uiuc
Collaborator

What's the cuDNN/cuBLASLt version tested here?

@b8zhong
Contributor Author

b8zhong commented Mar 9, 2026

@YangXu1990uiuc Hi, I don't exactly remember (unfortunately, I also deleted the Docker container...). I believe it was 9.13, though. But I also noticed a perf increase in cuDNN 9.19 in separate testing, so I agree the gap could be closer on newer versions.

frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…mm (flashinfer-ai#2660)

<!-- .github/pull_request_template.md -->

## 📌 Description

This kernel supports mxfp4 and mxfp8. Pass in block size = 32 and E8M0
scale.

For most shapes of mxfp4, its performance is quite good: geomean
speedup around 1.20x. Passed refcheck.

<img width="2384" height="4830" alt="image"
src="https://github.com/user-attachments/assets/4e1bf17f-adaa-464c-92e4-4d28e7776dc7"
/>

For mxfp8, the performance is a bit lacking, as I recycled most of the
cute-dsl heuristics for mxfp4. Mainly for enablement first.

Local testing result (B200):

```
python tests/gemm/test_mm_fp4.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 4512 items                                                                                                                                                                                                                                                 

tests/gemm/test_mm_fp4.py [per-test progress output trimmed]                                                             [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_fp4.py: 588 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_fp4.py: 294 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1440 passed, 3072 skipped, 882 warnings in 535.13s (0:08:55) ====================================================================================================
```

```
pytest tests/gemm/test_mm_mxfp8.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 2131 items                                                                                                                                                                                                                                                 

tests/gemm/test_mm_mxfp8.py [per-test progress output trimmed]                                                           [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_mxfp8.py: 314 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_mxfp8.py: 157 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1633 passed, 498 skipped, 471 warnings in 245.76s (0:04:05) =====================================================================================================
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added CuTe DSL ("cute-dsl") as a selectable backend for MXFP8 and FP4
matrix-multiply paths and support for both FP4 and MXFP8 input
encodings.

* **Behavior**
* Backend-specific alpha handling and availability/layout checks for the
CuTe DSL path; runtime skips when layout/scale semantics aren’t
supported.

* **Tests**
* Tests and benchmarks expanded to include "cute-dsl" and "auto" in
backend selections and validate encoding/layout behaviors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…mm (flashinfer-ai#2660)

<!-- .github/pull_request_template.md -->

## 📌 Description

This kernel supports mxfp4 and mxfp8. Pass in block size = 32 and E8M0
scale.

For most shapes of mxfp4, its performance is quite good: geomean
speedup around 1.20x. Passed refcheck.

<img width="2384" height="4830" alt="image"
src="https://github.com/user-attachments/assets/4e1bf17f-adaa-464c-92e4-4d28e7776dc7"
/>

For mxfp8, the performance is a bit lacking, as I recycled most of the
cute-dsl heuristics for mxfp4. Mainly for enablement first.

Local testing result (B200):

```
python tests/gemm/test_mm_fp4.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 4512 items

tests/gemm/test_mm_fp4.py .............................................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  5%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................sssssssssssssssssssssssssssssssssssssssssssss................................ [ 10%]
.............................................................................................................................................................................................................................................................. [ 16%]
.............................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 21%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................sssssssssssssssssssssssssssssssssssssssssssss.......................................................................... [ 27%]
....................................................................................................................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 33%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 38%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssss [ 44%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 50%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 55%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 61%]
ssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 66%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.................... [ 78%]
......................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 83%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 89%]
ssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................................. [ 95%]
............................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                   [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_fp4.py: 588 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_fp4.py: 294 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1440 passed, 3072 skipped, 882 warnings in 535.13s (0:08:55) ====================================================================================================
```

```
pytest tests/gemm/test_mm_mxfp8.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 2131 items

tests/gemm/test_mm_mxfp8.py .................................................................................................................................................................................................................................. [ 10%]
.............................................................................................................................................................................................................................................................. [ 22%]
................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................................. [ 34%]
..................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss............................................................................................ [ 46%]
.............................................................................................................................................................................................................................................................. [ 58%]
......................................................................................................................................................................................................................ssssssssssssssssssssssssssssssssssssssss [ 70%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss................................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 82%]
ssssssssssssssssssssssssssssssssssssssssss......................................................................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssss............ [ 94%]
...............................................................................................................................                                                                                                                                [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_mxfp8.py: 314 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/gemm/test_mm_mxfp8.py: 157 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1633 passed, 498 skipped, 471 warnings in 245.76s (0:04:05) =====================================================================================================
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added CuTe DSL ("cute-dsl") as a selectable backend for MXFP8 and FP4
matrix-multiply paths and support for both FP4 and MXFP8 input
encodings.

* **Behavior**
* Backend-specific alpha handling and availability/layout checks for the
CuTe DSL path; runtime skips when layout/scale semantics aren’t
supported.

* **Tests**
* Tests and benchmarks expanded to include "cute-dsl" and "auto" in
backend selections and validate encoding/layout behaviors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>