@@ -169,30 +169,42 @@ def apply_monolithic(
169169 scale = routed_scaling_factor ,
170170 )
171171
172- moe_expand = torch .empty (
173- (M * top_k , N ), dtype = hidden_states .dtype , device = hidden_states .device
174- ) # [M, top_k, N], float
175- expert_m = torch .zeros (
176- global_num_experts , dtype = torch .int32 , device = hidden_states .device
177- ) # [E]
178- sorted_tokens_num_lod = torch .zeros (
179- global_num_experts + 1 , dtype = torch .int32 , device = hidden_states .device
180- ) # [E+1]
181- sorted_tokens_idx = torch .zeros (
182- M * top_k , dtype = torch .int32 , device = hidden_states .device
183- )
172+ if M * top_k > 768 :
173+ moe_expand = torch .empty (
174+ (M * top_k , N ), dtype = hidden_states .dtype , device = hidden_states .device
175+ ) # [M, top_k, N], float
176+ expert_m = torch .zeros (
177+ global_num_experts , dtype = torch .int32 , device = hidden_states .device
178+ ) # [E]
179+ sorted_tokens_num_lod = torch .zeros (
180+ global_num_experts + 1 , dtype = torch .int32 , device = hidden_states .device
181+ ) # [E+1]
182+ sorted_tokens_idx = torch .zeros (
183+ M * top_k , dtype = torch .int32 , device = hidden_states .device
184+ )
184185
185- torch .ops ._C .gen_block_statistic (topk_ids , block_statistic )
186+ torch .ops ._C .gen_block_statistic (topk_ids , block_statistic )
186187
187- torch .ops ._C .moe_pre_sorted (
188- x = hidden_states ,
189- topk_index = topk_ids ,
190- block_statistic = block_statistic ,
191- moe_expand = moe_expand ,
192- moe_index = sorted_tokens_idx ,
193- expert_m = expert_m ,
194- sorted_tokens_num_lod = sorted_tokens_num_lod ,
195- )
188+ torch .ops ._C .moe_pre_sorted (
189+ x = hidden_states ,
190+ topk_index = topk_ids ,
191+ block_statistic = block_statistic ,
192+ moe_expand = moe_expand ,
193+ moe_index = sorted_tokens_idx ,
194+ expert_m = expert_m ,
195+ sorted_tokens_num_lod = sorted_tokens_num_lod ,
196+ )
197+ del expert_m
198+ else :
199+ sorted_tokens_idx , sorted_tokens_num_lod , moe_expand = (
200+ torch .ops .xspeedgate_ops .moe_pre_small (
201+ topk_ids ,
202+ global_num_experts ,
203+ index_have_neg = False ,
204+ sort_mode = True ,
205+ x = hidden_states ,
206+ )
207+ )
196208
197209 y = torch .empty (
198210 M ,
@@ -261,7 +273,7 @@ def apply_monolithic(
261273 # sort_mode=False,
262274 act = None ,
263275 )
264- del x_q , x_scale , sorted_tokens_num_lod , expert_m
276+ del x_q , x_scale , sorted_tokens_num_lod
265277
266278 dequant_scale = torch .ones ([M , top_k ], dtype = torch .float32 , device = out .device )
267279 output = torch .empty (
0 commit comments