Skip to content

Commit f69e116

Browse files
authored
feat: support force-unregistering the shared pinned-memory buffer (#62)
* feat: force-unregister the shared memory pool
* feat: add tests for force-unregistering the shared memory pool
1 parent e88d462 commit f69e116

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

checkpoint_engine/ps.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def register_checkpoint(
10271027
self.unregister_checkpoint(checkpoint_name)
10281028
raise
10291029

1030-
def unregister_checkpoint(self, checkpoint_name: str):
1030+
def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
10311031
"""
10321032
Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
10331033
from p2p store if p2p store is initialized.
@@ -1041,10 +1041,7 @@ def unregister_checkpoint(self, checkpoint_name: str):
10411041
)
10421042
return
10431043

1044-
# TODO: currently, we just mark the shared memory pool as unused when unregistering.
1045-
# Physically releasing the shared memory pool is not supported yet.
1046-
# We may add unregister shared memory pool logic in the future if necessary.
1047-
if checkpoint_name == self._current_shared_memory_pool_user:
1044+
if checkpoint_name == self._current_shared_memory_pool_user and not force:
10481045
self._current_shared_memory_pool_user = ""
10491046
return
10501047

@@ -1054,7 +1051,12 @@ def unregister_checkpoint(self, checkpoint_name: str):
10541051
f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
10551052
)
10561053

1057-
del self._memory_pool[checkpoint_name]
1054+
if checkpoint_name == self._current_shared_memory_pool_user:
1055+
self._current_shared_memory_pool_user = ""
1056+
del self._memory_pool[self.shared_memory_pool_name]
1057+
self._memory_pool[self.shared_memory_pool_name] = []
1058+
else:
1059+
del self._memory_pool[checkpoint_name]
10581060
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
10591061
# this works by using torch>=2.5.0
10601062
torch._C._host_emptyCache()
@@ -1353,8 +1355,13 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
13531355
if len(pool) == 0:
13541356
return
13551357
named_tensors, tensor_ptrs = {}, []
1358+
register_name = (
1359+
checkpoint_name
1360+
if checkpoint_name != self._current_shared_memory_pool_user
1361+
else self.shared_memory_pool_name
1362+
)
13561363
for idx, memory_buffer in enumerate(pool):
1357-
named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
1364+
named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
13581365
tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
13591366
self._p2p_store.register_named_tensors(named_tensors)
13601367

@@ -1363,8 +1370,13 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
13631370
pool = self._get_memory_pool(checkpoint_name)
13641371
if len(pool) == 0:
13651372
return 0
1373+
unregister_name = (
1374+
checkpoint_name
1375+
if checkpoint_name != self._current_shared_memory_pool_user
1376+
else self.shared_memory_pool_name
1377+
)
13661378
return self._p2p_store.unregister_named_tensors(
1367-
[f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
1379+
[f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
13681380
)
13691381

13701382
def _update_per_bucket(

tests/test_pin_memory.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ def test_register_pin_memory():
2828
checkpoint_shared1 = generate_dummy_checkpoint()
2929
checkpoint2 = generate_dummy_checkpoint()
3030
checkpoint_shared2 = generate_dummy_checkpoint()
31+
checkpoint_shared3 = generate_dummy_checkpoint()
32+
checkpoint_shared3["layer3.weight"] = torch.randn(4096, 2048)
33+
checkpoint_shared3["layer3.bias"] = torch.randn(4096)
3134
ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
3235
ps.unregister_checkpoint("test_checkpoint1")
3336
assert "test_checkpoint1" not in ps._memory_pool
@@ -60,6 +63,15 @@ def test_register_pin_memory():
6063
assert "test_checkpoint1" not in ps._memory_pool
6164
ps.unregister_checkpoint("test_checkpoint2")
6265
assert "test_checkpoint2" not in ps._memory_pool
63-
ps.unregister_checkpoint("test_checkpoint_shared2")
66+
ps.unregister_checkpoint("test_checkpoint_shared2", force=True)
67+
assert ps._current_shared_memory_pool_user == ""
68+
assert "__shared_memory_pool__" in ps._memory_pool
69+
ps.register_checkpoint(
70+
"test_checkpoint_shared3", named_tensors=checkpoint_shared3, use_shared_memory_pool=True
71+
)
72+
assert "test_checkpoint_shared3" not in ps._memory_pool
73+
assert "__shared_memory_pool__" in ps._memory_pool
74+
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared3"
75+
ps.unregister_checkpoint("test_checkpoint_shared3")
6476
assert ps._current_shared_memory_pool_user == ""
6577
assert "__shared_memory_pool__" in ps._memory_pool

0 commit comments

Comments
 (0)