diff --git a/verl/utils/vllm/__init__.py b/verl/utils/vllm/__init__.py
index 00aa7bdb642..2221fa8e2de 100644
--- a/verl/utils/vllm/__init__.py
+++ b/verl/utils/vllm/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .npu_vllm_patch import check_vllm_ascend_before_server_launch
 from .utils import TensorLoRARequest, VLLMHijack, is_version_ge
 
 # The contents of vllm/patch.py should not be imported here, because the contents of
@@ -23,4 +24,5 @@
     "TensorLoRARequest",
     "VLLMHijack",
     "is_version_ge",
+    "check_vllm_ascend_before_server_launch",
 ]
diff --git a/verl/utils/vllm/npu_vllm_patch.py b/verl/utils/vllm/npu_vllm_patch.py
new file mode 100644
index 00000000000..256e7695ae2
--- /dev/null
+++ b/verl/utils/vllm/npu_vllm_patch.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from functools import wraps
+
+from verl.utils.device import is_torch_npu_available
+
+
+def vllm_ascend_select_moe_comm_method_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, num_tokens, with_prefill):
+        moe_comm_method = fn(self, num_tokens, with_prefill)
+        from vllm_ascend.ascend_forward_context import MoECommType
+        from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
+
+        soc_version = get_ascend_soc_version()
+
+        # AscendSocVersion.A2 does not support MC2 in the single-card multi-process scenario for now.
+        if soc_version in {AscendSocVersion.A2} and moe_comm_method == MoECommType.MC2:
+            quant_type = getattr(self.vllm_config.model_config.hf_config, "moe_quantize", None)
+            # Currently, w4a8_dynamic does not support allgather-EP, so fall back to all-to-all.
+            if quant_type == "w4a8_dynamic":
+                moe_comm_method = MoECommType.ALLTOALL
+            else:
+                moe_comm_method = MoECommType.ALLGATHER
+
+            if with_prefill:
+                from vllm_ascend.utils import enable_sp
+
+                if enable_sp():
+                    moe_comm_method = MoECommType.ALLGATHER
+                else:
+                    moe_comm_method = MoECommType.NAIVE_MULTICAST
+
+        return moe_comm_method
+
+    return wrapper
+
+
+def vllm_ascend_matmul_and_reduce_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
+
+        soc_version = get_ascend_soc_version()
+        # AscendSocVersion.A2 does not support matmul+reduce (mmrs) fusion in the single-card multi-process scenario for now.
+        if soc_version in {AscendSocVersion.A2}:
+            from vllm.forward_context import get_forward_context
+
+            try:
+                forward_context = get_forward_context()
+                forward_context.mmrs_fusion = False
+            except AssertionError:
+                # No active forward context; matmul_and_reduce already treats mmrs_fusion as false then.
+                pass
+        return fn(self, *args, **kwargs)
+
+    return wrapper
+
+
+def check_vllm_ascend_before_server_launch():
+    import torch_npu
+    from vllm_ascend.utils import AscendSocVersion
+
+    def get_ascend_soc_version_local():
+        soc_version = torch_npu.npu.get_soc_version()
+        if 220 <= soc_version <= 225:
+            _ascend_soc_version = AscendSocVersion.A2
+        elif 250 <= soc_version <= 255:
+            _ascend_soc_version = AscendSocVersion.A3
+        else:
+            _ascend_soc_version = AscendSocVersion.UNDEFINED
+        return _ascend_soc_version
+
+    soc_version = get_ascend_soc_version_local()
+    if soc_version in {AscendSocVersion.A2}:
+        VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE = bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0")))
+        if VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
+            raise AssertionError(
+                "AscendSocVersion.A2 does not support VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE in "
+                "the single-card multi-process scenario for now."
+            )
+
+
+if is_torch_npu_available(check_device=False):
+    from vllm_ascend.ops.linear_op import SequenceRowParallelOp
+    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
+    NPUModelRunner._select_moe_comm_method = vllm_ascend_select_moe_comm_method_wrapper(
+        NPUModelRunner._select_moe_comm_method
+    )
+    SequenceRowParallelOp.matmul_and_reduce = vllm_ascend_matmul_and_reduce_wrapper(
+        SequenceRowParallelOp.matmul_and_reduce
+    )
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 1e30a01beda..bcd092b61fb 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -36,7 +36,7 @@
 from vllm.v1.engine.async_llm import AsyncLLM
 
 from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.device import get_resource_name, get_visible_devices_keyword
+from verl.utils.device import get_resource_name, get_visible_devices_keyword, is_torch_npu_available
 from verl.utils.net_utils import get_free_port, is_valid_ipv6_address
 from verl.utils.profiler import DistProfiler, build_vllm_profiler_args
 from verl.utils.tokenizer import normalize_token_ids
@@ -239,7 +239,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
         quantization = self.config.quantization
 
         hf_overrides = {}
+        if is_torch_npu_available(check_device=False):
+            from verl.utils.vllm.npu_vllm_patch import check_vllm_ascend_before_server_launch
+            check_vllm_ascend_before_server_launch()
         # Handle QAT (Quantization-Aware Training) configuration
         qat_config_dict = getattr(self.config, "qat", {}) or {}
         if qat_config_dict.get("enable", False):