@@ -939,7 +939,9 @@ def init_process_group(
939939 )
940940 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
941941
942- def store_based_barrier (self , store : dist .TCPStore ) -> None :
942+ def store_based_barrier (
943+ self , store : dist .TCPStore , timeout : timedelta = timedelta (minutes = 5 )
944+ ) -> None :
943945 """
944946 Perform a store-based barrier synchronization across all ranks.
945947
@@ -955,14 +957,15 @@ def store_based_barrier(self, store: dist.TCPStore) -> None:
955957 store = store ,
956958 group_name = "parameter_server_barrier" ,
957959 rendezvous_count = self ._world_size ,
958- timeout = timedelta ( minutes = 5 ) ,
960+ timeout = timeout ,
959961 )
960962
961963 def update (
962964 self ,
963965 checkpoint_name : str ,
964966 req_func : Callable [[list [tuple [str , str ]]], None ],
965967 * ,
968+ timeout : timedelta = timedelta (minutes = 10 ),
966969 ranks : list [int ] | None = None ,
967970 ) -> None :
968971 """
@@ -981,28 +984,23 @@ def update(
981984 try :
982985 master_addr = os .getenv ("MASTER_ADDR" )
983986 assert master_addr , "master_addr is required"
984-
985- # HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
986- manager_store = dist . TCPStore (
987- master_addr ,
988- _get_master_port () + 1 ,
989- self . _world_size ,
990- timeout = timedelta ( minutes = 10 ),
991- is_master = self . _rank == 0 ,
992- )
993-
994- if self . _auto_pg and not dist . is_initialized ():
995- self .init_process_group ()
996-
987+ if self . _auto_pg :
988+ if not dist . is_initialized ():
989+ self . init_process_group ( timeout = timeout )
990+ manager_store = dist . distributed_c10d . _get_default_store ()
991+ else :
992+ # HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
993+ manager_store = dist . TCPStore (
994+ master_addr ,
995+ _get_master_port () + 1 ,
996+ self . _world_size ,
997+ timeout = timeout ,
998+ is_master = self ._rank == 0 ,
999+ )
9971000 # if both ranks is None or [], it will use fully broadcast to update to all ranks
9981001 ranks_group = dist .new_group (ranks if ranks else None )
999- if not ranks :
1000- self ._update_per_bucket (checkpoint_name , req_func , ranks_group )
1001- else :
1002- self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
1003-
1002+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
10041003 self .store_based_barrier (manager_store )
1005-
10061004 except Exception as e :
10071005 logger .exception (
10081006 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
0 commit comments