Skip to content

Commit b38d248

Browse files
authored
add output for fused_moe interface (#137)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent ad547c1 commit b38d248

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

vllm_xpu_kernels/fused_moe_interface.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def xpu_fused_moe(hidden_states,
122122
num_experts,
123123
ep_rank=0,
124124
ep_size=1,
125+
output=None,
125126
is_fp8=False,
126127
is_int4=False,
127128
is_mxfp4=False):
@@ -147,8 +148,11 @@ def xpu_fused_moe(hidden_states,
147148
is_int4: bool
148149
is_mxfp4: bool
149150
'''
150-
151-
output = torch.empty_like(hidden_states)
151+
if output is None:
152+
output = torch.empty_like(hidden_states)
153+
else:
154+
assert output.shape == hidden_states.shape, \
155+
"output shape must be the same as hidden_states shape"
152156
inter_size = list(w13.shape)[-2] // 2
153157

154158
assert w13.is_contiguous() and w2.is_contiguous()

0 commit comments

Comments
 (0)