
Commit a8ceb08

[Feature] Add model-side integration for fused operator DispatchGmmCombineDecode.

This commit adds model-side integration for the previously introduced experimental AscendC fused operator DispatchGmmCombineDecode, used in MoE decoding. The operator implementation itself was added in the prior PR #4139; this change only adapts the model execution path to optionally use the fused operator.

When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=1 is set, the original MC2 path composed of multiple operators (A8W8 dispatch → GMM → SwiGLU → GMM → combine) is replaced by the single fused operator DispatchGmmCombineDecode. By default, the existing multi-operator MC2 implementation is preserved.

Signed-off-by: wangqiankun <[email protected]>
1 parent df7e0fe commit a8ceb08
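For reference, a minimal sketch of the behavior described in the commit message. The environment variable and the operator name come from this commit; run_moe_decode and the two expert-path methods are purely illustrative stand-ins, not real APIs.

import os

# The flag gates which MoE decode path is taken. By default (flag unset or "0"),
# the existing multi-operator MC2 path is used unchanged.
USE_FUSED_MC2 = bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")))

def run_moe_decode(hidden_states, experts):  # hypothetical helper, for illustration only
    if USE_FUSED_MC2:
        # single fused operator: DispatchGmmCombineDecode
        return experts.dispatch_gmm_combine_decode(hidden_states)
    # original path composed of multiple operators:
    # A8W8 dispatch -> GMM -> SwiGLU -> GMM -> combine
    return experts.multi_op_mc2(hidden_states)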

File tree

7 files changed (+99, -10 lines)


vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ class MoECommType(Enum):
     MC2 = 1
     ALLTOALL = 2
     FUSED_ALLTOALL = 3
+    FUSED_MC2 = 4


 @contextmanager

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,9 @@
     # Whether to anbale dynamic EPLB
     "DYNAMIC_EPLB":
     lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
+    # Whether to anbale fused mc2(dispatch_gmm_combine_decode operator)
+    "VLLM_ASCEND_ENABLE_FUSED_MC2":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0'))),
 }

 # end-env-vars-definition
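As a side note on the parsing above: the registered lambda only accepts integer strings, so "1" (or any non-zero integer) enables the fused path and "0" disables it. A quick sketch, assuming nothing beyond the lambda added in this diff:

import os

def fused_mc2_enabled() -> bool:
    # Mirrors the lambda registered for VLLM_ASCEND_ENABLE_FUSED_MC2 above.
    return bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')))

os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"
print(fused_mc2_enabled())  # True
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "0"
print(fused_mc2_enabled())  # False
# Non-integer values such as "true" would raise ValueError inside int().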

vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -352,7 +352,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         shared_out = fc3_context.shared_experts(hidden_states)
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                 and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         set_flash_common3_context(shared_out=shared_out)
@@ -527,7 +527,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                 and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         else:

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 69 additions & 0 deletions
@@ -46,6 +46,7 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
     _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
         moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


 class MoECommMethod(ABC):
@@ -315,3 +316,71 @@ def fused_experts(
             out=out,
         )
         return out
+
+
+class FusedMC2CommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `class MC2CommImpl` can be used.
+    2. `VLLM_ASCEND_ENABLE_FUSED_MC2` is enabled.
+    3. `w8a8_dynamic` quantization is used.
+    This implementation uses the `dispatch_gmm_combine_decode` operator, which is a fused
+    operator for MoE decoding that combines communication and computation for optimization
+    on Ascend devices.
+    """
+
+    def _get_token_dispatcher(self):
+        return TokenDispatcherWithMC2()
+
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithMC2(self.moe_config)
+
+    def fused_experts(
+            self,
+            hidden_states: torch.Tensor,
+            w1: torch.Tensor | list[torch.Tensor],
+            w2: torch.Tensor | list[torch.Tensor],
+            topk_weights: torch.Tensor,
+            topk_ids: torch.Tensor,
+            activation: str = "silu",
+            apply_router_weight_on_input: bool = False,
+            use_int8_w8a8: bool = False,
+            use_int4_w4a8: bool = False,
+            use_int4_w4a16: bool = False,
+            global_num_experts: Optional[int] = None,
+            expert_map: Optional[torch.Tensor] = None,
+            w1_scale: Optional[list[torch.Tensor]] = None,
+            w2_scale: Optional[list[torch.Tensor]] = None,
+            w1_scale_bias: torch.Tensor = None,
+            w2_scale_bias: torch.Tensor = None,
+            w1_offset: Optional[torch.Tensor] = None,
+            w2_offset: Optional[torch.Tensor] = None,
+            # For Cube/Vector parallel
+            shared_experts: Optional[Any] = None,
+            quantized_x_for_share: Optional[Any] = None,
+            dynamic_scale_for_share: Optional[Any] = None,
+            # For load balance
+            log2phy: torch.Tensor = None,
+            global_redundant_expert_num: int = 0,
+            need_trans: bool = False,
+            dynamic_eplb: bool = False,
+            mc2_mask: torch.Tensor = None,
+            pertoken_scale: Optional[torch.Tensor] = None):
+
+        assert w1_scale is not None, "w1_scale should not be None"
+        assert w2_scale is not None, "w2_scale should not be None"
+        assert expert_map is not None, "expert_map should not be None"
+        output, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
+            x=hidden_states,
+            expert_ids=topk_ids,
+            gmm1_permuted_weight=w1[0],
+            gmm1_permuted_weight_scale=w1_scale[0],
+            gmm2_weight=w2[0],
+            gmm2_weight_scale=w2_scale[0],
+            expert_smooth_scales=None,
+            expert_scales=topk_weights.to(torch.float32),
+            group_ep=self.token_dispatcher.moe_all_to_all_group_name,
+            ep_rank_size=self.token_dispatcher.ep_world_size,
+            ep_rank_id=self.token_dispatcher.ep_rank_id,
+            moe_expert_num=len(expert_map),
+            global_bs=self.token_dispatcher.global_bs)
+        return output
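To make the wiring above easier to follow, here is a stripped-down sketch of the registry pattern that setup_moe_comm_method uses: each MoECommType maps to one comm-method implementation, and FusedMC2CommImpl simply becomes one more entry that the forward path can look up. The class and variable names below are simplified placeholders, not the real module.

from abc import ABC, abstractmethod
from enum import Enum

class CommType(Enum):          # simplified stand-in for MoECommType
    MC2 = 1
    FUSED_MC2 = 4

class CommMethod(ABC):         # simplified stand-in for MoECommMethod
    @abstractmethod
    def fused_experts(self, hidden_states): ...

class MC2Impl(CommMethod):
    def fused_experts(self, hidden_states):
        return f"multi-operator MC2 path on {hidden_states}"

class FusedMC2Impl(CommMethod):
    def fused_experts(self, hidden_states):
        return f"single fused DispatchGmmCombineDecode on {hidden_states}"

_METHODS: dict[CommType, CommMethod] = {}

def setup_comm_methods():
    _METHODS[CommType.MC2] = MC2Impl()
    _METHODS[CommType.FUSED_MC2] = FusedMC2Impl()  # the new entry added by this commit

setup_comm_methods()
print(_METHODS[CommType.FUSED_MC2].fused_experts("x"))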

vllm_ascend/ops/register_custom_ops.py

Lines changed: 2 additions & 1 deletion
@@ -250,7 +250,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
     if moe_comm_type in {
-            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
+            MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL,
+            MoECommType.FUSED_MC2
     } or forward_context.sp_enabled:
         return final_hidden_states
     else:

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 8 additions & 1 deletion
@@ -235,6 +235,8 @@ def apply(
         topk_weights = topk_weights.to(self.in_dtype)

         moe_comm_method = get_forward_context().moe_comm_method
+        fused_mc2_flag = get_forward_context(
+        ).moe_comm_type == MoECommType.FUSED_MC2
         if self.dynamic_eplb:
             w1 = layer.w13_weight_list
             w1_scale = layer.w13_weight_scale_fp32_list
@@ -244,7 +246,10 @@
             w1 = [layer.w13_weight]
             w1_scale = [layer.w13_weight_scale_fp32]
             w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
+            w2_scale = [
+                layer.w2_weight_scale_fp32
+                if fused_mc2_flag else layer.w2_weight_scale
+            ]

         fused_flag = get_forward_context(
         ).moe_comm_type == MoECommType.FUSED_ALLTOALL
@@ -284,6 +289,8 @@ def process_weights_after_loading(self, layer):
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
             layer.w2_weight_scale.data.shape[0], -1)
+        layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
+            torch.float32)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
             layer.w2_weight_offset.data.shape[0], -1)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 6 deletions
@@ -445,7 +445,9 @@ def _skip_all_reduce_acorss_dp_group(self) -> bool:
         # moe_comm_method of each rank is MC2 and recomputation would never happen in D
         # nodes. So here we check whether recompute_scheduler_enable is True.
         return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
-            potential_max_num_tokens) == MoECommType.MC2
+            potential_max_num_tokens) in {
+                MoECommType.MC2, MoECommType.FUSED_MC2
+            }

     def _sync_metadata_across_dp(
         self, num_tokens: int,
@@ -1418,10 +1420,16 @@ def _select_moe_comm_method(self,
             moe_comm_type = MoECommType.ALLGATHER

         elif soc_version in {AscendDeviceType._910_93}:
-            moe_comm_type = (
-                MoECommType.MC2 if num_tokens <= mc2_tokens_capacity else
-                MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic"
-                and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL)
+            if num_tokens <= mc2_tokens_capacity:
+                if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic":
+                    moe_comm_type = MoECommType.FUSED_MC2
+                else:
+                    moe_comm_type = MoECommType.MC2
+            elif quant_type == "w8a8_dynamic" and get_ep_group(
+            ).world_size <= 16:
+                moe_comm_type = MoECommType.FUSED_ALLTOALL
+            else:
+                moe_comm_type = MoECommType.ALLTOALL
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")

@@ -2291,7 +2299,7 @@ def profile_run(self) -> None:
         # allowing vLLM to correctly estimate the maximum memory required.
         mc2_tokens_capacity = get_mc2_tokens_capacity()
         if self.max_num_tokens > mc2_tokens_capacity and \
-            self._select_moe_comm_method(mc2_tokens_capacity) == MoECommType.MC2:
+            self._select_moe_comm_method(mc2_tokens_capacity) in {MoECommType.MC2, MoECommType.FUSED_MC2}:
             self._dummy_run(mc2_tokens_capacity, with_prefill=True)

         output = None
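For clarity, the A3 (_910_93) branch of the selection logic above can be restated as a standalone function; the decision order is unchanged, only the inputs are passed explicitly instead of being read from the runner and the environment. Enum values are taken from the ascend_forward_context.py diff.

from enum import Enum

class MoECommType(Enum):
    MC2 = 1
    ALLTOALL = 2
    FUSED_ALLTOALL = 3
    FUSED_MC2 = 4

def select_moe_comm_method(num_tokens: int, mc2_tokens_capacity: int,
                           quant_type: str, ep_world_size: int,
                           fused_mc2_enabled: bool) -> MoECommType:
    # Decode-sized batches prefer MC2; the fused variant additionally requires
    # the env flag and w8a8_dynamic quantization.
    if num_tokens <= mc2_tokens_capacity:
        if fused_mc2_enabled and quant_type == "w8a8_dynamic":
            return MoECommType.FUSED_MC2
        return MoECommType.MC2
    if quant_type == "w8a8_dynamic" and ep_world_size <= 16:
        return MoECommType.FUSED_ALLTOALL
    return MoECommType.ALLTOALL

print(select_moe_comm_method(64, 512, "w8a8_dynamic", 8, True))    # FUSED_MC2
print(select_moe_comm_method(64, 512, "w8a8_dynamic", 8, False))   # MC2
print(select_moe_comm_method(2048, 512, "w8a8_dynamic", 8, True))  # FUSED_ALLTOALL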
