Skip to content

Commit 2555243

Browse files
kip-cxjyexin
authored and committed
add checkpoint engine for one step off policy
1 parent 249c083 commit 2555243

File tree

11 files changed

+292
-42
lines changed

11 files changed

+292
-42
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
# Copyright 2025 Meituan Ltd. and/or its affiliates
3+
# Copyright 2025 Huawei Ltd. and/or its affiliates
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
import os
19+
import time
20+
21+
import httpx
22+
import torch
23+
import torch.distributed
24+
from checkpoint_engine.ps import ParameterServer, request_inference_to_update
25+
from omegaconf import DictConfig, OmegaConf
26+
27+
from verl.single_controller.base import Worker
28+
from verl.single_controller.base.decorator import Dispatch, register
29+
from verl.utils.device import (
30+
get_device_name,
31+
)
32+
33+
logger = logging.getLogger(__file__)
34+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
35+
36+
device_name = get_device_name()
37+
38+
39+
class CkptEngineWorker(Worker):
    """Standalone worker that relays actor weights to the rollout (inference)
    engines through the checkpoint-engine parameter server, for
    one-step-off-policy training.
    """

    def __init__(self, rank_offset, ps_world_size, inference_parallel_size):
        """
        Args:
            rank_offset: offset added to this worker's local rank to obtain its
                global rank inside the parameter-server process group.
            ps_world_size: total world size of the parameter-server group
                (trainer ranks + checkpoint-engine ranks).
            inference_parallel_size: tensor-parallel size of one inference
                engine; the first rank of each group acts as the HTTP sender.
        """
        super().__init__()
        rank = self.rank + rank_offset
        self.ps_rank = rank
        self.ps_rank_offset = rank_offset
        self.ps_world_size = ps_world_size
        self.inference_parallel_size = inference_parallel_size
        self.ps = ParameterServer(rank=rank, world_size=ps_world_size)
        # Sync counter used to derive checkpoint names.
        # NOTE(review): it is never incremented, so every sync reuses "sync_0" —
        # confirm checkpoint_engine overwrites checkpoints by name.
        self.index = 0

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def init_process_group(self):
        """Initialize the parameter-server process group.

        The HCCL socket port range is pinned only while the group is being set
        up; the previous value of the env var (if any) is restored afterwards
        instead of being unconditionally deleted.
        """
        prev = os.environ.get("HCCL_NPU_SOCKET_PORT_RANGE")
        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020-61050"
        try:
            self.ps.init_process_group(device_index=0, master_port=60010)
        finally:
            # Restore rather than delete, in case the launcher had set it.
            if prev is None:
                os.environ.pop("HCCL_NPU_SOCKET_PORT_RANGE", None)
            else:
                os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = prev

    def check_vllm_ready(self, uds: str | None = None):
        """Block until the inference HTTP server answers its /health endpoint.

        Only the first rank of each inference-parallel group polls; all other
        ranks return immediately. Retries indefinitely with a 5s backoff.

        Args:
            uds: optional unix-domain-socket path to reach the server through.
        """
        # Only the leader rank of each inference group talks to the server.
        if self.ps_rank % self.inference_parallel_size != 0:
            return
        transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
        retry_num = 0
        while True:
            try:
                # Close the client on every attempt so failed connection
                # attempts do not leak sockets across retries.
                with httpx.Client(transport=transport) as client:
                    response = client.get(f"{self.endpoint}/health", timeout=10)
                    response.raise_for_status()
                break
            except (httpx.ConnectError, httpx.HTTPStatusError) as e:
                retry_num += 1
                logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
                time.sleep(5)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def set_server_addresses(self, server_addresses: list[str]):
        """Record the inference API server endpoint and wait until it is healthy."""
        # todo support multiple api server
        self.endpoint = f"http://{server_addresses[0]}"
        self.check_vllm_ready()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def sync_rollout_weights_by_ckpt_engine(self):
        """Pull the latest actor weights through the parameter server and push
        them into the colocated inference engine via its /collective_rpc route.
        """
        rank = self.rank
        # Leader rank of this worker's inference-parallel group.
        src = rank // self.inference_parallel_size * self.inference_parallel_size

        def req_func(socket_paths: list[tuple[str, str]]) -> None:
            # Only the group leader issues the HTTP update request.
            if rank == src:
                request_inference_to_update(
                    url=f"{self.endpoint}/collective_rpc",
                    socket_paths=dict(socket_paths),
                )

        checkpoint_name = f"sync_{self.index}"
        self.ps.register_checkpoint(checkpoint_name=checkpoint_name)
        self.ps.gather_metas(checkpoint_name)
        # All parameter-server ranks from the ckpt-engine offset onward take part.
        ranks = list(range(self.ps_rank_offset, self.ps_world_size))
        self.ps.update(checkpoint_name, req_func, ranks=ranks)

recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ actor_rollout_ref:
2020
free_cache_engine: False
2121
# Must be enabled! Otherwise, log_probs cannot be calculated.
2222
calculate_log_probs: True
23+
engine_kwargs:
24+
vllm:
25+
worker_extension_cls: checkpoint_engine.worker.VllmColocateWorkerExtension
2326

2427
# Only then will the use of log probs be correct.
2528
# And it can be used in conjunction with other rollout_correction algorithms.
2629
algorithm:
2730
rollout_correction:
28-
bypass_mode: True
31+
bypass_mode: True

recipe/one_step_off_policy/fsdp_workers.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
import torch
2020
import torch.distributed
21+
from checkpoint_engine.ps import ParameterServer
2122
from omegaconf import DictConfig
2223
from ray.util.collective import collective
2324
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
2425

25-
from recipe.one_step_off_policy.distributed_util import vllm_stateless_init_process_group
2626
from verl.single_controller.base.decorator import Dispatch, register
2727
from verl.utils.device import (
2828
get_device_name,
@@ -53,17 +53,6 @@ class DetachSync(AsyncActorRolloutRefWorker):
5353
def _get_actor_params(self):
5454
pass
5555

56-
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
57-
def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
58-
rank = torch.distributed.get_rank() + rank_offset
59-
self._weight_sync_group = vllm_stateless_init_process_group(
60-
master_address,
61-
master_port,
62-
rank,
63-
world_size,
64-
get_torch_device().current_device(),
65-
)
66-
6756
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
6857
def sync_rollout_weights(self):
6958
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
@@ -127,6 +116,61 @@ async def update_weights(self, inference_engine, params):
127116

128117

129118
class DetachActorWorker(DetachSync):
119+
def __init__(self, config: DictConfig, role: str, **kwargs):
    """Detached actor worker that publishes weights via the checkpoint engine.

    Args:
        config: actor_rollout_ref section of the trainer config.
        role: worker role; parameter-server state is created only for "actor".
        **kwargs: optionally carries "rank_offset" and "ps_world_size" for the
            parameter-server process group.
    """
    # Deliberately bypasses DetachSync.__init__ and initializes the base
    # ActorRolloutRefWorker directly.
    ActorRolloutRefWorker.__init__(self, config, role)

    if role == "actor":
        # Global offset of the actor ranks inside the parameter-server group;
        # defaults fall back to this worker's own rank/world size.
        self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
        self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
        self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)

    # Sync counter used to derive checkpoint names (e.g. "sync_0").
    self.index = 0
129+
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def init_process_group(self):
    """Initialize the checkpoint-engine parameter-server process group.

    The HCCL socket port env var is pinned only for the duration of the group
    setup and removed afterwards.
    """
    # NOTE(review): this pins a single port ("61020") while CkptEngineWorker
    # uses the range "61020-61050" — confirm which form HCCL expects here.
    os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
    self.ps.init_process_group(device_index=0, master_port=60010)
    # NOTE(review): deletes unconditionally, clobbering any launcher-set value.
    del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
134+
135+
def split_tensors(self) -> dict[str, torch.Tensor]:
    """Gather this rank's share of the full actor weights onto the CPU.

    The flat weight list ``self._weights_info`` is split into contiguous,
    evenly-sized index ranges across actor ranks; each rank returns only the
    parameters whose index falls into its own range.

    Returns:
        Mapping from parameter name to a CPU copy of the full (unsharded)
        tensor, for the parameters assigned to this rank only.
    """
    assert self._is_actor and not self.config.hybrid_engine
    assert hasattr(self, "_weights_info") and self._weights_info is not None

    if self._is_actor and self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    params = self._get_actor_params()

    named_tensors = {}

    world_size = self.world_size
    rank = self.rank

    # Ceil-divide so every weight index is owned by exactly one rank.
    weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
    lo = rank * weights_per_rank
    hi = lo + weights_per_rank
    for index, (key, _, _) in enumerate(self._weights_info):
        assert key in params
        # full_tensor() is called on every key, not only the kept ones —
        # presumably because it is a collective over the FSDP mesh and must
        # run on all ranks (TODO confirm against torch DTensor semantics).
        tensor = params[key].full_tensor()
        if lo <= index < hi:
            named_tensors[key] = tensor.to("cpu", non_blocking=True)

    # Wait for the async device-to-host copies before returning the dict.
    get_torch_device().synchronize()

    return named_tensors
158+
159+
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_ckpt_engine(self):
    """Publish this actor's weight shard to the checkpoint-engine parameter server."""

    def req_func(socket_paths: list[tuple[str, str]]):
        # The actor side never issues the inference-update HTTP request;
        # the ckpt-engine workers do that on their end.
        return

    named_tensors = self.split_tensors()

    # NOTE(review): self.index is never incremented, so every sync reuses the
    # name "sync_0" — confirm checkpoint_engine overwrites checkpoints by name.
    checkpoint_name = f"sync_{self.index}"

    self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
    self.ps.gather_metas(checkpoint_name)
    # All parameter-server ranks from the actor offset onward take part.
    ranks = list(range(self.ps_rank_offset, self.ps_world_size))

    self.ps.update(checkpoint_name, req_func, ranks=ranks)
173+
130174
def _get_actor_params(self):
131175
assert self._is_actor
132176
params = self.actor_module_fsdp.state_dict()
@@ -159,8 +203,7 @@ def get_actor_weights_info(self):
159203

160204

161205
class DetachAsyncRolloutWorker(DetachSync):
162-
def __init__(self, config: DictConfig, role: str):
163-
print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
206+
def __init__(self, config: DictConfig, role: str, **kwargs):
164207
ActorRolloutRefWorker.__init__(self, config, role)
165208

166209
@register(dispatch_mode=Dispatch.ONE_TO_ALL)

recipe/one_step_off_policy/main_ppo.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from verl.utils.config import validate_config
3333
from verl.utils.device import auto_set_ascend_device_name
3434

35+
from .ckpt_engine_worker import CkptEngineWorker
36+
3537

3638
def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
3739
"""
@@ -69,6 +71,14 @@ def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
6971
resource_pool_spec["rollout_pool"] = rollout_pool
7072
mapping[Role.Rollout] = "rollout_pool"
7173

74+
if Role.CkptEngine in roles:
75+
assert config.rollout.n_gpus_per_node > 0, "ckpt_engine config.rollout.n_gpus_per_node must be greater than 0"
76+
assert config.rollout.nnodes > 0, "ckpt_engine config.rollout.nnodes must be greater than 0"
77+
# the same as rollout pool
78+
ckpt_engine_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
79+
resource_pool_spec["ckpt_engine_pool"] = ckpt_engine_pool
80+
mapping[Role.CkptEngine] = "ckpt_engine_pool"
81+
7282
return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
7383

7484

@@ -111,6 +121,7 @@ def create_role_worker_mapping(config):
111121
Role.Actor: ray.remote(DetachActorWorker),
112122
Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
113123
Role.Critic: ray.remote(CriticWorker),
124+
Role.CkptEngine: ray.remote(CkptEngineWorker),
114125
}
115126

116127
if config.reward_model.enable:
@@ -140,6 +151,9 @@ def run(self, config):
140151

141152
from verl.utils.fs import copy_to_local
142153

154+
if os.environ.get("ASCEND_RT_VISIBLE_DEVICES", None) is not None:
155+
del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
156+
143157
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
144158

145159
pprint(OmegaConf.to_container(config, resolve=True))

recipe/one_step_off_policy/ray_trainer.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def __init__(
127127
if config.algorithm.use_kl_in_reward:
128128
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
129129

130+
self.rank_offset = config.trainer.n_gpus_per_node * config.trainer.nnodes
131+
self.ps_world_size = self.rank_offset + config.rollout.n_gpus_per_node * config.rollout.nnodes
132+
130133
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
131134

132135
def _validate(self):
@@ -149,7 +152,8 @@ def init_workers(self):
149152
self._init_async_rollout_manager()
150153

151154
def _init_resource_pools(self):
    """Create resource pools, colocating ckpt-engine workers with rollout workers."""
    # Each rollout device is fractionally shared: 0.2 for the ckpt-engine pool
    # and 0.8 for the rollout pool, so both land on the same accelerators.
    # NOTE(review): fractions and the "NPU" resource key are hard-coded —
    # confirm behavior on non-Ascend (GPU) deployments.
    additional = {"ckpt_engine_pool": {"CPU": 1, "NPU": 0.2}, "rollout_pool": {"CPU": 1, "NPU": 0.8}}
    self.resource_pool_manager.create_resource_pool(additional=additional)

    self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
155159

@@ -158,6 +162,19 @@ def _create_worker_classes(self):
158162
self._create_critic_class()
159163
self._create_reference_policy_class()
160164
self._create_reward_model_class()
165+
self._create_ckpt_engine_class()
166+
167+
def _create_ckpt_engine_class(self):
    """Register the checkpoint-engine worker class on its resource pool.

    The worker receives the parameter-server geometry (rank offset, world size)
    plus the inference tensor-parallel size so it can pick per-group leaders.
    """
    # create ckpt engine (the original `if True:` wrapper was dead code)
    resource_pool = self.resource_pool_manager.get_resource_pool(Role.CkptEngine)
    ckpt_engine_cls = RayClassWithInitArgs(
        cls=self.role_worker_mapping[Role.CkptEngine],
        rank_offset=self.rank_offset,
        ps_world_size=self.ps_world_size,
        inference_parallel_size=self.config.actor_rollout_ref.rollout.tensor_model_parallel_size,
    )
    self.resource_pool_to_cls[resource_pool][str(Role.CkptEngine)] = ckpt_engine_cls
161178

162179
def _create_actor_rollout_classes(self):
163180
for role in [Role.Actor, Role.Rollout]:
@@ -166,6 +183,8 @@ def _create_actor_rollout_classes(self):
166183
cls=self.role_worker_mapping[role],
167184
config=self.config.actor_rollout_ref,
168185
role=str(role),
186+
rank_offset=self.rank_offset,
187+
ps_world_size=self.ps_world_size,
169188
)
170189
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
171190

@@ -249,26 +268,15 @@ def _init_models(self):
249268
self.rollout_wg = self.all_wg[str(Role.Rollout)]
250269
self.actor_wg.init_model()
251270
self.rollout_wg.init_model()
271+
self.ckpt_engine_wg = self.all_wg[str(Role.CkptEngine)]
252272
self.actor_rollout_wg = self.actor_wg
253273
weights_info = self.actor_wg.get_actor_weights_info()[0]
254274
self.rollout_wg.set_actor_weights_info(weights_info)
255275
self._create_weight_sync_group()
256276

257277
def _create_weight_sync_group(self):
    """Bring up the checkpoint-engine parameter-server process group.

    Actor workers and ckpt-engine workers join the same group, so the two init
    calls are launched concurrently: the actor side is started non-blocking,
    and ray.get() on the ckpt-engine side returns once group initialization
    completes (which requires the actor side to have joined as well).
    """
    # NOTE(review): the actor-side ObjectRefs are dropped, so an actor-side
    # failure would surface only indirectly (hang/failure of the get below).
    self.actor_wg.init_process_group()
    ray.get(self.ckpt_engine_wg.init_process_group())
272280

273281
def _init_async_rollout_manager(self):
274282
# create async rollout manager and request scheduler
@@ -286,9 +294,11 @@ def _init_async_rollout_manager(self):
286294
config=self.config, worker_group=self.rollout_wg, rm_resource_pool=rm_resource_pool
287295
)
288296

297+
ray.get(self.ckpt_engine_wg.set_server_addresses(self.async_rollout_manager.server_addresses))
298+
289299
def sync_rollout_weights(self):
    """Push the newest actor weights to the rollout engines via the checkpoint engine.

    The actor side is launched non-blocking so both sides can take part in the
    same parameter-server transfer; blocking on the ckpt-engine side presumably
    implies the transfer has completed (TODO confirm — actor ObjectRefs are dropped).
    """
    self.actor_wg.sync_rollout_weights_by_ckpt_engine()
    ray.get(self.ckpt_engine_wg.sync_rollout_weights_by_ckpt_engine())
292302

293303
def _create_continuous_iterator(self):
294304
"""

0 commit comments

Comments
 (0)