@@ -931,6 +931,15 @@ def init_process_group(
931931 )
932932 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
933933
934+ def store_based_barrier (self , store : dist .TCPStore ):
935+ dist .distributed_c10d ._store_based_barrier (
936+ rank = self ._rank ,
937+ store = store ,
938+ group_name = "parameter_server_barrier" ,
939+ rendezvous_count = self ._world_size ,
940+ timeout = timedelta (minutes = 5 ),
941+ )
942+
934943 def update (
935944 self ,
936945 checkpoint_name : str ,
@@ -958,34 +967,27 @@ def update(
958967 timeout = timedelta (minutes = 10 ),
959968 is_master = self ._rank == 0 ,
960969 )
961- # if both ranks is None or [], it will use fully broadcast to update to all ranks
962- if not ranks :
963- if self ._auto_pg and not dist .is_initialized ():
964- self .init_process_group ()
965- ranks_group = dist .new_group ()
966- self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
967- else :
968- if self ._auto_pg :
970+ if self ._auto_pg :
971+ if not ranks :
972+ if not dist .is_initialized ():
973+ self .init_process_group ()
974+ else :
969975 if dist .is_initialized ():
970976 dist .destroy_process_group ()
971977 # HACK: wait 2s to ensure destroy is finished
972978 time .sleep (2 )
973979 self .init_process_group ()
974- ranks_group = dist .new_group (ranks )
975- logger .info (
976- f"[rank{ self ._rank } ] default pg: { dist .group .WORLD } , ranks group: { ranks_group } "
977- )
980+
981+ # if ranks is None or [], use a full broadcast to update all ranks
982+ ranks_group = dist .new_group (ranks if ranks else None )
983+ if not ranks :
984+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group )
985+ else :
978986 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
979987
980- dist .distributed_c10d ._store_based_barrier (
981- rank = self ._rank ,
982- store = manager_store ,
983- group_name = "manager_store_barrier" ,
984- rendezvous_count = self ._world_size ,
985- timeout = timedelta (minutes = 5 ),
986- )
988+ self .store_based_barrier (manager_store )
989+
987990 dist .destroy_process_group (ranks_group )
988- del ranks_group
989991 if self ._auto_pg :
990992 dist .destroy_process_group ()
991993
0 commit comments