Skip to content

Commit 93f3fa9

Browse files
committed
misc
1 parent c380f0c commit 93f3fa9

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

checkpoint_engine/ps.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,17 @@ def _pin(t: torch.Tensor):
522522
)
523523
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
524524

525-
local_memory_buffers: list[MemoryBuffer] = []
525+
memory_buffers: list[MemoryBuffer] = []
526526
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
527-
local_memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
528-
return local_memory_buffers
527+
memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
528+
return memory_buffers
529529

530530

531531
def _normal_pin_memory(
532532
files: list[str],
533533
named_tensors: dict[str, torch.Tensor],
534534
rank: int | None = None,
535+
shared_pin_memory: list[MemoryBuffer] | None = None,
535536
) -> list[MemoryBuffer]:
536537
parameters = _load_checkpoint(files)
537538
if named_tensors:
@@ -554,27 +555,27 @@ class MemoryBucket(BaseModel):
554555
)
555556
buckets[-1].size += size
556557

557-
local_memory_buffers = [
558+
memory_buffers = [
558559
MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
559560
for bucket in buckets
560561
]
561562

562-
def register_pin_memory(
563-
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
564-
) -> tuple[int, torch.Tensor]:
565-
if shared_pin_memory:
566-
# If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
567-
# Reusing pin memory only supports a fixed shape of checkpoints, which is registered the first time
568-
assert idx < len(shared_pin_memory), (
569-
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
570-
)
571-
assert shared_pin_memory[idx].size == size, (
572-
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
573-
)
574-
return idx, shared_pin_memory[idx].buffer
575-
else:
576-
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
577-
return idx, buffer
563+
def register_pin_memory(
564+
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
565+
) -> tuple[int, torch.Tensor]:
566+
if shared_pin_memory:
567+
# If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
568+
# Reusing pin memory only supports a fixed shape of checkpoints, which is registered the first time
569+
assert idx < len(shared_pin_memory), (
570+
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
571+
)
572+
assert shared_pin_memory[idx].size == size, (
573+
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
574+
)
575+
return idx, shared_pin_memory[idx].buffer
576+
else:
577+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
578+
return idx, buffer
578579

579580
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
580581
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -595,7 +596,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
595596
assert buffer.numel() == buckets[idx].size, (
596597
f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
597598
)
598-
local_memory_buffers[idx].buffer = buffer
599+
memory_buffers[idx].buffer = buffer
599600
logger.info(
600601
f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
601602
f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
@@ -612,14 +613,15 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
612613
offset += size
613614
for future in concurrent.futures.as_completed(new_futures):
614615
future.result()
615-
return local_memory_buffers
616+
return memory_buffers
616617

617618

618619
def _register_checkpoint(
619620
*,
620621
files: list[str],
621622
named_tensors: dict[str, torch.Tensor],
622623
rank: int | None = None,
624+
shared_pin_memory: list[MemoryBuffer] | None = None,
623625
) -> list[MemoryBuffer]:
624626
logger.info(
625627
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -635,7 +637,12 @@ def _register_checkpoint(
635637
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
636638
if files_to_normal_pin or named_tensors:
637639
memory_buffers.extend(
638-
_normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors, rank=rank)
640+
_normal_pin_memory(
641+
files=files_to_normal_pin,
642+
named_tensors=named_tensors,
643+
rank=rank,
644+
shared_pin_memory=shared_pin_memory,
645+
)
639646
)
640647
if files_to_inplace_pin:
641648
memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
@@ -986,8 +993,8 @@ def register_checkpoint(
986993
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
987994
)
988995
assert self._current_shared_memory_pool_user == "", (
989-
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
990-
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
996+
f"cannot register checkpoint '{checkpoint_name}' to shared memory pool, "
997+
f"since checkpoint '{self._current_shared_memory_pool_user}' is already using shared memory pool. "
991998
f"This registration may cause unexpected conflicts."
992999
)
9931000
# Since we set the uninitialized shared memory pool to empty list,

0 commit comments

Comments
 (0)