@@ -1059,6 +1059,8 @@ def update(
10591059 * ,
10601060 timeout : timedelta = timedelta (minutes = 10 ),
10611061 ranks : list [int ] | None = None ,
1062+ master_addr : str | None = None ,
1063+ master_port : int | None = None ,
10621064 ) -> None :
10631065 """
10641066 Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1070,21 +1072,27 @@ def update(
10701072 which is the fastest way to update weights, especially in colocated architecture.
10711073 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
10721074 which is useful in disaggregated architecture.
1075+ master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
1076+ master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
1077+ timeout: The timeout of the barrier operation.
10731078 """
10741079 assert req_func is not None , "req_func is required"
10751080 ranks_group = None
10761081 try :
1077- master_addr = os .getenv ("MASTER_ADDR" )
1082+ master_addr = os .getenv ("MASTER_ADDR" ) or master_addr
10781083 assert master_addr , "master_addr is required"
10791084 if self ._auto_pg :
10801085 if not dist .is_initialized ():
1081- self .init_process_group (timeout = timeout )
1086+ self .init_process_group (
1087+ timeout = timeout , master_addr = master_addr , master_port = master_port
1088+ )
10821089 manager_store = dist .distributed_c10d ._get_default_store ()
10831090 else :
1084- # HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
1091+ # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
1092+ # If master_port is provided, use master_port+1 for barrier store
10851093 manager_store = dist .TCPStore (
10861094 master_addr ,
1087- _get_master_port () + 1 ,
1095+ _get_master_port (master_port ) + 1 ,
10881096 self ._world_size ,
10891097 timeout = timeout ,
10901098 is_master = self ._rank == 0 ,
@@ -1331,16 +1339,15 @@ def _update_per_bucket(
13311339 self ._copy_to_buffer (checkpoint_name , bucket , buffer_b )
13321340 else :
13331341 buffer_b .data .copy_ (h2d_buffer [: bucket .size ])
1334- brank = bcast_rank_map [receiver_rank ]
1335- dist .broadcast (buffer_b , src = brank )
1342+ dist .broadcast (buffer_b , src = receiver_rank , group = ranks_group )
13361343 resp = socket .recv ()
13371344 if resp != b"" :
13381345 msg = resp .decode ("utf-8" )
13391346 logger .error (
13401347 f"[rank{ self ._rank } ] receive error response from rank { receiver_rank } for bucket { gidx } in checkpoint { checkpoint_name } : { msg } "
13411348 )
13421349 ret_code .fill_ (1 )
1343- dist .all_reduce (ret_code , op = dist .ReduceOp .SUM )
1350+ dist .all_reduce (ret_code , op = dist .ReduceOp .SUM , group = ranks_group )
13441351 self .device_manager .device_module .synchronize ()
13451352 if ret_code .item () != 0 :
13461353 # quit early if any rank failed
0 commit comments