Skip to content

Commit bf73f42

Browse files
committed
fix: resolve pr comment issues
1 parent 5916eb9 commit bf73f42

File tree

1 file changed

+30
-39
lines changed

1 file changed

+30
-39
lines changed

checkpoint_engine/ps.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,9 @@ def _assign_receiver_ranks(
553553
buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
554554

555555
buckets_matrix = list(buckets_by_rdma_device.values())
556+
assert buckets_matrix, "buckets_matrix should not be empty"
556557

557-
# select receiver ranks
558+
# Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
558559
num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
559560
receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
560561

@@ -582,6 +583,19 @@ def _get_master_port(master_port: int | None = None) -> int:
582583
master_port = int(os.getenv("MASTER_PORT")) + 1
583584
return master_port
584585

586+
def _get_bcast_rank_map(world_size, ranks: list[int] | None) -> dict[int, int]:
587+
"""
588+
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
589+
which are generated in self.init_process_group_for_ranks
590+
"""
591+
bcast_rank_map: dict[int, int] = {}
592+
if not ranks:
593+
bcast_rank_map = {r: r for r in range(world_size)}
594+
else:
595+
for i, r in enumerate(ranks):
596+
bcast_rank_map[r] = i
597+
return bcast_rank_map
598+
585599

586600
class P2PStore:
587601
def __init__(self):
@@ -877,6 +891,7 @@ def update(
877891
If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
878892
which is useful in disaggregated architecture.
879893
"""
894+
assert req_func is not None, "req_func is required"
880895
try:
881896
# if both ranks is None or [], it will use fully broadcast to update to all ranks
882897
if not ranks:
@@ -1062,19 +1077,6 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
10621077
[f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
10631078
)
10641079

1065-
def _get_bcast_rank_map(self, ranks: list[int]) -> dict[int, int]:
1066-
"""
1067-
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
1068-
which are generated in self.init_process_group_for_ranks
1069-
"""
1070-
bcast_rank_map: dict[int, int] = {}
1071-
if not ranks:
1072-
bcast_rank_map = {r: r for r in range(self._world_size)}
1073-
else:
1074-
for i, r in enumerate(ranks):
1075-
bcast_rank_map[r] = i
1076-
return bcast_rank_map
1077-
10781080
def _update_per_bucket(
10791081
self,
10801082
checkpoint_name: str,
@@ -1084,24 +1086,15 @@ def _update_per_bucket(
10841086
logger.warning(
10851087
f"[rank{self._rank}] Using _update_per_bucket, which is an experimental feature."
10861088
)
1087-
assert req_func is not None
1089+
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1090+
assert dist.is_initialized(), "process group is not initialized"
10881091
# if both ranks is None or [], it will use fully broadcast to update to all ranks
10891092
if not ranks:
1090-
if len(self._current_global_parameter_metas) == 0:
1091-
raise ValueError("parameter metas is empty")
1092-
1093-
assert dist.is_initialized(), "process group is not initialized"
1094-
10951093
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
10961094
# if ranks is set, it will use p2p to update to the ranks
10971095
else:
10981096
assert self._p2p_store is not None, "p2p store is not initialized"
10991097
assert ranks, "ranks should be set"
1100-
if len(self._current_global_parameter_metas) == 0:
1101-
raise ValueError("parameter metas is empty")
1102-
assert dist.is_initialized(), (
1103-
"process group is not initialized when update model per bucket p2p"
1104-
)
11051098

11061099
need_update = self._rank in ranks
11071100
logger.info(
@@ -1131,10 +1124,10 @@ def _update_per_bucket(
11311124
# p2p store need to register h2d_buffer to let other ranks read
11321125
if ranks:
11331126
h2d_buffer_name = "__h2d_buffer__"
1134-
self._p2p_store.register_named_tensors(
1135-
{h2d_buffer_name: h2d_buffer}
1136-
) if h2d_buffer is not None else None
1137-
1127+
if h2d_buffer is not None and self._p2p_store is not None:
1128+
self._p2p_store.register_named_tensors(
1129+
{h2d_buffer_name: h2d_buffer}
1130+
)
11381131
receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
11391132
for receiver_rank, owner_rank, bucket in buckets:
11401133
if receiver_rank != self._rank:
@@ -1160,17 +1153,15 @@ def _update_per_bucket(
11601153
socket.send_pyobj(handle)
11611154

11621155
gidx = 0
1156+
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
11631157
for i in range(max_len):
11641158
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1165-
if not ranks:
1166-
self._copy_to_buffer(checkpoint_name, receiver_rank_buckets[i][1], h2d_buffer)
1167-
else:
1168-
self._copy_to_buffer(
1169-
checkpoint_name,
1170-
receiver_rank_buckets[i][1],
1171-
h2d_buffer,
1172-
receiver_rank_buckets[i][0],
1173-
)
1159+
self._copy_to_buffer(
1160+
checkpoint_name,
1161+
receiver_rank_buckets[i][1],
1162+
h2d_buffer,
1163+
receiver_rank_buckets[i][0] if ranks else None,
1164+
)
11741165
for receiver_rank, _buckets in buckets_by_receiver_rank.items():
11751166
if i >= len(_buckets):
11761167
continue
@@ -1191,7 +1182,7 @@ def _update_per_bucket(
11911182
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
11921183
else:
11931184
buffer_b.data.copy_(h2d_buffer[: bucket.size])
1194-
brank = self._get_bcast_rank_map(ranks)[receiver_rank]
1185+
brank = bcast_rank_map[receiver_rank]
11951186
dist.broadcast(buffer_b, src=brank)
11961187
socket.recv()
11971188
dist.barrier()

0 commit comments

Comments
 (0)