Skip to content

Commit 9b644df

Browse files
committed
fix: fix PR issues
1 parent 992dbba commit 9b644df

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

checkpoint_engine/ps.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ def register_pin_memory(
488488
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
489489
) -> tuple[int, torch.Tensor]:
490490
if shared_pin_memory:
491+
# If shared_pin_memory is provided, reuse the existing pin memory buffer instead of allocating a new one
491492
# Reusing pin memory only supports checkpoints with a fixed shape, which is registered the first time
492493
assert idx < len(shared_pin_memory), (
493494
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
@@ -765,6 +766,8 @@ def batch_transfer_sync_read(
765766

766767

767768
class ParameterServer:
769+
shared_memory_pool_name = "__shared_memory_pool__"
770+
768771
def __init__(
769772
self,
770773
*,
@@ -808,7 +811,6 @@ def __init__(
808811
self._zmq_ctx = zmq.Context()
809812
self._zmq_addr_counter = 0
810813

811-
self.shared_memory_pool_name = "__shared_memory_pool__"
812814
# stores the name of the checkpoint currently using the shared memory pool, or empty string if none
813815
self._current_shared_memory_pool_user: str = ""
814816
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
@@ -829,10 +831,9 @@ def __init__(
829831

830832
def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
831833
if checkpoint_name == self._current_shared_memory_pool_user:
832-
if not self._memory_pool[self.shared_memory_pool_name]:
833-
raise RuntimeError(
834-
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
835-
)
834+
assert self._memory_pool[self.shared_memory_pool_name], (
835+
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
836+
)
836837
return self._memory_pool[self.shared_memory_pool_name]
837838
elif checkpoint_name in self._memory_pool:
838839
return self._memory_pool[checkpoint_name]
@@ -885,27 +886,32 @@ def register_checkpoint(
885886
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
886887
f"This registration may cause unexpected conflicts."
887888
)
889+
# Since we set the uninitialized shared memory pool to an empty list,
890+
# we can check whether this is the first time the shared memory pool is used
891+
_is_first_time = not self._memory_pool[self.shared_memory_pool_name]
888892
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
889893
files=files or [],
890894
named_tensors=named_tensors or {},
891895
rank=self._rank,
892896
shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
893897
)
894898
self._current_shared_memory_pool_user = checkpoint_name
899+
if self._p2p_store is not None and _is_first_time:
900+
self._register_parameters_to_p2p_store(checkpoint_name)
895901
else:
896902
assert checkpoint_name not in self._memory_pool, (
897903
f"checkpoint {checkpoint_name} already registered"
898904
)
899905
self._memory_pool[checkpoint_name] = _register_checkpoint(
900906
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
901907
)
902-
if self._p2p_store is not None:
903-
self._register_parameters_to_p2p_store(checkpoint_name)
908+
if self._p2p_store is not None:
909+
self._register_parameters_to_p2p_store(checkpoint_name)
904910
except Exception:
905911
logger.exception(
906912
f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
907913
)
908-
if self._p2p_store is not None:
914+
if self._p2p_store is not None and not use_shared_memory_pool:
909915
self._unregister_parameters_from_p2p_store(checkpoint_name)
910916
self.unregister_checkpoint(checkpoint_name)
911917
raise
@@ -920,20 +926,19 @@ def unregister_checkpoint(self, checkpoint_name: str):
920926
and checkpoint_name != self._current_shared_memory_pool_user
921927
):
922928
logger.warning(
923-
f"[rank{self._rank}] unregister checkpoint failed, checkpoint name {checkpoint_name} not found"
929+
f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
924930
)
925931
return
932+
933+
if checkpoint_name == self._current_shared_memory_pool_user:
934+
self._current_shared_memory_pool_user = ""
935+
return
936+
926937
if self._p2p_store is not None:
927938
num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
928939
logger.info(
929940
f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
930941
)
931-
if checkpoint_name == self._current_shared_memory_pool_user:
932-
logger.info(
933-
f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool"
934-
)
935-
self._current_shared_memory_pool_user = ""
936-
return
937942

938943
del self._memory_pool[checkpoint_name]
939944
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018

tests/test_pin_memory.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ def test_register_pin_memory():
3939
assert "__shared_memory_pool__" in ps._memory_pool
4040
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
4141
assert "test_checkpoint2" in ps._memory_pool
42-
ps.register_checkpoint(
43-
"test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
44-
) # this will fail
42+
try:
43+
ps.register_checkpoint(
44+
"test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
45+
) # this will fail
46+
except AssertionError:
47+
print("Caught expected AssertionError when registering second shared memory pool user")
4548
assert "test_checkpoint_shared2" not in ps._memory_pool
4649
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
4750
ps.unregister_checkpoint("test_checkpoint_shared1")
@@ -53,7 +56,7 @@ def test_register_pin_memory():
5356
assert "test_checkpoint_shared2" not in ps._memory_pool
5457
assert "__shared_memory_pool__" in ps._memory_pool
5558
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
56-
ps.unregister_checkpoint("test_checkpoint1")
59+
ps.unregister_checkpoint("test_checkpoint1") # this will trigger a warning
5760
assert "test_checkpoint1" not in ps._memory_pool
5861
ps.unregister_checkpoint("test_checkpoint2")
5962
assert "test_checkpoint2" not in ps._memory_pool

0 commit comments

Comments
 (0)