diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index 61eafb7..7b31f44 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -916,7 +916,7 @@ def register_checkpoint(
             self.unregister_checkpoint(checkpoint_name)
             raise
 
-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server.
         This function will also unregister the checkpoint from p2p store if p2p store is initialized.
@@ -930,10 +930,7 @@ def unregister_checkpoint(self, checkpoint_name: str):
             )
             return
 
-        # TODO: currently, we just mark the shared memory pool as unused when unregistering.
-        # Physically releasing the shared memory pool is not supported yet.
-        # We may add unregister shared memory pool logic in the future if necessary.
-        if checkpoint_name == self._current_shared_memory_pool_user:
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
             self._current_shared_memory_pool_user = ""
             return
 
@@ -943,7 +940,12 @@ def unregister_checkpoint(self, checkpoint_name: str):
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
 
-        del self._memory_pool[checkpoint_name]
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
@@ -1242,8 +1244,13 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
 
         self._p2p_store.register_named_tensors(named_tensors)
@@ -1252,8 +1259,13 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )
 
     def _update_per_bucket
diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py
index 8daa4ac..bb698b7 100644
--- a/tests/test_pin_memory.py
+++ b/tests/test_pin_memory.py
@@ -28,6 +28,9 @@ def test_register_pin_memory():
     checkpoint_shared1 = generate_dummy_checkpoint()
     checkpoint2 = generate_dummy_checkpoint()
     checkpoint_shared2 = generate_dummy_checkpoint()
+    checkpoint_shared3 = generate_dummy_checkpoint()
+    checkpoint_shared3["layer3.weight"] = torch.randn(4096, 2048)
+    checkpoint_shared3["layer3.bias"] = torch.randn(4096)
     ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
     ps.unregister_checkpoint("test_checkpoint1")
     assert "test_checkpoint1" not in ps._memory_pool
@@ -60,6 +63,15 @@
     assert "test_checkpoint1" not in ps._memory_pool
     ps.unregister_checkpoint("test_checkpoint2")
     assert "test_checkpoint2" not in ps._memory_pool
-    ps.unregister_checkpoint("test_checkpoint_shared2")
+    ps.unregister_checkpoint("test_checkpoint_shared2", force=True)
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared3", named_tensors=checkpoint_shared3, use_shared_memory_pool=True
+    )
+    assert "test_checkpoint_shared3" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared3"
+    ps.unregister_checkpoint("test_checkpoint_shared3")
     assert ps._current_shared_memory_pool_user == ""
     assert "__shared_memory_pool__" in ps._memory_pool