
perf(moe): optimize SM120 b12x MoE short decode #3193

Open

lukealonso wants to merge 1 commit into flashinfer-ai:main from lukealonso:main

Conversation

lukealonso commented Apr 27, 2026

Synchronize the SM120 b12x MoE implementation from the upstream b12x
kernels, including the short-decode dispatch and micro-kernel fixes.

📌 Description

Synchronizes the SM120 B12x MoE implementation with the upstream b12x kernel changes.

This PR updates the short-decode path for B12x fused MoE, including:

  • ReLU2 single-token dispatch shortcut that avoids the Triton compaction pre-pass.
  • Micro-kernel specializations for single-token decode, shared input quantization, and shared
    expert scales.
  • Tuned micro-kernel max-active-cluster (MAC) selection while preserving the existing static-kernel
    MAC behavior.
  • CUDA graph workspace sizing updated to use the same static/dynamic cutover helper as dispatch.

The goal is to bring over the upstream short-decode fixes and performance improvements without
reintroducing the reverted post-863 micro-kernel changes.
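
For orientation, here is a minimal sketch of the dispatch-side decisions listed above. All names here (MICRO_CUTOVER, _select_kernel) are illustrative assumptions, not the actual flashinfer API; only the flag semantics come from this PR.

# Hypothetical sketch of the short-decode dispatch logic described above.
MICRO_CUTOVER = 20  # assumed cutover; the real helper computes this dynamically

def _select_kernel(num_tokens: int, top_k: int, activation: str):
    routed_rows = num_tokens * top_k
    single_token = num_tokens == 1
    # static/dynamic cutover: small routed-row counts take the micro path
    use_micro = routed_rows <= MICRO_CUTOVER
    # ReLU2 single-token shortcut: skip the Triton compaction pre-pass
    skip_compaction = single_token and activation == "relu2"
    return use_micro, single_token, skip_compaction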

🔍 Related Issues

N/A

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the
following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred
    method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Validated with:

python -m pytest -q tests/moe/test_b12x_fused_moe.py

Result:

90 passed, 1 warning

Also ran a perf smoke test:

b12x_fused_moe relu2 bs1 topk22: median 0.030 ms

Reviewer Notes

Please focus review on the SM120 MoE short-decode path, especially:

  • ReLU2 single-token routing using flat_ids.
  • Gated/SwiGLU single-token behavior preserving compacted expert mapping.
  • The intentional choice to apply the tuned MAC ladder only to the micro path, leaving static MAC
    behavior unchanged unless explicitly overridden.
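
To make the first two points concrete, here is a hedged sketch of the single-token launch-id selection. flat_ids follows the name used in this PR, while _pick_launch_ids and compacted_expert_ids are assumed names for illustration.

def _pick_launch_ids(activation, single_token, flat_ids, compacted_expert_ids):
    # bs=1 ReLU2 addresses experts directly via flat_ids and skips the
    # compaction pre-pass; gated/SwiGLU keeps the compacted expert mapping.
    if activation == "relu2" and single_token:
        return flat_ids
    return compacted_expert_ids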

coderabbitai Bot (Contributor) commented Apr 27, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7b81688d-0d4a-49e1-bd1d-b45c7fdeba56

📥 Commits

Reviewing files that changed from the base of the PR and between 1bc2658 and 729e7e8.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py

📝 Walkthrough

Adds compile-time flags share_expert_scales and single_token across dispatch and micro-kernel; changes micro/static MAC selection, single-token launch path, quantization/packing layout, and synchronization/launch-id handling.

Changes

Cohort / File(s): Dispatch / Launcher
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
Summary: Threads share_expert_scales and single_token into micro-kernel cache key and compile args; sets topk_ids_dtype=launch_ids.dtype; computes/clamps a tuned static_mac and passes mac_override for static kernels; adjusts single-token launch-id mapping to use flat_ids for activation=="relu2".

Cohort / File(s): Micro-kernel implementation
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
Summary: Adds ctor flags share_expert_scales and single_token. Introduces single-token specialization (routing, phase init, compute-tile scheduling), changes pre-pack/resident-grid barrier logic to depend on shared-input flags, reworks quantization/scale indexing and packed-input ownership when scales/inputs are shared, and adapts routing metadata derivation for gated vs non-gated modes.
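
As a reading aid, here is a minimal sketch of the cache-key threading described for the dispatcher; only the flag names and the launch_ids.dtype detail come from the walkthrough, while the helper and tuple shape are assumptions.

_micro_kernel_cache = {}  # hypothetical cache for compiled micro-kernels

def _micro_cache_key(num_tokens, top_k, launch_ids,
                     share_expert_scales, single_token):
    # The new compile-time flags join the key so each specialization gets
    # its own compiled kernel; topk_ids_dtype is taken from launch_ids.dtype.
    return (num_tokens, top_k, launch_ids.dtype,
            share_expert_scales, single_token)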

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant Dispatch
    participant MicroKernel
    participant StaticKernel

    Client->>Dispatch: request routed compute (num_tokens, activation, flags)
    Dispatch->>Dispatch: decide use_micro, single_token, share_expert_scales, static_mac
    alt use micro-kernel
        Dispatch->>MicroKernel: compile/launch (topk_ids_dtype, single_token, share_expert_scales, micro_mac)
        MicroKernel-->>Dispatch: completion/results
    else use static kernel
        Dispatch->>StaticKernel: compile/launch (mac_override=static_mac)
        StaticKernel-->>Dispatch: completion/results
    end
    Dispatch->>Client: return outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • aleozlx
  • yzh119
  • samuellees
  • IwakuraRein
  • jiahanc
  • nv-yunzheq

Poem

🐰 Tiny kernels, bold and keen,
Single tokens hop between,
Scales shared snug in slot zero's nest,
Flat IDs leap and threads do rest,
A rabbit claps for faster zest.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
  • Title check: Passed. The title clearly and specifically refers to performance optimization of SM120 B12x MoE short decode, which is the main focus of the changeset.
  • Description check: Passed. The description covers all required template sections with sufficient detail: motivation, specific changes, related issues, test results, and reviewer guidance.
  • Docstring Coverage check: Passed. Docstring coverage is 80.00%, which is sufficient; the required threshold is 80.00%.
  • Linked Issues check: Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: Passed. Check skipped because no linked issues were found for this pull request.


bkryu added the run-ci label Apr 27, 2026
bkryu (Collaborator) commented Apr 27, 2026

/bot run

flashinfer-bot (Collaborator) commented:
GitLab MR !612 has been created, and the CI pipeline #49645982 is currently running. I'll report back once the pipeline job completes.

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request replaces static MoE cutover constants with dynamic environment variable lookups and introduces optimizations for relu2 single-token workloads, including shared expert scales and specialized MAC capping. Feedback focuses on critical synchronization issues in the micro kernel where skipping grid barriers or redundant quantization logic could cause race conditions. Further improvements are recommended for the dispatch logic to cache environment lookups more effectively and handle potential parsing errors for non-numeric configuration values.
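
To make the last suggestions about caching environment lookups and handling non-numeric values concrete, here is a hedged sketch of defensive environment-variable parsing; the helper and variable names below are assumptions for illustration, not code from this PR.

import os

def _env_int(name: str, default: int) -> int:
    # Call once at import time so the lookup is not repeated per dispatch,
    # and fall back to the default when the value is missing or not a
    # valid integer.
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default

_MOE_MICRO_CUTOVER = _env_int("FLASHINFER_MOE_MICRO_CUTOVER", 20)  # assumed name and default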

Outdated review comment threads: moe_micro_kernel.py (2), moe_dispatch.py (3).
coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py (1)

1019-1025: ⚠️ Potential issue | 🔴 Critical

Restore a grid-wide barrier in the shared-input single-token path.

Lines 1019-1025 and 1147-1160 remove the only inter-CTA rendezvous for share_input_across_experts. CTAs with bidz >= total_pairs skip route/pack entirely but can still enter phase 2 and read slot 0 before any writer CTA finishes, and lagging CTAs can still be zeroing scatter_output while others start atomics. That makes the bs1 ReLU2 fast path racy.

Suggested fix
-        if cutlass.const_expr(not self.share_input_across_experts):
-            self._resident_grid_barrier(
-                barrier_count,
-                barrier_epoch,
-                Int32(gdim_z),
-                is_cta_leader,
-            )
+        self._resident_grid_barrier(
+            barrier_count,
+            barrier_epoch,
+            Int32(gdim_z),
+            is_cta_leader,
+        )

-        if cutlass.const_expr(
-            self.share_input_across_experts and self.single_token and not self.is_gated
-        ):
-            cute.arch.sync_threads()
-            _threadfence()
-            cute.arch.fence_proxy("async.global")
-            cute.arch.sync_threads()
-        else:
-            self._resident_grid_barrier(
-                barrier_count,
-                barrier_epoch,
-                Int32(gdim_z),
-                is_cta_leader,
-            )
+        self._resident_grid_barrier(
+            barrier_count,
+            barrier_epoch,
+            Int32(gdim_z),
+            is_cta_leader,
+        )

Also applies to: 1147-1160

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py` around
lines 1019 - 1025, Restore the grid-wide rendezvous by calling
self._resident_grid_barrier in the shared-input single-token path: replace the
current conditional that uses cutlass.const_expr(not
self.share_input_across_experts) with
cutlass.const_expr(self.share_input_across_experts) so the barrier (invoking
self._resident_grid_barrier(barrier_count, barrier_epoch, Int32(gdim_z),
is_cta_leader)) runs when share_input_across_experts is enabled; make the same
change at the second analogous site handling the single-token/shared-input fast
path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 918-920: The current code forcibly caps static_mac with static_mac
= min(static_mac, 64) when use_micro is false and routed_rows < 40, which
unintentionally affects routed_rows 21–39; change the guard so the cap only
applies at the original cutover (<=20 routed rows) or when an explicit override
is requested: replace the condition routed_rows < 40 with routed_rows <= 20 (or
routed_rows < 21), keeping the same check for use_micro and using
_get_impl_mac("static", routed_rows=routed_rows) and static_mac to locate the
code to update.

---

Outside diff comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py`:
- Around line 1019-1025: Restore the grid-wide rendezvous by calling
self._resident_grid_barrier in the shared-input single-token path: replace the
current conditional that uses cutlass.const_expr(not
self.share_input_across_experts) with
cutlass.const_expr(self.share_input_across_experts) so the barrier (invoking
self._resident_grid_barrier(barrier_count, barrier_epoch, Int32(gdim_z),
is_cta_leader)) runs when share_input_across_experts is enabled; make the same
change at the second analogous site handling the single-token/shared-input fast
path.
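
For concreteness, the guard change the inline comment asks for would look roughly like this (a before/after fragment using the names from the comment; surrounding code elided):

# Before (unintentionally caps routed_rows 21-39):
#     if not use_micro and routed_rows < 40:
#         static_mac = min(static_mac, 64)
# After (cap only at the original <=20-row cutover):
if not use_micro and routed_rows <= 20:
    static_mac = min(static_mac, 64)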

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 048f41ee-ebe4-4a21-979b-4d894879186c

📥 Commits

Reviewing files that changed from the base of the PR and between f7acd25 and 2b77278.

📒 Files selected for processing (3)
  • flashinfer/fused_moe/cute_dsl/b12x_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py

Outdated review comment thread: moe_dispatch.py (1).
bkryu (Collaborator) commented Apr 27, 2026

/bot stop

flashinfer-bot (Collaborator) commented:
The GitLab CI pipeline #49645982 has been cancelled.

coderabbitai Bot (Contributor) left a comment

♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

838-841: ⚠️ Potential issue | 🟠 Major

Static launches still get the tuned MAC ladder.

This changes normal static-kernel behavior, not just the micro path. The new static_mac calculation plus the routed_rows < 40 clamp feed directly into _get_static_kernel(...) for every non-micro launch, which contradicts the PR objective of keeping the tuned MAC ladder micro-only and preserving existing static decode behavior.

Suggested fix
-    tuned_static_mac = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
-    static_mac = min(tuned_static_mac or base_mac, base_mac)
-    if not use_micro and routed_rows < 40:
-        static_mac = min(static_mac, 64)
@@
         compiled, mac = _get_static_kernel(
             workspace.state_E,
             num_experts,
             num_tokens,
             k,
             n,
             top_k,
             workspace.max_rows,
             topk_ids_dtype=torch.int32,
             input_scales_are_reciprocal=input_scales_are_reciprocal,
             fast_math=fast_math,
-            mac_override=static_mac,
             activation=activation,
         )

Also applies to: 920-931

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
838 - 841, The current code applies the tuned MAC ladder to all non-micro
launches by computing tuned_static_mac via
_lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows) and then using it to set
static_mac before calling _get_static_kernel, which breaks static-kernel
behavior; only apply the tuned ladder and the routed_rows < 40 clamp when
use_micro is True. Fix by gating the _lookup_mac_ladder call and the routed_rows
clamp behind use_micro (e.g., when use_micro: compute tuned_static_mac and set
static_mac = min(tuned_static_mac or base_mac, base_mac) and apply the
routed_rows < 40 => static_mac = min(static_mac, 64); otherwise set static_mac =
base_mac), and make the same change for the duplicate block that affects lines
around the other occurrence (the logic feeding _get_static_kernel).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 838-841: The current code applies the tuned MAC ladder to all
non-micro launches by computing tuned_static_mac via
_lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows) and then using it to set
static_mac before calling _get_static_kernel, which breaks static-kernel
behavior; only apply the tuned ladder and the routed_rows < 40 clamp when
use_micro is True. Fix by gating the _lookup_mac_ladder call and the routed_rows
clamp behind use_micro (e.g., when use_micro: compute tuned_static_mac and set
static_mac = min(tuned_static_mac or base_mac, base_mac) and apply the
routed_rows < 40 => static_mac = min(static_mac, 64); otherwise set static_mac =
base_mac), and make the same change for the duplicate block that affects lines
around the other occurrence (the logic feeding _get_static_kernel).
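
Restated as code, the gating this comment proposes would look roughly like the following; _choose_static_mac is a hypothetical wrapper for illustration, while the other names come from the comment itself.

def _choose_static_mac(use_micro, routed_rows, base_mac):
    # Apply the tuned ladder only on the micro path; non-micro launches
    # keep their existing residency (base_mac), per the reviewer.
    if use_micro:
        tuned = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
        return min(tuned or base_mac, base_mac)
    return base_mac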

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4a6ee59c-ffb1-47f9-b485-64836edbf4a8

📥 Commits

Reviewing files that changed from the base of the PR and between 2b77278 and 2dc841b.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py

coderabbitai Bot (Contributor) left a comment

♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

838-841: ⚠️ Potential issue | 🟠 Major

Keep the tuned MAC ladder micro-only.

This block now retunes the static backend too. tuned_static_mac changes the normal static residency for every _STATIC_MAC_LADDER bucket, and the extra routed_rows < 40 cap still hits the regular static path for top_k == 1 rows 21–39. That contradicts the PR goal of preserving existing static decode behavior.

Suggested fix
-    tuned_static_mac = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
-    static_mac = min(tuned_static_mac or base_mac, base_mac)
-    if not use_micro and routed_rows < 40:
-        static_mac = min(static_mac, 64)
+    static_mac = base_mac
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
838 - 841, The current assignment applies tuned_static_mac to the regular static
path; change the logic so tuned_static_mac (from _lookup_mac_ladder and
_STATIC_MAC_LADDER) is only used for micro backends: when use_micro is true set
static_mac = min(tuned_static_mac or base_mac, base_mac), otherwise set
static_mac = base_mac so existing static decode behavior is preserved; keep the
existing routed_rows < 40 cap (the min(static_mac, 64)) intact for the non-micro
path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 838-841: The current assignment applies tuned_static_mac to the
regular static path; change the logic so tuned_static_mac (from
_lookup_mac_ladder and _STATIC_MAC_LADDER) is only used for micro backends: when
use_micro is true set static_mac = min(tuned_static_mac or base_mac, base_mac),
otherwise set static_mac = base_mac so existing static decode behavior is
preserved; keep the existing routed_rows < 40 cap (the min(static_mac, 64))
intact for the non-micro path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b5ad14c-8904-424f-9887-ea51ee497f8b

📥 Commits

Reviewing files that changed from the base of the PR and between 2dc841b and 1bc2658.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py

Synchronize the SM120 b12x MoE implementation from the
upstream b12x kernels, including the short-decode dispatch
and micro-kernel fixes.
bkryu (Collaborator) commented Apr 28, 2026

/bot run

flashinfer-bot (Collaborator) commented:
GitLab MR !612 has been updated with latest changes, and the CI pipeline #49659338 is currently running. I'll report back once the pipeline job completes.
