@@ -270,19 +270,25 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata
             "method": ["init_process_group"] * trainer_world_size,
             "rank": list(range(0, trainer_world_size)),
             "trainer_world_size": [trainer_world_size] * trainer_world_size,
-            "rollout_world_size": [rollout_world_size] * rollout_world_size,
+            "rollout_world_size": [rollout_world_size] * trainer_world_size,
             "master_metadata": [metadata[0]] * trainer_world_size,
         }
         rollout_kwargs = {
             "method": ["init_process_group"] * rollout_world_size,
             "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)),
-            "trainer_world_size": [trainer_world_size] * trainer_world_size,
+            "trainer_world_size": [trainer_world_size] * rollout_world_size,
             "rollout_world_size": [rollout_world_size] * rollout_world_size,
             "master_metadata": [metadata[0]] * rollout_world_size,
         }
         return trainer_kwargs, rollout_kwargs

-    def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size: int, master_metadata: MasterMetadata):
+    def init_process_group(
+        self,
+        rank: int,
+        trainer_world_size: int,
+        rollout_world_size: int,
+        master_metadata: MasterMetadata,
+    ):
         """Initialize the ckpt engine process group.

         Args:
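The list-length fix in `build_topology` matters because each entry in these kwargs dicts is fanned out one-per-worker: index `i` of every list feeds worker `i` on that side. A minimal sketch of that invariant, with toy sizes and a hypothetical fan-out loop standing in for however the framework actually dispatches the calls:

```python
# Minimal sketch: every list in trainer_kwargs must have exactly
# trainer_world_size entries, because index i feeds trainer rank i.
# The fan-out loop below is hypothetical, not the project's dispatcher.
trainer_world_size, rollout_world_size = 4, 2

trainer_kwargs = {
    "method": ["init_process_group"] * trainer_world_size,
    "rank": list(range(0, trainer_world_size)),
    "trainer_world_size": [trainer_world_size] * trainer_world_size,
    # The bug being fixed: "* rollout_world_size" made this list length 2,
    # leaving trainer ranks 2 and 3 with no entry to read.
    "rollout_world_size": [rollout_world_size] * trainer_world_size,
}

for i in range(trainer_world_size):
    per_rank = {key: values[i] for key, values in trainer_kwargs.items()}
    assert per_rank["rank"] == i
```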
@@ -293,9 +299,8 @@ def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size
         self.trainer_world_size = trainer_world_size
         self.rollout_world_size = rollout_world_size
         self.world_size = trainer_world_size + rollout_world_size
-        # unregister_memory in transfer engine is not supported on NPU,
-        # so we have to initialize ParameterServer each time
-        if get_device_name() == "npu" or not self.initialized:
+
+        if not self.initialized:
             self.parameter_server = ParameterServer(
                 rank=rank,
                 world_size=self.world_size,
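With the NPU carve-out dropped, this hunk and the next converge on a plain initialize-once pattern: the `ParameterServer` is constructed on the first `init_process_group` call and reused on every later one. A reduced sketch of that control flow, using a toy `ParameterServer` stand-in (the real class takes more arguments, as the diff shows):

```python
# Reduced sketch of the initialize-once flow after this change.
# ParameterServer here is a toy stand-in, not the real implementation.
class ParameterServer:
    def __init__(self, rank: int, world_size: int):
        self.rank, self.world_size = rank, world_size

    def init_process_group(self) -> None:
        print(f"rank {self.rank}/{self.world_size}: process group up")


class CkptEngine:
    def __init__(self):
        self.initialized = False
        self.parameter_server = None

    def init_process_group(self, rank: int, world_size: int) -> None:
        if not self.initialized:
            # Construct and wire up exactly once; later calls reuse it.
            self.parameter_server = ParameterServer(rank=rank, world_size=world_size)
            self.parameter_server.init_process_group()
            self.initialized = True
```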
@@ -304,7 +309,7 @@ def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size
                 master_port=master_metadata.dist_port,
             )
             self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
-        if not self.initialized:
+
             dist.use_backend(f"vllm_{get_nccl_backend()}")
             self.parameter_server.init_process_group()

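The unchanged `types.MethodType` line in this hunk is instance-level monkey-patching: it binds a free `receive_tensor` function to the freshly built server so it behaves like a regular method. A self-contained illustration of that standard-library mechanism, with a toy class and function rather than the project's:

```python
import types

# types.MethodType(func, obj) returns `func` bound to `obj`, so calling
# it passes `obj` as `self` -- the same trick used on parameter_server.
class Server:
    def __init__(self, name: str):
        self.name = name


def receive_tensor(self, tensor_name: str) -> str:
    # Toy body; the real receive_tensor moves tensor data.
    return f"{self.name} received {tensor_name}"


srv = Server("ps-0")
srv.receive_tensor = types.MethodType(receive_tensor, srv)
print(srv.receive_tensor("model.embed_tokens.weight"))  # ps-0 received ...
```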
@@ -345,7 +350,7 @@ def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor

         self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors)
         named_tensors = {}
-        torch.cuda.empty_cache()
+        get_torch_device().empty_cache()
         logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s")

         self.parameter_server.gather_metas(self.checkpoint_name)
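The last hunk swaps the hard-coded `torch.cuda.empty_cache()` for `get_torch_device().empty_cache()`, which resolves the active accelerator's module at runtime so the same code path works on NPU. A sketch of the kind of helper `get_torch_device()` is; the project's real helper may differ in detail, and `torch.npu` exists only when `torch_npu` is installed:

```python
import torch

# Sketch of a device-dispatch helper in the spirit of get_torch_device();
# it returns the accelerator module (torch.cuda, torch.npu, ...) so
# device-specific calls like empty_cache() stay backend-agnostic.
def get_torch_device():
    if hasattr(torch, "npu") and torch.npu.is_available():  # torch_npu present
        return torch.npu
    return torch.cuda


# Frees cached blocks on whichever accelerator is active.
get_torch_device().empty_cache()
```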