@@ -650,20 +650,6 @@ def _get_master_port(master_port: int | None = None) -> int:
650650 return master_port
651651
652652
653- def _get_bcast_rank_map (world_size : int , ranks : list [int ] | None ) -> dict [int , int ]:
654- """
655- map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
656- which are generated in self.init_process_group_for_ranks
657- """
658- bcast_rank_map : dict [int , int ] = {}
659- if not ranks :
660- bcast_rank_map = {r : r for r in range (world_size )}
661- else :
662- for i , r in enumerate (ranks ):
663- bcast_rank_map [r ] = i
664- return bcast_rank_map
665-
666-
667653class P2PStore :
668654 def __init__ (self , device_manager : DeviceManager ):
669655 from mooncake .engine import TransferEngine
@@ -965,21 +951,41 @@ def update(
965951 """
966952 assert req_func is not None , "req_func is required"
967953 try :
954+ manager_store = dist .TCPStore (
955+ os .getenv ("MASTER_ADDR" ),
956+ _get_master_port () + 1 ,
957+ self ._world_size ,
958+ timeout = timedelta (minutes = 10 ),
959+ is_master = self ._rank == 0 ,
960+ )
968961 # if both ranks is None or [], it will use fully broadcast to update to all ranks
969962 if not ranks :
970963 if self ._auto_pg and not dist .is_initialized ():
971964 self .init_process_group ()
972- self ._update_per_bucket (checkpoint_name , req_func )
965+ ranks_group = dist .new_group ()
966+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
973967 else :
974968 if self ._auto_pg :
975969 if dist .is_initialized ():
976970 dist .destroy_process_group ()
977971 # HACK: wait 2s to ensure destroy is finished
978972 time .sleep (2 )
979- self .init_process_group_for_ranks (ranks )
980- if self ._rank not in ranks :
981- return
982- self ._update_per_bucket (checkpoint_name , req_func , ranks )
973+ self .init_process_group ()
974+ ranks_group = dist .new_group (ranks )
975+ logger .info (
976+ f"[rank{ self ._rank } ] default pg: { dist .group .WORLD } , ranks group: { ranks_group } "
977+ )
978+ self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
979+
980+ dist .distributed_c10d ._store_based_barrier (
981+ rank = self ._rank ,
982+ store = manager_store ,
983+ group_name = "manager_store_barrier" ,
984+ rendezvous_count = self ._world_size ,
985+ timeout = timedelta (minutes = 5 ),
986+ )
987+ dist .destroy_process_group (ranks_group )
988+ del ranks_group
983989 if self ._auto_pg :
984990 dist .destroy_process_group ()
985991
@@ -1006,7 +1012,9 @@ def zmq_handle(device_uuid: str) -> str:
10061012 self ._zmq_addr_counter += 1
10071013 return socket , socket_paths
10081014
1009- def _detect_bucket_size (self , * , disable_h2d_buffer : bool = False ) -> tuple [int , bool ]:
1015+ def _detect_bucket_size (
1016+ self , ranks_group : dist .ProcessGroup , * , disable_h2d_buffer : bool = False
1017+ ) -> tuple [int , bool ]:
10101018 GiB = 1 << 30 # noqa: N806
10111019 # auto detect bucket size
10121020 tensor = torch .tensor (
@@ -1022,7 +1030,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
10221030 dtype = torch .int64 ,
10231031 device = self .device_manager .device_type ,
10241032 )
1025- dist .all_reduce (tensor , op = dist .ReduceOp .MIN )
1033+ dist .all_reduce (tensor , op = dist .ReduceOp .MIN , group = ranks_group )
10261034 tensor = tensor .cpu ()
10271035 free_bytes , self ._zmq_addr_counter = tensor [0 ].item (), - tensor [1 ].item ()
10281036 max_tensor_bytes = 0
@@ -1085,47 +1093,6 @@ def _copy_to_buffer(
10851093 self ._p2p_store .batch_transfer_sync_read (target_addr , buf_ptrs , remote_ptrs , lens )
10861094 self .device_manager .device_module .synchronize ()
10871095
1088- def init_process_group_for_ranks (
1089- self ,
1090- ranks : list [int ],
1091- * ,
1092- master_port : int | None = None ,
1093- timeout : timedelta = timedelta (minutes = 10 ),
1094- ):
1095- """
1096- Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
1097-
1098- Args:
1099- ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
1100- master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
1101- timeout: The timeout of the process group.
1102- """
1103- assert not dist .is_initialized ()
1104- assert ranks , "ranks should be set"
1105- if self ._rank not in ranks :
1106- return
1107- assert self ._all_hosts , "all_hosts should be set"
1108- assert len (self ._all_hosts ) == self ._world_size // self ._gpu_count , (
1109- f"world_size { self ._world_size } should be equal to all_hosts { len (self ._all_hosts )} "
1110- )
1111- rank = ranks .index (self ._rank )
1112- master_addr = self ._all_hosts [ranks [0 ] // self ._gpu_count ]
1113- master_port = _get_master_port (master_port )
1114- logger .info (
1115- f"[rank{ self ._rank } ] start to init process group as virtual_rank { rank } , "
1116- f"master_addr { master_addr } , master_port { master_port } , world_size { len (ranks )} , "
1117- )
1118- # only initialize process group and store for ranks, other nodes are not initialized
1119- # and will not participate in this update. Since they have registered memory addresses
1120- # to p2p_store at the beginning, update ranks can directly get the memory addresses
1121- # from other nodes and put the weights into the buffer.
1122- store = dist .TCPStore (
1123- master_addr , master_port , len (ranks ), is_master = rank == 0 , timeout = timeout
1124- )
1125- dist .init_process_group (
1126- backend = "nccl" , world_size = len (ranks ), rank = rank , timeout = timeout , store = store
1127- )
1128-
11291096 def _get_addr_ptrs (self , owner_rank : int ) -> tuple [str , list [tuple [int , int ]]]:
11301097 addr = self ._current_global_parameter_metas [owner_rank ].p2p_store_addr
11311098 metas_list = self ._current_global_parameter_metas [owner_rank ].memory_buffer_metas_list
@@ -1155,10 +1122,13 @@ def _update_per_bucket(
11551122 self ,
11561123 checkpoint_name : str ,
11571124 req_func : Callable [[list [tuple [str , str ]]], None ],
1125+ ranks_group : dist .ProcessGroup ,
11581126 ranks : list [int ] | None = None ,
11591127 ):
11601128 assert len (self ._current_global_parameter_metas ) != 0 , "parameter metas is empty"
11611129 assert dist .is_initialized (), "process group is not initialized"
1130+ assert ranks_group is not None , "ranks_group should be set"
1131+
11621132 # if both ranks is None or [], it will use fully broadcast to update to all ranks
11631133 if not ranks :
11641134 logger .info (f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } " )
@@ -1176,9 +1146,9 @@ def _update_per_bucket(
11761146 if not need_update :
11771147 return
11781148 # first execute a barrier to avoid subsequent cuda oom
1179- dist .barrier ()
1149+ dist .barrier (group = ranks_group )
11801150
1181- bucket_size , disable_h2d_buffer = self ._detect_bucket_size ()
1151+ bucket_size , disable_h2d_buffer = self ._detect_bucket_size (ranks_group )
11821152 buckets = _gen_h2d_buckets (
11831153 self ._current_global_parameter_metas ,
11841154 bucket_size ,
@@ -1224,7 +1194,6 @@ def _update_per_bucket(
12241194 socket .send_pyobj (handle )
12251195
12261196 gidx = 0
1227- bcast_rank_map = _get_bcast_rank_map (self ._world_size , ranks )
12281197 for i in range (max_len ):
12291198 if i < len (receiver_rank_buckets ) and not disable_h2d_buffer :
12301199 self ._copy_to_buffer (
@@ -1253,18 +1222,17 @@ def _update_per_bucket(
12531222 self ._copy_to_buffer (checkpoint_name , bucket , buffer_b )
12541223 else :
12551224 buffer_b .data .copy_ (h2d_buffer [: bucket .size ])
1256- brank = bcast_rank_map [receiver_rank ]
1257- dist .broadcast (buffer_b , src = brank )
1225+ dist .broadcast (buffer_b , src = receiver_rank , group = ranks_group )
12581226 socket .recv ()
1259- dist .barrier ()
1227+ dist .barrier (group = ranks_group )
12601228 socket .send_pyobj (_to_named_tensor (bucket .items , gidx % 2 * bucket_size ))
12611229 gidx += 1
12621230
12631231 socket .recv ()
12641232 socket .send_pyobj (None )
12651233 socket .recv ()
12661234 req_thread .join ()
1267- dist .barrier ()
1235+ dist .barrier (group = ranks_group )
12681236 socket .close ()
12691237 if ranks and h2d_buffer is not None :
12701238 self ._p2p_store .unregister_named_tensors ([h2d_buffer_name ])
0 commit comments