Skip to content

Commit 6980e59

Browse files
committed
fix: fix logical error when destroying ranks group
1 parent 90c6456 commit 6980e59

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

checkpoint_engine/ps.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,7 @@ def update(
969969
which is useful in disaggregated architecture.
970970
"""
971971
assert req_func is not None, "req_func is required"
972+
ranks_group = None
972973
try:
973974
master_addr = os.getenv("MASTER_ADDR")
974975
assert master_addr, "master_addr is required"
@@ -993,22 +994,23 @@ def update(
993994
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
994995

995996
self.store_based_barrier(manager_store)
996-
logger.info(
997-
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
998-
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
999-
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1000-
)
997+
1001998
except Exception as e:
1002999
logger.exception(
10031000
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
10041001
)
10051002
raise
10061003
finally:
1007-
if not ranks_group:
1004+
if ranks_group:
10081005
dist.destroy_process_group(ranks_group)
10091006
if self._auto_pg and dist.is_initialized():
10101007
dist.destroy_process_group()
10111008
self.device_manager.device_module.empty_cache()
1009+
logger.info(
1010+
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1011+
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
1012+
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1013+
)
10121014

10131015
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
10141016
def zmq_handle(device_uuid: str) -> str:

0 commit comments

Comments
 (0)