fix: use correct SMEM capacity for SM120 consumer Blackwell GPUs#2835
fix: use correct SMEM capacity for SM120 consumer Blackwell GPUs#2835brandonmmusic-max wants to merge 1 commit intoflashinfer-ai:mainfrom
Conversation
SM120 consumer Blackwell GPUs (RTX PRO 6000, RTX 5090) have 99KB shared memory, but the CuTe DSL MoE kernels hardcode sm_100 (227KB) for the SMEM capacity lookup. This causes _compute_stages to over-allocate pipeline stages on SM120, leading to suboptimal performance. Add get_blackwell_smem_arch() helper that auto-detects SM120 vs SM100 and returns the correct architecture string. All 4 Blackwell grouped GEMM kernels now use dynamic detection instead of hardcoded sm_100. Affected hardware: RTX PRO 6000 (SM120), RTX 5090 (SM120), RTX 5080 (SM120) Not affected: B200, B300, GB200 (SM100) — these already have 227KB SMEM Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
📝 WalkthroughWalkthroughThis PR introduces a new Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Important Merge conflicts detected (Beta)
✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment Tip You can disable poems in the walkthrough.Disable the |
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical performance issue on SM120 consumer Blackwell GPUs by correcting the shared memory capacity detection within CuTe DSL MoE kernels. Previously, these kernels assumed a larger shared memory capacity, leading to inefficient resource allocation. The changes introduce dynamic GPU architecture detection, ensuring that kernels utilize the accurate shared memory size, thereby optimizing performance for relevant hardware without impacting data center GPUs. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly addresses a bug where the SMEM capacity for SM120 consumer Blackwell GPUs was hardcoded incorrectly. The introduction of the get_blackwell_smem_arch helper function to dynamically detect the SM architecture is a good solution. The changes are applied consistently across all affected kernel files. I've added one suggestion to cache the result of the new helper function to avoid redundant calls, which can improve performance.
| def get_blackwell_smem_arch() -> str: | ||
| """Return the correct SM architecture string for SMEM capacity lookup. | ||
|
|
||
| SM100 (B200/B300 data center) has 227KB shared memory. | ||
| SM120/SM121 (RTX PRO 6000/RTX 5090 consumer) has 99KB shared memory. | ||
|
|
||
| Using the wrong capacity causes _compute_stages to over-allocate pipeline | ||
| stages that don't fit in physical SMEM, degrading performance on SM120. | ||
| """ | ||
| import torch | ||
|
|
||
| if not torch.cuda.is_available(): | ||
| return "sm_100" # fallback | ||
| major, minor = torch.cuda.get_device_capability() | ||
| if major == 12: | ||
| return "sm_120" | ||
| return "sm_100" |
There was a problem hiding this comment.
This function may be called multiple times within the same process. To improve performance by avoiding redundant calls to torch.cuda.get_device_capability(), it's a good practice to cache its result. The device capability will not change during the execution of the program. You can use a decorator for this.
| def get_blackwell_smem_arch() -> str: | |
| """Return the correct SM architecture string for SMEM capacity lookup. | |
| SM100 (B200/B300 data center) has 227KB shared memory. | |
| SM120/SM121 (RTX PRO 6000/RTX 5090 consumer) has 99KB shared memory. | |
| Using the wrong capacity causes _compute_stages to over-allocate pipeline | |
| stages that don't fit in physical SMEM, degrading performance on SM120. | |
| """ | |
| import torch | |
| if not torch.cuda.is_available(): | |
| return "sm_100" # fallback | |
| major, minor = torch.cuda.get_device_capability() | |
| if major == 12: | |
| return "sm_120" | |
| return "sm_100" | |
| @__import__("functools").lru_cache(maxsize=None) | |
| def get_blackwell_smem_arch() -> str: | |
| """Return the correct SM architecture string for SMEM capacity lookup. | |
| SM100 (B200/B300 data center) has 227KB shared memory. | |
| SM120/SM121 (RTX PRO 6000/RTX 5090 consumer) has 99KB shared memory. | |
| Using the wrong capacity causes _compute_stages to over-allocate pipeline | |
| stages that don't fit in physical SMEM, degrading performance on SM120. | |
| """ | |
| import torch | |
| if not torch.cuda.is_available(): | |
| return "sm_100" # fallback | |
| major, minor = torch.cuda.get_device_capability() | |
| if major == 12: | |
| return "sm_120" | |
| return "sm_100" |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell/utils.py`:
- Line 74: The assignment unpacks torch.cuda.get_device_capability() into major
and minor but minor is unused; update the unpacking in the code that calls
torch.cuda.get_device_capability() (the line assigning major, minor) to use an
underscore for the unused value (e.g., major, _minor or major, _) so linters
stop complaining while keeping the major variable intact.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4a4eccaf-212a-4a64-afda-a3d9f0f405e1
📒 Files selected for processing (5)
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.pyflashinfer/fused_moe/cute_dsl/blackwell/utils.py
|
|
||
| if not torch.cuda.is_available(): | ||
| return "sm_100" # fallback | ||
| major, minor = torch.cuda.get_device_capability() |
There was a problem hiding this comment.
Avoid the unused variable warning at Line 74.
minor is unpacked but never used; rename it to _minor (or _) to keep lint clean.
🔧 Minimal fix
- major, minor = torch.cuda.get_device_capability()
+ major, _minor = torch.cuda.get_device_capability()🧰 Tools
🪛 Ruff (0.15.6)
[warning] 74-74: Unpacked variable minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/blackwell/utils.py` at line 74, The assignment
unpacks torch.cuda.get_device_capability() into major and minor but minor is
unused; update the unpacking in the code that calls
torch.cuda.get_device_capability() (the line assigning major, minor) to use an
underscore for the unused value (e.g., major, _minor or major, _) so linters
stop complaining while keeping the major variable intact.
|
Hi @brandonmmusic-max, I appreciate the PR, but the cute-dsl MoE kernels are intended for SM100 and SM103 only and are not compatible for SM120. It is not just the shared memory size differences. The tensor core architectures are inherently different, which makes the MoE API not usable for SM120. We may add an SM120 CuTe DSL MoE in the future, but it will require an entire rewrite of existing kernels |
thank you for the response! Just trying to be helpful to open source community. I'll close this! |
Summary
SM120 consumer Blackwell GPUs (RTX PRO 6000, RTX 5090, RTX 5080) have 99KB shared memory, but the CuTe DSL MoE kernels hardcode
"sm_100"for the SMEM capacity lookup, which returns 227KB (the SM100/B200 capacity). This causes_compute_stages()to compute pipeline stage counts based on 2.3x more SMEM than physically available on SM120.The Bug
SMEM capacity by architecture:
The Fix
Add
get_blackwell_smem_arch()helper inblackwell/utils.pythat auto-detects SM120 vs SM100 viatorch.cuda.get_device_capability()and returns the correct architecture string. All 4 affected kernel files now use dynamic detection.Files Changed
flashinfer/fused_moe/cute_dsl/blackwell/utils.py— newget_blackwell_smem_arch()helperflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.pyflashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.pyImpact
This fix is only relevant for SM120 consumer Blackwell GPUs. SM100 data center GPUs are unaffected (the helper returns
"sm_100"for them, preserving existing behavior).On SM120,
_compute_stages()will now correctly compute 3-5 pipeline stages instead of requesting 7-12 stages that overflow 99KB SMEM. This should improve MoE GEMM throughput for users running NVFP4 models (Qwen3.5-397B, DeepSeek, etc.) on RTX PRO 6000 and RTX 5090 hardware.Testing
Tested on 4x NVIDIA RTX PRO 6000 Blackwell (SM120, 96GB GDDR7) with Qwen3.5-397B-A17B-NVFP4:
get_smem_capacity_in_bytes("sm_100")returns 232448 (227KB)get_smem_capacity_in_bytes("sm_120")returns 101376 (99KB)get_blackwell_smem_arch()correctly returns"sm_120"on this hardwareRelated
Summary by CodeRabbit