checkpoint_engine/ps.py: 28 changes (20 additions, 8 deletions)
@@ -916,7 +916,7 @@ def register_checkpoint(
             self.unregister_checkpoint(checkpoint_name)
             raise

-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
@@ -930,10 +930,7 @@ def unregister_checkpoint(self, checkpoint_name: str):
             )
             return

-        # TODO: currently, we just mark the shared memory pool as unused when unregistering.
-        # Physically releasing the shared memory pool is not supported yet.
-        # We may add unregister shared memory pool logic in the future if necessary.
-        if checkpoint_name == self._current_shared_memory_pool_user:
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
             self._current_shared_memory_pool_user = ""
             return

@@ -943,7 +940,12 @@ def unregister_checkpoint(self, checkpoint_name: str):
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )

-        del self._memory_pool[checkpoint_name]
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
@@ -1242,8 +1244,13 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)

@@ -1252,8 +1259,13 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )

     def _update_per_bucket(
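A minimal usage sketch (not part of the PR, and not authoritative API documentation) of what the new force flag enables, based on the ps.py changes above and the test below: a plain unregister only marks the shared memory pool as unused, while force=True also drops the pooled buffers so a checkpoint with a different memory footprint can reuse the pool. The parameter-server constructor is not shown in this diff, so ps is left as an assumed, already initialized instance, and the checkpoint dicts are illustrative.

# Sketch only: intended usage of the new `force` flag, assuming `ps` is an
# already constructed parameter server (its constructor is not part of this diff).
import torch


def recycle_shared_pool(ps) -> None:
    small = {"layer0.weight": torch.randn(1024, 1024)}
    large = {"layer3.weight": torch.randn(4096, 2048), "layer3.bias": torch.randn(4096)}

    # The first checkpoint registered with use_shared_memory_pool=True fixes the
    # pool's buffer layout.
    ps.register_checkpoint("ckpt_small", named_tensors=small, use_shared_memory_pool=True)

    # A default unregister only marks the shared pool as unused; the pinned
    # buffers stay allocated with their original layout.
    ps.unregister_checkpoint("ckpt_small")

    # force=True also unregisters the pooled buffers from the p2p store (when one
    # is initialized) and clears them, so a checkpoint with a different memory
    # footprint can re-populate the pool afterwards.
    ps.register_checkpoint("ckpt_small", named_tensors=small, use_shared_memory_pool=True)
    ps.unregister_checkpoint("ckpt_small", force=True)
    ps.register_checkpoint("ckpt_large", named_tensors=large, use_shared_memory_pool=True)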
tests/test_pin_memory.py: 14 changes (13 additions, 1 deletion)
@@ -28,6 +28,9 @@ def test_register_pin_memory():
     checkpoint_shared1 = generate_dummy_checkpoint()
     checkpoint2 = generate_dummy_checkpoint()
     checkpoint_shared2 = generate_dummy_checkpoint()
+    checkpoint_shared3 = generate_dummy_checkpoint()
+    checkpoint_shared3["layer3.weight"] = torch.randn(4096, 2048)
+    checkpoint_shared3["layer3.bias"] = torch.randn(4096)
     ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
     ps.unregister_checkpoint("test_checkpoint1")
     assert "test_checkpoint1" not in ps._memory_pool
@@ -60,6 +63,15 @@ def test_register_pin_memory():
     assert "test_checkpoint1" not in ps._memory_pool
     ps.unregister_checkpoint("test_checkpoint2")
     assert "test_checkpoint2" not in ps._memory_pool
-    ps.unregister_checkpoint("test_checkpoint_shared2")
+    ps.unregister_checkpoint("test_checkpoint_shared2", force=True)
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared3", named_tensors=checkpoint_shared3, use_shared_memory_pool=True
+    )
Comment on lines +69 to +71 (Copilot AI, Dec 11, 2025):
The test registers checkpoint_shared3 with different tensor shapes (4096x2048 and 4096) than checkpoint_shared2 (1024x1024, 1024, 2048x1024, 2048) after force unregistering the shared pool. However, according to the register_checkpoint docstring (lines 875-877 in ps.py), "The pool's shape is fixed on first use and cannot accommodate checkpoints with different memory requirements." This test should verify that registering a checkpoint with different memory requirements after a force unregister works correctly, or it should assert/test for failure if the fixed shape restriction still applies after the shared pool is cleared.
assert "test_checkpoint_shared3" not in ps._memory_pool
assert "__shared_memory_pool__" in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared3"
ps.unregister_checkpoint("test_checkpoint_shared3")
assert ps._current_shared_memory_pool_user == ""
assert "__shared_memory_pool__" in ps._memory_pool