
Commit b6ff980

Author: kip-cxj

update weights by checkpoint_engine in sglang
1 parent a0e8e44 commit b6ff980

File tree

1 file changed: +26 -11 lines changed

1 file changed

+26
-11
lines changed

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 26 additions & 11 deletions
@@ -30,7 +30,9 @@
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
-from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
+from sglang.srt.checkpoint_engine.update import req_inference
+from checkpoint_engine.ps import ParameterServer
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from verl.workers.config import HFModelConfig, RolloutConfig
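
The import swap above drops SGLang's bucketed weight-sync helper in favor of the checkpoint-engine parameter server, which makes checkpoint-engine a hard dependency of this rollout path. A guarded variant, sketched below, would fail with a clearer message when the package is missing; this is hypothetical and not part of the commit:

    try:
        from sglang.srt.checkpoint_engine.update import req_inference
        from checkpoint_engine.ps import ParameterServer
    except ImportError as e:  # checkpoint-engine is an optional extra
        raise ImportError(
            "updating weights via checkpoint_engine requires the checkpoint-engine "
            "package and an SGLang build that ships sglang.srt.checkpoint_engine"
        ) from e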
@@ -165,10 +167,13 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        import torch.distributed as dist
+
+        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
+        inference_parallel_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size * self.config.pipeline_model_parallel_size
+        if tp_rank == 0:
             await self._init_server_adapter()
 
-        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
         if self.config.get("quantization", None) == "fp8":
             from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name
 
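The hunk above computes the bookkeeping the next hunk relies on: inference_parallel_size is the product of the tensor-, data-, and pipeline-parallel degrees, and each weight is later assigned round-robin to the rank matching idx % inference_parallel_size. A minimal sketch of that arithmetic with assumed degrees (TP=4, DP=2, PP=1; illustrative values, not from the commit):

    # Assumed parallel degrees, for illustration only.
    tensor_model_parallel_size = 4
    data_parallel_size = 2
    pipeline_model_parallel_size = 1
    inference_parallel_size = (
        tensor_model_parallel_size * data_parallel_size * pipeline_model_parallel_size
    )  # 8-way inference group

    # Round-robin ownership used by the next hunk: weight index i is
    # registered by the rank where i % inference_parallel_size == tp_rank.
    for tp_rank in range(inference_parallel_size):
        owned = [i for i in range(24) if i % inference_parallel_size == tp_rank]
        print(tp_rank, owned)  # rank 0 -> [0, 8, 16], rank 1 -> [1, 9, 17], ...
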
@@ -181,13 +186,23 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         else:
             weights = weights
 
-        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
-            await sgl_update_weights(
-                engine=self._engine,
-                params_batch=params_batch,
-                device_mesh_key="infer_tp",
-                device_mesh=self.device_mesh,
-            )
+        named_tensors = []
+        for idx, weight in enumerate(weights):
+            if idx % inference_parallel_size == tp_rank:
+                named_tensors.append((weight[0], weight[1].cpu()))
+
+        if tp_rank == 0:
+            endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
+        else:
+            endpoint = ""
+        req_func = req_inference(endpoint, inference_parallel_size)
+
+        checkpoint_name = "checkpoint_engine"
+        ps = ParameterServer()
+        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        dist.barrier()
+        ps.gather_metas(checkpoint_name)
+        ps.update(checkpoint_name, req_func)
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        if tp_rank == 0:
             await self._engine.flush_cache()
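
Taken together, the new path replaces the per-bucket sgl_update_weights RPCs with a single checkpoint_engine broadcast: shard the weights round-robin across the inference group, register each shard with a ParameterServer, exchange metadata, then push the update to the SGLang server. The sketch below restates that flow as a standalone helper, using only the API surface visible in this diff (register_checkpoint / gather_metas / update and req_inference); the function name and its arguments are illustrative, not part of the commit:

    import torch.distributed as dist

    from checkpoint_engine.ps import ParameterServer
    from sglang.srt.checkpoint_engine.update import req_inference

    def push_weights_via_checkpoint_engine(weights, tp_rank, inference_parallel_size, host, port):
        # Keep a disjoint round-robin slice of the weights on this rank,
        # moved to CPU before registration.
        named_tensors = [
            (name, tensor.cpu())
            for idx, (name, tensor) in enumerate(weights)
            if idx % inference_parallel_size == tp_rank
        ]

        # Only rank 0 addresses the SGLang server over HTTP; the other
        # ranks pass an empty endpoint, mirroring the diff above.
        endpoint = f"http://{host}:{port}" if tp_rank == 0 else ""
        req_func = req_inference(endpoint, inference_parallel_size)

        ps = ParameterServer()
        ps.register_checkpoint("checkpoint_engine", named_tensors=named_tensors)
        dist.barrier()  # every rank must register before metadata exchange
        ps.gather_metas("checkpoint_engine")
        ps.update("checkpoint_engine", req_func)

The dist.barrier() before gather_metas ensures every rank has finished registering its shard before metadata is collected, so the subsequent update sees a complete view of the checkpoint.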
