
Commit 5a80a2a

Author: kip-cxj
Commit message: add checkpoint engine for update weight
1 parent b6ff980 commit 5a80a2a

File tree

7 files changed: +128 −35 lines


verl/trainer/config/rollout/rollout.yaml

Lines changed: 3 additions & 0 deletions
@@ -281,6 +281,9 @@ skip_tokenizer_init: True
 # When enabled (True), the rollout will record the routing decisions.
 enable_rollout_routing_replay: False
 
+# Whether to use checkpoint_engine for weight updates.
+# When enabled (True), parameters are synced between the trainer and rollout through checkpoint_engine.
+enable_checkpoint_engine: False
 
 # profile the rollout model in `generate_sequence`
 profiler:

verl/workers/config/rollout.py

Lines changed: 2 additions & 0 deletions
@@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):
 
     enable_rollout_routing_replay: bool = False
 
+    enable_checkpoint_engine: bool = False
+
    def __post_init__(self):
        """Validate the rollout config"""
        if self.expert_parallel_size > 1:
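
Note: the new flag defaults to False in both the YAML schema and the dataclass. A minimal sketch of flipping it through an OmegaConf dotlist override, the way other rollout options are typically toggled from a launcher command line (illustrative snippet, not part of this commit):

from omegaconf import OmegaConf

# Start from the shipped default and flip the flag via a dotlist override.
base = OmegaConf.create({"rollout": {"enable_checkpoint_engine": False}})
override = OmegaConf.from_dotlist(["rollout.enable_checkpoint_engine=true"])
cfg = OmegaConf.merge(base, override)
assert cfg.rollout.enable_checkpoint_engine is True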

verl/workers/fsdp_workers.py

Lines changed: 46 additions & 9 deletions
@@ -20,8 +20,9 @@
 import logging
 import os
 import warnings
+from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any, Optional
+from typing import Any, Generator, Optional
 
 import numpy as np
 import psutil
@@ -577,6 +578,25 @@ def _build_model_optimizer(
 
         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
 
+    def update_weighs_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None],
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor
+
+        checkpoint_name = "checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        dist.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh
 
@@ -588,10 +608,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
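
Note: update_weighs_by_checkpoint_engine above shards the weight stream round-robin: rank r keeps every world_size-th named tensor, so the checkpoint is partitioned disjointly across trainer ranks before registration. A standalone sketch of that invariant (illustrative, not from this commit):

def shard_for_rank(names: list[str], rank: int, world_size: int) -> list[str]:
    # Same selection rule as the method above: keep index i iff i % world_size == rank.
    return [n for i, n in enumerate(names) if i % world_size == rank]

names = [f"layers.{i}.weight" for i in range(10)]
shards = [shard_for_rank(names, r, 4) for r in range(4)]
# The shards are disjoint and together cover every tensor exactly once.
assert sorted(sum(shards, [])) == sorted(names)
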
@@ -700,10 +720,14 @@ async def rollout_mode(self):
 
         set_expandable_segments(False)
 
+        if self.config.rollout.enable_checkpoint_engine:
+            device = "cpu"
+        else:
+            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
+
         if peft_config is not None and self.base_sync_done:
             per_tensor_param = params.items() if isinstance(params, dict) else params  # Fixed: handle dict case
         else:
-            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
             per_tensor_param = (
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in params.items()
@@ -718,10 +742,18 @@ async def rollout_mode(self):
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in base_model_params.items()
             )
-            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
+            if self.config.rollout.enable_checkpoint_engine:
+                req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+                self.update_weighs_by_checkpoint_engine(per_tensor_base_params, req_func)
+            else:
+                await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
             del base_model_params, per_tensor_base_params
 
-        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
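
Note: the same two-way branch now appears at every sync site (base weights and full/LoRA weights here, and the Megatron path below). Condensed, the control flow is as follows (hypothetical `worker` stand-in; method names as in this commit):

async def sync_weights(worker, per_tensor_param):
    if worker.config.rollout.enable_checkpoint_engine:
        # Ask the rollout for a request callback sized to the inference world,
        # then push this rank's shard through the parameter server.
        req_func = await worker.rollout.checkpoint_engine_req_func(worker.infer_world_size)
        worker.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
    else:
        # Legacy path: stream tensors directly into the rollout engine.
        await worker.rollout.update_weights(per_tensor_param)
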
@@ -863,6 +895,11 @@ def init_model(self):
             checkpoint_config=checkpoint_contents,
         )
 
+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
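
Note: the ParameterServer is created with auto_pg=False (process-group setup presumably left to the caller), and each sync then follows a fixed five-step handshake. A condensed sketch of that lifecycle, using only the calls shown in this diff (hypothetical ps/weights/req_func stand-ins):

import torch.distributed as dist

def sync_once(ps, weights, req_func, rank, world_size):
    # 1. shard: keep every world_size-th tensor on this rank
    shard = {name: t for i, (name, t) in enumerate(weights) if i % world_size == rank}
    # 2. register the shard with the local parameter server
    ps.register_checkpoint("checkpoint_engine", named_tensors=shard)
    # 3. all ranks agree on the global tensor layout
    dist.barrier()
    ps.gather_metas("checkpoint_engine")
    # 4. broadcast shards to the inference workers via the request callback
    ps.update("checkpoint_engine", req_func)
    # 5. release the staged checkpoint
    ps.unregister_checkpoint("checkpoint_engine")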

verl/workers/megatron_workers.py

Lines changed: 33 additions & 5 deletions
@@ -483,6 +483,24 @@ def _build_model_optimizer(
 
         return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config
 
+    def update_weighs_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None],
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor.to("cpu", non_blocking=True)
+
+        checkpoint_name = "checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        dist.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh

@@ -500,10 +518,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -661,6 +679,11 @@ def init_model(self):
         if not self.config.actor.megatron.use_mbridge:
             self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
 
+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After init_model finish", logger=logger)

@@ -689,7 +712,12 @@ async def rollout_mode(self):
 
         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
-        await self.rollout.update_weights(per_tensor_param)
+
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weighs_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
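
Note: unlike the FSDP variant, the Megatron method moves each kept tensor with tensor.to("cpu", non_blocking=True). A device-to-host copy only overlaps with compute when it lands in pinned host memory, and the bytes must be synchronized before they are consumed; a minimal sketch of the safe pattern (illustrative, not from this commit):

import torch

def offload_to_cpu(t: torch.Tensor) -> torch.Tensor:
    # Stage into pinned host memory so the D2H copy can run asynchronously.
    out = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    out.copy_(t, non_blocking=True)
    return out

# Before the CPU tensors are consumed (here: registered with the parameter
# server), synchronize so the pending copies are guaranteed to have landed:
# torch.cuda.current_stream().synchronize()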

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 19 additions & 21 deletions
@@ -23,16 +23,16 @@
 import ray
 import sglang.srt.entrypoints.engine
 import torch
+from collections.abc import Callable
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     assert_pkg_version,
     is_cuda,
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
+from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
 from sglang.srt.checkpoint_engine.update import req_inference
-from checkpoint_engine.ps import ParameterServer
-import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from verl.workers.config import HFModelConfig, RolloutConfig
@@ -167,13 +167,10 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        import torch.distributed as dist
-
-        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
-        inference_parallel_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size * self.config.pipeline_model_parallel_size
-        if tp_rank == 0:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             await self._init_server_adapter()
 
+        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
         if self.config.get("quantization", None) == "fp8":
             from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name

@@ -186,23 +183,24 @@
         else:
             weights = weights
 
-        named_tensors = []
-        for idx, weight in enumerate(weights):
-            if idx % inference_parallel_size == tp_rank:
-                named_tensors.append((weight[0], weight[1].cpu()))
+        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
+            await sgl_update_weights(
+                engine=self._engine,
+                params_batch=params_batch,
+                device_mesh_key="infer_tp",
+                device_mesh=self.device_mesh,
+            )
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            await self._engine.flush_cache()
 
-        if tp_rank == 0:
+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            await self._init_server_adapter()
             endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
         else:
             endpoint = ""
-        req_func = req_inference(endpoint, inference_parallel_size)
 
-        checkpoint_name = "checkpoint_engine"
-        ps = ParameterServer()
-        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
-        dist.barrier()
-        ps.gather_metas(checkpoint_name)
-        ps.update(checkpoint_name, req_func)
+        req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)
 
-        if tp_rank == 0:
-            await self._engine.flush_cache()
+        return req_func
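
Note: the rewritten update_weights now ships tensors in size-capped batches; update_weights_bucket_megabytes is converted to bytes with << 20 (1 MiB = 2**20 bytes), so a value of 512 caps each batch at 512 MiB. An illustrative re-implementation of the bucketing idea (assumed behavior of get_named_tensor_buckets, not copied from verl):

from collections.abc import Iterable, Iterator

import torch

def bucket_named_tensors(
    weights: Iterable[tuple[str, torch.Tensor]], bucket_bytes: int
) -> Iterator[list[tuple[str, torch.Tensor]]]:
    bucket, size = [], 0
    for name, tensor in weights:
        nbytes = tensor.element_size() * tensor.numel()
        if bucket and size + nbytes > bucket_bytes:
            yield bucket  # flush once the next tensor would exceed the cap
            bucket, size = [], 0
        bucket.append((name, tensor))
        size += nbytes
    if bucket:
        yield bucket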

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 2 additions & 0 deletions
@@ -290,6 +290,7 @@ async def launch_server(self, master_address: str = None, master_port: int = None
             "override_generation_config": json.dumps(override_generation_config),
             "quantization": quantization,
             "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
+            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension" if self.config.enable_checkpoint_engine else None,
             **engine_kwargs,
         }

@@ -690,6 +691,7 @@ async def launch_servers(self):
                 soft=False,
             ),
             name=name,
+            runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}} if self.config.enable_checkpoint_engine else None,
         ).remote(
             config=self.config,
             model_config=self.model_config,

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 23 additions & 0 deletions
@@ -31,6 +31,7 @@
 import logging
 import os
 from dataclasses import asdict
+from collections.abc import Callable
 from types import MethodType
 from typing import Any, Generator

@@ -133,6 +134,13 @@ def __init__(
         else:
             self.sleep_level = VLLM_SLEEP_LEVEL
 
+        rank = int(os.environ["RANK"])
+        local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
+        rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
+        self.replica_rank = rank // rollout_world_size
+        self.rollout_rank = rank % rollout_world_size
+        self.node_rank = self.rollout_rank // local_world_size
+
     def _init_zeromq(self) -> str:
         tensor_parallel_size = self.config.tensor_model_parallel_size
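
Note: a worked example of the rank bookkeeping above (illustrative values). With tensor_model_parallel_size=4, data_parallel_size=1, and 8 workers per node, global rank 13 is the second GPU of the fourth rollout replica:

rank = 13                   # RANK
local_world_size = 8        # RAY_LOCAL_WORLD_SIZE
rollout_world_size = 4 * 1  # tensor_model_parallel_size * data_parallel_size

replica_rank = rank // rollout_world_size     # 3: fourth replica
rollout_rank = rank % rollout_world_size      # 1: second GPU in that replica
node_rank = rollout_rank // local_world_size  # 0: the replica fits on one node

assert (replica_rank, rollout_rank, node_rank) == (3, 1, 0)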

@@ -262,6 +270,21 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
             logger.info("Loading standard weights (non-FP8, async)")
             model.load_weights(weights)
 
+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        from checkpoint_engine.ps import request_inference_to_update
+
+        rank = int(os.environ["RANK"])
+        src = rank // inference_parallel_size * inference_parallel_size
+
+        server_actor = ray.get_actor(f"vllm_server_{self.replica_rank}_{self.node_rank}")
+        server_address, server_port = await server_actor.get_server_address.remote()
+
+        def req_func(socket_paths: list[tuple[str, str]]):
+            if rank == src:
+                request_inference_to_update(
+                    f"http://{server_address}:{server_port}/collective_rpc",
+                    dict(socket_paths[src : src + inference_parallel_size]),
+                )
+
+        return req_func
+
     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""
         raise NotImplementedError
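
Note: src above picks one leader per inference group, and only that rank issues the HTTP request, forwarding the socket paths of its own slice. A worked example of the selection rule (illustrative values):

inference_parallel_size = 4

for rank in range(8):
    src = rank // inference_parallel_size * inference_parallel_size
    # Ranks 0-3 share leader 0; ranks 4-7 share leader 4. Only rank == src
    # posts to /collective_rpc, with socket_paths[src : src + 4].
    assert src == (0 if rank < 4 else 4)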
