File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed
Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -583,7 +583,8 @@ def _get_master_port(master_port: int | None = None) -> int:
583583 master_port = int (os .getenv ("MASTER_PORT" )) + 1
584584 return master_port
585585
586- def _get_bcast_rank_map (world_size , ranks : list [int ] | None ) -> dict [int , int ]:
586+
587+ def _get_bcast_rank_map (world_size : int , ranks : list [int ] | None ) -> dict [int , int ]:
587588 """
588589 map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
589590 which are generated in self.init_process_group_for_ranks
@@ -1125,9 +1126,7 @@ def _update_per_bucket(
11251126 if ranks :
11261127 h2d_buffer_name = "__h2d_buffer__"
11271128 if h2d_buffer is not None and self ._p2p_store is not None :
1128- self ._p2p_store .register_named_tensors (
1129- {h2d_buffer_name : h2d_buffer }
1130- )
1129+ self ._p2p_store .register_named_tensors ({h2d_buffer_name : h2d_buffer })
11311130 receiver_rank_buckets : list [tuple [int , H2DBucket ]] = []
11321131 for receiver_rank , owner_rank , bucket in buckets :
11331132 if receiver_rank != self ._rank :
You can’t perform that action at this time.
0 commit comments