refactor: Port upstream CUTLASS fixes and refactor grouped_gemm_nt_masked GEMM module location (#2503)
Conversation
Walkthrough (CodeRabbit): Integrates CuTe-DSL kernels into GEMM exports, parameterizes Blackwell masked GEMM threading, and updates the CuTe-DSL blockscaled GEMM benchmark to use `bench_gpu_time`.

Pre-merge checks: 2 passed, 1 failed (warning).
Summary of Changes (Gemini Code Assist)

Hello @bkryu, I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

Activity
Code Review
This pull request ports upstream CUTLASS fixes and refactors the location of the grouped_gemm_nt_masked module. The changes are solid, including updating benchmarks to use more robust timing functions and applying several important fixes to the kernel implementation, such as correcting a thread count calculation and removing magic numbers. My only suggestion is to refactor the conditional imports in flashinfer/gemm/__init__.py to reduce code duplication.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/bench_cute_dsl_blockscaled_gemm.py`:
- Around line 49-57: bench_gpu_time returns milliseconds but the code treats t
as seconds; fix by converting t to seconds immediately (e.g., t_s = t / 1e3) and
use t_s for all downstream computations (replace uses of t in TFLOPS and GB/s
formulas and the microsecond print). Specifically, update places referencing
bench_gpu_time result (variable t) and change the microsecond display from t *
1e6 to t * 1e3 (or use t_s * 1e6), and divide the TFLOPS and GB/s calculations
by 1e3 (i.e., use t_s instead of t) so TFLOPS and GB/s are computed with
seconds. Ensure all references to t (printing and performance calculations) use
the consistent t_s value.
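The unit conversion the comment asks for can be sketched as follows. The function and argument names here are hypothetical, for illustration only, not the benchmark's actual code:

```python
# Hypothetical sketch of the suggested fix: the timing value (assumed to be
# in milliseconds, as bench_gpu_time reportedly returns) is converted to
# seconds once, and all derived quantities use that seconds value.
def report(t_ms: float, flops: float, bytes_moved: float):
    t_s = t_ms / 1e3                  # milliseconds -> seconds
    tflops = flops / t_s / 1e12       # throughput computed from seconds
    gb_per_s = bytes_moved / t_s / 1e9
    t_us = t_s * 1e6                  # seconds -> microseconds for display
    return tflops, gb_per_s, t_us
```

Converting once at the top avoids the mixed-unit arithmetic the bot flagged, where `t * 1e6` was treated as microseconds even though `t` was already in milliseconds.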
In `@flashinfer/gemm/__init__.py`:
- Around line 63-68: The __all__ extension list is unsorted and triggers Ruff
RUF022; reorder the added symbols alphabetically so the list is sorted. In the
block that appends to __all__ (the one with grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor),
rearrange the strings into alphabetical order (create_scale_factor_tensor,
grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel) so the
module-level __all__ remains lexicographically sorted.
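As a sketch, the ordering requested above matches a case-insensitive sort of the symbol names (the list contents come from the comment; the variable name is illustrative):

```python
# The order RUF022 expects for these three exports is case-insensitive,
# so the capitalized kernel class sorts by 's', after 'create' and 'grouped'.
cute_dsl_exports = [
    "create_scale_factor_tensor",
    "grouped_gemm_nt_masked",
    "Sm100BlockScaledPersistentDenseGemmKernel",
]

# A plain sorted() would put the capitalized name first (uppercase sorts
# before lowercase in ASCII); key=str.lower reproduces the expected order.
assert cute_dsl_exports == sorted(cute_dsl_exports, key=str.lower)
```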
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/__init__.py`:
- Around line 26-44: The current broad try/except around importing
is_cute_dsl_available and the CuTe-DSL kernels swallows any ImportError from the
kernel module; change the logic so you only suppress the absence of the CuTe-DSL
utils but let kernel import errors surface: import is_cute_dsl_available inside
a narrow try/except (or set a fallback that returns False) and then, if
is_cute_dsl_available() is True, import grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel, and create_scale_factor_tensor
normally (no broad try/except) and set _cute_dsl_kernels accordingly so genuine
import failures in those symbols are raised instead of being silenced.
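A minimal, generic sketch of the import structure the comment asks for. The module names here are placeholders, not flashinfer's real paths:

```python
import importlib


def load_optional_kernels(utils_module: str, kernels_module: str):
    """Guard only the availability probe; let kernel import errors surface."""
    try:
        utils = importlib.import_module(utils_module)
        available = utils.is_cute_dsl_available()
    except ImportError:
        # Only "the optional utils module is absent" is suppressed.
        available = False
    if not available:
        return None
    # No try/except here: if the probe says CuTe-DSL is available, a failure
    # importing the kernels is a genuine bug and should propagate.
    return importlib.import_module(kernels_module)
```

The design point is the scope of the `except`: a broad try/except around both the probe and the kernel imports would mask real bugs in the kernel module as "CuTe-DSL not installed".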
In `@flashinfer/gemm/kernels/__init__.py`:
- Around line 33-38: The __all__ list appended inside the
is_cute_dsl_available() block is not lexicographically sorted; update the list
assigned to __all__ (the entries "grouped_gemm_nt_masked",
"Sm100BlockScaledPersistentDenseGemmKernel", "create_scale_factor_tensor") so
they are in sorted order (alphabetical) to satisfy RUF022, keeping the append
inside the is_cute_dsl_available() conditional and preserving the exact symbol
names.
/bot run
yzh119 left a comment:

LGTM, seems the performance before and after this PR looks similar?

Yes, you read it correctly. The performance seems identical 👍
/bot stop
The GitLab CI pipeline #43379685 has been cancelled.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/gemm/test_cute_dsl_blockscaled_gemm.py (1)
262-281: ⚠️ Potential issue | 🟡 Minor

Pre-existing bug: `sm_count=132` is not a parameter of `test_blockscaled_gemm_python_interface`.

The `__main__` block passes `sm_count=132`, but the function signature only has `enable_dst_signals`; `sm_count` is computed internally (line 95). Running this file directly will raise `TypeError`. This is pre-existing, but worth fixing while you're here.

Suggested fix:

```diff
 test_blockscaled_gemm_python_interface(
     lm=(1, 1024),
     kn=(7168, 4096),
     ab_dtype="float4_e2m1fn",
     sf_dtype="float8_e8m0fnu",
     sf_vec_size=16,
     c_dtype="float16",
     a_major="k",
     b_major="k",
     c_major="n",
     fuse_alpha=False,
     alpha_dtype="float32",
     mma_tiler_mn=(128, 128),
     cluster_shape_mn=(2, 1),
     tolerance=1e-01,
     iterations=3,
-    sm_count=132,
     enable_dst_signals=True,
 )
```
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/__init__.py`:
- Around line 39-47: The module-level deprecation warning in
flashinfer.cute_dsl/__init__.py fires on any import; change it so the warning
only appears when deprecated GEMM symbols are actually accessed: either move the
warnings.warn call into the shim module (blockscaled_gemm.py) where
grouped_gemm_nt_masked, Sm100BlockScaledPersistentDenseGemmKernel, and
create_scale_factor_tensor are defined, or implement module-level __getattr__ in
flashinfer.cute_dsl.__init__ that checks the attribute name (e.g.,
"grouped_gemm_nt_masked", "Sm100BlockScaledPersistentDenseGemmKernel",
"create_scale_factor_tensor"), emits the DeprecationWarning with stacklevel=2,
then lazily imports and returns the requested symbol; leave other imports (like
is_cute_dsl_available or rmsnorm_fp4quant) untouched so they won't trigger the
GEMM deprecation.
🧹 Nitpick comments (1)
tests/gemm/test_cute_dsl_blockscaled_gemm.py (1)
83-88: Consider using `flashinfer.utils.get_compute_capability` for the architecture skip.

The device capability check uses `torch.cuda.get_device_capability()` directly with hardcoded tuples. The coding guidelines require test files to use `flashinfer.utils` functions (e.g., `get_compute_capability`) to skip tests on unsupported GPU architectures. This is pre-existing code, so it can be addressed separately. As per coding guidelines, `tests/**/*.py`: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures."
/bot run
[FAILED] Pipeline #43391828: 12/20 passed
📌 Description
CUTLASS Upstream Updates
Ported the following commits from cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py:
`1cfbb53a`: Fix SM100 block-scale gemm overlapping accumulator and threads_per_warp

- Introduced a `self.threads_per_warp = 32` constant to avoid hardcoded magic numbers
- Fixed the `num_acc_consumer_threads` calculation (was missing `threads_per_warp * multiplier`)
- Added an `elect_one` context around `acc_pipeline.consumer_release()` calls
- Uses `self.threads_per_warp` consistently

`acb45938`: Update nvvm API call from nvvm enum to str

- Changed `cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta)` to `cute.arch.fence_proxy("async.shared", space="cta")`

Code Reorganization
- Moved `flashinfer/cute_dsl/blockscaled_gemm.py` → `flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py`
- Added `flashinfer/gemm/kernels/__init__.py` for the new module exports
- Kept a shim `flashinfer/cute_dsl/blockscaled_gemm.py` that re-exports from the new location
- Updated `flashinfer/gemm/__init__.py` to export CuTe-DSL kernels when available

All existing import paths continue to work.
Benchmarking results via `bench_cute_dsl_blockscaled_gemm.py` show no perf difference before and after this PR.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- … (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Deprecation
Chores
Improvements