feat: support mxfp4 & mxfp8 entrypoint for blackwell cutedsl dense gemm #2660
Conversation
📝 Walkthrough
Adds CuTe DSL backend support to FP4 and MXFP8 GEMM: backend wiring, availability checks, requirement validators, CuTe DSL runners and kernel caches, dual-input FP4/MXFP8 kernel paths, and updated benchmarks and tests to cover the new backend.
Sequence Diagram(s)
sequenceDiagram
actor User as User Code
participant Dispatcher as GEMM Dispatcher
participant BackendSel as Backend Selector
participant ReqCheck as Requirement Validator
participant Runner as CuTe DSL Runner
participant Kernel as Kernel Executor
User->>Dispatcher: call mm_fp4 / mm_mxfp8 (backend="cute-dsl")
Dispatcher->>BackendSel: resolve backend
BackendSel->>ReqCheck: validate requirements (_cute_dsl_gemm_*_requirement)
alt requirements valid
ReqCheck-->>BackendSel: valid
BackendSel->>Runner: instantiate / fetch runner
Runner->>Runner: generate tactics & compile if needed
Runner->>Kernel: launch kernel with chosen tactic
Kernel->>Kernel: branch: FP4 recast (uint8→fp4, k=k_raw*2) or MXFP8 direct (float8, k=k_raw)
Kernel-->>Runner: complete
Runner-->>Dispatcher: return output tensor
else requirements invalid
ReqCheck-->>BackendSel: invalid
BackendSel-->>Dispatcher: fallback / error
end
Dispatcher-->>User: return result or error
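The dispatch flow in the diagram above can be modeled as a small, self-contained Python sketch. The names (`resolve_backend`, `_cute_dsl_requirement`) and the specific requirement conditions are illustrative assumptions, not the actual FlashInfer API:

```python
# Minimal pure-Python model of the dispatch flow: resolve the requested
# backend, validate cute-dsl requirements, fall back to cutlass otherwise.
# Function names and requirement conditions are hypothetical.

def _cute_dsl_requirement(m: int, n: int, k: int, dtype: str) -> bool:
    # Illustrative validator: the cute-dsl path only supports the two
    # block-scaled dtypes this PR adds, and an even K.
    return dtype in ("mxfp4", "mxfp8") and k % 2 == 0

def resolve_backend(requested: str, m: int, n: int, k: int, dtype: str) -> str:
    """Return the backend to run, taking the else-branch (fallback)
    from the diagram when cute-dsl requirements are not met."""
    if requested == "cute-dsl" and _cute_dsl_requirement(m, n, k, dtype):
        return "cute-dsl"
    return "cutlass"
```

A caller requesting `backend="cute-dsl"` with an unsupported dtype would thus silently land on the fallback path rather than the CuTe DSL runner.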
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes (Gemini Code Assist): This pull request significantly expands the mixed-precision GEMM capabilities by integrating the CuTe DSL backend for mxfp4 and mxfp8 dense GEMM on Blackwell.
Code Review
This pull request adds support for mxfp4 and mxfp8 data types to the cute-dsl backend for dense GEMM operations on Blackwell GPUs. The changes include updating backend lists, adding new kernel runners and requirement checks, and modifying existing FP4 runners to handle the new data types. The test suite is also updated to cover these new features. My review focuses on improving code maintainability by reducing duplication and simplifying logic. Overall, the changes are well-structured and align with the PR's objectives.
Actionable comments posted: 3
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
2787-2792: Cache SM count once during tactic search.
torch.cuda.get_device_properties(...).multi_processor_count is queried repeatedly inside nested loops. Hoisting it once per call reduces Python/CUDA overhead during autotune search.
♻️ Suggested change
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> list:
     (a, b, a_descale, b_descale, _, out, _) = inputs
     m = a.shape[0]
     real_k = a.shape[1]
     n = b.shape[1]
+    sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
@@
     if use_prefetch:
         cta_nums = self._get_approximate_cta_nums(
             kernel_m, kernel_n, mma_tiler_mn, cluster_shape_mn
         )
-        sm_count = torch.cuda.get_device_properties(
-            a.device
-        ).multi_processor_count
         cta_wave_ratio = cta_nums / sm_count
         if not (0.5 < cta_wave_ratio < 1.0 or real_k >= 8192):
             continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2787 - 2792, Hoist the CUDA SM query out of the inner nested loops by calling torch.cuda.get_device_properties(a.device).multi_processor_count once at the start of the tactic search (store in a local sm_count variable) and reuse that sm_count when computing cta_wave_ratio (used with cta_nums and real_k) instead of calling get_device_properties repeatedly; update the scope where sm_count, cta_wave_ratio, cta_nums and real_k are used so the single cached sm_count replaces the repeated calls.
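The hoisting suggested above can be demonstrated with a small stand-alone sketch. Here `query_sm_count` is a hypothetical stand-in for `torch.cuda.get_device_properties(a.device).multi_processor_count`, with a call counter to show that the hoisted version queries the device exactly once per tactic search; the 0.5–1.0 wave-ratio filter mirrors the snippet above:

```python
# Sketch of the hoisting pattern: an expensive per-device query is cached
# once per search call instead of once per loop iteration.

calls = {"n": 0}

def query_sm_count() -> int:
    # Stand-in for the CUDA device-property query; counts invocations.
    calls["n"] += 1
    return 148  # assumed SM count, for illustration only

def search_tactics(candidate_cta_nums, real_k: int):
    sm_count = query_sm_count()  # hoisted: queried once, not per tactic
    valid = []
    for cta_nums in candidate_cta_nums:
        cta_wave_ratio = cta_nums / sm_count
        if 0.5 < cta_wave_ratio < 1.0 or real_k >= 8192:
            valid.append(cta_nums)
    return valid
```

Running the search over several candidate CTA counts performs a single device query regardless of how many tactics are examined.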
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/gemm.py`:
- Line 1300: The autotune filtering currently excludes "auto" causing
user-requested backends to be removed; update the autotune_supported_backends
list (and the equivalent logic around lines where it's replicated) to include
"auto" or add explicit handling so that run_backend("auto") is preserved when
--autotune is set; locate the autotune_supported_backends symbol in
benchmarks/routines/gemm.py (and the duplicate logic around the 1353-1361
region) and either add "auto" to the list or alter the filter to allow "auto"
through before applying autotune-specific backend restrictions.
- Around line 1094-1098: The code is allocating torch.tensor([1.0],
dtype=torch.float32, device=device) inside the timed path when (not use_nvfp4
and backend == "cute-dsl"), which creates a new CUDA tensor each invocation and
skews timings; fix by creating a single reusable scalar tensor and reusing it
instead of allocating per call (e.g., compute a module-/scope-level or
outer-loop cached torch.tensor(1.0, dtype=torch.float32, device=device) or use a
Python float when the backend accepts it) and replace the inline allocation used
for the alpha argument (the conditional that checks use_nvfp4 and backend ==
"cute-dsl") in both occurrences (the alpha passed at the shown diff and the
similar block at 1136-1140).
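The caching suggested in the comment above can be sketched without CUDA. `FakeTensor` is a hypothetical stand-in for `torch.Tensor`, and `functools.lru_cache` plays the role of the outer-scope cache, so the constant alpha is allocated once per device rather than inside the timed loop:

```python
# Sketch: allocate the constant alpha scalar once per device and reuse it,
# instead of building a fresh tensor on every benchmarked invocation.

from functools import lru_cache

class FakeTensor:
    """Illustrative placeholder for a device scalar tensor."""
    def __init__(self, value: float, device: str):
        self.value = value
        self.device = device

@lru_cache(maxsize=None)
def cached_alpha(device: str) -> FakeTensor:
    # Allocated once per device string; subsequent calls return the
    # same object, keeping allocation out of the timed path.
    return FakeTensor(1.0, device)
```

With the real API, the same idea applies: hoist `torch.tensor(1.0, dtype=torch.float32, device=device)` out of the loop, or pass a Python float where the backend accepts it.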
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py`:
- Around line 2139-2151: The code currently selects the FP4 packed path based
only on mA.element_type, which can corrupt results if mB has a different dtype;
update the conditional to require both mA.element_type and mB.element_type be
cutlass.Uint8 (e.g., if cutlass.const_expr(mA.element_type == cutlass.Uint8 and
mB.element_type == cutlass.Uint8)): only in that case set k = k_raw * 2 and
recast both iterators with cute.recast_ptr(..., dtype=cutlass.Float4E2M1FN);
otherwise treat inputs as MXFP8 (k = k_raw with a_ptr = mA.iterator and b_ptr =
mB.iterator) and add a clear error/exception if dtypes are mismatched but not
supported so silent corruption cannot occur.
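The branch the prompt describes can be sketched in plain Python. The dtype strings mirror the cutlass enums (`Uint8`, `Float8E4M3FN`, `Float8E5M2`), and `select_path` is an illustrative name, not the kernel's actual structure:

```python
# Sketch of the guarded dtype routing: both inputs uint8 selects the
# packed-FP4 path (K doubles after recasting), matching FP8 dtypes select
# the MXFP8 path, and any other combination raises instead of silently
# corrupting results.

def select_path(a_dtype: str, b_dtype: str, k_raw: int):
    fp8_types = {"float8_e4m3fn", "float8_e5m2"}
    if a_dtype == "uint8" and b_dtype == "uint8":
        return "fp4", k_raw * 2  # 2 FP4 values packed per uint8 byte
    if a_dtype == b_dtype and a_dtype in fp8_types:
        return "mxfp8", k_raw  # inputs are already FP8, no recast
    raise TypeError(f"unsupported dtype pair: {a_dtype}/{b_dtype}")
```

The key point is that the FP4 recast is taken only when both operands are uint8, so a mixed uint8/FP8 call fails loudly at dispatch time.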
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/gemm.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
- tests/gemm/test_mm_fp4.py
- tests/gemm/test_mm_mxfp8.py
Actionable comments posted: 1
🧹 Nitpick comments (3)
flashinfer/gemm/gemm_base.py (3)
2700-2702: Remove dead code assignment. The sm_major and sm_minor parameters are passed to the function but only used in a no-op assignment. If they're reserved for future use, consider adding a comment. Otherwise, the parameters could be removed from the signature or prefixed with underscore.
🧹 Suggested cleanup
- _ = sm_major, sm_minor
+ # Reserved for future SM-specific kernel selection
+ _ = sm_major, sm_minor

Or if truly unused:

 def _cute_dsl_gemm_mxfp8_runner(
-    sm_major: int,
-    sm_minor: int,
+    sm_major: int,  # noqa: ARG001 - reserved for future SM-specific paths
+    sm_minor: int,  # noqa: ARG001 - reserved for future SM-specific paths
     enable_pdl: bool,
     out_dtype: torch.dtype,
 ):
-    _ = sm_major, sm_minor
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2700 - 2702, The assignment "_ = sm_major, sm_minor" is dead code; either remove that line entirely or make the unused parameters explicit by renaming them (e.g., _sm_major, _sm_minor) or adding a clarifying comment that they're reserved for future use. Update the function signature accordingly (if renaming) and adjust any callers if you change parameter names; otherwise just delete the no-op assignment in gemm_base.py near where c_cutlass_dtype is set so sm_major and sm_minor are not silently ignored.
2719-2721: Consider prefixing unused variables with underscore. The variables a_descale, b_descale, and out are unpacked but never used. Prefixing them with _ would suppress linter warnings and clarify intent.
🧹 Suggested cleanup
- (a, b, a_descale, b_descale, _, out, _) = inputs
+ (a, b, _a_descale, _b_descale, _, _out, _) = inputs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2719 - 2721, The unpacking in the method that processes "inputs" currently binds a_descale, b_descale, and out but never uses them; rename those to start with an underscore (e.g., _a_descale, _b_descale, _out or simply _ ) in the tuple assignment (the line "(a, b, a_descale, b_descale, _, out, _) = inputs") so linters know they are intentionally unused; update only the variable names in that unpacking and leave all other logic (variables a, b, m, etc.) unchanged.
2961-2983: Clarify: cute-dsl is not auto-selected for mm_mxfp8. The heuristic function only returns ["cutlass"] and doesn't include "cute-dsl". This means users must explicitly specify backend="cute-dsl" to use the CuTe DSL path - it won't be auto-selected. Based on the PR description mentioning "weaker performance" for mxfp8, this appears intentional. Consider adding a brief comment in _heuristic_func_mm_mxfp8 to document this design choice:
📝 Suggested documentation

 def _heuristic_func_mm_mxfp8(
     ...
 ) -> List[str]:
+    # Note: cute-dsl is available but not auto-selected due to performance characteristics.
+    # Users can explicitly specify backend="cute-dsl" if desired.
     if "cutlass" in suitable_backends:
         return ["cutlass"]
     return []
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2961 - 2983, The heuristic in _heuristic_func_mm_mxfp8 currently returns only ["cutlass"], intentionally excluding "cute-dsl" since CuTe DSL is not auto-selected for mxFP8 due to weaker performance; update the function by adding a concise comment above or inside _heuristic_func_mm_mxfp8 (and mention mm_mxfp8.suitable_auto_backends in the comment) stating that "cute-dsl" is intentionally omitted from auto-selection and must be explicitly requested via backend="cute-dsl", and briefly note the rationale (weaker performance) so future readers understand this design choice.
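The documented heuristic is short enough to sketch in full. This is an illustrative rendering of the behavior the reviewer describes, not the exact FlashInfer source:

```python
# Sketch of _heuristic_func_mm_mxfp8's auto-selection behavior:
# cute-dsl is intentionally omitted (weaker performance for mxfp8) and
# must be requested explicitly via backend="cute-dsl".

def heuristic_mm_mxfp8(suitable_backends):
    if "cutlass" in suitable_backends:
        return ["cutlass"]
    return []
```

So even when cute-dsl is available and validated, backend="auto" resolves to cutlass for mm_mxfp8.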
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py`:
- Around line 2142-2159: The else branch currently assumes any non-Uint8
matching dtype is MXFP8 and proceeds, but you must explicitly guard against
unsupported dtypes: call the same allowlist logic used by
is_valid_dtypes_and_scale_factor_vec_size to verify mA.element_type (and
mB.element_type) is one of the supported FP8/FP4 types (cutlass.Float4E2M1FN,
cutlass.Float8E5M2, cutlass.Float8E4M3FN) before taking the MXFP8 path that sets
k = k_raw and uses mA.iterator/mB.iterator; if the element types are not in that
set, raise a TypeError with a clear message rather than letting downstream MXFP8
logic fail.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/routines/gemm.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/routines/gemm.py
if cutlass.const_expr(
    mA.element_type == cutlass.Uint8 and mB.element_type == cutlass.Uint8
):
    # FP4 packed path: 2 FP4 values per uint8 byte
    k = k_raw * 2
    a_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN)
    b_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN)
elif cutlass.const_expr(mA.element_type != mB.element_type):
    raise TypeError(
        "Unsupported mixed input dtypes for block-scaled GEMM: "
        "mA and mB must have matching element_type "
        "(both Uint8 for FP4 path, or both FP8 for MXFP8 path)."
    )
else:
    # MXFP8 path: input tensors are already FP8.
    k = k_raw
    a_ptr = mA.iterator
    b_ptr = mB.iterator
Add explicit FP8 allowlist guard in the else path to reject unsupported non-Uint8 input types before MXFP8 processing.
After the mixed-dtype check, the else path accepts any matching non-Uint8 dtype without validating it's one of the supported FP8 types. This allows unsupported pairs like Int8/Int8 or Float32/Float32 to route into the MXFP8 flow, where they fail later with less actionable errors. The is_valid_dtypes_and_scale_factor_vec_size static method (line 1865) defines the allowlist: only Float4E2M1FN, Float8E5M2, and Float8E4M3FN are valid. Add an explicit guard to enforce this before the MXFP8 path:
🔧 Proposed fix
if cutlass.const_expr(
mA.element_type == cutlass.Uint8 and mB.element_type == cutlass.Uint8
):
# FP4 packed path: 2 FP4 values per uint8 byte
k = k_raw * 2
a_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN)
b_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN)
elif cutlass.const_expr(mA.element_type != mB.element_type):
raise TypeError(
"Unsupported mixed input dtypes for block-scaled GEMM: "
"mA and mB must have matching element_type "
"(both Uint8 for FP4 path, or both FP8 for MXFP8 path)."
)
+ elif cutlass.const_expr(
+ mA.element_type != cutlass.Float8E4M3FN
+ and mA.element_type != cutlass.Float8E5M2
+ ):
+ raise TypeError(
+ "Unsupported input dtype for MXFP8 path: expected Float8E4M3FN or Float8E5M2 when inputs are not Uint8."
+ )
else:
# MXFP8 path: input tensors are already FP8.
k = k_raw
a_ptr = mA.iterator
b_ptr = mB.iterator
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 2150-2154: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py` around lines 2142 -
2159, The else branch currently assumes any non-Uint8 matching dtype is MXFP8
and proceeds, but you must explicitly guard against unsupported dtypes: call the
same allowlist logic used by is_valid_dtypes_and_scale_factor_vec_size to verify
mA.element_type (and mB.element_type) is one of the supported FP8/FP4 types
(cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN) before taking
the MXFP8 path that sets k = k_raw and uses mA.iterator/mB.iterator; if the
element types are not in that set, raise a TypeError with a clear message rather
than letting downstream MXFP8 logic fail.
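The guard above can be exercised outside the kernel with a framework-free sketch (plain strings stand in for the cutlass dtypes; `select_blockscaled_path` is a hypothetical name, not flashinfer API):

```python
# Dispatch logic mirroring the review's proposed guard: Uint8/Uint8 pairs
# take the packed-FP4 path (two 4-bit values per byte, so effective K
# doubles), matching FP8 pairs take the MXFP8 path, anything else is
# rejected up front with an actionable error.

FP8_ALLOWLIST = {"Float8E4M3FN", "Float8E5M2"}

def select_blockscaled_path(a_dtype: str, b_dtype: str, k_raw: int):
    """Return (path_name, effective_k) or raise TypeError."""
    if a_dtype == "Uint8" and b_dtype == "Uint8":
        # FP4 packed path: 2 FP4 values per uint8 byte.
        return "fp4", k_raw * 2
    if a_dtype != b_dtype:
        raise TypeError(
            "mA and mB must have matching element_type "
            "(both Uint8 for FP4, or both FP8 for MXFP8)"
        )
    if a_dtype not in FP8_ALLOWLIST:
        raise TypeError(f"unsupported input dtype for MXFP8 path: {a_dtype}")
    # MXFP8 path: inputs are already FP8, K is used as-is.
    return "mxfp8", k_raw

print(select_blockscaled_path("Uint8", "Uint8", 128))                # ('fp4', 256)
print(select_blockscaled_path("Float8E4M3FN", "Float8E4M3FN", 128))  # ('mxfp8', 128)
```

With the guard, an Int8/Int8 pair fails immediately with a clear TypeError instead of surfacing a cryptic error deep in the MXFP8 flow.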
|
@bkryu this is a contribution by SGLang team. Could you assign this to someone from FlashInfer team for review? thanks! |
|
@nv-yunzheq, this PR expands the cute dsl FP4 GEMM you integrated. Can you help review? |
| print(f"[VVERBOSE] {mat2_fp4.dtype = }") | ||
|
|
||
| alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None | ||
| alpha_for_cute_dsl_mxfp4 = ( |
Is there a reason why mxfp4 cutedsl backend have to be using a device tensor with value 1.0?
When I removed it, I encountered some compilation relating to make_fake_compact_tensor(cutlass.Float32, (1,)), as I believe they still share the exact same code path. To avoid touching too much cute-dsl which I am not super familiar with, I left it as passing it 1.0 for now. But I also think it can be removed.
Could you add a comment something like TODO so that in the future people could understand why you add this and we should get it removed in the future?
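The workaround discussed here boils down to caching a per-device scalar-one tensor so the shared compile path always receives a real alpha operand. A minimal stand-in sketch (no torch; the names and the list standing in for a device tensor are hypothetical):

```python
# Cache one "alpha = [1.0]" object per device and reuse it, instead of
# allocating a fresh scalar on every call when the caller passes alpha=None.
# A plain list stands in for torch.tensor([1.0], dtype=float32, device=...).

_ALPHA_ONE_CACHE = {}  # device string -> cached scalar-one "tensor"

def get_alpha_one(device: str):
    """Return a cached device-resident [1.0] stand-in for this device."""
    t = _ALPHA_ONE_CACHE.get(device)
    if t is None:
        t = [1.0]  # stand-in allocation; real code would allocate on `device`
        _ALPHA_ONE_CACHE[device] = t
    return t

a = get_alpha_one("cuda:0")
b = get_alpha_one("cuda:0")
print(a is b)  # True: same cached object, no per-call allocation
```

This keeps the cute-dsl path's compile signature unchanged while making the "alpha must exist" constraint explicit and cheap.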
Actionable comments posted: 1
🧹 Nitpick comments (2)
flashinfer/gemm/gemm_base.py (2)
3774-3789: Consider refactoring lambda assignments to function definitions.
Per static analysis (E731), prefer `def` statements over lambda assignments for clarity:
♻️ Suggested refactor

```diff
- if kernel_type == "sm103" and Sm103Kernel is not None:
-     make_kernel = lambda: Sm103Kernel(
-         sf_vec_size,
-         mma_tiler_mn,
-         cluster_shape_mn,
-         use_tma_store,
-         enable_pdl,
-     )
- else:
-     make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
-         sf_vec_size,
-         mma_tiler_mn,
-         cluster_shape_mn,
-         use_prefetch,
-         enable_pdl,
-     )
+ def make_kernel():
+     if kernel_type == "sm103" and Sm103Kernel is not None:
+         return Sm103Kernel(
+             sf_vec_size,
+             mma_tiler_mn,
+             cluster_shape_mn,
+             use_tma_store,
+             enable_pdl,
+         )
+     else:
+         return Sm100BlockScaledPersistentDenseGemmKernel(
+             sf_vec_size,
+             mma_tiler_mn,
+             cluster_shape_mn,
+             use_prefetch,
+             enable_pdl,
+         )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 3774 - 3789, Refactor the two inline lambda assignments to named functions to satisfy E731: replace the lambda assigned to make_kernel in the kernel_type == "sm103" branch with a def make_kernel(): that returns Sm103Kernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store, enable_pdl), and similarly replace the else-branch lambda with a def make_kernel(): that returns Sm100BlockScaledPersistentDenseGemmKernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_prefetch, enable_pdl); keep the same function name make_kernel so callers are unchanged and ensure both definitions are placed in the same scope where make_kernel was originally assigned.
2931-2932: Consider removing or documenting the unused SM version assignment.
The line `_ = sm_major, sm_minor` silences the unused variable warning, but these parameters might be needed for SM-specific kernel selection (e.g., SM103-specific tactics). If they're reserved for future use, add a comment; otherwise, consider removing them from the function signature.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2931 - 2932, The assignment `_ = sm_major, sm_minor` in gemm_base.py silently discards SM version parameters (sm_major, sm_minor); either remove these parameters from the function signature if they are not needed, or document their intentional reservation for future SM-specific kernel selection by replacing that discard line with a brief comment referencing SM-specific tactics (e.g., "reserved for SM-specific kernel selection / SM103 tactics") and keep the parameters; locate the usage near c_cutlass_dtype / cutlass_dtype_attr in the same function and update the function signature and docstring accordingly (remove sm_major/sm_minor if unused, or add the explanatory comment and leave them).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1c366cea-a4ec-422a-b78b-5244a8322d2d
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
| return alpha_tensor.reshape(1) | ||
|
|
||
|
|
||
| _CUTE_DSL_MM_MXFP8_KERNEL_CACHE: dict[tuple, tuple] = {} |
Add SM version to kernel cache keys to prevent kernel type mismatch on multi-SM systems.
The kernel caches don't include SM version in their keys. While compiled CuTe kernels are device-agnostic within the same SM version, when SM103 support is re-enabled, Sm100 and Sm103 kernels would share identical cache keys, causing incorrect kernel reuse. For example, a kernel compiled for SM103 would be replaced by or collide with an SM100 kernel in the module-level cache.
Add (sm_major, sm_minor) to the cache key tuple to prevent this collision:
- Line 2894 (`_CUTE_DSL_MM_MXFP8_KERNEL_CACHE`)
- Line 3570 (`_CUTE_DSL_MM_FP4_KERNEL_CACHE`)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` at line 2894, The kernel cache dictionaries
_CUTE_DSL_MM_MXFP8_KERNEL_CACHE and _CUTE_DSL_MM_FP4_KERNEL_CACHE are missing SM
version in their keys which causes kernel collisions across devices with
different (sm_major, sm_minor); update the cache key construction and any
lookup/insert sites that use these caches to include the device SM tuple
(sm_major, sm_minor) as part of the key (i.e., change key tuples to prepend or
append (sm_major, sm_minor) so compiled kernels for Sm100 and Sm103 are
distinct), and ensure all places that read from or write to these caches (the
code that constructs the current cache key and calls into
_CUTE_DSL_MM_MXFP8_KERNEL_CACHE and _CUTE_DSL_MM_FP4_KERNEL_CACHE) are modified
consistently to use the new tuple form.
| batch_size = 1 | ||
|
|
||
| if tactic is None or tactic == -1: | ||
| tactic = ((128, 128), (1, 1), False, False) |
How about using something like `_SM100_MMA_TILER_MN_CANDIDATES[0]` instead of hardcoding it here?
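The suggestion above in sketch form (candidate values are illustrative, not the library's real tables):

```python
# Default-tactic resolution that reads from named candidate tables instead
# of hardcoding ((128, 128), (1, 1), False, False), so the default stays in
# sync if the candidate lists ever change.

_SM100_MMA_TILER_MN_CANDIDATES = [(128, 128), (256, 128)]
_SM100_CLUSTER_SHAPE_MN_CANDIDATES = [(1, 1), (2, 1)]

def resolve_tactic(tactic):
    if tactic is None or tactic == -1:
        # Default to the first entry of each candidate table.
        return (
            _SM100_MMA_TILER_MN_CANDIDATES[0],
            _SM100_CLUSTER_SHAPE_MN_CANDIDATES[0],
            False,  # use_tma_store
            False,  # use_prefetch
        )
    return tactic

print(resolve_tactic(-1))  # ((128, 128), (1, 1), False, False)
```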
|
/bot run |
🧹 Nitpick comments (3)
benchmarks/routines/gemm.py (1)
1074-1080: The CuTe DSL `alpha` workaround looks obsolete now.
`flashinfer/gemm/gemm_base.py:2880-2894` already normalizes `alpha=None` to a cached device `tensor([1.0])`, so this benchmark-only `alpha_for_cute_dsl_mxfp4` path adds extra state without changing the launch contract. Passing `alpha` through unchanged would simplify the benchmark harness and remove one more backend-specific exception.
♻️ Proposed simplification

```diff
- # TODO: for MXFP4, we don't need a global scale, we should change the compile interface to make
- # alpha optional.
- alpha_for_cute_dsl_mxfp4 = (
-     torch.tensor([1.0], dtype=torch.float32, device=device)
-     if not use_nvfp4
-     else None
- )
@@
- alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
+ alpha=alpha,
@@
- alpha=(alpha_for_cute_dsl_mxfp4 if (backend == "cute-dsl") else alpha),
+ alpha=alpha,
```
Also applies to: 1101-1101, 1139-1139
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/gemm.py` around lines 1074 - 1080, Remove the benchmark-only CuTe DSL alpha workaround: delete the special-case variable alpha_for_cute_dsl_mxfp4 and its conditional construction (the use_nvfp4/device torch.tensor branch) and instead pass the original alpha argument through unchanged to the CuTe DSL calls; rely on gemm_base.py's normalization which converts alpha=None to a cached device tensor, so update all call sites that used alpha_for_cute_dsl_mxfp4 (the occurrences around the current blocks including the ones referenced) to use the existing alpha variable directly and remove any extra state or backend-specific branching (symbols to update: alpha_for_cute_dsl_mxfp4, use_nvfp4, and the call sites that previously substituted that variable).flashinfer/gemm/gemm_base.py (2)
2943-2943: Use underscore prefix for unused unpacked variables.
Static analysis correctly flagged that `a_descale`, `b_descale`, and `out` are unpacked but never used in `get_valid_tactics`. Use an underscore prefix to indicate intentionally unused variables.
♻️ Suggested fix

```diff
- (a, b, a_descale, b_descale, _, out, _) = inputs
+ (a, b, _a_descale, _b_descale, _, _out, _) = inputs
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` at line 2943, In get_valid_tactics update the tuple unpack at "(a, b, a_descale, b_descale, _, out, _) = inputs" to mark unused variables with a leading underscore so static analysis knows they're intentionally unused; rename a_descale -> _a_descale, b_descale -> _b_descale and out -> _out (leave the existing anonymous "_" entries as-is) so the function signature and subsequent references remain correct.
3790-3804: Consider converting lambdas to local functions for clarity.
Static analysis flagged these lambda assignments. While they work correctly (called immediately, no late binding issues), converting to local `def` statements would be cleaner and follow Python best practices.
♻️ Suggested refactor

```diff
  if kernel_type == "sm103" and Sm103Kernel is not None:
-     make_kernel = lambda: Sm103Kernel(
-         sf_vec_size,
-         mma_tiler_mn,
-         cluster_shape_mn,
-         use_tma_store,
-         enable_pdl,
-     )
+     def make_kernel():
+         return Sm103Kernel(
+             sf_vec_size,
+             mma_tiler_mn,
+             cluster_shape_mn,
+             use_tma_store,
+             enable_pdl,
+         )
  else:
-     make_kernel = lambda: Sm100BlockScaledPersistentDenseGemmKernel(
-         sf_vec_size,
-         mma_tiler_mn,
-         cluster_shape_mn,
-         use_prefetch,
-         enable_pdl,
-     )
+     def make_kernel():
+         return Sm100BlockScaledPersistentDenseGemmKernel(
+             sf_vec_size,
+             mma_tiler_mn,
+             cluster_shape_mn,
+             use_prefetch,
+             enable_pdl,
+         )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 3790 - 3804, The two lambda assignments to make_kernel should be replaced with local def functions for clarity: in the branch that currently assigns make_kernel = lambda: Sm103Kernel(...), define a local function def make_kernel() that returns Sm103Kernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store, enable_pdl), and in the else branch define def make_kernel() that returns Sm100BlockScaledPersistentDenseGemmKernel(sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_prefetch, enable_pdl); ensure the local function names match the original symbol make_kernel and capture the same variables (sf_vec_size, mma_tiler_mn, cluster_shape_mn, use_tma_store/use_prefetch, enable_pdl) so callers of make_kernel remain unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: eb68dc46-c987-4a22-8bd4-7e563375cbba
📒 Files selected for processing (2)
benchmarks/routines/gemm.py
flashinfer/gemm/gemm_base.py
|
[FAILED] Pipeline #45467505: 9/20 passed |
nv-yunzheq
left a comment
The CI result looks good. Approve.
Thanks for contributing to the project!
|
what's the cudnn/cublaslt version tested here? |
|
@YangXu1990uiuc Hi, I don't exactly remember (unfortunately, I also deleted the docker container...). I believe it was 9.13, though. But I also noticed a perf increase in cuDNN 9.19 in separate testing, so I agree that the gap could be closer on newer versions. |
…mm (flashinfer-ai#2660) <!-- .github/pull_request_template.md --> ## 📌 Description This kernel supports mxfp4 and mxfp8. Pass in block size = 32 and E8M0 scale. For most shapes of mxfp4, it's performance is quite good. Geomean speedup around 1.20x. Passed refcheck. <img width="2384" height="4830" alt="image" src="https://github.com/user-attachments/assets/4e1bf17f-adaa-464c-92e4-4d28e7776dc7" /> For mxfp8, the performance is a bit lacking, as I recycled most of the cute-dsl heuristics for mxfp4. Mainly for enablement first. Local testing result (B200): ``` python tests/gemm/test_mm_fp4.py ======================================================================================================================== test session starts ========================================================================================================================= platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0 rootdir: /sgl-workspace/sglang/flashinfer configfile: pytest.ini plugins: anyio-4.12.1, typeguard-4.5.1 collected 4512 items tests/gemm/test_mm_fp4.py .............................................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 5%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................sssssssssssssssssssssssssssssssssssssssssssss................................ [ 10%] .............................................................................................................................................................................................................................................................. 
[ 16%] .............................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 21%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................sssssssssssssssssssssssssssssssssssssssssssss.......................................................................... [ 27%] ....................................................................................................................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 33%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 38%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssss [ 44%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 50%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 55%] 
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 61%] ssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 66%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%] ssssssssssssssssssssssssssssssssssssssssssssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.................... [ 78%] ......................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 83%] ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 89%] ssssssssssss..........................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............................................................. 
[progress output elided]
========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_fp4.py: 588 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)
tests/gemm/test_mm_fp4.py: 294 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1440 passed, 3072 skipped, 882 warnings in 535.13s (0:08:55) ====================================================================================================
```

```
pytest tests/gemm/test_mm_mxfp8.py
======================================================================================================================== test session starts =========================================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
plugins: anyio-4.12.1, typeguard-4.5.1
collected 2131 items

[progress output elided]

========================================================================================================================== warnings summary ==========================================================================================================================
tests/gemm/test_mm_mxfp8.py: 314 warnings
  /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)
tests/gemm/test_mm_mxfp8.py: 157 warnings
  /sgl-workspace/sglang/flashinfer/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py:1365: DeprecationWarning: The 'alignment' parameter of CooperativeGroup's constructor is deprecated and will be removed in a subsequent release, please remove it from your code.
    c_producer_group = pipeline.CooperativeGroup(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================== 1633 passed, 498 skipped, 471 warnings in 245.76s (0:04:05) =====================================================================================================
```

## Summary by CodeRabbit

* **New Features**
  * Added CuTe DSL ("cute-dsl") as a selectable backend for MXFP8 and FP4 matrix-multiply paths and support for both FP4 and MXFP8 input encodings.
* **Behavior**
  * Backend-specific alpha handling and availability/layout checks for the CuTe DSL path; runtime skips when layout/scale semantics aren't supported.
* **Tests**
  * Tests and benchmarks expanded to include "cute-dsl" and "auto" in backend selections and validate encoding/layout behaviors.
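Both tested paths rely on MX block scaling, where each run of 32 consecutive elements along K shares a single scale factor. A minimal reference sketch of that dequantization (plain Python with a hypothetical helper name — not the FlashInfer kernel):

```python
# Reference sketch of MX block-scaled dequantization: every 32 consecutive
# elements along K share one scale, so element i is multiplied by the scale
# of block i // 32.

BLOCK = 32  # MX block size used by these entrypoints

def dequantize_blockwise(elems: list[float], scales: list[float]) -> list[float]:
    """elems: k quantized values; scales: one scale per 32-element block."""
    assert len(elems) == BLOCK * len(scales)
    return [x * scales[i // BLOCK] for i, x in enumerate(elems)]

# 64 elements -> 2 blocks, scaled by 2.0 and 0.5 respectively.
vals = dequantize_blockwise([1.0] * 64, [2.0, 0.5])
print(vals[0], vals[32])  # 2.0 0.5
```

The real kernels fold this scaling into the tensor-core MMA rather than materializing dequantized operands; the sketch only pins down the indexing.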
…mm (flashinfer-ai#2660)

## 📌 Description

This kernel supports mxfp4 and mxfp8. Pass in block size = 32 and an E8M0 scale.

For most mxfp4 shapes, its performance is quite good: a geomean speedup of around 1.20x. Passed refcheck.

<img width="2384" height="4830" alt="image" src="https://github.com/user-attachments/assets/4e1bf17f-adaa-464c-92e4-4d28e7776dc7" />

For mxfp8, performance is a bit lacking, since I reused most of the cute-dsl heuristics tuned for mxfp4; this is mainly for enablement first.

Local testing results (B200) are the pytest logs above: test_mm_fp4.py with 1440 passed, 3072 skipped; test_mm_mxfp8.py with 1633 passed, 498 skipped.

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>