Skip to content

Commit f9b5a0f

Browse files
committed
refactor: rewrite update process management logic
1 parent 0c8d3f2 commit f9b5a0f

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

checkpoint_engine/ps.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,25 @@ def init_process_group(
10311031
)
10321032
logger.info(f"[rank{self._rank}] init process group successfully.")
10331033

1034+
def store_based_barrier(self, store: dist.TCPStore) -> None:
1035+
"""
1036+
Perform a store-based barrier synchronization across all ranks.
1037+
1038+
This barrier uses a TCP store directly rather than a process group,
1039+
allowing all ranks to synchronize regardless of which process group
1040+
they belong to.
1041+
1042+
Args:
1043+
store: The TCPStore instance to use for synchronization.
1044+
"""
1045+
dist.distributed_c10d._store_based_barrier(
1046+
rank=self._rank,
1047+
store=store,
1048+
group_name="parameter_server_barrier",
1049+
rendezvous_count=self._world_size,
1050+
timeout=timedelta(minutes=5),
1051+
)
1052+
10341053
def update(
10351054
self,
10361055
checkpoint_name: str,
@@ -1050,52 +1069,42 @@ def update(
10501069
which is useful in disaggregated architecture.
10511070
"""
10521071
assert req_func is not None, "req_func is required"
1072+
ranks_group = None
10531073
try:
1074+
master_addr = os.getenv("MASTER_ADDR")
1075+
assert master_addr, "master_addr is required"
1076+
1077+
# HACK: MASTER_PORT+1 for main process group, MASTER_PORT+2 for barrier store
10541078
manager_store = dist.TCPStore(
1055-
os.getenv("MASTER_ADDR"),
1079+
master_addr,
10561080
_get_master_port() + 1,
10571081
self._world_size,
10581082
timeout=timedelta(minutes=10),
10591083
is_master=self._rank == 0,
10601084
)
1085+
1086+
if self._auto_pg and not dist.is_initialized():
1087+
self.init_process_group()
1088+
10611089
# if ranks is None or [], a full broadcast is used to update all ranks
1090+
ranks_group = dist.new_group(ranks if ranks else None)
10621091
if not ranks:
1063-
if self._auto_pg and not dist.is_initialized():
1064-
self.init_process_group()
1065-
ranks_group = dist.new_group()
1066-
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
1092+
self._update_per_bucket(checkpoint_name, req_func, ranks_group)
10671093
else:
1068-
if self._auto_pg:
1069-
if dist.is_initialized():
1070-
dist.destroy_process_group()
1071-
# HACK: wait 2s to ensure destroy is finished
1072-
time.sleep(2)
1073-
self.init_process_group()
1074-
ranks_group = dist.new_group(ranks)
1075-
logger.info(
1076-
f"[rank{self._rank}] default pg: {dist.group.WORLD}, ranks group: {ranks_group}"
1077-
)
10781094
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
10791095

1080-
dist.distributed_c10d._store_based_barrier(
1081-
rank=self._rank,
1082-
store=manager_store,
1083-
group_name="manager_store_barrier",
1084-
rendezvous_count=self._world_size,
1085-
timeout=timedelta(minutes=5),
1086-
)
1087-
dist.destroy_process_group(ranks_group)
1088-
del ranks_group
1096+
self.store_based_barrier(manager_store)
10891097

10901098
except Exception as e:
10911099
logger.exception(
10921100
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
10931101
)
10941102
raise
10951103
finally:
1096-
if self._auto_pg and (not ranks or self._rank in ranks):
1104+
if ranks_group:
1105+
dist.destroy_process_group(ranks_group)
1106+
if self._auto_pg and dist.is_initialized():
10971107
dist.destroy_process_group()
1098-
10991108
self.device_manager.device_module.empty_cache()
11001109
logger.info(
11011110
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -1228,7 +1237,6 @@ def _update_per_bucket(
12281237
):
12291238
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
12301239
assert dist.is_initialized(), "process group is not initialized"
1231-
assert ranks_group is not None, "ranks_group should be set"
12321240

12331241
# if ranks is None or [], a full broadcast is used to update all ranks
12341242
if not ranks:

0 commit comments

Comments
 (0)