Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions verl/utils/vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .utils import TensorLoRARequest, VLLMHijack, is_version_ge
from .npu_vllm_patch import check_vllm_ascend_before_server_launch

# The contents of vllm/patch.py should not be imported here, because the contents of
# patch.py should be imported after the vllm LLM instance is created. Therefore,
Expand All @@ -23,4 +24,5 @@
"TensorLoRARequest",
"VLLMHijack",
"is_version_ge",
"check_vllm_ascend_before_server_launch",
]
99 changes: 99 additions & 0 deletions verl/utils/vllm/npu_vllm_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The comment "AscendSocVersion.A2 is not support MC2 in Single-card multi-process scenario now." is vague. "now" implies a temporary state, but it's better to state the current limitation clearly without temporal ambiguity. Please clarify if this is a known, permanent limitation or if there's a specific version or condition under which it might change.

# limitations under the License.
import os
from functools import wraps
from verl.utils.device import is_torch_npu_available


def vllm_ascend_select_moe_comm_method_wrapper(fn):
    """Wrap ``NPUModelRunner._select_moe_comm_method`` for Ascend A2 SoCs.

    On AscendSocVersion.A2 the MC2 MoE communication method is not supported
    in the single-card multi-process scenario, so when the original method
    selects MC2 this wrapper downgrades it to a supported method and disables
    the matmul + reduce-scatter (mmrs) fusion on the current forward context.

    Args:
        fn: The original ``_select_moe_comm_method(self, num_tokens,
            with_prefill)`` method being patched.

    Returns:
        A wrapped method with the same signature and return value type.
    """

    @wraps(fn)
    def wrapper(self, num_tokens, with_prefill):
        moe_comm_method = fn(self, num_tokens, with_prefill)
        # Lazy imports: vllm_ascend is only present on NPU hosts, and this
        # patch module may be imported before vllm_ascend is ready.
        from vllm_ascend.ascend_forward_context import MoECommType
        from vllm_ascend.utils import AscendSocVersion, enable_sp, get_ascend_soc_version

        soc_version = get_ascend_soc_version()

        # AscendSocVersion.A2 does not support MC2 in the single-card
        # multi-process scenario, so replace it with a supported method.
        if soc_version == AscendSocVersion.A2 and moe_comm_method == MoECommType.MC2:
            quant_type = getattr(self.vllm_config.model_config.hf_config, "moe_quantize", None)
            # w4a8_dynamic does not support allgather-EP, so use all-to-all.
            if quant_type == "w4a8_dynamic":
                moe_comm_method = MoECommType.ALLTOALL
            else:
                moe_comm_method = MoECommType.ALLGATHER

            if with_prefill:
                # With sequence parallelism enabled all-gather is usable;
                # otherwise fall back to naive multicast for prefill.
                if enable_sp():
                    moe_comm_method = MoECommType.ALLGATHER
                else:
                    moe_comm_method = MoECommType.NAIVE_MULTICAST

            from vllm.forward_context import get_forward_context

            try:
                forward_context = get_forward_context()
                forward_context.mmrs_fusion = False
            except AssertionError:
                # get_forward_context() asserts when no forward context has
                # been set (e.g. outside a model forward pass). There is no
                # fusion flag to clear in that case, so it is safe to skip.
                pass

        return moe_comm_method

    return wrapper


def vllm_ascend_matmul_and_reduce_wrapper(fn):
    """Wrap ``SequenceRowParallelOp.matmul_and_reduce`` for Ascend A2 SoCs.

    AscendSocVersion.A2 does not support MC2 in the single-card
    multi-process scenario, so before delegating to the original method the
    wrapper forces the matmul + reduce-scatter (mmrs) fusion flag off on the
    active forward context.

    Args:
        fn: The original ``matmul_and_reduce`` method being patched.

    Returns:
        A wrapped method with the same signature and return value.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version

        if get_ascend_soc_version() == AscendSocVersion.A2:
            from vllm.forward_context import get_forward_context

            try:
                get_forward_context().mmrs_fusion = False
            except AssertionError:
                # No forward context is set; matmul_and_reduce treats the
                # mmrs_fusion flag as false in that case anyway.
                pass
        return fn(self, *args, **kwargs)

    return wrapper


def check_vllm_ascend_before_server_launch():
    """Validate Ascend-specific vLLM settings before launching the server.

    AscendSocVersion.A2 does not support the matmul-allreduce feature in the
    single-card multi-process scenario; launching with it enabled would fail
    later in a hard-to-diagnose way, so fail fast here instead.

    Raises:
        AssertionError: If the current SoC is A2 and the
            ``VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE`` environment variable is
            set to a truthy integer.
    """
    # Lazy imports: torch_npu / vllm_ascend only exist on NPU hosts.
    import torch_npu
    from vllm_ascend.utils import AscendSocVersion

    # Map the raw torch_npu SoC version code onto vllm_ascend's enum.
    # NOTE(review): the 220-225 / 250-255 ranges mirror vllm_ascend's own
    # SoC-version detection — confirm they stay in sync across releases.
    raw_soc_version = torch_npu.npu.get_soc_version()
    if 220 <= raw_soc_version <= 225:
        soc_version = AscendSocVersion.A2
    elif 250 <= raw_soc_version <= 255:
        soc_version = AscendSocVersion.A3
    else:
        soc_version = AscendSocVersion.UNDEFINED

    if soc_version == AscendSocVersion.A2:
        matmul_allreduce_enabled = bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0")))
        if matmul_allreduce_enabled:
            raise AssertionError(
                "AscendSocVersion.A2 does not support "
                "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE in the single-card "
                "multi-process scenario."
            )


# Install the Ascend-specific monkey patches at import time, but only when an
# NPU torch backend is available (vllm_ascend cannot be imported otherwise).
if is_torch_npu_available(check_device=False):
    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
    from vllm_ascend.ops.linear_op import SequenceRowParallelOp
    # Force a supported MoE comm method (and disable mmrs fusion) on A2 SoCs.
    NPUModelRunner._select_moe_comm_method = vllm_ascend_select_moe_comm_method_wrapper(
        NPUModelRunner._select_moe_comm_method
    )
    # Keep mmrs fusion disabled on A2 while matmul_and_reduce runs.
    SequenceRowParallelOp.matmul_and_reduce = vllm_ascend_matmul_and_reduce_wrapper(
        SequenceRowParallelOp.matmul_and_reduce
    )
6 changes: 4 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from vllm.v1.engine.async_llm import AsyncLLM

from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_resource_name, get_visible_devices_keyword
from verl.utils.device import get_resource_name, get_visible_devices_keyword, is_torch_npu_available
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
from verl.utils.profiler import DistProfiler, build_vllm_profiler_args
from verl.utils.tokenizer import normalize_token_ids
Expand Down Expand Up @@ -239,7 +239,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non

quantization = self.config.quantization
hf_overrides = {}

if is_torch_npu_available(check_device=False):
from verl.utils.vllm.npu_vllm_patch import check_vllm_ascend_before_server_launch
check_vllm_ascend_before_server_launch()
# Handle QAT (Quantization-Aware Training) configuration
qat_config_dict = getattr(self.config, "qat", {}) or {}
if qat_config_dict.get("enable", False):
Expand Down
Loading