Skip to content

Commit b49bb85

Browse files
authored
[Kernel] Add moe_pre_small kernel for compressed_tensors_moe (#333)
Signed-off-by: youzeyu <youzeyu@baidu.com>
1 parent: 25f58b7 · commit: b49bb85

1 file changed

Lines changed: 35 additions & 23 deletions

File tree

vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 35 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -169,30 +169,42 @@ def apply_monolithic(
169169
scale=routed_scaling_factor,
170170
)
171171

172-
moe_expand = torch.empty(
173-
(M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device
174-
) # [M, top_k, N], float
175-
expert_m = torch.zeros(
176-
global_num_experts, dtype=torch.int32, device=hidden_states.device
177-
) # [E]
178-
sorted_tokens_num_lod = torch.zeros(
179-
global_num_experts + 1, dtype=torch.int32, device=hidden_states.device
180-
) # [E+1]
181-
sorted_tokens_idx = torch.zeros(
182-
M * top_k, dtype=torch.int32, device=hidden_states.device
183-
)
172+
if M * top_k > 768:
173+
moe_expand = torch.empty(
174+
(M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device
175+
) # [M, top_k, N], float
176+
expert_m = torch.zeros(
177+
global_num_experts, dtype=torch.int32, device=hidden_states.device
178+
) # [E]
179+
sorted_tokens_num_lod = torch.zeros(
180+
global_num_experts + 1, dtype=torch.int32, device=hidden_states.device
181+
) # [E+1]
182+
sorted_tokens_idx = torch.zeros(
183+
M * top_k, dtype=torch.int32, device=hidden_states.device
184+
)
184185

185-
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
186+
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
186187

187-
torch.ops._C.moe_pre_sorted(
188-
x=hidden_states,
189-
topk_index=topk_ids,
190-
block_statistic=block_statistic,
191-
moe_expand=moe_expand,
192-
moe_index=sorted_tokens_idx,
193-
expert_m=expert_m,
194-
sorted_tokens_num_lod=sorted_tokens_num_lod,
195-
)
188+
torch.ops._C.moe_pre_sorted(
189+
x=hidden_states,
190+
topk_index=topk_ids,
191+
block_statistic=block_statistic,
192+
moe_expand=moe_expand,
193+
moe_index=sorted_tokens_idx,
194+
expert_m=expert_m,
195+
sorted_tokens_num_lod=sorted_tokens_num_lod,
196+
)
197+
del expert_m
198+
else:
199+
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
200+
torch.ops.xspeedgate_ops.moe_pre_small(
201+
topk_ids,
202+
global_num_experts,
203+
index_have_neg=False,
204+
sort_mode=True,
205+
x=hidden_states,
206+
)
207+
)
196208

197209
y = torch.empty(
198210
M,
@@ -261,7 +273,7 @@ def apply_monolithic(
261273
# sort_mode=False,
262274
act=None,
263275
)
264-
del x_q, x_scale, sorted_tokens_num_lod, expert_m
276+
del x_q, x_scale, sorted_tokens_num_lod
265277

266278
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
267279
output = torch.empty(

0 commit comments

Comments
 (0)