116 changes: 98 additions & 18 deletions checkpoint_engine/ps.py
@@ -454,6 +454,7 @@ def _register_checkpoint(
files: list[str],
named_tensors: dict[str, torch.Tensor],
rank: int | None = None,
shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
logger.info(
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -483,16 +484,34 @@ class MemoryBucket(BaseModel):
for bucket in buckets
]

def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer
def register_pin_memory(
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
) -> tuple[int, torch.Tensor]:
if shared_pin_memory:
# If shared_pin_memory is provided, reuse its pinned buffers instead of allocating new ones.
# Reusing pinned memory only supports a fixed checkpoint shape, which is locked in on first registration.
assert idx < len(shared_pin_memory), (
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
)
assert shared_pin_memory[idx].size == size, (
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
)
return idx, shared_pin_memory[idx].buffer
else:
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer

def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
futures = [
executor.submit(register_pin_memory, idx, bucket.size)
executor.submit(
register_pin_memory,
idx,
bucket.size,
shared_pin_memory,
)
for idx, bucket in enumerate(buckets)
]
new_futures = []
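Aside, not part of the diff: the reason to reuse the pool is that pinned (page-locked) host allocations are expensive. A minimal sketch of the cost difference, assuming a CUDA-enabled PyTorch build; the buffer size and timings are purely illustrative:

```python
# Illustrative only: first-time pinned allocation vs. reuse.
# torch.empty(..., pin_memory=True) page-locks host memory (cudaHostAlloc
# under the hood), which is far slower than a regular allocation; handing
# back an already pinned buffer, as shared_pin_memory does, skips that cost.
import time

import torch

size = 256 * 1024 * 1024  # 256 MiB, illustrative

t0 = time.perf_counter()
pool = torch.empty(size, dtype=torch.uint8, pin_memory=True)  # pays the pinning cost
t1 = time.perf_counter()

buf = pool  # reuse path: no new allocation, no pinning
t2 = time.perf_counter()

print(f"fresh pinned alloc: {t1 - t0:.4f}s, reuse: {t2 - t1:.6f}s")
```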
@@ -747,6 +766,8 @@ def batch_transfer_sync_read(


class ParameterServer:
shared_memory_pool_name = "__shared_memory_pool__"

def __init__(
self,
*,
@@ -790,7 +811,10 @@ def __init__(
self._zmq_ctx = zmq.Context()
self._zmq_addr_counter = 0

# stores the name of the checkpoint currently using the shared memory pool, or empty string if none
self._current_shared_memory_pool_user: str = ""
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
self._memory_pool[self.shared_memory_pool_name] = []
# dict key is owner_rank, value is a bucket metas list in owner_rank
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
# NPU transfer engine initialization requires prior set_device.
@@ -805,6 +829,17 @@ def __init__(
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
if checkpoint_name == self._current_shared_memory_pool_user:
assert self._memory_pool[self.shared_memory_pool_name], (
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
)
return self._memory_pool[self.shared_memory_pool_name]
elif checkpoint_name in self._memory_pool:
return self._memory_pool[checkpoint_name]
else:
raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
logger.info(msg)
@@ -828,6 +863,7 @@ def register_checkpoint(
*,
files: list[str] | None = None,
named_tensors: dict[str, torch.Tensor] | None = None,
use_shared_memory_pool: bool = False,
) -> None:
"""
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
@@ -836,21 +872,46 @@ def register_checkpoint(
checkpoint_name: The name of the checkpoint.
files: The safetensors files to register.
named_tensors: The named tensors to register.
use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
cannot accommodate checkpoints with different memory requirements.
"""
try:
assert checkpoint_name not in self._memory_pool, (
f"checkpoint {checkpoint_name} already registered"
)
self._memory_pool[checkpoint_name] = _register_checkpoint(
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
)
if self._p2p_store is not None:
self._register_parameters_to_p2p_store(checkpoint_name)
if use_shared_memory_pool:
logger.info(
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
)
assert self._current_shared_memory_pool_user == "", (
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
f"This registration may cause unexpected conflicts."
)
# Since the uninitialized shared memory pool is set to an empty list,
# we can check whether this is the first time the shared memory pool is used
_is_first_time = not self._memory_pool[self.shared_memory_pool_name]
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
files=files or [],
named_tensors=named_tensors or {},
rank=self._rank,
shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
)
self._current_shared_memory_pool_user = checkpoint_name
if self._p2p_store is not None and _is_first_time:
self._register_parameters_to_p2p_store(checkpoint_name)
else:
assert checkpoint_name not in self._memory_pool, (
f"checkpoint {checkpoint_name} already registered"
)
self._memory_pool[checkpoint_name] = _register_checkpoint(
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
)
if self._p2p_store is not None:
self._register_parameters_to_p2p_store(checkpoint_name)
except Exception:
logger.exception(
f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
)
if self._p2p_store is not None:
if self._p2p_store is not None and not use_shared_memory_pool:
self._unregister_parameters_from_p2p_store(checkpoint_name)
self.unregister_checkpoint(checkpoint_name)
raise
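For orientation, a typical shared-pool round trip, distilled from tests/test_pin_memory.py below; the checkpoint names here are illustrative and a CUDA-capable host is assumed:

```python
import os

import torch

from checkpoint_engine.ps import ParameterServer

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
ps = ParameterServer()
tensors = {"layer.weight": torch.randn(1024, 1024)}

# First shared-pool registration allocates and pins the pool.
ps.register_checkpoint("ckpt_a", named_tensors=tensors, use_shared_memory_pool=True)

# Unregistering only marks the pool as unused; the pinned memory stays allocated.
ps.unregister_checkpoint("ckpt_a")

# A same-shape checkpoint can now take over the pool and reuse its pinned buffers.
ps.register_checkpoint("ckpt_b", named_tensors=tensors, use_shared_memory_pool=True)
```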
@@ -860,13 +921,28 @@ def unregister_checkpoint(self, checkpoint_name: str):
Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
from p2p store if p2p store is initialized.
"""
if checkpoint_name not in self._memory_pool:
if (
checkpoint_name not in self._memory_pool
and checkpoint_name != self._current_shared_memory_pool_user
):
logger.warning(
f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
)
return

# TODO: currently, we just mark the shared memory pool as unused when unregistering.
# Physically releasing the shared memory pool is not supported yet.
# We may add logic to unregister the shared memory pool in the future if necessary.
if checkpoint_name == self._current_shared_memory_pool_user:
self._current_shared_memory_pool_user = ""
return

if self._p2p_store is not None:
num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
logger.info(
f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
)

del self._memory_pool[checkpoint_name]
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
# this requires torch>=2.5.0
@@ -882,14 +958,18 @@ def gather_metas(self, checkpoint_name: str):
self.init_process_group()
assert dist.is_initialized(), "process group is not initialized"
metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
try:
memory_pool = self._get_memory_pool(checkpoint_name)
except RuntimeError:
Collaborator: Silently ignoring RuntimeError here is not great, since it is a fairly general exception. Better to follow a dict.get-style approach: add a parameter like allow_not_found=True or default=None, and return None so the caller can check for it.
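A sketch of what that suggestion could look like; the default parameter and the adjusted call site are hypothetical, not part of this PR:

```python
# Hypothetical dict.get-style variant of _get_memory_pool, per the comment
# above: return a caller-supplied default instead of raising RuntimeError.
def _get_memory_pool(
    self, checkpoint_name: str, default: list[MemoryBuffer] | None = None
) -> list[MemoryBuffer] | None:
    if checkpoint_name == self._current_shared_memory_pool_user:
        return self._memory_pool[self.shared_memory_pool_name]
    return self._memory_pool.get(checkpoint_name, default)

# The gather_metas call site would then shrink to:
#     memory_pool = self._get_memory_pool(checkpoint_name, default=[])
```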
memory_pool = []
metas = DataToGather(
memory_buffer_metas_list=[
MemoryBufferMetas(
metas=x.metas,
ptr=x.buffer.data_ptr(),
size=x.size,
)
for x in self._memory_pool.get(checkpoint_name, [])
for x in memory_pool
],
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
host_ip=get_ip(),
@@ -1095,7 +1175,7 @@ def _copy_to_buffer(
remote_ptrs.append(ptrs[b.idx][0] + b.offset)
lens.append(b.size)
else:
pool = self._memory_pool[checkpoint_name][b.idx]
pool = self._get_memory_pool(checkpoint_name)[b.idx]
buffer[offset : offset + b.size].data.copy_(
pool.buffer[b.offset : b.offset + b.size],
non_blocking=True,
@@ -1158,7 +1238,7 @@ def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:

def _register_parameters_to_p2p_store(self, checkpoint_name: str):
assert self._p2p_store is not None, "p2p store is not initialized"
pool = self._memory_pool[checkpoint_name]
pool = self._get_memory_pool(checkpoint_name)
if len(pool) == 0:
return
named_tensors, tensor_ptrs = {}, []
@@ -1169,7 +1249,7 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):

def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
assert self._p2p_store is not None, "p2p store is not initialized"
pool = self._memory_pool[checkpoint_name]
pool = self._get_memory_pool(checkpoint_name)
if len(pool) == 0:
return 0
return self._p2p_store.unregister_named_tensors(
65 changes: 65 additions & 0 deletions tests/test_pin_memory.py
@@ -0,0 +1,65 @@
import os

import pytest
import torch

from checkpoint_engine.ps import ParameterServer


def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
"""
Generate dummy checkpoint data
"""
named_tensors = {
"layer1.weight": torch.randn(1024, 1024),
"layer1.bias": torch.randn(1024),
"layer2.weight": torch.randn(2048, 1024),
"layer2.bias": torch.randn(2048),
}
return named_tensors


@pytest.mark.gpu
def test_register_pin_memory():
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
ps = ParameterServer()
checkpoint1 = generate_dummy_checkpoint()
checkpoint_shared1 = generate_dummy_checkpoint()
checkpoint2 = generate_dummy_checkpoint()
checkpoint_shared2 = generate_dummy_checkpoint()
ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
ps.unregister_checkpoint("test_checkpoint1")
assert "test_checkpoint1" not in ps._memory_pool
ps.register_checkpoint(
"test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
)
ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
assert "test_checkpoint_shared1" not in ps._memory_pool
assert "__shared_memory_pool__" in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
assert "test_checkpoint2" in ps._memory_pool
with pytest.raises(AssertionError):
    # registering a second shared-pool user while the pool is in use must fail
    ps.register_checkpoint(
        "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
    )
assert "test_checkpoint_shared2" not in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
ps.unregister_checkpoint("test_checkpoint_shared1")
assert ps._current_shared_memory_pool_user == ""
assert "__shared_memory_pool__" in ps._memory_pool
ps.register_checkpoint(
"test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
)
assert "test_checkpoint_shared2" not in ps._memory_pool
assert "__shared_memory_pool__" in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
ps.unregister_checkpoint("test_checkpoint1") # this will trigger an warning
assert "test_checkpoint1" not in ps._memory_pool
ps.unregister_checkpoint("test_checkpoint2")
assert "test_checkpoint2" not in ps._memory_pool
ps.unregister_checkpoint("test_checkpoint_shared2")
assert ps._current_shared_memory_pool_user == ""
assert "__shared_memory_pool__" in ps._memory_pool