Skip to content

Commit e9325bb

Browse files
committed
refactor: rewrite update process management logic
1 parent 5bf79dc commit e9325bb

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

checkpoint_engine/ps.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -931,6 +931,15 @@ def init_process_group(
931931
)
932932
logger.info(f"[rank{self._rank}] init process group successfully.")
933933

934+
def store_based_barrier(self, store: dist.TCPStore):
935+
dist.distributed_c10d._store_based_barrier(
936+
rank=self._rank,
937+
store=store,
938+
group_name="parameter_server_barrier",
939+
rendezvous_count=self._world_size,
940+
timeout=timedelta(minutes=5),
941+
)
942+
934943
def update(
935944
self,
936945
checkpoint_name: str,
@@ -958,34 +967,27 @@ def update(
958967
timeout=timedelta(minutes=10),
959968
is_master=self._rank == 0,
960969
)
961-
# if both ranks is None or [], it will use fully broadcast to update to all ranks
962-
if not ranks:
963-
if self._auto_pg and not dist.is_initialized():
964-
self.init_process_group()
965-
ranks_group = dist.new_group()
966-
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
967-
else:
968-
if self._auto_pg:
970+
if self._auto_pg:
971+
if not ranks:
972+
if not dist.is_initialized():
973+
self.init_process_group()
974+
else:
969975
if dist.is_initialized():
970976
dist.destroy_process_group()
971977
# HACK: wait 2s to ensure destroy is finished
972978
time.sleep(2)
973979
self.init_process_group()
974-
ranks_group = dist.new_group(ranks)
975-
logger.info(
976-
f"[rank{self._rank}] default pg: {dist.group.WORLD}, ranks group: {ranks_group}"
977-
)
980+
981+
# if both ranks is None or [], it will use fully broadcast to update to all ranks
982+
ranks_group = dist.new_group(ranks if ranks else None)
983+
if not ranks:
984+
self._update_per_bucket(checkpoint_name, req_func, ranks_group)
985+
else:
978986
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
979987

980-
dist.distributed_c10d._store_based_barrier(
981-
rank=self._rank,
982-
store=manager_store,
983-
group_name="manager_store_barrier",
984-
rendezvous_count=self._world_size,
985-
timeout=timedelta(minutes=5),
986-
)
988+
self.store_based_barrier(manager_store)
989+
987990
dist.destroy_process_group(ranks_group)
988-
del ranks_group
989991
if self._auto_pg:
990992
dist.destroy_process_group()
991993

0 commit comments

Comments
 (0)