@@ -174,8 +174,10 @@ async def receive_tensor(
174174
@dataclass
class MasterMetadata:
    """Connection endpoints published by the master rank.

    ``prepare()`` fills only the ``dist_*`` pair and passes ``None`` for the
    ZMQ pair, so the ZMQ fields must be Optional (the original ``str``/``int``
    annotations were wrong for that call site).
    """

    # ZMQ endpoint; currently populated with None by prepare() — TODO confirm
    # which producer fills these in.
    zmq_ip: Optional[str]
    zmq_port: Optional[int]
    # Endpoint used to bootstrap the distributed process group.
    dist_ip: str
    dist_port: int
179181
180182
181183class BroadcastOperation :
@@ -231,17 +233,11 @@ class KIMICheckpointEngine(CheckpointEngine):
231233
232234 def __init__ (
233235 self ,
234- train_world_size : int ,
235- rollout_world_size : int ,
236236 bucket_size : int ,
237237 rebuild_group : bool = False ,
238238 is_master : bool = False ,
239239 rollout_dtype : torch .dtype = torch .bfloat16 ,
240240 ) -> None :
241- self .train_world_size = train_world_size
242- self .rollout_world_size = rollout_world_size
243- self .world_size = train_world_size + rollout_world_size
244-
245241 self .bucket_size = bucket_size
246242 self .rebuild_group = rebuild_group
247243 self .rollout_dtype = rollout_dtype
@@ -254,39 +250,65 @@ def prepare(self) -> MasterMetadata:
254250 self .ip = ray .util .get_node_ip_address ().strip ("[]" )
255251 self .listen_port , _ = get_free_port (self .ip )
256252
257- return MasterMetadata (ip = self .ip , port = self .listen_port ) if self .is_master else None
253+ return (
254+ MasterMetadata (zmq_ip = None , zmq_port = None , dist_ip = self .ip , dist_port = self .listen_port )
255+ if self .is_master
256+ else None
257+ )
258258
def finalize(self):
    """Tear down checkpoint-engine process-group bookkeeping.

    The collective group is destroyed only when this engine owns it
    (``rebuild_group`` is True); the rank/world-size/initialized fields are
    cleared unconditionally so a later ``init_process_group`` starts clean.
    """
    if self.rebuild_group:
        dist.destroy_process_group()
    self.rank = None
    self.world_size = None
    self.initialized = False
266266
267- def init_process_group (self , rank : int , world_size : int , master_metadata : MasterMetadata ):
@classmethod
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
    """Build per-worker ``init_process_group`` kwargs for both worker roles.

    Every list in a role's kwargs dict must have exactly one entry per worker
    of that role (trainer ranks come first, rollout ranks follow). Fixes the
    original sizing bugs: ``trainer_kwargs["rollout_world_size"]`` was sized by
    ``rollout_world_size`` and ``rollout_kwargs["trainer_world_size"]`` by
    ``trainer_world_size``, producing mismatched list lengths.

    Args:
        trainer_world_size: Number of trainer workers.
        rollout_world_size: Number of rollout workers.
        metadata: Per-worker metadata; only ``metadata[0]`` (the master's) is
            broadcast to every worker.

    Returns:
        Tuple ``(trainer_kwargs, rollout_kwargs)`` of parallel-list dicts.
    """
    master = metadata[0]
    trainer_kwargs = {
        "method": ["init_process_group"] * trainer_world_size,
        "rank": list(range(trainer_world_size)),
        "trainer_world_size": [trainer_world_size] * trainer_world_size,
        "rollout_world_size": [rollout_world_size] * trainer_world_size,
        "master_metadata": [master] * trainer_world_size,
    }
    rollout_kwargs = {
        "method": ["init_process_group"] * rollout_world_size,
        "rank": list(range(trainer_world_size, trainer_world_size + rollout_world_size)),
        "trainer_world_size": [trainer_world_size] * rollout_world_size,
        "rollout_world_size": [rollout_world_size] * rollout_world_size,
        "master_metadata": [master] * rollout_world_size,
    }
    return trainer_kwargs, rollout_kwargs
284+
def init_process_group(self, rank: int, trainer_world_size: int, rollout_world_size: int, master_metadata: MasterMetadata):
    """Initialize the ckpt engine process group.

    Args:
        rank (int): Global rank of this worker; trainer ranks occupy
            ``[0, trainer_world_size)`` and rollout ranks follow.
        trainer_world_size (int): Number of trainer workers.
        rollout_world_size (int): Number of rollout workers.
        master_metadata (MasterMetadata): Endpoints of the master rank used
            to bootstrap the process group.
    """
    self.rank = rank
    self.trainer_world_size = trainer_world_size
    self.rollout_world_size = rollout_world_size
    self.world_size = trainer_world_size + rollout_world_size

    # unregister_memory in transfer engine is not supported on NPU,
    # so we have to initialize ParameterServer each time there; elsewhere
    # only on the first call.
    if get_device_name() == "npu" or not self.initialized:
        self.parameter_server = ParameterServer(
            rank=rank,
            world_size=self.world_size,
            auto_pg=False,
            master_addr=master_metadata.dist_ip,
            master_port=master_metadata.dist_port,
        )
        # Bind the module-level receive_tensor as an instance method override.
        self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
        if not self.initialized:
            dist.use_backend(f"vllm_{get_nccl_backend()}")
            self.parameter_server.init_process_group()

    # Sub-group spanning only the rollout ranks (placed after trainer ranks).
    self.rollout_ranks = list(range(self.trainer_world_size, self.world_size))
    self.rollout_group = dist.new_group(self.rollout_ranks)
    self.initialized = True
292314
@@ -304,7 +326,7 @@ def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch
304326 start_time = time .time ()
305327 named_tensors = {}
306328 for named_tensors_gpu in ckpt_get_named_tensor_buckets (
307- weights , self .bucket_size , self .train_world_size , self .rank , self .rollout_dtype
329+ weights , self .bucket_size , self .trainer_world_size , self .rank , self .rollout_dtype
308330 ):
309331 with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
310332 futures = [
0 commit comments