Skip to content

Commit 7dfd396

Browse files
author
kip-cxj
committed
revert to store_based_barrier
1 parent 47a2561 commit 7dfd396

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

checkpoint_engine/distributed/nccl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def new_group(self, ranks: list[int], **kwargs) -> CommGroup:
223223
else:
224224
ranks.sort()
225225

226+
group: CommGroup = None
226227
newcomm = self.pynccl.create_newcomm(ranks)
227228
if newcomm:
228229
group = CommGroup(newcomm.value, ranks)

checkpoint_engine/ps.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,14 +553,26 @@ def update(
553553
try:
554554
master_addr = os.getenv("MASTER_ADDR") or master_addr
555555
assert master_addr, "master_addr is required"
556-
if self._auto_pg and not dist.is_initialized():
557-
self.init_process_group(
558-
timeout=timeout, master_addr=master_addr, master_port=master_port
556+
if self._auto_pg:
557+
if not dist.is_initialized():
558+
self.init_process_group(
559+
timeout=timeout, master_addr=master_addr, master_port=master_port
560+
)
561+
manager_store = torch.distributed.distributed_c10d._get_default_store()
562+
else:
563+
# HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
564+
# If master_port is provided, use master_port+1 for barrier store
565+
manager_store = torch.distributed.TCPStore(
566+
master_addr,
567+
_get_master_port(master_port) + 1,
568+
self._world_size,
569+
timeout=timeout,
570+
is_master=self._rank == 0,
559571
)
560572
# if ranks is None or [], it will use fully broadcast to update to all ranks
561573
ranks_group = dist.new_group(ranks) if ranks else None
562574
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
563-
dist.barrier()
575+
self.store_based_barrier(manager_store)
564576
except Exception as e:
565577
logger.exception(
566578
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"

0 commit comments

Comments
 (0)