Skip to content

Commit 71af910

Browse files
committed
fix: resolve PR issues
1 parent 5b3e9da commit 71af910

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

checkpoint_engine/ps.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -487,12 +487,18 @@ class MemoryBucket(BaseModel):
487487
def register_pin_memory(
488488
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
489489
) -> tuple[int, torch.Tensor]:
490-
buffer = (
491-
torch.empty(size, dtype=torch.uint8, pin_memory=True)
492-
if not shared_pin_memory
493-
else shared_pin_memory[idx].buffer
494-
)
495-
return idx, buffer
490+
if shared_pin_memory:
491+
# Reusing pinned memory only supports checkpoints of a fixed shape, which is set when the pool is first registered
492+
assert idx < len(shared_pin_memory), (
493+
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
494+
)
495+
assert shared_pin_memory[idx].size == size, (
496+
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
497+
)
498+
return idx, shared_pin_memory[idx].buffer
499+
else:
500+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
501+
return idx, buffer
496502

497503
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
498504
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -794,8 +800,7 @@ def __init__(
794800
self._zmq_addr_counter = 0
795801

796802
self.shared_memory_pool_name = "__shared_memory_pool__"
797-
# this dict stores all currently registered checkpoints
798-
# dict key is checkpoint_name, value is whether use shared memory pool
803+
# Stores the name of the checkpoint currently using the shared memory pool, or an empty string if none
799804
self._current_shared_memory_pool_user: str = ""
800805
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
801806
self._memory_pool[self.shared_memory_pool_name] = []
@@ -813,16 +818,16 @@ def __init__(
813818
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
814819

815820
def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
816-
if (
817-
checkpoint_name not in self._memory_pool
818-
and checkpoint_name != self._current_shared_memory_pool_user
819-
):
820-
raise RuntimeError(f"checkpoint {checkpoint_name} not registered in memory pool")
821-
return (
822-
self._memory_pool[checkpoint_name]
823-
if checkpoint_name != self._current_shared_memory_pool_user
824-
else self._memory_pool[self.shared_memory_pool_name]
825-
)
821+
if checkpoint_name == self._current_shared_memory_pool_user:
822+
if not self._memory_pool[self.shared_memory_pool_name]:
823+
raise RuntimeError(
824+
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
825+
)
826+
return self._memory_pool[self.shared_memory_pool_name]
827+
elif checkpoint_name in self._memory_pool:
828+
return self._memory_pool[checkpoint_name]
829+
else:
830+
raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
826831

827832
def _logger_rank0(self, msg: str):
828833
if self._local_rank == 0:
@@ -856,20 +861,20 @@ def register_checkpoint(
856861
checkpoint_name: The name of the checkpoint.
857862
files: The safetensors files to register.
858863
named_tensors: The named tensors to register.
864+
use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
865+
Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
866+
cannot accommodate checkpoints with different memory requirements.
859867
"""
860868
try:
861869
if use_shared_memory_pool:
862870
logger.info(
863871
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
864872
)
865-
if self._current_shared_memory_pool_user:
866-
logger.error(
867-
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
868-
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
869-
f"This registration may cause unexpected conflicts."
870-
)
871-
return
872-
873+
assert self._current_shared_memory_pool_user == "", (
874+
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
875+
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
876+
f"This registration may cause unexpected conflicts."
877+
)
873878
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
874879
files=files or [],
875880
named_tensors=named_tensors or {},

0 commit comments

Comments
 (0)