perf(moe): optimize SM120 b12x MoE short decode #3193
lukealonso wants to merge 1 commit into flashinfer-ai:main
Conversation
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough
Adds compile-time flags (e.g. `single_token`, `share_expert_scales`; see the diagram below) to the SM120 b12x MoE micro-kernel dispatch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Client
participant Dispatch
participant MicroKernel
participant StaticKernel
Client->>Dispatch: request routed compute (num_tokens, activation, flags)
Dispatch->>Dispatch: decide use_micro, single_token, share_expert_scales, static_mac
alt use micro-kernel
Dispatch->>MicroKernel: compile/launch (topk_ids_dtype, single_token, share_expert_scales, micro_mac)
MicroKernel-->>Dispatch: completion/results
else use static kernel
Dispatch->>StaticKernel: compile/launch (mac_override=static_mac)
StaticKernel-->>Dispatch: completion/results
end
Dispatch->>Client: return outputs
```
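For orientation, here is a minimal host-side sketch of the decision step in the diagram. This is not the PR's code: `choose_kernel` is a made-up name, the <= 20 routed-rows cutover is taken from the review comments below, and the real dispatch in `moe_dispatch.py` considers more inputs.

```python
# Hypothetical sketch of the dispatch decision shown in the diagram above.
def choose_kernel(num_tokens: int, top_k: int, activation: str):
    routed_rows = num_tokens * top_k
    single_token = num_tokens == 1
    use_micro = routed_rows <= 20          # assumed short-decode cutover
    # Assumed relu2-only single-token optimization from this PR.
    share_expert_scales = single_token and activation == "relu2"
    if use_micro:
        return "micro", {"single_token": single_token,
                         "share_expert_scales": share_expert_scales}
    return "static", {"mac_override": None}

print(choose_kernel(1, 2, "relu2"))  # -> ('micro', {...})
```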
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
/bot run
Code Review
This pull request replaces static MoE cutover constants with dynamic environment variable lookups and introduces optimizations for relu2 single-token workloads, including shared expert scales and specialized MAC capping. Feedback focuses on critical synchronization issues in the micro kernel where skipping grid barriers or redundant quantization logic could cause race conditions. Further improvements are recommended for the dispatch logic to cache environment lookups more effectively and handle potential parsing errors for non-numeric configuration values.
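For illustration, a cached and parse-tolerant environment lookup along the lines this review suggests might look like the sketch below. The variable name is made up; this is not the PR's actual helper.

```python
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def _env_int(name: str, default: int) -> int:
    """Read an integer override once per process, ignoring bad values."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default  # fall back instead of raising on non-numeric input

# Made-up variable name, for illustration only.
MICRO_CUTOVER = _env_int("FLASHINFER_MOE_MICRO_CUTOVER", 20)
```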
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py (1)
1019-1025: ⚠️ Potential issue | 🔴 Critical: Restore a grid-wide barrier in the shared-input single-token path.
Lines 1019-1025 and 1147-1160 remove the only inter-CTA rendezvous for `share_input_across_experts`. CTAs with `bidz >= total_pairs` skip route/pack entirely but can still enter phase 2 and read slot 0 before any writer CTA finishes, and lagging CTAs can still be zeroing `scatter_output` while others start atomics. That makes the bs1 ReLU2 fast path racy.

Suggested fix
```diff
-        if cutlass.const_expr(not self.share_input_across_experts):
-            self._resident_grid_barrier(
-                barrier_count,
-                barrier_epoch,
-                Int32(gdim_z),
-                is_cta_leader,
-            )
+        self._resident_grid_barrier(
+            barrier_count,
+            barrier_epoch,
+            Int32(gdim_z),
+            is_cta_leader,
+        )
 …
-        if cutlass.const_expr(
-            self.share_input_across_experts and self.single_token and not self.is_gated
-        ):
-            cute.arch.sync_threads()
-            _threadfence()
-            cute.arch.fence_proxy("async.global")
-            cute.arch.sync_threads()
-        else:
-            self._resident_grid_barrier(
-                barrier_count,
-                barrier_epoch,
-                Int32(gdim_z),
-                is_cta_leader,
-            )
+        self._resident_grid_barrier(
+            barrier_count,
+            barrier_epoch,
+            Int32(gdim_z),
+            is_cta_leader,
+        )
```

Also applies to: 1147-1160
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py` around lines 1019 - 1025, Restore the grid-wide rendezvous by calling self._resident_grid_barrier in the shared-input single-token path: replace the current conditional that uses cutlass.const_expr(not self.share_input_across_experts) with cutlass.const_expr(self.share_input_across_experts) so the barrier (invoking self._resident_grid_barrier(barrier_count, barrier_epoch, Int32(gdim_z), is_cta_leader)) runs when share_input_across_experts is enabled; make the same change at the second analogous site handling the single-token/shared-input fast path.
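The hazard is easiest to see with an ordinary host-side analogy. The sketch below is plain Python, not the kernel code; `bidz` and `total_pairs` just mirror the names in the comment above, and `threading.Barrier` stands in for `_resident_grid_barrier`.

```python
import threading

NUM_CTAS, TOTAL_PAIRS = 4, 2
slot = [None]
barrier = threading.Barrier(NUM_CTAS)  # stand-in for _resident_grid_barrier

def cta(bidz: int) -> None:
    if bidz < TOTAL_PAIRS:
        slot[0] = f"packed-by-{bidz}"  # phase 1: writer CTAs fill slot 0
    barrier.wait()  # remove this and readers may see slot 0 unwritten
    assert slot[0] is not None         # phase 2: every CTA reads slot 0

threads = [threading.Thread(target=cta, args=(i,)) for i in range(NUM_CTAS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```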
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 918-920: The current code forcibly caps static_mac with static_mac
= min(static_mac, 64) when use_micro is false and routed_rows < 40, which
unintentionally affects routed_rows 21–39; change the guard so the cap only
applies at the original cutover (<=20 routed rows) or when an explicit override
is requested: replace the condition routed_rows < 40 with routed_rows <= 20 (or
routed_rows < 21), keeping the same check for use_micro and using
_get_impl_mac("static", routed_rows=routed_rows) and static_mac to locate the
code to update.
---
Outside diff comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py`:
- Around line 1019-1025: Restore the grid-wide rendezvous by calling
self._resident_grid_barrier in the shared-input single-token path: replace the
current conditional that uses cutlass.const_expr(not
self.share_input_across_experts) with
cutlass.const_expr(self.share_input_across_experts) so the barrier (invoking
self._resident_grid_barrier(barrier_count, barrier_epoch, Int32(gdim_z),
is_cta_leader)) runs when share_input_across_experts is enabled; make the same
change at the second analogous site handling the single-token/shared-input fast
path.
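Rendered as code, the corrected guard from the first inline comment above might look like this sketch; the 64-MAC cap and the <= 20 cutover are taken from the comment, not verified against the source.

```python
def cap_static_mac(static_mac: int, use_micro: bool, routed_rows: int) -> int:
    # Apply the cap only at the original cutover; rows 21-39 keep the
    # tuned ladder value instead of being clamped to 64.
    if not use_micro and routed_rows <= 20:   # was: routed_rows < 40
        return min(static_mac, 64)
    return static_mac

assert cap_static_mac(96, use_micro=False, routed_rows=30) == 96  # untouched
assert cap_static_mac(96, use_micro=False, routed_rows=16) == 64  # capped
```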
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 048f41ee-ebe4-4a21-979b-4d894879186c
📒 Files selected for processing (3)
- flashinfer/fused_moe/cute_dsl/b12x_moe.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
/bot stop

The GitLab CI pipeline #49645982 has been cancelled.
♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
838-841: ⚠️ Potential issue | 🟠 Major: Static launches still get the tuned MAC ladder.
This changes normal static-kernel behavior, not just the micro path. The new `static_mac` calculation plus the `routed_rows < 40` clamp feed directly into `_get_static_kernel(...)` for every non-micro launch, which contradicts the PR objective of keeping the tuned MAC ladder micro-only and preserving existing static decode behavior.

Suggested fix
```diff
-    tuned_static_mac = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
-    static_mac = min(tuned_static_mac or base_mac, base_mac)
-    if not use_micro and routed_rows < 40:
-        static_mac = min(static_mac, 64)
@@
     compiled, mac = _get_static_kernel(
         workspace.state_E,
         num_experts,
         num_tokens,
         k,
         n,
         top_k,
         workspace.max_rows,
         topk_ids_dtype=torch.int32,
         input_scales_are_reciprocal=input_scales_are_reciprocal,
         fast_math=fast_math,
-        mac_override=static_mac,
         activation=activation,
     )
```

Also applies to: 920-931
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 838 - 841, The current code applies the tuned MAC ladder to all non-micro launches by computing tuned_static_mac via _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows) and then using it to set static_mac before calling _get_static_kernel, which breaks static-kernel behavior; only apply the tuned ladder and the routed_rows < 40 clamp when use_micro is True. Fix by gating the _lookup_mac_ladder call and the routed_rows clamp behind use_micro (e.g., when use_micro: compute tuned_static_mac and set static_mac = min(tuned_static_mac or base_mac, base_mac) and apply the routed_rows < 40 => static_mac = min(static_mac, 64); otherwise set static_mac = base_mac), and make the same change for the duplicate block that affects lines around the other occurrence (the logic feeding _get_static_kernel).
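A self-contained sketch of the gating this prompt describes follows. `_lookup_mac_ladder` and `_STATIC_MAC_LADDER` follow the names in the diff above, but the ladder values and lookup semantics here are stand-ins, not the PR's actual tables.

```python
_STATIC_MAC_LADDER = {40: 64, 128: 96}  # stand-in ladder: rows -> MAC count

def _lookup_mac_ladder(ladder: dict, routed_rows: int):
    for limit in sorted(ladder):
        if routed_rows <= limit:
            return ladder[limit]
    return None  # beyond the ladder: caller falls back to base_mac

def pick_static_mac(use_micro: bool, routed_rows: int, base_mac: int) -> int:
    if not use_micro:
        return base_mac  # preserve existing static decode behavior
    tuned = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
    static_mac = min(tuned or base_mac, base_mac)
    if routed_rows < 40:
        static_mac = min(static_mac, 64)
    return static_mac

assert pick_static_mac(False, 24, 128) == 128  # non-micro path unchanged
```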
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 838-841: The current code applies the tuned MAC ladder to all
non-micro launches by computing tuned_static_mac via
_lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows) and then using it to set
static_mac before calling _get_static_kernel, which breaks static-kernel
behavior; only apply the tuned ladder and the routed_rows < 40 clamp when
use_micro is True. Fix by gating the _lookup_mac_ladder call and the routed_rows
clamp behind use_micro (e.g., when use_micro: compute tuned_static_mac and set
static_mac = min(tuned_static_mac or base_mac, base_mac) and apply the
routed_rows < 40 => static_mac = min(static_mac, 64); otherwise set static_mac =
base_mac), and make the same change for the duplicate block that affects lines
around the other occurrence (the logic feeding _get_static_kernel).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4a6ee59c-ffb1-47f9-b485-64836edbf4a8
📒 Files selected for processing (2)
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
838-841: ⚠️ Potential issue | 🟠 Major: Keep the tuned MAC ladder micro-only.
This block now retunes the static backend too.
`tuned_static_mac` changes the normal static residency for every `_STATIC_MAC_LADDER` bucket, and the extra `routed_rows < 40` cap still hits the regular static path for `top_k == 1` rows 21–39. That contradicts the PR goal of preserving existing static decode behavior.

Suggested fix
```diff
-    tuned_static_mac = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
-    static_mac = min(tuned_static_mac or base_mac, base_mac)
-    if not use_micro and routed_rows < 40:
-        static_mac = min(static_mac, 64)
+    static_mac = base_mac
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 838 - 841, The current assignment applies tuned_static_mac to the regular static path; change the logic so tuned_static_mac (from _lookup_mac_ladder and _STATIC_MAC_LADDER) is only used for micro backends: when use_micro is true set static_mac = min(tuned_static_mac or base_mac, base_mac), otherwise set static_mac = base_mac so existing static decode behavior is preserved; keep the existing routed_rows < 40 cap (the min(static_mac, 64)) intact for the non-micro path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 838-841: The current assignment applies tuned_static_mac to the
regular static path; change the logic so tuned_static_mac (from
_lookup_mac_ladder and _STATIC_MAC_LADDER) is only used for micro backends: when
use_micro is true set static_mac = min(tuned_static_mac or base_mac, base_mac),
otherwise set static_mac = base_mac so existing static decode behavior is
preserved; keep the existing routed_rows < 40 cap (the min(static_mac, 64))
intact for the non-micro path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2b5ad14c-8904-424f-9887-ea51ee497f8b
📒 Files selected for processing (2)
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
Synchronize the SM120 b12x MoE implementation from the upstream b12x kernels, including the short-decode dispatch and micro-kernel fixes.
/bot run
📌 Description
Synchronizes the SM120 B12x MoE implementation with the upstream b12x kernel changes.
This PR updates the short-decode path for B12x fused MoE, including:
- single-token relu2 optimizations such as shared expert scales.
- environment-variable overrides for the MoE cutover constants, with default behavior preserved.
The goal is to bring over the upstream short-decode fixes and performance improvements without
reintroducing the reverted post-863 micro-kernel changes.
🔍 Related Issues
N/A
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the
following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
🧪 Tests
Validated with:
python -m pytest -q tests/moe/test_b12x_fused_moe.py
Result:
90 passed, 1 warning
Also ran a perf smoke test:
b12x_fused_moe relu2 bs1 topk22: median 0.030 ms
Reviewer Notes
Please focus review on the SM120 MoE short-decode path, especially:
- the new environment-variable overrides, which should leave default behavior unchanged unless explicitly overridden.