-
Notifications
You must be signed in to change notification settings - Fork 78
feat: reuse pin_memory when registering checkpoint #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5ee3886
feat: Reuse shared pinned memory buffers
specture724 ebf793c
feat: basic test added
specture724 992dbba
fix: resolve PR issues
specture724 9b644df
fix: fix PR issues
specture724 b976022
fix: handle runtime error when getting memory pool
specture724 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -454,6 +454,7 @@ def _register_checkpoint( | |
| files: list[str], | ||
| named_tensors: dict[str, torch.Tensor], | ||
| rank: int | None = None, | ||
| shared_pin_memory: list[MemoryBuffer] | None = None, | ||
| ) -> list[MemoryBuffer]: | ||
| logger.info( | ||
| f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" | ||
|
|
@@ -483,16 +484,34 @@ class MemoryBucket(BaseModel): | |
| for bucket in buckets | ||
| ] | ||
|
|
||
| def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]: | ||
| buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) | ||
| return idx, buffer | ||
| def register_pin_memory( | ||
| idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None | ||
| ) -> tuple[int, torch.Tensor]: | ||
| if shared_pin_memory: | ||
| # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one | ||
| # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time | ||
| assert idx < len(shared_pin_memory), ( | ||
| f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" | ||
| ) | ||
| assert shared_pin_memory[idx].size == size, ( | ||
| f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" | ||
| ) | ||
| return idx, shared_pin_memory[idx].buffer | ||
| else: | ||
| buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) | ||
| return idx, buffer | ||
|
|
||
| def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): | ||
| buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) | ||
|
|
||
| with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: | ||
| futures = [ | ||
| executor.submit(register_pin_memory, idx, bucket.size) | ||
| executor.submit( | ||
| register_pin_memory, | ||
| idx, | ||
| bucket.size, | ||
| shared_pin_memory, | ||
| ) | ||
| for idx, bucket in enumerate(buckets) | ||
| ] | ||
| new_futures = [] | ||
|
|
@@ -747,6 +766,8 @@ def batch_transfer_sync_read( | |
|
|
||
|
|
||
| class ParameterServer: | ||
| shared_memory_pool_name = "__shared_memory_pool__" | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
|
|
@@ -790,7 +811,10 @@ def __init__( | |
| self._zmq_ctx = zmq.Context() | ||
| self._zmq_addr_counter = 0 | ||
|
|
||
| # stores the name of the checkpoint currently using the shared memory pool, or empty string if none | ||
| self._current_shared_memory_pool_user: str = "" | ||
| self._memory_pool: dict[str, list[MemoryBuffer]] = {} | ||
| self._memory_pool[self.shared_memory_pool_name] = [] | ||
| # dict key is owner_rank, value is a bucket metas list in owner_rank | ||
| self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {} | ||
| # NPU transfer engine initialization requires prior set_device. | ||
|
|
@@ -805,6 +829,17 @@ def __init__( | |
| self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) | ||
| self._rdma_device = None if self._p2p_store is None else self._p2p_store.device | ||
|
|
||
| def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]: | ||
| if checkpoint_name == self._current_shared_memory_pool_user: | ||
| assert self._memory_pool[self.shared_memory_pool_name], ( | ||
| f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it" | ||
| ) | ||
| return self._memory_pool[self.shared_memory_pool_name] | ||
| elif checkpoint_name in self._memory_pool: | ||
| return self._memory_pool[checkpoint_name] | ||
| else: | ||
| raise RuntimeError(f"checkpoint {checkpoint_name} is not registered") | ||
|
|
||
| def _logger_rank0(self, msg: str): | ||
| if self._local_rank == 0: | ||
| logger.info(msg) | ||
|
|
@@ -828,6 +863,7 @@ def register_checkpoint( | |
| *, | ||
| files: list[str] | None = None, | ||
| named_tensors: dict[str, torch.Tensor] | None = None, | ||
| use_shared_memory_pool: bool = False, | ||
| ) -> None: | ||
| """ | ||
| Register a checkpoint to the parameter server. Both files and named_tensors will be registered together. | ||
|
|
@@ -836,21 +872,46 @@ def register_checkpoint( | |
| checkpoint_name: The name of the checkpoint. | ||
| files: The safetensors files to register. | ||
| named_tensors: The named tensors to register. | ||
specture724 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory. | ||
| Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and | ||
| cannot accommodate checkpoints with different memory requirements. | ||
| """ | ||
| try: | ||
| assert checkpoint_name not in self._memory_pool, ( | ||
| f"checkpoint {checkpoint_name} already registered" | ||
| ) | ||
| self._memory_pool[checkpoint_name] = _register_checkpoint( | ||
| files=files or [], named_tensors=named_tensors or {}, rank=self._rank | ||
| ) | ||
| if self._p2p_store is not None: | ||
| self._register_parameters_to_p2p_store(checkpoint_name) | ||
| if use_shared_memory_pool: | ||
| logger.info( | ||
| f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool" | ||
| ) | ||
| assert self._current_shared_memory_pool_user == "", ( | ||
| f"cannot register checkpoint {checkpoint_name} to shared memory pool, " | ||
| f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. " | ||
| f"This registration may cause unexpected conflicts." | ||
| ) | ||
| # Since we set the uninitialized shared memory pool to empty list, | ||
| # we can check whether this is the first time to use shared memory pool | ||
| _is_first_time = not self._memory_pool[self.shared_memory_pool_name] | ||
| self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint( | ||
| files=files or [], | ||
| named_tensors=named_tensors or {}, | ||
| rank=self._rank, | ||
| shared_pin_memory=self._memory_pool[self.shared_memory_pool_name], | ||
| ) | ||
| self._current_shared_memory_pool_user = checkpoint_name | ||
| if self._p2p_store is not None and _is_first_time: | ||
| self._register_parameters_to_p2p_store(checkpoint_name) | ||
| else: | ||
| assert checkpoint_name not in self._memory_pool, ( | ||
| f"checkpoint {checkpoint_name} already registered" | ||
| ) | ||
| self._memory_pool[checkpoint_name] = _register_checkpoint( | ||
| files=files or [], named_tensors=named_tensors or {}, rank=self._rank | ||
| ) | ||
| if self._p2p_store is not None: | ||
| self._register_parameters_to_p2p_store(checkpoint_name) | ||
| except Exception: | ||
specture724 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| logger.exception( | ||
| f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}" | ||
| ) | ||
| if self._p2p_store is not None: | ||
| if self._p2p_store is not None and not use_shared_memory_pool: | ||
| self._unregister_parameters_from_p2p_store(checkpoint_name) | ||
| self.unregister_checkpoint(checkpoint_name) | ||
| raise | ||
|
|
@@ -860,13 +921,28 @@ def unregister_checkpoint(self, checkpoint_name: str): | |
| Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint | ||
| from p2p store if p2p store is initialized. | ||
| """ | ||
| if checkpoint_name not in self._memory_pool: | ||
| if ( | ||
| checkpoint_name not in self._memory_pool | ||
| and checkpoint_name != self._current_shared_memory_pool_user | ||
| ): | ||
| logger.warning( | ||
| f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found" | ||
| ) | ||
| return | ||
|
|
||
| # TODO: currently, we just mark the shared memory pool as unused when unregistering. | ||
| # Physically releasing the shared memory pool is not supported yet. | ||
| # We may add unregister shared memory pool logic in the future if necessary. | ||
| if checkpoint_name == self._current_shared_memory_pool_user: | ||
| self._current_shared_memory_pool_user = "" | ||
specture724 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return | ||
|
|
||
| if self._p2p_store is not None: | ||
| num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name) | ||
| logger.info( | ||
| f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}" | ||
| ) | ||
|
|
||
| del self._memory_pool[checkpoint_name] | ||
| # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018 | ||
| # this works by using torch>=2.5.0 | ||
|
|
@@ -882,14 +958,18 @@ def gather_metas(self, checkpoint_name: str): | |
| self.init_process_group() | ||
| assert dist.is_initialized(), "process group is not initialized" | ||
| metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore | ||
| try: | ||
| memory_pool = self._get_memory_pool(checkpoint_name) | ||
| except RuntimeError: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里直接忽略RuntimeError不太好，因为这是个比较general的异常。最好是类似 |
||
| memory_pool = [] | ||
| metas = DataToGather( | ||
| memory_buffer_metas_list=[ | ||
| MemoryBufferMetas( | ||
| metas=x.metas, | ||
| ptr=x.buffer.data_ptr(), | ||
| size=x.size, | ||
| ) | ||
| for x in self._memory_pool.get(checkpoint_name, []) | ||
| for x in memory_pool | ||
| ], | ||
| p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr, | ||
| host_ip=get_ip(), | ||
|
|
@@ -1095,7 +1175,7 @@ def _copy_to_buffer( | |
| remote_ptrs.append(ptrs[b.idx][0] + b.offset) | ||
| lens.append(b.size) | ||
| else: | ||
| pool = self._memory_pool[checkpoint_name][b.idx] | ||
| pool = self._get_memory_pool(checkpoint_name)[b.idx] | ||
| buffer[offset : offset + b.size].data.copy_( | ||
| pool.buffer[b.offset : b.offset + b.size], | ||
| non_blocking=True, | ||
|
|
@@ -1158,7 +1238,7 @@ def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]: | |
|
|
||
| def _register_parameters_to_p2p_store(self, checkpoint_name: str): | ||
| assert self._p2p_store is not None, "p2p store is not initialized" | ||
| pool = self._memory_pool[checkpoint_name] | ||
| pool = self._get_memory_pool(checkpoint_name) | ||
| if len(pool) == 0: | ||
| return | ||
| named_tensors, tensor_ptrs = {}, [] | ||
|
|
@@ -1169,7 +1249,7 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str): | |
|
|
||
| def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int: | ||
| assert self._p2p_store is not None, "p2p store is not initialized" | ||
| pool = self._memory_pool[checkpoint_name] | ||
| pool = self._get_memory_pool(checkpoint_name) | ||
| if len(pool) == 0: | ||
| return 0 | ||
| return self._p2p_store.unregister_named_tensors( | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from checkpoint_engine.ps import ParameterServer | ||
|
|
||
|
|
||
| def generate_dummy_checkpoint() -> dict[str, torch.Tensor]: | ||
| """ | ||
| Generate dummy checkpoint data | ||
| """ | ||
| named_tensors = { | ||
| "layer1.weight": torch.randn(1024, 1024), | ||
| "layer1.bias": torch.randn(1024), | ||
| "layer2.weight": torch.randn(2048, 1024), | ||
| "layer2.bias": torch.randn(2048), | ||
| } | ||
| return named_tensors | ||
|
|
||
|
|
||
| @pytest.mark.gpu | ||
| def test_register_pin_memory(): | ||
| os.environ["RANK"] = "0" | ||
| os.environ["WORLD_SIZE"] = "1" | ||
| ps = ParameterServer() | ||
| checkpoint1 = generate_dummy_checkpoint() | ||
| checkpoint_shared1 = generate_dummy_checkpoint() | ||
| checkpoint2 = generate_dummy_checkpoint() | ||
| checkpoint_shared2 = generate_dummy_checkpoint() | ||
| ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1) | ||
| ps.unregister_checkpoint("test_checkpoint1") | ||
| assert "test_checkpoint1" not in ps._memory_pool | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True | ||
| ) | ||
| ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2) | ||
| assert "test_checkpoint_shared1" not in ps._memory_pool | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1" | ||
| assert "test_checkpoint2" in ps._memory_pool | ||
| try: | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True | ||
| ) # this will fail | ||
| except AssertionError: | ||
| print("Caught expected AssertionError when registering second shared memory pool user") | ||
| assert "test_checkpoint_shared2" not in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1" | ||
| ps.unregister_checkpoint("test_checkpoint_shared1") | ||
| assert ps._current_shared_memory_pool_user == "" | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True | ||
| ) | ||
| assert "test_checkpoint_shared2" not in ps._memory_pool | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2" | ||
| ps.unregister_checkpoint("test_checkpoint1") # this will trigger a warning | ||
| assert "test_checkpoint1" not in ps._memory_pool | ||
| ps.unregister_checkpoint("test_checkpoint2") | ||
| assert "test_checkpoint2" not in ps._memory_pool | ||
| ps.unregister_checkpoint("test_checkpoint_shared2") | ||
| assert ps._current_shared_memory_pool_user == "" | ||
| assert "__shared_memory_pool__" in ps._memory_pool |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.