10 changes: 7 additions & 3 deletions vllm_ascend/ascend_forward_context.py
@@ -27,6 +27,7 @@ class MoECommType(Enum):
MC2 = 1
ALLTOALL = 2
FUSED_ALLTOALL = 3
FUSED_MC2 = 4


@contextmanager
@@ -257,9 +258,12 @@ def select_moe_comm_method(num_tokens: int,
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb)
moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
else MoECommType.FUSED_ALLTOALL
if fused_all2all_enable else MoECommType.ALLTOALL)
fused_mc2_enable = quant_type == "w8a8_dynamic" and \
envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and (not dynamic_eplb)
moe_comm_type = (
(MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2) if
num_tokens <= mc2_tokens_capacity else MoECommType.FUSED_ALLTOALL
if fused_all2all_enable else MoECommType.ALLTOALL)
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type
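
For reference, the precedence the updated `select_moe_comm_method` encodes can be reduced to the sketch below; `pick` and the literal capacity/EP values are illustrative stand-ins, not the actual vllm_ascend signatures, and the real function also handles soc_version and config plumbing.

# Simplified sketch of the selection precedence added in this hunk (illustrative only).
from enum import Enum

class MoECommType(Enum):
    MC2 = 1
    ALLTOALL = 2
    FUSED_ALLTOALL = 3
    FUSED_MC2 = 4

def pick(num_tokens, mc2_tokens_capacity, quant_type, ep_world_size,
         dynamic_eplb, fused_mc2_env):
    fused_all2all = (quant_type == "w8a8_dynamic" and ep_world_size <= 16
                     and not dynamic_eplb)
    fused_mc2 = (quant_type == "w8a8_dynamic" and fused_mc2_env
                 and not dynamic_eplb)
    if num_tokens <= mc2_tokens_capacity:
        # small (decode-sized) batches stay on the MC2 path; the fused variant wins when enabled
        return MoECommType.FUSED_MC2 if fused_mc2 else MoECommType.MC2
    # larger batches fall back to all-to-all, fused only if the EP-size guard allows it
    return MoECommType.FUSED_ALLTOALL if fused_all2all else MoECommType.ALLTOALL

assert pick(64, 512, "w8a8_dynamic", 8, False, True) is MoECommType.FUSED_MC2
assert pick(2048, 512, "w8a8_dynamic", 8, False, True) is MoECommType.FUSED_ALLTOALL
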
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -132,6 +132,9 @@
# Whether to anbale dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to enable fused MC2 (dispatch_gmm_combine_decode operator)
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0'))),
}

# end-env-vars-definition
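
Note that the new flag is parsed with `bool(int(...))`, so it expects an integer-style string. A quick illustration (the environment assignment here is only for demonstration):

import os

# "1" enables the fused MC2 path; values such as "true" would raise inside int().
os.environ["VLLM_ASCEND_ENABLE_FUSED_MC2"] = "1"
enabled = bool(int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")))
assert enabled is True
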
4 changes: 2 additions & 2 deletions vllm_ascend/ops/fused_moe/fused_moe.py
@@ -358,7 +358,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)
@@ -533,7 +533,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
69 changes: 69 additions & 0 deletions vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -45,6 +45,7 @@ def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
moe_config)
_MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)


class MoECommMethod(ABC):
@@ -306,3 +307,71 @@ def fused_experts(
out=out,
)
return out


class FusedMC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `class MC2CommImpl` can be used.
2. `VLLM_ASCEND_ENABLE_FUSED_MC2` is enabled.
3. `w8a8_dynamic` quantization is used.
This implementation uses the `dispatch_gmm_combine_decode` operator, which is a fused
operator for MoE decoding that combines communication and computation for optimization
on Ascend devices.
"""

def _get_token_dispatcher(self):
return TokenDispatcherWithMC2()

def _get_prepare_finalize(self):
return PrepareAndFinalizeWithMC2(self.moe_config)

def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):

assert w1_scale is not None, "w1_scale should not be None"
assert w2_scale is not None, "w2_scale should not be None"
assert expert_map is not None, "expert_map should not be None"
output, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1[0],
gmm1_permuted_weight_scale=w1_scale[0],
gmm2_weight=w2[0],
gmm2_weight_scale=w2_scale[0],
expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id,
moe_expert_num=len(expert_map),
global_bs=self.token_dispatcher.global_bs)
return output
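
As a rough mental model, the single `dispatch_gmm_combine_decode` call replaces the separate dispatch, grouped-matmul and combine steps. The unfused, single-rank, unquantized reference below sketches the semantics only; the activation is simplified to plain SiLU (the real path handles the quantized gate/up weights) and none of the names are the real operator API.

import torch

def unfused_moe_reference(x, topk_ids, topk_weights, w1, w2):
    # x: [T, H]; topk_ids/topk_weights: [T, K]; w1: [E, H, I]; w2: [E, I, H]
    # Dense loop standing in for dispatch + grouped matmuls + weighted combine.
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            h = torch.nn.functional.silu(x[t] @ w1[e])   # gmm1 + activation (simplified)
            out[t] += topk_weights[t, k] * (h @ w2[e])   # gmm2 + combine
    return out

T, H, I, E, K = 4, 8, 16, 3, 2
x = torch.randn(T, H)
topk_ids = torch.randint(0, E, (T, K))
topk_weights = torch.softmax(torch.randn(T, K), dim=-1)
w1, w2 = torch.randn(E, H, I), torch.randn(E, I, H)
assert unfused_moe_reference(x, topk_ids, topk_weights, w1, w2).shape == (T, H)
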
3 changes: 2 additions & 1 deletion vllm_ascend/ops/register_custom_ops.py
@@ -250,7 +250,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL
MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL,
MoECommType.FUSED_MC2
} or forward_context.sp_enabled:
return final_hidden_states
else:
9 changes: 8 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
@@ -235,6 +235,8 @@ def apply(
topk_weights = topk_weights.to(self.in_dtype)

moe_comm_method = get_forward_context().moe_comm_method
fused_mc2_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_MC2
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
@@ -244,7 +246,10 @@
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
w2_scale = [
layer.w2_weight_scale_fp32
if fused_mc2_flag else layer.w2_weight_scale
]

fused_flag = get_forward_context(
).moe_comm_type == MoECommType.FUSED_ALLTOALL
@@ -284,6 +289,8 @@ def process_weights_after_loading(self, layer):
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
torch.float32)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)

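
The extra `w2_weight_scale_fp32` buffer above exists so the fused MC2 path pays the float32 cast once at load time instead of on every forward; a minimal sketch of that trade-off, with placeholder names:

import torch

class FakeLayer:
    def __init__(self, scale: torch.Tensor):
        self.w2_weight_scale = scale
        # cast once after loading, mirroring process_weights_after_loading
        self.w2_weight_scale_fp32 = scale.to(torch.float32)

layer = FakeLayer(torch.rand(4, 16, dtype=torch.bfloat16))
fused_mc2_flag = True
w2_scale = [layer.w2_weight_scale_fp32 if fused_mc2_flag else layer.w2_weight_scale]
assert w2_scale[0].dtype is torch.float32
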
6 changes: 4 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -431,7 +431,8 @@ def _skip_all_reduce_acorss_dp_group(self) -> bool:
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
# nodes. So here we check whether recompute_scheduler_enable is True.
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
potential_max_num_tokens, self.vllm_config) == MoECommType.MC2
potential_max_num_tokens,
self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}

def _sync_metadata_across_dp(
self, num_tokens: int,
@@ -2188,7 +2189,8 @@ def _dummy_sampler_run(
def profile_run(self) -> None:
mc2_tokens_capacity = get_mc2_tokens_capacity()
if self.max_num_tokens > mc2_tokens_capacity and \
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) in {
MoECommType.MC2, MoECommType.FUSED_MC2}:
self._dummy_run(mc2_tokens_capacity,
with_prefill=True,
is_profile=True)
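
Both model-runner call sites above only widen an equality check into set membership so FUSED_MC2 is treated like MC2 for the KV-consumer and profiling guards; the profiling guard boils down to the toy check below (string stand-ins for the MoECommType members):

MC2_LIKE = {"MC2", "FUSED_MC2"}

def needs_capacity_profile_run(max_num_tokens: int, mc2_tokens_capacity: int,
                               comm_type_at_capacity: str) -> bool:
    # mirrors the widened condition in profile_run
    return max_num_tokens > mc2_tokens_capacity and comm_type_at_capacity in MC2_LIKE

assert needs_capacity_profile_run(8192, 512, "FUSED_MC2")
assert not needs_capacity_profile_run(256, 512, "MC2")
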