Commit 388dd0e

Use the DispatchGmmCombineDecode operator to replace MC2 when the env variable VLLM_ASCEND_ENABLE_FUSED_MC2=1 is set
Signed-off-by: wangqiankun <[email protected]>
1 parent: df7e0fe

6 files changed (+84, -5 lines)

vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ class MoECommType(Enum):
     MC2 = 1
     ALLTOALL = 2
     FUSED_ALLTOALL = 3
+    FUSED_MC2 = 4
 
 
 @contextmanager

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,9 @@
     # Whether to anbale dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to enable fused MC2 (the dispatch_gmm_combine_decode operator)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0'))),
 }
 
 # end-env-vars-definition
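Unlike DYNAMIC_EPLB just above, which keeps the raw lowercased string, the new flag is parsed strictly as an integer. A minimal standalone sketch of that behavior, mirroring the lambda in the diff:

import os


def fused_mc2_enabled() -> bool:
    # bool(int(...)): "0" -> False, "1" (or any nonzero integer string)
    # -> True; non-integer values such as "true" raise ValueError.
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")))


os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"
assert fused_mc2_enabled() is True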

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 70 additions & 0 deletions
@@ -46,6 +46,8 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
     _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
         moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(
+        moe_config)
 
 
 class MoECommMethod(ABC):

@@ -315,3 +317,71 @@ def fused_experts(
         out=out,
     )
     return out
+
+
+class FusedMC2CommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is
+       not supported).
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are
+       available.
+
+    This implementation uses the FusedMC2 communication method, which is
+    optimized for communication and computation parallelism on Ascend
+    devices.
+    """
+
+    def _get_token_dispatcher(self):
+        return TokenDispatcherWithMC2()
+
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithMC2(self.moe_config)
+
+    def fused_experts(
+            self,
+            hidden_states: torch.Tensor,
+            w1: torch.Tensor,
+            w2: torch.Tensor,
+            topk_weights: torch.Tensor,
+            topk_ids: torch.Tensor,
+            activation: str = "silu",
+            apply_router_weight_on_input: bool = False,
+            use_int8_w8a8: bool = False,
+            use_int4_w4a8: bool = False,
+            use_int4_w4a16: bool = False,
+            global_num_experts: Optional[int] = None,
+            expert_map: Optional[torch.Tensor] = None,
+            w1_scale: Optional[torch.Tensor] = None,
+            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale_bias: Optional[torch.Tensor] = None,
+            w2_scale_bias: Optional[torch.Tensor] = None,
+            w1_offset: Optional[torch.Tensor] = None,
+            w2_offset: Optional[torch.Tensor] = None,
+            # For Cube/Vector parallel
+            shared_experts: Optional[Any] = None,
+            quantized_x_for_share: Optional[Any] = None,
+            dynamic_scale_for_share: Optional[Any] = None,
+            # For load balance
+            log2phy: Optional[torch.Tensor] = None,
+            global_redundant_expert_num: int = 0,
+            need_trans: bool = False,
+            dynamic_eplb: bool = False,
+            mc2_mask: Optional[torch.Tensor] = None,
+            pertoken_scale: Optional[torch.Tensor] = None):
+
+        output, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
+            x=hidden_states,
+            expert_ids=topk_ids,
+            gmm1_permuted_weight=w1[0],
+            gmm1_permuted_weight_scale=w1_scale[0],
+            gmm2_weight=w2[0],
+            gmm2_weight_scale=w2_scale[0],
+            expert_smooth_scales=None,
+            expert_scales=topk_weights.to(torch.float32),
+            group_ep=self.token_dispatcher.moe_all_to_all_group_name,
+            ep_rank_size=self.token_dispatcher.ep_world_size,
+            ep_rank_id=self.token_dispatcher.ep_rank_id,
+            moe_expert_num=len(expert_map),
+            shared_expert_num=1,
+            shared_expert_rank_num=0,
+            quant_mode=0,
+            global_bs=self.token_dispatcher.global_bs)
+        return output
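The registration above follows an enum-keyed registry: setup_moe_comm_method builds one impl per comm type, and the MoE layer fetches the impl matching the MoECommType chosen per forward pass. A minimal, self-contained sketch of that pattern (ToyFusedMC2 and the final lookup are illustrative, not the module's real API):

from abc import ABC, abstractmethod
from enum import Enum


class MoECommType(Enum):
    MC2 = 1
    ALLTOALL = 2
    FUSED_ALLTOALL = 3
    FUSED_MC2 = 4


class MoECommMethod(ABC):

    @abstractmethod
    def fused_experts(self, hidden_states):
        ...


class ToyFusedMC2(MoECommMethod):

    def fused_experts(self, hidden_states):
        # Stand-in for the dispatch_gmm_combine_decode call.
        return hidden_states


_MoECommMethods: dict[MoECommType, MoECommMethod] = {}
_MoECommMethods[MoECommType.FUSED_MC2] = ToyFusedMC2()

# Per-forward lookup: the runner picks a MoECommType, the layer
# retrieves the matching impl from the registry.
impl = _MoECommMethods[MoECommType.FUSED_MC2]
print(impl.fused_experts("tokens"))

Note also that the new fused_experts indexes w1[0], w1_scale[0], and so on: the w8a8_dynamic apply path below wraps each weight and scale in a single-element list, and this path consumes the first entry.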

vllm_ascend/ops/register_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL, MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:
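The effect of the change: when moe_comm_type is FUSED_MC2 (like the other MC2/all-to-all variants), the function returns the hidden states as-is instead of all-reducing across the tensor-parallel group. An editorial sketch, not part of this commit, of the same check with the growing literal set hoisted into a named constant:

from vllm_ascend.ascend_forward_context import MoECommType

# Comm types for which the early return above applies, i.e. no
# tensor-model-parallel all-reduce is performed on the MoE output.
_SKIP_TP_ALL_REDUCE = frozenset({
    MoECommType.ALLTOALL,
    MoECommType.MC2,
    MoECommType.FUSED_ALLTOALL,
    MoECommType.FUSED_MC2,
})


def needs_tp_all_reduce(moe_comm_type: MoECommType, sp_enabled: bool) -> bool:
    return moe_comm_type not in _SKIP_TP_ALL_REDUCE and not sp_enabled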

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 5 additions & 1 deletion
@@ -235,6 +235,8 @@ def apply(
         topk_weights = topk_weights.to(self.in_dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
+        fused_mc2_flag = get_forward_context(
+        ).moe_comm_type == MoECommType.FUSED_MC2
         if self.dynamic_eplb:
             w1 = layer.w13_weight_list
             w1_scale = layer.w13_weight_scale_fp32_list

@@ -244,7 +246,7 @@ def apply(
             w1 = [layer.w13_weight]
             w1_scale = [layer.w13_weight_scale_fp32]
             w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
+            w2_scale = [layer.w2_weight_scale_fp32 if fused_mc2_flag else layer.w2_weight_scale]
 
         fused_flag = get_forward_context(
         ).moe_comm_type == MoECommType.FUSED_ALLTOALL

@@ -284,6 +286,8 @@ def process_weights_after_loading(self, layer):
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
             layer.w2_weight_scale.data.shape[0], -1)
+        layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
+            torch.float32)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
             layer.w2_weight_offset.data.shape[0], -1)
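The pattern here is cast-once-at-load: process_weights_after_loading caches a float32 copy of the w2 scale, and apply selects it when the forward context says FUSED_MC2, avoiding a per-forward cast (the fused kernel appears to take float32 scales, judging by the _fp32 naming used throughout). A toy sketch of the same pattern; shapes and dtypes are placeholders, not the real layer's:

import torch


class ToyLayer:

    def __init__(self):
        # Per-channel scale as stored after quantized-weight loading;
        # bfloat16 here is a placeholder dtype.
        self.w2_weight_scale = torch.rand(8, 16, dtype=torch.bfloat16)

    def process_weights_after_loading(self):
        # One-time cast, mirroring the diff: cache a float32 copy so the
        # fused MC2 path never re-casts inside the forward pass.
        self.w2_weight_scale_fp32 = self.w2_weight_scale.to(torch.float32)

    def pick_w2_scale(self, fused_mc2_flag: bool) -> torch.Tensor:
        # Mirrors the apply() change: fp32 copy for FUSED_MC2,
        # the original tensor otherwise.
        return (self.w2_weight_scale_fp32
                if fused_mc2_flag else self.w2_weight_scale)


layer = ToyLayer()
layer.process_weights_after_loading()
assert layer.pick_w2_scale(True).dtype == torch.float32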

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 3 deletions
@@ -1419,9 +1419,10 @@ def _select_moe_comm_method(self,
 
         elif soc_version in {AscendDeviceType._910_93}:
             moe_comm_type = (
-                MoECommType.MC2 if num_tokens <= mc2_tokens_capacity else
-                MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic"
-                and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL)
+                (MoECommType.FUSED_MC2 if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
+                 else MoECommType.MC2) if num_tokens <= mc2_tokens_capacity
+                else MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic" and get_ep_group().world_size <= 16
+                else MoECommType.ALLTOALL)
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
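The replacement is a nested conditional expression, which is compact but hard to scan. An equivalent plain-function spelling of the 910_93 branch, illustrative only, with parameters mirroring the locals in the diff:

from vllm_ascend.ascend_forward_context import MoECommType


def select_moe_comm_type_910_93(num_tokens: int, mc2_tokens_capacity: int,
                                quant_type: str, ep_world_size: int,
                                fused_mc2_enabled: bool) -> MoECommType:
    if num_tokens <= mc2_tokens_capacity:
        # Batches within MC2 capacity take the MC2 path; the fused kernel
        # is opt-in via VLLM_ASCEND_ENABLE_FUSED_MC2 and w8a8_dynamic-only.
        if fused_mc2_enabled and quant_type == "w8a8_dynamic":
            return MoECommType.FUSED_MC2
        return MoECommType.MC2
    # Larger batches: fused all-to-all for w8a8_dynamic on EP groups of
    # up to 16 ranks, plain all-to-all otherwise.
    if quant_type == "w8a8_dynamic" and ep_world_size <= 16:
        return MoECommType.FUSED_ALLTOALL
    return MoECommType.ALLTOALL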
