@@ -1031,6 +1031,25 @@ def init_process_group(
10311031 )
10321032 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
10331033
1034+ def store_based_barrier (self , store : dist .TCPStore ) -> None :
1035+ """
1036+ Perform a store-based barrier synchronization across all ranks.
1037+
1038+ This barrier uses a TCP store directly rather than a process group,
1039+ allowing all ranks to synchronize regardless of which process group
1040+ they belong to.
1041+
1042+ Args:
1043+ store: The TCPStore instance to use for synchronization.
1044+ """
1045+ dist .distributed_c10d ._store_based_barrier (
1046+ rank = self ._rank ,
1047+ store = store ,
1048+ group_name = "parameter_server_barrier" ,
1049+ rendezvous_count = self ._world_size ,
1050+ timeout = timedelta (minutes = 5 ),
1051+ )
1052+
10341053 def update (
10351054 self ,
10361055 checkpoint_name : str ,
@@ -1050,52 +1069,42 @@ def update(
10501069 which is useful in disaggregated architecture.
10511070 """
10521071 assert req_func is not None , "req_func is required"
1072+ ranks_group = None
10531073 try :
1074+ master_addr = os .getenv ("MASTER_ADDR" )
1075+ assert master_addr , "master_addr is required"
1076+
1077+ # HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
10541078 manager_store = dist .TCPStore (
1055- os . getenv ( "MASTER_ADDR" ) ,
1079+ master_addr ,
10561080 _get_master_port () + 1 ,
10571081 self ._world_size ,
10581082 timeout = timedelta (minutes = 10 ),
10591083 is_master = self ._rank == 0 ,
10601084 )
1085+
1086+ if self ._auto_pg and not dist .is_initialized ():
1087+ self .init_process_group ()
1088+
10611089 # if both ranks is None or [], it will use fully broadcast to update to all ranks
1090+ ranks_group = dist .new_group (ranks if ranks else None )
10621091 if not ranks :
1063- if self ._auto_pg and not dist .is_initialized ():
1064- self .init_process_group ()
1065- ranks_group = dist .new_group ()
1066- self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
1092+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group )
10671093 else :
1068- if self ._auto_pg :
1069- if dist .is_initialized ():
1070- dist .destroy_process_group ()
1071- # HACK: wait 2s to ensure destroy is finished
1072- time .sleep (2 )
1073- self .init_process_group ()
1074- ranks_group = dist .new_group (ranks )
1075- logger .info (
1076- f"[rank{ self ._rank } ] default pg: { dist .group .WORLD } , ranks group: { ranks_group } "
1077- )
10781094 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
10791095
1080- dist .distributed_c10d ._store_based_barrier (
1081- rank = self ._rank ,
1082- store = manager_store ,
1083- group_name = "manager_store_barrier" ,
1084- rendezvous_count = self ._world_size ,
1085- timeout = timedelta (minutes = 5 ),
1086- )
1087- dist .destroy_process_group (ranks_group )
1088- del ranks_group
1096+ self .store_based_barrier (manager_store )
10891097
10901098 except Exception as e :
10911099 logger .exception (
10921100 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
10931101 )
10941102 raise
10951103 finally :
1096- if self ._auto_pg and (not ranks or self ._rank in ranks ):
1104+ if ranks_group :
1105+ dist .destroy_process_group (ranks_group )
1106+ if self ._auto_pg and dist .is_initialized ():
10971107 dist .destroy_process_group ()
1098-
10991108 self .device_manager .device_module .empty_cache ()
11001109 logger .info (
11011110 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } done. "
@@ -1228,7 +1237,6 @@ def _update_per_bucket(
12281237 ):
12291238 assert len (self ._current_global_parameter_metas ) != 0 , "parameter metas is empty"
12301239 assert dist .is_initialized (), "process group is not initialized"
1231- assert ranks_group is not None , "ranks_group should be set"
12321240
12331241 # if both ranks is None or [], it will use fully broadcast to update to all ranks
12341242 if not ranks :
0 commit comments