
Commit 729e7e8

perf(moe): optimize SM120 b12x MoE short decode
Synchronize the SM120 b12x MoE implementation with the upstream b12x kernels, including the short-decode dispatch and micro-kernel fixes.
1 parent f7acd25 commit 729e7e8

2 files changed: 111 additions & 66 deletions

File tree

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py

Lines changed: 33 additions & 18 deletions
@@ -557,6 +557,8 @@ def _get_micro_kernel(
     input_scales_are_reciprocal: bool = False,
     fast_math: bool = True,
     share_input_across_experts: bool = False,
+    share_expert_scales: bool = False,
+    single_token: bool = False,
     mac_override: int | None = None,
     activation: str = "silu",
 ):
@@ -588,6 +590,8 @@ def _get_micro_kernel(
         input_scales_are_reciprocal,
         fast_math,
         share_input_across_experts,
+        share_expert_scales,
+        single_token,
         activation,
     )
     cached = _MICRO_KERNEL_CACHE.get(cache_key)
@@ -607,6 +611,8 @@ def _get_micro_kernel(
         fast_math=fast_math,
         activation=activation,
         share_input_across_experts=share_input_across_experts,
+        share_expert_scales=share_expert_scales,
+        single_token=single_token,
     )

     is_gated = activation == "silu"
@@ -815,6 +821,7 @@ def launch_sm120_static_moe(
     # the m=1 relu2 shared-input micro optimization only applies when every
     # expert sees the same FC1-input global scale.
     input_gs_is_shared = input_gs.numel() == 1
+    down_input_scale_is_shared = down_input_scale.numel() == 1

     # Broadcast scalar scales to per-expert [E] tensors
     input_gs = _expand_to_experts(input_gs, num_experts)
@@ -828,19 +835,24 @@ def launch_sm120_static_moe(

     sm_count = get_num_sm(torch.device("cuda"))
     base_mac = min(get_max_active_clusters(1), sm_count)
+    tuned_static_mac = _lookup_mac_ladder(_STATIC_MAC_LADDER, routed_rows)
+    static_mac = min(tuned_static_mac or base_mac, base_mac)
+    if not use_micro and routed_rows < 40:
+        static_mac = min(static_mac, 64)

     if use_micro:
         assert flat_ids.numel() <= workspace.compact_topk_ids.numel(), (
             f"compact_topk_ids buffer too small: "
             f"{workspace.compact_topk_ids.numel()} < {flat_ids.numel()}"
         )
-        compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
-        if num_tokens == 1:
-            # A single token's top-k is already a dense unique expert set,
-            # so we can build the compact local-id mapping on the host
-            # without launching the Triton compaction kernel. The micro
-            # kernel still reads weight_expert_ids the same way it does
-            # for m>1; it just sees a pre-filled workspace.
+        # Single-token ReLU2 is non-gated, so the micro kernel can launch on
+        # the routed expert ids directly. Gated SiLU still goes through the
+        # compact id buffer so the kernel can map compact launch ids back to
+        # the physical gate/up weight experts.
+        if num_tokens == 1 and activation == "relu2":
+            launch_ids = flat_ids
+        elif num_tokens == 1:
+            compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
             compact_ids.copy_(
                 torch.arange(
                     flat_ids.numel(),
@@ -852,7 +864,9 @@ def launch_sm120_static_moe(
                 flat_ids.to(torch.int32)
             )
             workspace.active_expert_count.fill_(flat_ids.numel())
+            launch_ids = compact_ids
         else:
+            compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
             from .triton_compact import compact_topk_ids as _triton_compact_topk_ids

             _triton_compact_topk_ids(
@@ -861,23 +875,24 @@ def launch_sm120_static_moe(
                 workspace.weight_expert_ids,
                 workspace.active_expert_count,
             )
-        launch_ids = compact_ids
+            launch_ids = compact_ids
         # Select micro MAC: min of tuned ladder, work tiles, and hardware limit.
-        # The hardware cap (base_mac) prevents deadlocks on GPUs with fewer SMs
-        # than the profiled tuning target.
         micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
         tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows)
         micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac)
-        # For m=1 relu2 with a shared FC1-input scale, all experts see the
-        # same quantized activation — quantize once and share the packed
-        # buffer slot across all K top-k pairs. Env override lets us flip
-        # this off without a code change if a regression surfaces.
+        # For m=1 ReLU2 with a shared FC1-input scale, all experts see the
+        # same quantized activation. Match FI main's synchronization model:
+        # one CTA writes a shared packed slot, then the resident-grid barrier
+        # below makes it visible before all CTAs read it for FC1.
         share_input_across_experts = (
             activation == "relu2"
             and num_tokens == 1
             and input_gs_is_shared
             and os.environ.get("FLASHINFER_B12X_MICRO_SHARE_INPUT", "1") != "0"
         )
+        share_expert_scales = (
+            activation == "relu2" and input_gs_is_shared and down_input_scale_is_shared
+        )
         compiled, mac = _get_micro_kernel(
             workspace.state_E,
             num_experts,
@@ -886,17 +901,16 @@ def launch_sm120_static_moe(
             n,
             top_k,
             workspace.max_rows,
-            topk_ids_dtype=torch.int32,
+            topk_ids_dtype=launch_ids.dtype,
             input_scales_are_reciprocal=input_scales_are_reciprocal,
             fast_math=fast_math,
             share_input_across_experts=share_input_across_experts,
+            share_expert_scales=share_expert_scales,
+            single_token=num_tokens == 1,
             mac_override=micro_mac,
             activation=activation,
         )
     else:
-        # Static path — use hardware default MAC (same as main).
-        # MAC tuning for the static kernel is deferred to a follow-up
-        # to avoid changing behavior for existing static workloads.
         compiled, mac = _get_static_kernel(
             workspace.state_E,
             num_experts,
@@ -908,6 +922,7 @@ def launch_sm120_static_moe(
             topk_ids_dtype=torch.int32,
             input_scales_are_reciprocal=input_scales_are_reciprocal,
             fast_math=fast_math,
+            mac_override=static_mac,
             activation=activation,
         )
         launch_ids = flat_ids
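For context on the short-decode dispatch change above, here is a minimal standalone sketch of the decision it implements. The names `flat_ids`, `compact_topk_ids`, `weight_expert_ids`, and `active_expert_count` mirror the diff, but the `workspace` object, its shapes, and the function signature below are assumptions for illustration, not the actual FlashInfer API.

import torch

def pick_launch_ids(flat_ids, workspace, num_tokens, activation, use_micro):
    """Sketch of the m=1 short-decode dispatch shown in the diff above.

    Assumes `workspace` exposes pre-allocated `compact_topk_ids`,
    `weight_expert_ids`, and `active_expert_count` tensors, and that
    `flat_ids` holds the routed expert ids for all top-k slots.
    """
    if not use_micro:
        # Static path launches directly on the routed ids.
        return flat_ids

    if num_tokens == 1 and activation == "relu2":
        # Non-gated ReLU2: a single token's top-k is already a dense,
        # unique expert set, so the micro kernel consumes it as-is.
        return flat_ids

    compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
    if num_tokens == 1:
        # Gated SiLU with one token: build the compact local-id mapping on
        # the host instead of launching a compaction kernel.
        compact_ids.copy_(
            torch.arange(flat_ids.numel(), device=flat_ids.device, dtype=torch.int32)
        )
        workspace.weight_expert_ids[: flat_ids.numel()].copy_(flat_ids.to(torch.int32))
        workspace.active_expert_count.fill_(flat_ids.numel())
        return compact_ids

    # m > 1: fall back to the Triton compaction kernel (not shown here).
    return compact_ids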
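The MAC selection added for the static path follows the same pattern the micro path already used: look up a tuned value from a ladder keyed on routed rows, then clamp it by the hardware limit. The sketch below uses a made-up ladder table and a hypothetical `_lookup_mac_ladder` written to match the call shape in the diff; the real table values and lookup semantics are not part of this commit.

# Hypothetical ladder: (max routed_rows, tuned max-active-CTAs) pairs,
# ordered by increasing row count. The real tables live in moe_dispatch.py.
_EXAMPLE_MAC_LADDER = [(8, 32), (40, 64), (256, 96)]

def _lookup_mac_ladder(ladder, routed_rows):
    """Return the tuned MAC for the smallest bucket covering routed_rows,
    or None if the row count is beyond the last bucket."""
    for max_rows, mac in ladder:
        if routed_rows <= max_rows:
            return mac
    return None

def pick_static_mac(routed_rows, base_mac, use_micro):
    # Tuned value if available, otherwise the hardware default; never
    # exceed the hardware limit (base_mac).
    tuned = _lookup_mac_ladder(_EXAMPLE_MAC_LADDER, routed_rows)
    static_mac = min(tuned or base_mac, base_mac)
    # Very short static decodes are capped further, mirroring the diff's
    # `routed_rows < 40` clamp.
    if not use_micro and routed_rows < 40:
        static_mac = min(static_mac, 64)
    return static_mac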

0 commit comments
