Skip to content

Commit 7c6054c

Browse files
committed
misc: fix pr issues
1 parent f9b5a0f commit 7c6054c

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

checkpoint_engine/ps.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,9 @@ def init_process_group(
10311031
)
10321032
logger.info(f"[rank{self._rank}] init process group successfully.")
10331033

1034-
def store_based_barrier(self, store: dist.TCPStore) -> None:
1034+
def store_based_barrier(
1035+
self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
1036+
) -> None:
10351037
"""
10361038
Perform a store-based barrier synchronization across all ranks.
10371039
@@ -1047,14 +1049,15 @@ def store_based_barrier(self, store: dist.TCPStore) -> None:
10471049
store=store,
10481050
group_name="parameter_server_barrier",
10491051
rendezvous_count=self._world_size,
1050-
timeout=timedelta(minutes=5),
1052+
timeout=timeout,
10511053
)
10521054

10531055
def update(
10541056
self,
10551057
checkpoint_name: str,
10561058
req_func: Callable[[list[tuple[str, str]]], None],
10571059
*,
1060+
timeout: timedelta = timedelta(minutes=10),
10581061
ranks: list[int] | None = None,
10591062
) -> None:
10601063
"""
@@ -1073,28 +1076,23 @@ def update(
10731076
try:
10741077
master_addr = os.getenv("MASTER_ADDR")
10751078
assert master_addr, "master_addr is required"
1076-
1077-
# HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
1078-
manager_store = dist.TCPStore(
1079-
master_addr,
1080-
_get_master_port() + 1,
1081-
self._world_size,
1082-
timeout=timedelta(minutes=10),
1083-
is_master=self._rank == 0,
1084-
)
1085-
1086-
if self._auto_pg and not dist.is_initialized():
1087-
self.init_process_group()
1088-
1079+
if self._auto_pg:
1080+
if not dist.is_initialized():
1081+
self.init_process_group(timeout=timeout)
1082+
manager_store = dist.distributed_c10d._get_default_store()
1083+
else:
1084+
# HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
1085+
manager_store = dist.TCPStore(
1086+
master_addr,
1087+
_get_master_port() + 1,
1088+
self._world_size,
1089+
timeout=timeout,
1090+
is_master=self._rank == 0,
1091+
)
10891092
# if both ranks is None or [], it will use fully broadcast to update to all ranks
10901093
ranks_group = dist.new_group(ranks if ranks else None)
1091-
if not ranks:
1092-
self._update_per_bucket(checkpoint_name, req_func, ranks_group)
1093-
else:
1094-
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
1095-
1094+
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
10961095
self.store_based_barrier(manager_store)
1097-
10981096
except Exception as e:
10991097
logger.exception(
11001098
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"

0 commit comments

Comments
 (0)