@@ -1031,7 +1031,9 @@ def init_process_group(
10311031 )
10321032 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
10331033
1034- def store_based_barrier (self , store : dist .TCPStore ) -> None :
1034+ def store_based_barrier (
1035+ self , store : dist .TCPStore , timeout : timedelta = timedelta (minutes = 5 )
1036+ ) -> None :
10351037 """
10361038 Perform a store-based barrier synchronization across all ranks.
10371039
@@ -1047,14 +1049,15 @@ def store_based_barrier(self, store: dist.TCPStore) -> None:
10471049 store = store ,
10481050 group_name = "parameter_server_barrier" ,
10491051 rendezvous_count = self ._world_size ,
1050- timeout = timedelta ( minutes = 5 ) ,
1052+ timeout = timeout ,
10511053 )
10521054
10531055 def update (
10541056 self ,
10551057 checkpoint_name : str ,
10561058 req_func : Callable [[list [tuple [str , str ]]], None ],
10571059 * ,
1060+ timeout : timedelta = timedelta (minutes = 10 ),
10581061 ranks : list [int ] | None = None ,
10591062 ) -> None :
10601063 """
@@ -1073,28 +1076,23 @@ def update(
10731076 try :
10741077 master_addr = os .getenv ("MASTER_ADDR" )
10751078 assert master_addr , "master_addr is required"
1076-
1077- # HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
1078- manager_store = dist . TCPStore (
1079- master_addr ,
1080- _get_master_port () + 1 ,
1081- self . _world_size ,
1082- timeout = timedelta ( minutes = 10 ),
1083- is_master = self . _rank == 0 ,
1084- )
1085-
1086- if self . _auto_pg and not dist . is_initialized ():
1087- self .init_process_group ()
1088-
1079+ if self . _auto_pg :
1080+ if not dist . is_initialized ():
1081+ self . init_process_group ( timeout = timeout )
1082+ manager_store = dist . distributed_c10d . _get_default_store ()
1083+ else :
1084+ # HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
1085+ manager_store = dist . TCPStore (
1086+ master_addr ,
1087+ _get_master_port () + 1 ,
1088+ self . _world_size ,
1089+ timeout = timeout ,
1090+ is_master = self ._rank == 0 ,
1091+ )
10891092 # if ranks is None or [], a full broadcast is used to update all ranks
10901093 ranks_group = dist .new_group (ranks if ranks else None )
1091- if not ranks :
1092- self ._update_per_bucket (checkpoint_name , req_func , ranks_group )
1093- else :
1094- self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
1095-
1094+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
10961095 self .store_based_barrier (manager_store )
1097-
10981096 except Exception as e :
10991097 logger .exception (
11001098 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
0 commit comments