Skip to content

Commit 5a26fbf

Browse files
committed
fix: merge issues fixed
1 parent 7c6054c commit 5a26fbf

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

checkpoint_engine/ps.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,8 @@ def update(
10591059
*,
10601060
timeout: timedelta = timedelta(minutes=10),
10611061
ranks: list[int] | None = None,
1062+
master_addr: str | None = None,
1063+
master_port: int | None = None,
10621064
) -> None:
10631065
"""
10641066
Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1070,21 +1072,27 @@ def update(
10701072
which is the fastest way to update weights, especially in colocated architecture.
10711073
If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
10721074
which is useful in disaggregated architecture.
1075+
master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
1076+
master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
1077+
timeout: The timeout of the barrier operation.
10731078
"""
10741079
assert req_func is not None, "req_func is required"
10751080
ranks_group = None
10761081
try:
1077-
master_addr = os.getenv("MASTER_ADDR")
1082+
master_addr = os.getenv("MASTER_ADDR") or master_addr
10781083
assert master_addr, "master_addr is required"
10791084
if self._auto_pg:
10801085
if not dist.is_initialized():
1081-
self.init_process_group(timeout=timeout)
1086+
self.init_process_group(
1087+
timeout=timeout, master_addr=master_addr, master_port=master_port
1088+
)
10821089
manager_store = dist.distributed_c10d._get_default_store()
10831090
else:
1084-
# HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
1091+
# HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
1092+
# If master_port is provided, use master_port+1 for barrier store
10851093
manager_store = dist.TCPStore(
10861094
master_addr,
1087-
_get_master_port() + 1,
1095+
_get_master_port(master_port) + 1,
10881096
self._world_size,
10891097
timeout=timeout,
10901098
is_master=self._rank == 0,
@@ -1331,16 +1339,15 @@ def _update_per_bucket(
13311339
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
13321340
else:
13331341
buffer_b.data.copy_(h2d_buffer[: bucket.size])
1334-
brank = bcast_rank_map[receiver_rank]
1335-
dist.broadcast(buffer_b, src=brank)
1342+
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
13361343
resp = socket.recv()
13371344
if resp != b"":
13381345
msg = resp.decode("utf-8")
13391346
logger.error(
13401347
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
13411348
)
13421349
ret_code.fill_(1)
1343-
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
1350+
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
13441351
self.device_manager.device_module.synchronize()
13451352
if ret_code.item() != 0:
13461353
# quit early if any rank failed

0 commit comments

Comments
 (0)