Skip to content

Commit 992dbba

Browse files
committed
fix: resolve PR issues
1 parent ebf793c commit 992dbba

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

checkpoint_engine/ps.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -487,12 +487,18 @@ class MemoryBucket(BaseModel):
487487
def register_pin_memory(
488488
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
489489
) -> tuple[int, torch.Tensor]:
490-
buffer = (
491-
torch.empty(size, dtype=torch.uint8, pin_memory=True)
492-
if not shared_pin_memory
493-
else shared_pin_memory[idx].buffer
494-
)
495-
return idx, buffer
490+
if shared_pin_memory:
491+
# Reusing pinned memory only supports checkpoints with a fixed shape, which is registered on first use
492+
assert idx < len(shared_pin_memory), (
493+
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
494+
)
495+
assert shared_pin_memory[idx].size == size, (
496+
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
497+
)
498+
return idx, shared_pin_memory[idx].buffer
499+
else:
500+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
501+
return idx, buffer
496502

497503
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
498504
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -803,8 +809,7 @@ def __init__(
803809
self._zmq_addr_counter = 0
804810

805811
self.shared_memory_pool_name = "__shared_memory_pool__"
806-
# this dict stores all currently registered checkpoints
807-
# dict key is checkpoint_name, value is whether use shared memory pool
812+
# stores the name of the checkpoint currently using the shared memory pool, or an empty string if none
808813
self._current_shared_memory_pool_user: str = ""
809814
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
810815
self._memory_pool[self.shared_memory_pool_name] = []
@@ -823,16 +828,16 @@ def __init__(
823828
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
824829

825830
def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
826-
if (
827-
checkpoint_name not in self._memory_pool
828-
and checkpoint_name != self._current_shared_memory_pool_user
829-
):
830-
raise RuntimeError(f"checkpoint {checkpoint_name} not registered in memory pool")
831-
return (
832-
self._memory_pool[checkpoint_name]
833-
if checkpoint_name != self._current_shared_memory_pool_user
834-
else self._memory_pool[self.shared_memory_pool_name]
835-
)
831+
if checkpoint_name == self._current_shared_memory_pool_user:
832+
if not self._memory_pool[self.shared_memory_pool_name]:
833+
raise RuntimeError(
834+
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
835+
)
836+
return self._memory_pool[self.shared_memory_pool_name]
837+
elif checkpoint_name in self._memory_pool:
838+
return self._memory_pool[checkpoint_name]
839+
else:
840+
raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
836841

837842
def _logger_rank0(self, msg: str):
838843
if self._local_rank == 0:
@@ -866,20 +871,20 @@ def register_checkpoint(
866871
checkpoint_name: The name of the checkpoint.
867872
files: The safetensors files to register.
868873
named_tensors: The named tensors to register.
874+
use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
875+
Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
876+
cannot accommodate checkpoints with different memory requirements.
869877
"""
870878
try:
871879
if use_shared_memory_pool:
872880
logger.info(
873881
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
874882
)
875-
if self._current_shared_memory_pool_user:
876-
logger.error(
877-
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
878-
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
879-
f"This registration may cause unexpected conflicts."
880-
)
881-
return
882-
883+
assert self._current_shared_memory_pool_user == "", (
884+
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
885+
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
886+
f"This registration may cause unexpected conflicts."
887+
)
883888
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
884889
files=files or [],
885890
named_tensors=named_tensors or {},

0 commit comments

Comments
 (0)