@@ -305,6 +305,17 @@ def __init__(self, args, model, weights, *, model_name, quantization_config, voc
305305 self .param_info_buckets = get_param_info_buckets (self .args , self .model )
306306 self .weight_version = 0
307307
308+ # create the group within megatron.
309+ for start_rank in range (0 , dist .get_world_size (), self .args .rollout_num_gpus_per_engine ):
310+ end_rank = start_rank + self .args .rollout_num_gpus_per_engine
311+ group_ranks = list (range (start_rank , end_rank ))
312+ new_group = dist .new_group (ranks = group_ranks , backend = "gloo" )
313+ if dist .get_rank () in group_ranks :
314+ self ._ipc_gather_group = new_group
315+ self ._ipc_gather_src = start_rank
316+
317+ self ._model_update_groups = None
318+
308319 def connect_rollout_engines (self , rollout_engines , rollout_engine_lock ):
309320 self .rollout_engines = rollout_engines
310321 colocate_engine_nums = (
@@ -322,6 +333,11 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
322333 )
323334 self ._group_name = "slime"
324335 if self ._is_distributed_src_rank :
336+ if self ._model_update_groups is not None :
337+ disconnect_rollout_engines_from_distributed (
338+ self .args , self ._group_name , self ._model_update_groups , self .distributed_rollout_engines
339+ )
340+
325341 self ._model_update_groups = connect_rollout_engines_from_distributed (
326342 self .args , self ._group_name , self .distributed_rollout_engines
327343 )
@@ -331,13 +347,7 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
331347 start_rank = i * self .args .rollout_num_gpus_per_engine
332348 end_rank = (i + 1 ) * self .args .rollout_num_gpus_per_engine
333349 group_ranks = list (range (start_rank , end_rank ))
334- new_group = dist .new_group (
335- ranks = group_ranks ,
336- backend = "gloo" ,
337- )
338350 if dist .get_rank () in group_ranks :
339- self ._ipc_gather_src = start_rank
340- self ._ipc_gather_group = new_group
341351 self ._ipc_engine = engine
342352
343353 @torch .no_grad ()
@@ -496,6 +506,7 @@ def __init__(self, args, model, weights, *, model_name, quantization_config, voc
496506 self .vocab_size = vocab_size
497507 self .quantization_config = quantization_config
498508 self .weight_version = 0
509+ self ._model_update_groups = None
499510
500511 def connect_rollout_engines (self , rollout_engines , rollout_engine_lock ):
501512 self .rollout_engines = rollout_engines
@@ -512,6 +523,10 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
512523 self ._group_name = f"slime-pp_{ pp_rank } "
513524
514525 if self ._is_pp_src_rank :
526+ if self ._model_update_groups is not None :
527+ disconnect_rollout_engines_from_distributed (
528+ self .args , self ._group_name , self ._model_update_groups , self .rollout_engines
529+ )
515530 self ._model_update_groups = connect_rollout_engines_from_distributed (
516531 self .args , self ._group_name , rollout_engines
517532 )
@@ -670,6 +685,12 @@ def connect_rollout_engines_from_distributed(args, group_name, rollout_engines):
670685 return model_update_groups
671686
672687
# Added helper: tear down an existing distributed weights-update channel so the
# trainer side can safely re-run connect_rollout_engines (the new guards earlier
# in this diff call it whenever self._model_update_groups is already set).
688+ def disconnect_rollout_engines_from_distributed (args , group_name , model_update_groups , rollout_engines ):
# Launch the engine-side teardown first: each rollout engine is a Ray actor
# (`.remote` / `ray.get` below), so engine-side destruction overlaps with the
# local teardown on the next line.
# NOTE(review): `args` is unused in this body — presumably kept for signature
# symmetry with connect_rollout_engines_from_distributed; confirm before removing.
689+ refs = [engine .destroy_weights_update_group .remote (group_name ) for engine in rollout_engines ]
# Destroy the trainer-side torch.distributed process group for this channel.
690+ dist .destroy_process_group (model_update_groups )
# Block until every engine has acknowledged dropping its side of the group.
691+ ray .get (refs )
692+
693+
673694def update_weights_from_distributed (args , group_name , group , weight_version , rollout_engines , converted_named_tensors ):
674695 refs = [
675696 engine .update_weights_from_distributed .remote (
0 commit comments