@@ -557,6 +557,8 @@ def _get_micro_kernel(
557557 input_scales_are_reciprocal : bool = False ,
558558 fast_math : bool = True ,
559559 share_input_across_experts : bool = False ,
560+ share_expert_scales : bool = False ,
561+ single_token : bool = False ,
560562 mac_override : int | None = None ,
561563 activation : str = "silu" ,
562564):
@@ -588,6 +590,8 @@ def _get_micro_kernel(
588590 input_scales_are_reciprocal ,
589591 fast_math ,
590592 share_input_across_experts ,
593+ share_expert_scales ,
594+ single_token ,
591595 activation ,
592596 )
593597 cached = _MICRO_KERNEL_CACHE .get (cache_key )
@@ -607,6 +611,8 @@ def _get_micro_kernel(
607611 fast_math = fast_math ,
608612 activation = activation ,
609613 share_input_across_experts = share_input_across_experts ,
614+ share_expert_scales = share_expert_scales ,
615+ single_token = single_token ,
610616 )
611617
612618 is_gated = activation == "silu"
@@ -815,6 +821,7 @@ def launch_sm120_static_moe(
815821 # the m=1 relu2 shared-input micro optimization only applies when every
816822 # expert sees the same FC1-input global scale.
817823 input_gs_is_shared = input_gs .numel () == 1
824+ down_input_scale_is_shared = down_input_scale .numel () == 1
818825
819826 # Broadcast scalar scales to per-expert [E] tensors
820827 input_gs = _expand_to_experts (input_gs , num_experts )
@@ -828,19 +835,24 @@ def launch_sm120_static_moe(
828835
829836 sm_count = get_num_sm (torch .device ("cuda" ))
830837 base_mac = min (get_max_active_clusters (1 ), sm_count )
838+ tuned_static_mac = _lookup_mac_ladder (_STATIC_MAC_LADDER , routed_rows )
839+ static_mac = min (tuned_static_mac or base_mac , base_mac )
840+ if not use_micro and routed_rows < 40 :
841+ static_mac = min (static_mac , 64 )
831842
832843 if use_micro :
833844 assert flat_ids .numel () <= workspace .compact_topk_ids .numel (), (
834845 f"compact_topk_ids buffer too small: "
835846 f"{ workspace .compact_topk_ids .numel ()} < { flat_ids .numel ()} "
836847 )
837- compact_ids = workspace .compact_topk_ids [: flat_ids .numel ()]
838- if num_tokens == 1 :
839- # A single token's top-k is already a dense unique expert set,
840- # so we can build the compact local-id mapping on the host
841- # without launching the Triton compaction kernel. The micro
842- # kernel still reads weight_expert_ids the same way it does
843- # for m>1; it just sees a pre-filled workspace.
848+ # Single-token ReLU2 is non-gated, so the micro kernel can launch on
849+ # the routed expert ids directly. Gated SiLU still goes through the
850+ # compact id buffer so the kernel can map compact launch ids back to
851+ # the physical gate/up weight experts.
852+ if num_tokens == 1 and activation == "relu2" :
853+ launch_ids = flat_ids
854+ elif num_tokens == 1 :
855+ compact_ids = workspace .compact_topk_ids [: flat_ids .numel ()]
844856 compact_ids .copy_ (
845857 torch .arange (
846858 flat_ids .numel (),
@@ -852,7 +864,9 @@ def launch_sm120_static_moe(
852864 flat_ids .to (torch .int32 )
853865 )
854866 workspace .active_expert_count .fill_ (flat_ids .numel ())
867+ launch_ids = compact_ids
855868 else :
869+ compact_ids = workspace .compact_topk_ids [: flat_ids .numel ()]
856870 from .triton_compact import compact_topk_ids as _triton_compact_topk_ids
857871
858872 _triton_compact_topk_ids (
@@ -861,23 +875,24 @@ def launch_sm120_static_moe(
861875 workspace .weight_expert_ids ,
862876 workspace .active_expert_count ,
863877 )
864- launch_ids = compact_ids
878+ launch_ids = compact_ids
865879 # Select micro MAC: min of tuned ladder, work tiles, and hardware limit.
866- # The hardware cap (base_mac) prevents deadlocks on GPUs with fewer SMs
867- # than the profiled tuning target.
868880 micro_work_tiles = max (1 , routed_rows * max (1 , (n + 128 - 1 ) // 128 ))
869881 tuned_mac = _lookup_mac_ladder (_MICRO_MAC_LADDER , routed_rows )
870882 micro_mac = min (tuned_mac or base_mac , micro_work_tiles , base_mac )
871- # For m=1 relu2 with a shared FC1-input scale, all experts see the
872- # same quantized activation — quantize once and share the packed
873- # buffer slot across all K top-k pairs. Env override lets us flip
874- # this off without a code change if a regression surfaces .
883+ # For m=1 ReLU2 with a shared FC1-input scale, all experts see the
884+ # same quantized activation. Match FI main's synchronization model:
885+ # one CTA writes a shared packed slot, then the resident-grid barrier
886+ # below makes it visible before all CTAs read it for FC1 .
875887 share_input_across_experts = (
876888 activation == "relu2"
877889 and num_tokens == 1
878890 and input_gs_is_shared
879891 and os .environ .get ("FLASHINFER_B12X_MICRO_SHARE_INPUT" , "1" ) != "0"
880892 )
893+ share_expert_scales = (
894+ activation == "relu2" and input_gs_is_shared and down_input_scale_is_shared
895+ )
881896 compiled , mac = _get_micro_kernel (
882897 workspace .state_E ,
883898 num_experts ,
@@ -886,17 +901,16 @@ def launch_sm120_static_moe(
886901 n ,
887902 top_k ,
888903 workspace .max_rows ,
889- topk_ids_dtype = torch . int32 ,
904+ topk_ids_dtype = launch_ids . dtype ,
890905 input_scales_are_reciprocal = input_scales_are_reciprocal ,
891906 fast_math = fast_math ,
892907 share_input_across_experts = share_input_across_experts ,
908+ share_expert_scales = share_expert_scales ,
909+ single_token = num_tokens == 1 ,
893910 mac_override = micro_mac ,
894911 activation = activation ,
895912 )
896913 else :
897- # Static path — use hardware default MAC (same as main).
898- # MAC tuning for the static kernel is deferred to a follow-up
899- # to avoid changing behavior for existing static workloads.
900914 compiled , mac = _get_static_kernel (
901915 workspace .state_E ,
902916 num_experts ,
@@ -908,6 +922,7 @@ def launch_sm120_static_moe(
908922 topk_ids_dtype = torch .int32 ,
909923 input_scales_are_reciprocal = input_scales_are_reciprocal ,
910924 fast_math = fast_math ,
925+ mac_override = static_mac ,
911926 activation = activation ,
912927 )
913928 launch_ids = flat_ids