Skip to content

Commit 377e7eb

Browse files
remove xpu_fused_moe weights handling (vllm-project#163)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent bded432 commit 377e7eb

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

tests/fused_moe/test_fused_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ def test_fused_moe(m, n, k, e, topk, dtype, w_dtype, has_bias):
231231
flat_expert_weights, flat_expert_indices, topk,
232232
"silu", e)
233233

234+
w13.data = w13.transpose(-1, -2).contiguous()
235+
w2.data = w2.transpose(-1, -2).contiguous()
236+
234237
output = xpu_fused_moe(hidden_states=a,
235238
w13=w13,
236239
w13_scales=w13_scales,
@@ -562,6 +565,9 @@ def test_fused_moe_ep(m, n, k, e, topk, ep_rank, ep_size, dtype, w_dtype,
562565
expert_start_id = e * ep_rank
563566
expert_end_id = expert_start_id + e
564567

568+
w13.data = w13.transpose(-1, -2).contiguous()
569+
w2.data = w2.transpose(-1, -2).contiguous()
570+
565571
output = xpu_fused_moe(hidden_states=a,
566572
w13=w13[expert_start_id:expert_end_id],
567573
w13_scales=w13_scales[expert_start_id:expert_end_id]

vllm_xpu_kernels/fused_moe_interface.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,17 @@ def xpu_fused_moe(hidden_states,
154154
else:
155155
assert output.shape == hidden_states.shape, \
156156
"output shape must be the same as hidden_states shape"
157-
inter_size = list(w13.shape)[-2] // 2
158-
159-
assert w13.is_contiguous() and w2.is_contiguous()
160157

161158
# 4bits support [E, N, K]
162159
# other types [E, K, N]
163160
if not is_int4 and not is_mxfp4:
164-
if not hasattr(w13, 'xpu_fused_moe'):
165-
w13.data = w13.transpose(-1, -2).contiguous()
166-
w2.data = w2.transpose(-1, -2).contiguous()
167-
w13.xpu_fused_moe = True
168-
w13.inter_size = inter_size
169-
else:
170-
inter_size = w13.inter_size
161+
inter_size = list(w13.shape)[-1] // 2
162+
else:
163+
inter_size = list(w13.shape)[-2] // 2
164+
165+
assert w13.is_contiguous() and w2.is_contiguous()
171166

167+
# FIXME: move this to vllm
172168
if is_int4 and not hasattr(w13, 'xpu_fused_moe'):
173169
w13_tmp = torch.empty_like(w13)
174170
w2_tmp = torch.empty_like(w2)
@@ -257,7 +253,7 @@ def xpu_fused_moe(hidden_states,
257253
torch.ops._C.silu_and_mul(act_output, gemm1_output)
258254
elif activation == "gelu":
259255
torch.ops._C.gelu_and_mul(act_output, gemm1_output)
260-
elif activation == "swigluoai":
256+
elif activation == "swigluoai" or ("SWIGLUOAI" in str(activation)):
261257
torch.ops._C.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
262258
else:
263259
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

0 commit comments

Comments (0)