
Commit 35add0e

This commit adds model-side integration for the previously introduced experimental AscendC fused operator DispatchGmmCombineDecode, used in MoE decoding.

The operator implementation itself was added in a prior PR #4139. This change only adapts the model execution path to optionally use the fused operator.

When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=1 is set, the original MC2 path composed of multiple operators (A8W8 dispatch → GMM → SwiGLU → GMM → combine) is replaced by the single fused operator DispatchGmmCombineDecode. By default, the existing multi-operator MC2 implementation is preserved.

Signed-off-by: wangqiankun <[email protected]>
1 parent df7e0fe commit 35add0e
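
For orientation, a minimal sketch of how the opt-in looks on the user side, assuming the flag is read once at engine startup; only the environment variable name is taken from this change:

import os

# "1" opts in to the fused DispatchGmmCombineDecode decode path; unset or
# "0" (the default) keeps the multi-operator MC2 implementation.
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"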

File tree (6 files changed: +85 −6 lines)

vllm_ascend/ascend_forward_context.py
vllm_ascend/envs.py
vllm_ascend/ops/fused_moe/moe_comm_method.py
vllm_ascend/ops/register_custom_ops.py
vllm_ascend/quantization/w8a8_dynamic.py
vllm_ascend/worker/model_runner_v1.py


vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ class MoECommType(Enum):
     MC2 = 1
     ALLTOALL = 2
     FUSED_ALLTOALL = 3
+    FUSED_MC2 = 4


 @contextmanager

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,9 @@
     # Whether to enable dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to enable fused MC2 (dispatch_gmm_combine_decode operator)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0'))),
 }

 # end-env-vars-definition
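
The new entry parses its value differently from the string-based DYNAMIC_EPLB flag above it; a self-contained check of that behavior (the helper name is invented for the demo):

import os

def fused_mc2_enabled() -> bool:
    # Same parsing as the new env entry: the raw value must be an integer
    # literal, so "1" -> True and "0"/unset -> False; a value such as
    # "true" would raise ValueError under this scheme.
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")))

os.environ.pop("VLLM_ASCEND_ENABLE_FUSED_MC2", None)
assert fused_mc2_enabled() is False
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"
assert fused_mc2_enabled() is True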

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 66 additions & 0 deletions
@@ -46,6 +46,8 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
     _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
         moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(
+        moe_config)


 class MoECommMethod(ABC):

@@ -315,3 +317,67 @@ def fused_experts(
         out=out,
     )
     return out
+
+class FusedMC2CommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `class MC2CommImpl` can be used.
+    2. `VLLM_ASCEND_ENABLE_FUSED_MC2` is enabled.
+    3. `w8a8_dynamic` quantization is used.
+
+    This implementation uses the `dispatch_gmm_combine_decode` operator,
+    which is a fused operator for MoE decoding that combines communication
+    and computation for optimization on Ascend devices.
+    """
+
+    def _get_token_dispatcher(self):
+        return TokenDispatcherWithMC2()
+
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithMC2(self.moe_config)
+
+    def fused_experts(
+            self,
+            hidden_states: torch.Tensor,
+            w1: torch.Tensor,
+            w2: torch.Tensor,
+            topk_weights: torch.Tensor,
+            topk_ids: torch.Tensor,
+            activation: str = "silu",
+            apply_router_weight_on_input: bool = False,
+            use_int8_w8a8: bool = False,
+            use_int4_w4a8: bool = False,
+            use_int4_w4a16: bool = False,
+            global_num_experts: Optional[int] = None,
+            expert_map: Optional[torch.Tensor] = None,
+            w1_scale: Optional[torch.Tensor] = None,
+            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale_bias: Optional[torch.Tensor] = None,
+            w2_scale_bias: Optional[torch.Tensor] = None,
+            w1_offset: Optional[torch.Tensor] = None,
+            w2_offset: Optional[torch.Tensor] = None,
+            # For Cube/Vector parallel
+            shared_experts: Optional[Any] = None,
+            quantized_x_for_share: Optional[Any] = None,
+            dynamic_scale_for_share: Optional[Any] = None,
+            # For load balance
+            log2phy: Optional[torch.Tensor] = None,
+            global_redundant_expert_num: int = 0,
+            need_trans: bool = False,
+            dynamic_eplb: bool = False,
+            mc2_mask: Optional[torch.Tensor] = None,
+            pertoken_scale: Optional[torch.Tensor] = None):
+
+        # The single fused operator replaces the multi-op MC2 path
+        # (dispatch -> GMM -> SwiGLU -> GMM -> combine).
+        output, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
+            x=hidden_states,
+            expert_ids=topk_ids,
+            gmm1_permuted_weight=w1[0],
+            gmm1_permuted_weight_scale=w1_scale[0],
+            gmm2_weight=w2[0],
+            gmm2_weight_scale=w2_scale[0],
+            expert_smooth_scales=None,
+            expert_scales=topk_weights.to(torch.float32),
+            group_ep=self.token_dispatcher.moe_all_to_all_group_name,
+            ep_rank_size=self.token_dispatcher.ep_world_size,
+            ep_rank_id=self.token_dispatcher.ep_rank_id,
+            moe_expert_num=len(expert_map),
+            global_bs=self.token_dispatcher.global_bs
+        )
+        return output
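
The registration in the first hunk follows the file's existing pattern: one implementation instance per MoECommType member, looked up by type at dispatch time. A simplified, self-contained sketch of that pattern (method bodies are stubbed; only the class and enum names mirror the diff):

from abc import ABC, abstractmethod
from enum import Enum


class MoECommType(Enum):
    MC2 = 1
    ALLTOALL = 2
    FUSED_ALLTOALL = 3
    FUSED_MC2 = 4


class MoECommMethod(ABC):

    @abstractmethod
    def fused_experts(self, hidden_states):
        ...


class MC2CommImpl(MoECommMethod):

    def fused_experts(self, hidden_states):
        # Stand-in for the multi-operator path:
        # dispatch -> GMM -> SwiGLU -> GMM -> combine.
        return f"mc2({hidden_states})"


class FusedMC2CommImpl(MoECommMethod):

    def fused_experts(self, hidden_states):
        # Stand-in for the single fused dispatch_gmm_combine_decode call.
        return f"fused_mc2({hidden_states})"


_MoECommMethods = {
    MoECommType.MC2: MC2CommImpl(),
    MoECommType.FUSED_MC2: FusedMC2CommImpl(),
}

assert _MoECommMethods[MoECommType.FUSED_MC2].fused_experts("x") == "fused_mc2(x)"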

vllm_ascend/ops/register_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL, MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:
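
The apparent intent of this one-line change, sketched as a standalone function: comm methods whose combine step already merges expert outputs across ranks (now including FUSED_MC2) return the hidden states untouched instead of issuing another tensor-parallel all-reduce. The enum and the all-reduce below are simplified stand-ins:

from enum import Enum, auto


class MoECommType(Enum):
    ALLGATHER = auto()
    MC2 = auto()
    ALLTOALL = auto()
    FUSED_ALLTOALL = auto()
    FUSED_MC2 = auto()


def tensor_model_parallel_all_reduce(t):
    return t  # identity stand-in for the real collective


def maybe_all_reduce(final_hidden_states, moe_comm_type, sp_enabled=False):
    # These comm methods already produce combined outputs, so the extra
    # all-reduce is skipped; so is any path with sequence parallelism on.
    if moe_comm_type in {
            MoECommType.ALLTOALL, MoECommType.MC2,
            MoECommType.FUSED_ALLTOALL, MoECommType.FUSED_MC2
    } or sp_enabled:
        return final_hidden_states
    return tensor_model_parallel_all_reduce(final_hidden_states)


assert maybe_all_reduce("h", MoECommType.FUSED_MC2) == "h"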

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 5 additions & 1 deletion
@@ -235,6 +235,8 @@ def apply(
         topk_weights = topk_weights.to(self.in_dtype)

         moe_comm_method = get_forward_context().moe_comm_method
+        fused_mc2_flag = get_forward_context(
+        ).moe_comm_type == MoECommType.FUSED_MC2
         if self.dynamic_eplb:
             w1 = layer.w13_weight_list
             w1_scale = layer.w13_weight_scale_fp32_list

@@ -244,7 +246,7 @@ def apply(
             w1 = [layer.w13_weight]
             w1_scale = [layer.w13_weight_scale_fp32]
             w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
+            w2_scale = [layer.w2_weight_scale_fp32 if fused_mc2_flag else layer.w2_weight_scale]

         fused_flag = get_forward_context(
         ).moe_comm_type == MoECommType.FUSED_ALLTOALL

@@ -284,6 +286,8 @@ def process_weights_after_loading(self, layer):
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
             layer.w2_weight_scale.data.shape[0], -1)
+        layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
+            torch.float32)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
             layer.w2_weight_offset.data.shape[0], -1)
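
In isolation, the scale handling added here: a float32 copy of the GMM2 weight scale is materialized once after weight loading, and apply() picks it only when the fused MC2 path is active. The shapes and the bfloat16 source dtype below are illustrative; the fused operator is assumed to expect float32 scales:

import torch

# One-time cast at weight-loading time, as in
# process_weights_after_loading above.
w2_weight_scale = torch.rand(8, 4).to(torch.bfloat16)  # per-expert scales
w2_weight_scale_fp32 = w2_weight_scale.to(torch.float32)

# Per-forward selection, as in apply() above.
fused_mc2_flag = True  # derived from forward_context.moe_comm_type
w2_scale = [w2_weight_scale_fp32 if fused_mc2_flag else w2_weight_scale]
assert w2_scale[0].dtype == torch.float32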

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 4 deletions
@@ -1418,10 +1418,15 @@ def _select_moe_comm_method(self,
             moe_comm_type = MoECommType.ALLGATHER

         elif soc_version in {AscendDeviceType._910_93}:
-            moe_comm_type = (
-                MoECommType.MC2 if num_tokens <= mc2_tokens_capacity else
-                MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic"
-                and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL)
+            if num_tokens <= mc2_tokens_capacity:
+                if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic":
+                    moe_comm_type = MoECommType.FUSED_MC2
+                else:
+                    moe_comm_type = MoECommType.MC2
+            elif quant_type == "w8a8_dynamic" and get_ep_group().world_size <= 16:
+                moe_comm_type = MoECommType.FUSED_ALLTOALL
+            else:
+                moe_comm_type = MoECommType.ALLTOALL
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
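
The rewritten branch, restated as a pure function so the decision table is easy to test (runtime lookups such as the EP group size and the env flag are passed in as plain arguments; string return values stand in for the enum members):

def select_moe_comm_type(num_tokens: int, mc2_tokens_capacity: int,
                         quant_type: str, ep_world_size: int,
                         fused_mc2_enabled: bool) -> str:
    # Mirrors the 910_93 branch of _select_moe_comm_method above.
    if num_tokens <= mc2_tokens_capacity:
        if fused_mc2_enabled and quant_type == "w8a8_dynamic":
            return "FUSED_MC2"
        return "MC2"
    if quant_type == "w8a8_dynamic" and ep_world_size <= 16:
        return "FUSED_ALLTOALL"
    return "ALLTOALL"


assert select_moe_comm_type(64, 512, "w8a8_dynamic", 8, True) == "FUSED_MC2"
assert select_moe_comm_type(64, 512, "w8a8_dynamic", 8, False) == "MC2"
assert select_moe_comm_type(1024, 512, "w8a8_dynamic", 8, False) == "FUSED_ALLTOALL"
assert select_moe_comm_type(1024, 512, "bf16", 32, False) == "ALLTOALL"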
