@@ -553,8 +553,9 @@ def _assign_receiver_ranks(
553553 buckets_by_rdma_device [owner_rdma_device ].append ((owner_rank , bucket ))
554554
555555 buckets_matrix = list (buckets_by_rdma_device .values ())
556+ assert buckets_matrix , "buckets_matrix should not be empty"
556557
557- # select receiver ranks
558+ # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
558559 num_receivers = min (len (local_topo ), len (buckets_by_rdma_device ))
559560 receiver_list = [min (ranks ) for ranks in list (local_topo .values ())[:num_receivers ]]
560561
@@ -582,6 +583,19 @@ def _get_master_port(master_port: int | None = None) -> int:
582583 master_port = int (os .getenv ("MASTER_PORT" )) + 1
583584 return master_port
584585
586+ def _get_bcast_rank_map (world_size , ranks : list [int ] | None ) -> dict [int , int ]:
587+ """
588+ map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
589+ which are generated in self.init_process_group_for_ranks
590+ """
591+ bcast_rank_map : dict [int , int ] = {}
592+ if not ranks :
593+ bcast_rank_map = {r : r for r in range (world_size )}
594+ else :
595+ for i , r in enumerate (ranks ):
596+ bcast_rank_map [r ] = i
597+ return bcast_rank_map
598+
585599
586600class P2PStore :
587601 def __init__ (self ):
@@ -877,6 +891,7 @@ def update(
877891 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
878892 which is useful in disaggregated architecture.
879893 """
894+ assert req_func is not None , "req_func is required"
880895 try :
881896 # if both ranks is None or [], it will use fully broadcast to update to all ranks
882897 if not ranks :
@@ -1062,19 +1077,6 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
10621077 [f"memory_pool_{ checkpoint_name } _{ idx } " for idx , _ in enumerate (pool )]
10631078 )
10641079
1065- def _get_bcast_rank_map (self , ranks : list [int ]) -> dict [int , int ]:
1066- """
1067- map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
1068- which are generated in self.init_process_group_for_ranks
1069- """
1070- bcast_rank_map : dict [int , int ] = {}
1071- if not ranks :
1072- bcast_rank_map = {r : r for r in range (self ._world_size )}
1073- else :
1074- for i , r in enumerate (ranks ):
1075- bcast_rank_map [r ] = i
1076- return bcast_rank_map
1077-
10781080 def _update_per_bucket (
10791081 self ,
10801082 checkpoint_name : str ,
@@ -1084,24 +1086,15 @@ def _update_per_bucket(
10841086 logger .warning (
10851087 f"[rank{ self ._rank } ] Using _update_per_bucket, which is an experimental feature."
10861088 )
1087- assert req_func is not None
1089+ assert len (self ._current_global_parameter_metas ) != 0 , "parameter metas is empty"
1090+ assert dist .is_initialized (), "process group is not initialized"
10881091 # if both ranks is None or [], it will use fully broadcast to update to all ranks
10891092 if not ranks :
1090- if len (self ._current_global_parameter_metas ) == 0 :
1091- raise ValueError ("parameter metas is empty" )
1092-
1093- assert dist .is_initialized (), "process group is not initialized"
1094-
10951093 logger .info (f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } " )
10961094 # if ranks is set, it will use p2p to update to the ranks
10971095 else :
10981096 assert self ._p2p_store is not None , "p2p store is not initialized"
10991097 assert ranks , "ranks should be set"
1100- if len (self ._current_global_parameter_metas ) == 0 :
1101- raise ValueError ("parameter metas is empty" )
1102- assert dist .is_initialized (), (
1103- "process group is not initialized when update model per bucket p2p"
1104- )
11051098
11061099 need_update = self ._rank in ranks
11071100 logger .info (
@@ -1131,10 +1124,10 @@ def _update_per_bucket(
11311124 # p2p store need to register h2d_buffer to let other ranks read
11321125 if ranks :
11331126 h2d_buffer_name = "__h2d_buffer__"
1134- self ._p2p_store . register_named_tensors (
1135- { h2d_buffer_name : h2d_buffer }
1136- ) if h2d_buffer is not None else None
1137-
1127+ if h2d_buffer is not None and self ._p2p_store is not None :
1128+ self . _p2p_store . register_named_tensors (
1129+ { h2d_buffer_name : h2d_buffer }
1130+ )
11381131 receiver_rank_buckets : list [tuple [int , H2DBucket ]] = []
11391132 for receiver_rank , owner_rank , bucket in buckets :
11401133 if receiver_rank != self ._rank :
@@ -1160,17 +1153,15 @@ def _update_per_bucket(
11601153 socket .send_pyobj (handle )
11611154
11621155 gidx = 0
1156+ bcast_rank_map = _get_bcast_rank_map (self ._world_size , ranks )
11631157 for i in range (max_len ):
11641158 if i < len (receiver_rank_buckets ) and not disable_h2d_buffer :
1165- if not ranks :
1166- self ._copy_to_buffer (checkpoint_name , receiver_rank_buckets [i ][1 ], h2d_buffer )
1167- else :
1168- self ._copy_to_buffer (
1169- checkpoint_name ,
1170- receiver_rank_buckets [i ][1 ],
1171- h2d_buffer ,
1172- receiver_rank_buckets [i ][0 ],
1173- )
1159+ self ._copy_to_buffer (
1160+ checkpoint_name ,
1161+ receiver_rank_buckets [i ][1 ],
1162+ h2d_buffer ,
1163+ receiver_rank_buckets [i ][0 ] if ranks else None ,
1164+ )
11741165 for receiver_rank , _buckets in buckets_by_receiver_rank .items ():
11751166 if i >= len (_buckets ):
11761167 continue
@@ -1191,7 +1182,7 @@ def _update_per_bucket(
11911182 self ._copy_to_buffer (checkpoint_name , bucket , buffer_b )
11921183 else :
11931184 buffer_b .data .copy_ (h2d_buffer [: bucket .size ])
1194- brank = self . _get_bcast_rank_map ( ranks ) [receiver_rank ]
1185+ brank = bcast_rank_map [receiver_rank ]
11951186 dist .broadcast (buffer_b , src = brank )
11961187 socket .recv ()
11971188 dist .barrier ()
0 commit comments