22import torch
33
44try :
5+ from . import _C # noqa: F401
56 from . import _xpu_C # noqa: F401
67 FUSEDMOE_UNAVAILABLE_REASON = None
78 FUSEDMOE_AVAILABLE = True
@@ -57,6 +58,10 @@ def xpu_fused_moe(hidden_states, w13, w13_bias, w2, w2_bias, topk_weights,
5758 topk_ids , n_experts_per_token , activation , num_experts ):
5859
5960 output = torch .zeros_like (hidden_states )
61+ if w13 .is_contiguous ():
62+ # transpose and replace original data once
63+ w13 .data = w13 .transpose (- 1 , - 2 ).contiguous ().transpose (- 1 , - 2 )
64+ w2 .data = w2 .transpose (- 1 , - 2 ).contiguous ().transpose (- 1 , - 2 )
6065
6166 # TODO: this will all be integrated into the C++ function; temporarily exposed here until GEMM fusion is done
6267 num_rows , hidden_size = list (hidden_states .shape )
@@ -106,6 +111,8 @@ def config_ws(name, size):
106111 workspace = torch .zeros (map_offset ,
107112 dtype = torch .uint8 ,
108113 device = hidden_states .device )
114+ if topk_ids .dtype == torch .int32 :
115+ topk_ids = topk_ids .to (torch .int64 )
109116 torch .ops ._xpu_C .fused_moe (output = output ,
110117 input = hidden_states ,
111118 token_selected_experts = topk_ids ,
@@ -143,13 +150,12 @@ def config_ws(name, size):
143150 w2_bias = w2_bias .repeat_interleave (expert_token_count ,
144151 dim = 0 ).float ()
145152 expert_token_count = expert_token_count .cpu ()
146-
147153 gemm1_output = torch .empty ((num_moe_inputs , 2 * inter_size ),
148154 dtype = hidden_states .dtype ,
149155 device = hidden_states .device )
150156
151157 ########### gemm1 ##################
152- input_B = w13 . transpose ( - 1 , - 2 ). contiguous (). transpose ( - 1 , - 2 )
158+ input_B = w13
153159
154160 torch .ops ._xpu_C .cutlass_grouped_gemm (
155161 ptr_A = gemm1_input ,
@@ -163,13 +169,21 @@ def config_ws(name, size):
163169 groups = num_experts_per_node )
164170
165171 # act
166- gate , up_ = torch .split (gemm1_output , inter_size , dim = 1 )
167- act = torch .nn .SiLU ()
168- act_output = act (gate ) * up_
172+ act_output = torch .empty ((num_moe_inputs , inter_size ),
173+ dtype = gemm1_output .dtype ,
174+ device = gemm1_output .device )
175+ if activation == "silu" :
176+ torch .ops ._C .silu_and_mul (act_output , gemm1_output )
177+ elif activation == "gelu" :
178+ torch .ops ._C .gelu_and_mul (act_output , gemm1_output )
179+ elif activation == "swigluoai" :
180+ torch .ops ._C .swigluoai_and_mul (act_output , gemm1_output , 1.702 , 7.0 )
181+ else :
182+ raise ValueError (f"Unsupported FusedMoe activation: { activation } ." )
169183
170184 ########### gemm2 ##################
171185 input_A = act_output .contiguous ()
172- input_B = w2 . transpose ( - 1 , - 2 ). contiguous (). transpose ( - 1 , - 2 )
186+ input_B = w2
173187 gemm2_output = torch .empty ((num_moe_inputs , hidden_size ),
174188 dtype = hidden_states .dtype ,
175189 device = hidden_states .device )
@@ -184,23 +198,23 @@ def config_ws(name, size):
184198 K = inter_size ,
185199 groups = num_experts_per_node )
186200
187- topk_weights = topk_weights .view (- 1 , 1 )
188201 expert_cache = output
189202
190- for expert_id , end_idx in enumerate (expert_first_token_offset ):
191- start_idx = 0 if expert_id == 0 else expert_first_token_offset [
192- expert_id - 1 ]
203+ iter_for_weight_apply = expert_first_token_offset [1 :]
204+ for expert_id , end_idx in enumerate (iter_for_weight_apply ):
205+ start_idx = 0 if expert_id == 0 else iter_for_weight_apply [expert_id -
206+ 1 ]
193207 if start_idx == end_idx :
194208 continue
195209
196- exp_token_idxs = permuted_row_to_unpermuted_row [
197- start_idx :end_idx ] % num_rows
210+ exp_token_idxs = permuted_row_to_unpermuted_row [start_idx :end_idx ]
211+ scores_token_ids = exp_token_idxs % num_rows
212+ scores_k_slot = exp_token_idxs // num_rows
213+ scores = topk_weights [scores_token_ids , scores_k_slot ]
198214 expert_out = gemm2_output [start_idx :end_idx ]
199- expert_out .mul_ (
200- topk_weights [permuted_row_to_unpermuted_row [start_idx :end_idx ] %
201- num_rows ])
215+ expert_out .mul_ (scores .view (- 1 , 1 ))
202216 expert_cache .scatter_reduce_ (0 ,
203- exp_token_idxs .view (- 1 , 1 ).repeat (
217+ scores_token_ids .view (- 1 , 1 ).repeat (
204218 1 , hidden_size ),
205219 expert_out ,
206220 reduce = 'sum' )
0 commit comments