     set_prometheus_multiproc_dir,
     set_ulimit,
 )
-from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
+from sglang.srt.checkpoint_engine.update import req_inference
+from checkpoint_engine.ps import ParameterServer
+import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 
 from verl.workers.config import HFModelConfig, RolloutConfig
@@ -165,10 +167,13 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
         - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
         """
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        import torch.distributed as dist
+
+        tp_rank = self.device_mesh["infer_tp"].get_local_rank()
+        inference_parallel_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size * self.config.pipeline_model_parallel_size
+        if tp_rank == 0:
             await self._init_server_adapter()
 
-        update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
         if self.config.get("quantization", None) == "fp8":
             from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name
 
@@ -181,13 +186,23 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         else:
             weights = weights
 
-        for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
-            await sgl_update_weights(
-                engine=self._engine,
-                params_batch=params_batch,
-                device_mesh_key="infer_tp",
-                device_mesh=self.device_mesh,
-            )
+        named_tensors = []
+        for idx, weight in enumerate(weights):
+            if idx % inference_parallel_size == tp_rank:
+                named_tensors.append((weight[0], weight[1].cpu()))
+
+        if tp_rank == 0:
+            endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
+        else:
+            endpoint = ""
+        req_func = req_inference(endpoint, inference_parallel_size)
+
+        checkpoint_name = "checkpoint_engine"
+        ps = ParameterServer()
+        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        dist.barrier()
+        ps.gather_metas(checkpoint_name)
+        ps.update(checkpoint_name, req_func)
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+        if tp_rank == 0:
             await self._engine.flush_cache()
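
For reference, the caller passes weights as a generator of (name, tensor) pairs, and every rank must iterate those pairs in the same order, because the idx % inference_parallel_size == tp_rank filter above shards weights purely by position. A minimal caller-side sketch under that assumption follows; the rollout handle, model object, and helper name are illustrative and not part of this diff.

    import torch

    def iter_named_weights(model: torch.nn.Module):
        # Hypothetical helper: yield (name, tensor) pairs in a fixed,
        # deterministic order so every rank sees the same idx for each weight.
        for name, param in model.state_dict().items():
            yield name, param.detach()

    # Trainer side (illustrative): push the current actor weights into the
    # rollout engine after each optimizer step.
    # await rollout.update_weights(iter_named_weights(actor_model))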