-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[vllm] chore: fix mc2 used in vllm_ascend on A2 npu #5560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
c309b67
a417cb7
b964f78
c6f17d5
4147264
22838f3
607c1d7
569f1e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| # Copyright 2025 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Copyright 2025 The Qwen Team and The HuggingFace Inc. team | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import os | ||
| from functools import wraps | ||
| from verl.utils.device import is_torch_npu_available | ||
|
|
||
|
|
||
def vllm_ascend_select_moe_comm_method_wrapper(fn):
    """Wrap ``NPUModelRunner._select_moe_comm_method`` for Ascend A2 SoCs.

    A2 SoCs do not support the MC2 MoE communication method in the
    single-card multi-process scenario, so when the wrapped method picks
    MC2 on an A2 SoC this wrapper substitutes a supported method.

    Args:
        fn: The original ``_select_moe_comm_method(self, num_tokens,
            with_prefill)`` bound-method function.

    Returns:
        A wrapper with the same signature returning a ``MoECommType``.
    """

    @wraps(fn)
    def wrapper(self, num_tokens, with_prefill):
        moe_comm_method = fn(self, num_tokens, with_prefill)
        # Imported lazily so this module can be imported without
        # vllm_ascend installed.
        from vllm_ascend.ascend_forward_context import MoECommType
        from vllm_ascend.utils import (
            AscendSocVersion,
            enable_sp,
            get_ascend_soc_version,
        )

        soc_version = get_ascend_soc_version()

        # A2 SoCs do not support MC2 in the single-card multi-process
        # scenario, so pick a fallback communication method.
        if soc_version == AscendSocVersion.A2 and moe_comm_method == MoECommType.MC2:
            quant_type = getattr(
                self.vllm_config.model_config.hf_config, "moe_quantize", None
            )
            # w4a8_dynamic does not support allgather-EP, so fall back to
            # all-to-all instead of all-gather for that quantization.
            if quant_type == "w4a8_dynamic":
                moe_comm_method = MoECommType.ALLTOALL
            else:
                moe_comm_method = MoECommType.ALLGATHER

            # NOTE(review): during prefill this overrides the quant-based
            # choice above (including the w4a8_dynamic ALLTOALL fallback) —
            # confirm that is intended.
            if with_prefill:
                if enable_sp():
                    moe_comm_method = MoECommType.ALLGATHER
                else:
                    moe_comm_method = MoECommType.NAIVE_MULTICAST

        return moe_comm_method

    return wrapper
|
Comment on lines
+46
to
+50
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching a broad try:
forward_context = get_forward_context()
forward_context.mmrs_fusion = False
except AssertionError as e:
# Log the error or handle it more specifically if it's an expected condition.
# For example, if forward_context is not available in certain setups.
# logging.warning(f"Could not set mmrs_fusion: {e}")
pass |
||
|
|
||
def vllm_ascend_matmul_and_reduce_wrapper(fn):
    """Wrap ``SequenceRowParallelOp.matmul_and_reduce`` for Ascend A2 SoCs.

    A2 SoCs do not support MC2 in the single-card multi-process scenario,
    so before delegating to the original method this wrapper disables the
    fused matmul+reduce (``mmrs_fusion``) flag on the current forward
    context when running on an A2 SoC.

    Args:
        fn: The original ``matmul_and_reduce`` method.

    Returns:
        A wrapper with an identical signature and return value.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Imported lazily so this module can be imported without
        # vllm_ascend installed.
        from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version

        # A2 SoCs do not support MC2 in the single-card multi-process
        # scenario, so force the mmrs fusion off.
        if get_ascend_soc_version() == AscendSocVersion.A2:
            from vllm.forward_context import get_forward_context

            try:
                forward_context = get_forward_context()
            except AssertionError:
                # No forward context is active; matmul_and_reduce treats
                # mmrs_fusion as false in that case, so nothing to do.
                forward_context = None
            if forward_context is not None:
                forward_context.mmrs_fusion = False
        return fn(self, *args, **kwargs)

    return wrapper
|
|
||
|
|
||
def check_vllm_ascend_before_server_launch():
    """Validate vllm-ascend related settings before launching the server.

    Raises:
        AssertionError: If ``VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE`` is
            enabled while running on an Ascend A2 SoC, which does not
            support it in the single-card multi-process scenario.
    """
    import torch_npu
    from vllm_ascend.utils import AscendSocVersion

    def _get_ascend_soc_version():
        # Maps the raw torch_npu SoC id onto AscendSocVersion, mirroring
        # vllm_ascend.utils.get_ascend_soc_version.
        # NOTE(review): re-implemented locally, presumably because the
        # vllm_ascend global state is not initialized this early in server
        # launch — confirm, otherwise call the library helper directly.
        raw_version = torch_npu.npu.get_soc_version()
        if 220 <= raw_version <= 225:
            return AscendSocVersion.A2
        if 250 <= raw_version <= 255:
            return AscendSocVersion.A3
        return AscendSocVersion.UNDEFINED

    if _get_ascend_soc_version() == AscendSocVersion.A2:
        matmul_allreduce_enabled = bool(
            int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0"))
        )
        if matmul_allreduce_enabled:
            # The original message used a backslash line-continuation inside
            # the string literal, embedding a long run of indentation spaces
            # into the runtime error text; implicit concatenation fixes that.
            raise AssertionError(
                "AscendSocVersion.A2 does not support "
                "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE in the single-card "
                "multi-process scenario."
            )
|
|
||
|
|
||
# Install the Ascend-specific monkey patches, but only when an NPU stack
# is present (check_device=False: probe availability without touching a
# device).
if is_torch_npu_available(check_device=False):
    from vllm_ascend.ops.linear_op import SequenceRowParallelOp
    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

    # Wrap the MoE communication-method selector and the fused
    # matmul+reduce entry points with the A2-aware fallbacks above.
    _original_select = NPUModelRunner._select_moe_comm_method
    _original_matmul = SequenceRowParallelOp.matmul_and_reduce
    NPUModelRunner._select_moe_comm_method = (
        vllm_ascend_select_moe_comm_method_wrapper(_original_select)
    )
    SequenceRowParallelOp.matmul_and_reduce = (
        vllm_ascend_matmul_and_reduce_wrapper(_original_matmul)
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment "AscendSocVersion.A2 is not support MC2 in Single-card multi-process scenario now." is vague. "now" implies a temporary state, but it's better to state the current limitation clearly without temporal ambiguity. Please clarify if this is a known, permanent limitation or if there's a specific version or condition under which it might change.