Skip to content

Commit 5b68376

Browse files
committed
misc: fix pr issues
1 parent 40e03b3 commit 5b68376

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

checkpoint_engine/ps.py

Lines changed: 19 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -939,7 +939,9 @@ def init_process_group(
939939
)
940940
logger.info(f"[rank{self._rank}] init process group successfully.")
941941

942-
def store_based_barrier(self, store: dist.TCPStore) -> None:
942+
def store_based_barrier(
943+
self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
944+
) -> None:
943945
"""
944946
Perform a store-based barrier synchronization across all ranks.
945947
@@ -955,14 +957,15 @@ def store_based_barrier(self, store: dist.TCPStore) -> None:
955957
store=store,
956958
group_name="parameter_server_barrier",
957959
rendezvous_count=self._world_size,
958-
timeout=timedelta(minutes=5),
960+
timeout=timeout,
959961
)
960962

961963
def update(
962964
self,
963965
checkpoint_name: str,
964966
req_func: Callable[[list[tuple[str, str]]], None],
965967
*,
968+
timeout: timedelta = timedelta(minutes=10),
966969
ranks: list[int] | None = None,
967970
) -> None:
968971
"""
@@ -981,28 +984,23 @@ def update(
981984
try:
982985
master_addr = os.getenv("MASTER_ADDR")
983986
assert master_addr, "master_addr is required"
984-
985-
# HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
986-
manager_store = dist.TCPStore(
987-
master_addr,
988-
_get_master_port() + 1,
989-
self._world_size,
990-
timeout=timedelta(minutes=10),
991-
is_master=self._rank == 0,
992-
)
993-
994-
if self._auto_pg and not dist.is_initialized():
995-
self.init_process_group()
996-
987+
if self._auto_pg:
988+
if not dist.is_initialized():
989+
self.init_process_group(timeout=timeout)
990+
manager_store = dist.distributed_c10d._get_default_store()
991+
else:
992+
# HACK: MASTER_PORT+2 for barrier store, _get_master_port() returns MASTER_PORT+1
993+
manager_store = dist.TCPStore(
994+
master_addr,
995+
_get_master_port() + 1,
996+
self._world_size,
997+
timeout=timeout,
998+
is_master=self._rank == 0,
999+
)
9971000
# if ranks is None or an empty list, fall back to a full broadcast that updates all ranks
9981001
ranks_group = dist.new_group(ranks if ranks else None)
999-
if not ranks:
1000-
self._update_per_bucket(checkpoint_name, req_func, ranks_group)
1001-
else:
1002-
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
1003-
1002+
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
10041003
self.store_based_barrier(manager_store)
1005-
10061004
except Exception as e:
10071005
logger.exception(
10081006
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"

0 commit comments

Comments
 (0)