Skip to content

Commit 2d20085

Browse files
[FusedMoE] fix topk>1 acc issue and support different activation (vllm-project#66)
* fix topk>1 acc issue Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com> * format Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com> * Update fused_moe_interface.py * add activation types support Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com> * fix test Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com> --------- Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
1 parent 2d03d25 commit 2d20085

2 files changed

Lines changed: 31 additions & 22 deletions

File tree

tests/fused_moe/test_fused_moe.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,7 @@ def test_grouped_gemm(m, n, k, e, topk, dtype, has_bias):
8282
pre_token_sum += cur_token_num
8383
ref = torch.cat(ref, dim=0)
8484

85-
try:
86-
torch.testing.assert_close(output, ref, rtol=1e-2, atol=1e-2)
87-
print("a and b close enough")
88-
except AssertionError as e:
89-
print("a and b diffs")
90-
print(e)
85+
torch.testing.assert_close(output, ref, rtol=2e-2, atol=1e-2)
9186

9287

9388
def ref_fused_moe(x, w13, w13_bias, w2, w2_bias, flat_expert_weights,

vllm_xpu_kernels/fused_moe_interface.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33

44
try:
5+
from . import _C # noqa: F401
56
from . import _xpu_C # noqa: F401
67
FUSEDMOE_UNAVAILABLE_REASON = None
78
FUSEDMOE_AVAILABLE = True
@@ -57,6 +58,10 @@ def xpu_fused_moe(hidden_states, w13, w13_bias, w2, w2_bias, topk_weights,
5758
topk_ids, n_experts_per_token, activation, num_experts):
5859

5960
output = torch.zeros_like(hidden_states)
61+
if w13.is_contiguous():
62+
# transpose and replace original data once
63+
w13.data = w13.transpose(-1, -2).contiguous().transpose(-1, -2)
64+
w2.data = w2.transpose(-1, -2).contiguous().transpose(-1, -2)
6065

6166
# TODO: will all integrated in Cpp func. Temporary expose before gemm fusion
6267
num_rows, hidden_size = list(hidden_states.shape)
@@ -106,6 +111,8 @@ def config_ws(name, size):
106111
workspace = torch.zeros(map_offset,
107112
dtype=torch.uint8,
108113
device=hidden_states.device)
114+
if topk_ids.dtype == torch.int32:
115+
topk_ids = topk_ids.to(torch.int64)
109116
torch.ops._xpu_C.fused_moe(output=output,
110117
input=hidden_states,
111118
token_selected_experts=topk_ids,
@@ -143,13 +150,12 @@ def config_ws(name, size):
143150
w2_bias = w2_bias.repeat_interleave(expert_token_count,
144151
dim=0).float()
145152
expert_token_count = expert_token_count.cpu()
146-
147153
gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
148154
dtype=hidden_states.dtype,
149155
device=hidden_states.device)
150156

151157
########### gemm1 ##################
152-
input_B = w13.transpose(-1, -2).contiguous().transpose(-1, -2)
158+
input_B = w13
153159

154160
torch.ops._xpu_C.cutlass_grouped_gemm(
155161
ptr_A=gemm1_input,
@@ -163,13 +169,21 @@ def config_ws(name, size):
163169
groups=num_experts_per_node)
164170

165171
# act
166-
gate, up_ = torch.split(gemm1_output, inter_size, dim=1)
167-
act = torch.nn.SiLU()
168-
act_output = act(gate) * up_
172+
act_output = torch.empty((num_moe_inputs, inter_size),
173+
dtype=gemm1_output.dtype,
174+
device=gemm1_output.device)
175+
if activation == "silu":
176+
torch.ops._C.silu_and_mul(act_output, gemm1_output)
177+
elif activation == "gelu":
178+
torch.ops._C.gelu_and_mul(act_output, gemm1_output)
179+
elif activation == "swigluoai":
180+
torch.ops._C.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
181+
else:
182+
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
169183

170184
########### gemm2 ##################
171185
input_A = act_output.contiguous()
172-
input_B = w2.transpose(-1, -2).contiguous().transpose(-1, -2)
186+
input_B = w2
173187
gemm2_output = torch.empty((num_moe_inputs, hidden_size),
174188
dtype=hidden_states.dtype,
175189
device=hidden_states.device)
@@ -184,23 +198,23 @@ def config_ws(name, size):
184198
K=inter_size,
185199
groups=num_experts_per_node)
186200

187-
topk_weights = topk_weights.view(-1, 1)
188201
expert_cache = output
189202

190-
for expert_id, end_idx in enumerate(expert_first_token_offset):
191-
start_idx = 0 if expert_id == 0 else expert_first_token_offset[
192-
expert_id - 1]
203+
iter_for_weight_apply = expert_first_token_offset[1:]
204+
for expert_id, end_idx in enumerate(iter_for_weight_apply):
205+
start_idx = 0 if expert_id == 0 else iter_for_weight_apply[expert_id -
206+
1]
193207
if start_idx == end_idx:
194208
continue
195209

196-
exp_token_idxs = permuted_row_to_unpermuted_row[
197-
start_idx:end_idx] % num_rows
210+
exp_token_idxs = permuted_row_to_unpermuted_row[start_idx:end_idx]
211+
scores_token_ids = exp_token_idxs % num_rows
212+
scores_k_slot = exp_token_idxs // num_rows
213+
scores = topk_weights[scores_token_ids, scores_k_slot]
198214
expert_out = gemm2_output[start_idx:end_idx]
199-
expert_out.mul_(
200-
topk_weights[permuted_row_to_unpermuted_row[start_idx:end_idx] %
201-
num_rows])
215+
expert_out.mul_(scores.view(-1, 1))
202216
expert_cache.scatter_reduce_(0,
203-
exp_token_idxs.view(-1, 1).repeat(
217+
scores_token_ids.view(-1, 1).repeat(
204218
1, hidden_size),
205219
expert_out,
206220
reduce='sum')

0 commit comments

Comments (0)