From b6ff9808655cc11405dd5ebe09456df4dc3ef7b7 Mon Sep 17 00:00:00 2001
From: kip-cxj
Date: Wed, 17 Dec 2025 15:21:37 +0800
Subject: [PATCH 1/4] update weights by checkpoint_engine in sglang

---
 .../rollout/sglang_rollout/sglang_rollout.py | 37 +++++++++++++------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 63d3b0c36af..2268e66ed36 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -30,7 +30,9 @@
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
-from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
+from sglang.srt.checkpoint_engine.update import req_inference
+from checkpoint_engine.ps import ParameterServer
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh

 from verl.workers.config import HFModelConfig, RolloutConfig
@@ -165,10 +167,13 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        import torch.distributed as dist
+
+        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
+        inference_parallel_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size * self.config.pipeline_model_parallel_size
+        if tp_rank == 0:
             await self._init_server_adapter()

-        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
         if self.config.get("quantization", None) == "fp8":
             from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name
@@ -181,13 +186,23 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         else:
             weights = weights

-        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
-            await sgl_update_weights(
-                engine=self._engine,
-                params_batch=params_batch,
-                device_mesh_key="infer_tp",
-                device_mesh=self.device_mesh,
-            )
+        named_tensors = []
+        for idx, weight in enumerate(weights):
+            if idx % inference_parallel_size == tp_rank:
+                named_tensors.append((weight[0], weight[1].cpu()))
+
+        if tp_rank == 0:
+            endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
+        else:
+            endpoint = ""
+        req_func = req_inference(endpoint, inference_parallel_size)
+
+        checkpoint_name = "checkpoint_engine"
+        ps = ParameterServer()
+        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        dist.barrier()
+        ps.gather_metas(checkpoint_name)
+        ps.update(checkpoint_name, req_func)

-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        if tp_rank == 0:
             await self._engine.flush_cache()

From 4755ae3c8811feefc5380a3ec0fe07a4ff834e3f Mon Sep 17 00:00:00 2001
From: kip-cxj
Date: Thu, 18 Dec 2025 19:11:59 +0800
Subject: [PATCH 2/4] add checkpoint engine for weight update

---
 verl/trainer/config/rollout/rollout.yaml      |  3 +
 verl/workers/config/rollout.py                |  2 +
 verl/workers/fsdp_workers.py                  | 55 ++++++++++++++++---
 verl/workers/megatron_workers.py              | 41 ++++++++++++--
 .../rollout/sglang_rollout/sglang_rollout.py  | 40 +++++++-------
 .../rollout/vllm_rollout/vllm_async_server.py |  2 +
 .../rollout/vllm_rollout/vllm_rollout.py      | 23 ++++++++
 7 files changed, 130 insertions(+), 36 deletions(-)
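For reference, the flow that PATCH 1 wires into SGLangRollout.update_weights, condensed out of the diff above. The APIs here (req_inference, ParameterServer and its register_checkpoint / gather_metas / update methods) are used exactly as in the patch; the standalone function and its arguments are illustrative stand-ins for the method's self attributes, not part of the series:

import torch.distributed as dist

from checkpoint_engine.ps import ParameterServer
from sglang.srt.checkpoint_engine.update import req_inference


def push_weights_via_checkpoint_engine(weights, tp_rank, inference_parallel_size, endpoint):
    # Shard the weight stream round-robin: rank r keeps the tensors whose
    # index satisfies idx % inference_parallel_size == r, moved to CPU.
    named_tensors = [
        (name, tensor.cpu())
        for idx, (name, tensor) in enumerate(weights)
        if idx % inference_parallel_size == tp_rank
    ]

    # Only rank 0 knows the SGLang HTTP endpoint; the other ranks pass "".
    req_func = req_inference(endpoint if tp_rank == 0 else "", inference_parallel_size)

    ps = ParameterServer()
    ps.register_checkpoint("checkpoint_engine", named_tensors=named_tensors)
    dist.barrier()                            # every rank has registered its shard
    ps.gather_metas("checkpoint_engine")      # exchange shard metadata across ranks
    ps.update("checkpoint_engine", req_func)  # stream the checkpoint into the engine

PATCH 2 below generalizes this: the sharding and ParameterServer calls move into the trainer workers (update_weighs_by_checkpoint_engine), while each rollout backend only supplies the req_func through a new checkpoint_engine_req_func method.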
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 968d9e11277..b2a01111705 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -281,6 +281,9 @@ skip_tokenizer_init: True
 # When enabled (True), the rollout will record the routing decisions.
 enable_rollout_routing_replay: False

+# Whether to use checkpoint_engine to update weights
+# When enabled (True), parameters are synced between trainer and rollout through checkpoint_engine.
+enable_checkpoint_engine: False

 # profile the rollout model in `generate_sequence`
 profiler:
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index bea1bd4520d..b0367a8cbe9 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):

     enable_rollout_routing_replay: bool = False

+    enable_checkpoint_engine: bool = False
+
     def __post_init__(self):
         """Validate the rollout config"""
         if self.expert_parallel_size > 1:
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index b7d89134d72..a7076c627bc 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -20,8 +20,9 @@
 import logging
 import os
 import warnings
+from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any, Optional
+from typing import Any, Generator, Optional

 import numpy as np
 import psutil
@@ -577,6 +578,25 @@ def _build_model_optimizer(

         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

+    def update_weighs_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None]
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor
+
+        checkpoint_name = f"checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        dist.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh

@@ -588,10 +608,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -700,10 +720,14 @@ async def rollout_mode(self):

         set_expandable_segments(False)

+        if self.config.rollout.enable_checkpoint_engine:
+            device = "cpu"
+        else:
+            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
+
         if peft_config is not None and self.base_sync_done:
             per_tensor_param = params.items() if isinstance(params, dict) else params  # Fixed: handle dict case
         else:
-            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
             per_tensor_param = (
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in params.items()
@@ -718,10 +742,18 @@ async def rollout_mode(self):
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in base_model_params.items()
             )
-            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
+            if self.config.rollout.enable_checkpoint_engine:
+                req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+                self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+            else:
+                await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
             del base_model_params, per_tensor_base_params
-
-        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
@@ -863,6 +895,11 @@ def init_model(self):
             checkpoint_config=checkpoint_contents,
         )

+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index db2e3fb1b97..e3c23a79e8f 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -19,11 +19,12 @@
 import logging
 import os
 import time
-from typing import Any, Optional
+from typing import Any, Generator, Optional

 import psutil
 import torch
 import torch.distributed
+from collections.abc import Callable
 from codetiming import Timer
 from omegaconf import DictConfig, OmegaConf

@@ -483,6 +484,24 @@ def _build_model_optimizer(

         return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config

+    def update_weighs_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None]
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor.to("cpu", non_blocking=True)
+
+        checkpoint_name = f"checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        torch.distributed.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh

@@ -500,10 +519,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -661,6 +680,11 @@ def init_model(self):
         if not self.config.actor.megatron.use_mbridge:
             self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)

+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After init_model finish", logger=logger)

@@ -689,7 +713,12 @@ async def rollout_mode(self):

         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
-        await self.rollout.update_weights(per_tensor_param)
+
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 2268e66ed36..3172461d8e0 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -23,6 +23,7 @@
 import ray
 import sglang.srt.entrypoints.engine
 import torch
+from collections.abc import Callable
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     assert_pkg_version,
@@ -30,9 +31,8 @@
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
+from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
 from sglang.srt.checkpoint_engine.update import req_inference
-from checkpoint_engine.ps import ParameterServer
-import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh

 from verl.workers.config import HFModelConfig, RolloutConfig
@@ -167,13 +167,10 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        import torch.distributed as dist
-
-        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
-        inference_parallel_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size * self.config.pipeline_model_parallel_size
-        if tp_rank == 0:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             await self._init_server_adapter()

+        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
         if self.config.get("quantization", None) == "fp8":
             from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name
@@ -186,23 +183,24 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         else:
             weights = weights

-        named_tensors = []
-        for idx, weight in enumerate(weights):
-            if idx % inference_parallel_size == tp_rank:
-                named_tensors.append((weight[0], weight[1].cpu()))
+        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
+            await sgl_update_weights(
+                engine=self._engine,
+                params_batch=params_batch,
+                device_mesh_key="infer_tp",
+                device_mesh=self.device_mesh,
+            )
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            await self._engine.flush_cache()

-        if tp_rank == 0:
+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            await self._init_server_adapter()
             endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
         else:
             endpoint = ""
-        req_func = req_inference(endpoint, inference_parallel_size)
-
-        checkpoint_name = "checkpoint_engine"
-        ps = ParameterServer()
-        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
-        dist.barrier()
-        ps.gather_metas(checkpoint_name)
-        ps.update(checkpoint_name, req_func)
+        req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)

-        if tp_rank == 0:
-            await self._engine.flush_cache()
+        return req_func
\ No newline at end of file
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index e0c292b8397..c0f0e4dc1bd 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,6 +290,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
             "override_generation_config": json.dumps(override_generation_config),
             "quantization": quantization,
             "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
+            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension" if self.config.enable_checkpoint_engine else None,
             **engine_kwargs,
         }

@@ -690,6 +691,7 @@ async def launch_servers(self):
                     soft=False,
                 ),
                 name=name,
+                runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}} if self.config.enable_checkpoint_engine else None,
             ).remote(
                 config=self.config,
                 model_config=self.model_config,
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 42a1cd96885..35ff2897166 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -31,6 +31,7 @@
 import logging
 import os
 from dataclasses import asdict
+from collections.abc import Callable
 from types import MethodType
 from typing import Any, Generator

@@ -133,6 +134,13 @@ def __init__(
         else:
             self.sleep_level = VLLM_SLEEP_LEVEL

+        rank = int(os.environ["RANK"])
+        local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
+        rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
+        self.replica_rank = rank // rollout_world_size
+        self.rollout_rank = rank % rollout_world_size
+        self.node_rank = self.rollout_rank // local_world_size
+
     def _init_zeromq(self) -> str:
         tensor_parallel_size = self.config.tensor_model_parallel_size

@@ -262,6 +270,21 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
             logger.info("Loading standard weights (non-FP8, async)")
             model.load_weights(weights)

+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        from checkpoint_engine.ps import request_inference_to_update
+        rank = int(os.getenv("RANK", None))
+        src = rank // inference_parallel_size * inference_parallel_size
+
+        server_actor = ray.get_actor(f"vllm_server_{self.replica_rank}_{self.node_rank}")
+        server_address, server_port = await server_actor.get_server_address.remote()
+        def req_func(socket_paths: list[tuple[str, str]]):
+            if rank == src:
+                request_inference_to_update(
+                    f"http://{server_address}:{server_port}/collective_rpc",
+                    dict(socket_paths[src : src + inference_parallel_size]),
+                )
+        return req_func
+
     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""
         raise NotImplementedError

From 94e228b4f3ee4962a38ba94551aad0232e5fc2b5 Mon Sep 17 00:00:00 2001
From: kip-cxj
Date: Wed, 24 Dec 2025 17:49:59 +0800
Subject: [PATCH 3/4] fix pre_commit

---
 .../config/_generated_ppo_megatron_trainer.yaml  |  1 +
 verl/trainer/config/_generated_ppo_trainer.yaml  |  1 +
 verl/workers/fsdp_workers.py                     | 13 +++++++------
 verl/workers/megatron_workers.py                 |  6 +++---
 .../sglang_rollout/async_sglang_server.py        | 14 +++++++++++---
 .../rollout/sglang_rollout/sglang_rollout.py     |  6 +++---
 .../rollout/vllm_rollout/vllm_async_server.py    |  8 ++++++--
 .../workers/rollout/vllm_rollout/vllm_rollout.py | 16 ++--------------
 8 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index a117c0f332f..02aee6fc551 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -272,6 +272,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
     enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 833ebb70d5b..14428e7bb41 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -261,6 +261,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
     enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index a7076c627bc..58648463e0d 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -581,14 +581,14 @@ def _build_model_optimizer(
     def update_weighs_by_checkpoint_engine(
         self,
         weights: Generator[tuple[str, torch.Tensor], None, None],
-        req_func: Callable[[list[tuple[str, str]]], None]
+        req_func: Callable[[list[tuple[str, str]]], None],
     ):
         named_tensors = {}
         for tensor_idx, (name, tensor) in enumerate(weights):
             if tensor_idx % self.world_size == self.rank:
                 named_tensors[name] = tensor

-        checkpoint_name = f"checkpoint_engine"
+        checkpoint_name = "checkpoint_engine"
         self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
         dist.barrier()
@@ -596,7 +596,6 @@ def update_weighs_by_checkpoint_engine(
         self.parameter_server.update(checkpoint_name, req_func)
         self.parameter_server.unregister_checkpoint(checkpoint_name)

-
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh

@@ -744,16 +743,18 @@ async def rollout_mode(self):
             )
             if self.config.rollout.enable_checkpoint_engine:
                 req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
-                self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+                self.update_weighs_by_checkpoint_engine(per_tensor_base_params, req_func)
             else:
                 await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
             del base_model_params, per_tensor_base_params
-
+
         if self.config.rollout.enable_checkpoint_engine:
             req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
             self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
         else:
-            await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+            await self.rollout.update_weights(
+                per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
+            )
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index e3c23a79e8f..233d77d8963 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -19,12 +19,12 @@
 import logging
 import os
 import time
+from collections.abc import Callable
 from typing import Any, Generator, Optional

 import psutil
 import torch
 import torch.distributed
-from collections.abc import Callable
 from codetiming import Timer
 from omegaconf import DictConfig, OmegaConf

@@ -487,14 +487,14 @@ def _build_model_optimizer(
     def update_weighs_by_checkpoint_engine(
         self,
         weights: Generator[tuple[str, torch.Tensor], None, None],
-        req_func: Callable[[list[tuple[str, str]]], None]
+        req_func: Callable[[list[tuple[str, str]]], None],
     ):
         named_tensors = {}
         for tensor_idx, (name, tensor) in enumerate(weights):
             if tensor_idx % self.world_size == self.rank:
                 named_tensors[name] = tensor.to("cpu", non_blocking=True)

-        checkpoint_name = f"checkpoint_engine"
+        checkpoint_name = "checkpoint_engine"
         self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
         torch.distributed.barrier()
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index e78700d9f7a..3f31a4db2d1 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -77,8 +77,10 @@ def __init__(
         cuda_visible_devices: str,
     ):
         print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}")
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-        assert torch.cuda.is_available(), "SGLang http server should run on GPU node"
+        os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_avilable else "ASCEND_RT_VISIBLE_DEVICES"] = (
+            cuda_visible_devices
+        )
+        assert torch.cuda.is_available() or torch.npu.is_available(), "SGLang http server should run on GPU/NPU node"

         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
@@ -337,7 +339,13 @@ async def launch_servers(self):
                     node_id=node_id,
                     soft=False,
                 ),
-                runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
+                runtime_env={
+                    "env_vars": {
+                        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"
+                        if torch.cuda.is_available()
+                        else "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1"
+                    }
+                },
                 name=name,
             ).remote(
                 config=self.config,
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 3172461d8e0..34fe5b274cb 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -18,12 +18,13 @@
 import logging
 import multiprocessing as mp
 import os
+from collections.abc import Callable
 from typing import Generator

 import ray
 import sglang.srt.entrypoints.engine
 import torch
-from collections.abc import Callable
+from sglang.srt.checkpoint_engine.update import req_inference
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     assert_pkg_version,
@@ -32,7 +33,6 @@
     set_ulimit,
 )
 from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
-from sglang.srt.checkpoint_engine.update import req_inference
 from torch.distributed.device_mesh import DeviceMesh

 from verl.workers.config import HFModelConfig, RolloutConfig
@@ -203,4 +203,4 @@ async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Call

         req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)

-        return req_func
\ No newline at end of file
+        return req_func
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index c0f0e4dc1bd..8c2fa30b11f 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,7 +290,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
             "override_generation_config": json.dumps(override_generation_config),
             "quantization": quantization,
             "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
-            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension" if self.config.enable_checkpoint_engine else None,
+            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension"
+            if self.config.enable_checkpoint_engine
+            else None,
             **engine_kwargs,
         }

@@ -691,7 +693,9 @@ async def launch_servers(self):
                     soft=False,
                 ),
                 name=name,
{"VLLM_SERVER_DEV_MODE": "1"}} if self.config.enable_checkpoint_engine else None, + runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}} + if self.config.enable_checkpoint_engine + else None, ).remote( config=self.config, model_config=self.model_config, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 35ff2897166..8161880a83d 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -30,8 +30,8 @@ import getpass import logging import os -from dataclasses import asdict from collections.abc import Callable +from dataclasses import asdict from types import MethodType from typing import Any, Generator @@ -271,19 +271,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None model.load_weights(weights) async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]: - from checkpoint_engine.ps import request_inference_to_update - rank = int(os.getenv("RANK", None)) - src = rank // inference_parallel_size * inference_parallel_size - - server_actor = ray.get_actor(f"vllm_server_{self.replica_rank}_{self.node_rank}") - server_address, server_port = await server_actor.get_server_address.remote() - def req_func(socket_paths: list[tuple[str, str]]): - if rank == src: - request_inference_to_update( - f"http://{server_address}:{server_port}/collective_rpc", - dict(socket_paths[src : src + inference_parallel_size]), - ) - return req_func + raise NotImplementedError def generate_sequences(self, prompts: DataProto) -> DataProto: """Batch generate sequences in sync mode.""" From e905a11201a2b971f08e296eae4fa76ce313df82 Mon Sep 17 00:00:00 2001 From: kip-cxj Date: Thu, 25 Dec 2025 14:40:19 +0800 Subject: [PATCH 4/4] fix --- verl/workers/rollout/sglang_rollout/async_sglang_server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 3f31a4db2d1..dd00280f696 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -77,7 +77,7 @@ def __init__( cuda_visible_devices: str, ): print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}") - os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_avilable else "ASCEND_RT_VISIBLE_DEVICES"] = ( + os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_available() else "ASCEND_RT_VISIBLE_DEVICES"] = ( cuda_visible_devices ) assert torch.cuda.is_available() or torch.npu.is_available(), "SGLang http server should run on GPU/NPU node" @@ -341,9 +341,8 @@ async def launch_servers(self): ), runtime_env={ "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES" - if torch.cuda.is_available() - else "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1" + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", } }, name=name,