Skip to content

Commit 5bf79dc

Browse files
committed
fix: store-based barrier for all processes' synchronization
1 parent 78f325c commit 5bf79dc

File tree

1 file changed

+37
-69
lines changed

1 file changed

+37
-69
lines changed

checkpoint_engine/ps.py

Lines changed: 37 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -650,20 +650,6 @@ def _get_master_port(master_port: int | None = None) -> int:
650650
return master_port
651651

652652

653-
def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
654-
"""
655-
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
656-
which are generated in self.init_process_group_for_ranks
657-
"""
658-
bcast_rank_map: dict[int, int] = {}
659-
if not ranks:
660-
bcast_rank_map = {r: r for r in range(world_size)}
661-
else:
662-
for i, r in enumerate(ranks):
663-
bcast_rank_map[r] = i
664-
return bcast_rank_map
665-
666-
667653
class P2PStore:
668654
def __init__(self, device_manager: DeviceManager):
669655
from mooncake.engine import TransferEngine
@@ -965,21 +951,41 @@ def update(
965951
"""
966952
assert req_func is not None, "req_func is required"
967953
try:
954+
manager_store = dist.TCPStore(
955+
os.getenv("MASTER_ADDR"),
956+
_get_master_port() + 1,
957+
self._world_size,
958+
timeout=timedelta(minutes=10),
959+
is_master=self._rank == 0,
960+
)
968961
# if both ranks is None or [], it will use fully broadcast to update to all ranks
969962
if not ranks:
970963
if self._auto_pg and not dist.is_initialized():
971964
self.init_process_group()
972-
self._update_per_bucket(checkpoint_name, req_func)
965+
ranks_group = dist.new_group()
966+
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
973967
else:
974968
if self._auto_pg:
975969
if dist.is_initialized():
976970
dist.destroy_process_group()
977971
# HACK: wait 2s to ensure destroy is finished
978972
time.sleep(2)
979-
self.init_process_group_for_ranks(ranks)
980-
if self._rank not in ranks:
981-
return
982-
self._update_per_bucket(checkpoint_name, req_func, ranks)
973+
self.init_process_group()
974+
ranks_group = dist.new_group(ranks)
975+
logger.info(
976+
f"[rank{self._rank}] default pg: {dist.group.WORLD}, ranks group: {ranks_group}"
977+
)
978+
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
979+
980+
dist.distributed_c10d._store_based_barrier(
981+
rank=self._rank,
982+
store=manager_store,
983+
group_name="manager_store_barrier",
984+
rendezvous_count=self._world_size,
985+
timeout=timedelta(minutes=5),
986+
)
987+
dist.destroy_process_group(ranks_group)
988+
del ranks_group
983989
if self._auto_pg:
984990
dist.destroy_process_group()
985991

@@ -1006,7 +1012,9 @@ def zmq_handle(device_uuid: str) -> str:
10061012
self._zmq_addr_counter += 1
10071013
return socket, socket_paths
10081014

1009-
def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
1015+
def _detect_bucket_size(
1016+
self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
1017+
) -> tuple[int, bool]:
10101018
GiB = 1 << 30 # noqa: N806
10111019
# auto detect bucket size
10121020
tensor = torch.tensor(
@@ -1022,7 +1030,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
10221030
dtype=torch.int64,
10231031
device=self.device_manager.device_type,
10241032
)
1025-
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
1033+
dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
10261034
tensor = tensor.cpu()
10271035
free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
10281036
max_tensor_bytes = 0
@@ -1085,47 +1093,6 @@ def _copy_to_buffer(
10851093
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
10861094
self.device_manager.device_module.synchronize()
10871095

1088-
def init_process_group_for_ranks(
1089-
self,
1090-
ranks: list[int],
1091-
*,
1092-
master_port: int | None = None,
1093-
timeout: timedelta = timedelta(minutes=10),
1094-
):
1095-
"""
1096-
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
1097-
1098-
Args:
1099-
ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
1100-
master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
1101-
timeout: The timeout of the process group.
1102-
"""
1103-
assert not dist.is_initialized()
1104-
assert ranks, "ranks should be set"
1105-
if self._rank not in ranks:
1106-
return
1107-
assert self._all_hosts, "all_hosts should be set"
1108-
assert len(self._all_hosts) == self._world_size // self._gpu_count, (
1109-
f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
1110-
)
1111-
rank = ranks.index(self._rank)
1112-
master_addr = self._all_hosts[ranks[0] // self._gpu_count]
1113-
master_port = _get_master_port(master_port)
1114-
logger.info(
1115-
f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
1116-
f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
1117-
)
1118-
# only initialize process group and store for ranks, other nodes are not initialized
1119-
# and will not participate in this update. Since they have registered memory addresses
1120-
# to p2p_store at the beginning, update ranks can directly get the memory addresses
1121-
# from other nodes and put the weights into the buffer.
1122-
store = dist.TCPStore(
1123-
master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
1124-
)
1125-
dist.init_process_group(
1126-
backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
1127-
)
1128-
11291096
def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
11301097
addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
11311098
metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1155,10 +1122,13 @@ def _update_per_bucket(
11551122
self,
11561123
checkpoint_name: str,
11571124
req_func: Callable[[list[tuple[str, str]]], None],
1125+
ranks_group: dist.ProcessGroup,
11581126
ranks: list[int] | None = None,
11591127
):
11601128
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
11611129
assert dist.is_initialized(), "process group is not initialized"
1130+
assert ranks_group is not None, "ranks_group should be set"
1131+
11621132
# if both ranks is None or [], it will use fully broadcast to update to all ranks
11631133
if not ranks:
11641134
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1176,9 +1146,9 @@ def _update_per_bucket(
11761146
if not need_update:
11771147
return
11781148
# first execute a barrier to avoid subsequent cuda oom
1179-
dist.barrier()
1149+
dist.barrier(group=ranks_group)
11801150

1181-
bucket_size, disable_h2d_buffer = self._detect_bucket_size()
1151+
bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
11821152
buckets = _gen_h2d_buckets(
11831153
self._current_global_parameter_metas,
11841154
bucket_size,
@@ -1224,7 +1194,6 @@ def _update_per_bucket(
12241194
socket.send_pyobj(handle)
12251195

12261196
gidx = 0
1227-
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
12281197
for i in range(max_len):
12291198
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
12301199
self._copy_to_buffer(
@@ -1253,18 +1222,17 @@ def _update_per_bucket(
12531222
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
12541223
else:
12551224
buffer_b.data.copy_(h2d_buffer[: bucket.size])
1256-
brank = bcast_rank_map[receiver_rank]
1257-
dist.broadcast(buffer_b, src=brank)
1225+
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
12581226
socket.recv()
1259-
dist.barrier()
1227+
dist.barrier(group=ranks_group)
12601228
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
12611229
gidx += 1
12621230

12631231
socket.recv()
12641232
socket.send_pyobj(None)
12651233
socket.recv()
12661234
req_thread.join()
1267-
dist.barrier()
1235+
dist.barrier(group=ranks_group)
12681236
socket.close()
12691237
if ranks and h2d_buffer is not None:
12701238
self._p2p_store.unregister_named_tensors([h2d_buffer_name])

0 commit comments

Comments (0)