Skip to content

Commit 9d4efa5

Browse files
committed
feat: make in-place pin and normal pin compatible
1 parent 52ea453 commit 9d4efa5

File tree

1 file changed

+54
-14
lines changed

1 file changed

+54
-14
lines changed

checkpoint_engine/ps.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class ParameterMeta(BaseModel):
9393
name: str
9494
dtype: _TorchDtype
9595
shape: _TorchSize
96+
manually_aligned: bool = True
9697

9798

9899
class BucketRange(NamedTuple):
@@ -141,7 +142,11 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
141142
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
142143
ret = []
143144
for meta in metas:
144-
size = _align_size(meta.dtype, meta.shape)
145+
size = (
146+
_align_size(meta.dtype, meta.shape)
147+
if meta.manually_aligned
148+
else meta.dtype.itemsize * meta.shape.numel()
149+
)
145150
ret.append(
146151
{
147152
"name": meta.name,
@@ -462,12 +467,8 @@ def _register_checkpoint(
462467
if not files and not named_tensors:
463468
return []
464469
memory_buffers: list[MemoryBuffer] = []
465-
inplace_pin = all(
466-
file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
467-
for file in files or []
468-
)
469-
if inplace_pin:
470470

471+
def inplace_pin_memory(files: list[str]) -> list[MemoryBuffer]:
471472
def _pin(t: torch.Tensor):
472473
"""
473474
Pin the memory of tensor in-place.
@@ -494,6 +495,7 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
494495
n = int.from_bytes(n, byteorder="little", signed=False)
495496
start_pos = n + flag_size
496497

498+
os.remove(file_path)
497499
time.sleep(3)
498500
header_tensor = t[flag_size:start_pos]
499501
header = json.loads(header_tensor.numpy().tobytes())
@@ -506,7 +508,10 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
506508
assert offset == start, f"offset {offset} should be equal to start {start}"
507509
metas.append(
508510
ParameterMeta(
509-
name=name, dtype=_getdtype(meta["dtype"]), shape=torch.Size(meta["shape"])
511+
name=name,
512+
dtype=_getdtype(meta["dtype"]),
513+
shape=torch.Size(meta["shape"]),
514+
manually_aligned=False,
510515
)
511516
)
512517
offset = end
@@ -518,13 +523,24 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
518523
_pin(buffer)
519524
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
520525

526+
local_memory_buffers: list[MemoryBuffer] = []
527+
lock = threading.Lock()
528+
idx = 0
521529
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
522530
futures = [executor.submit(_inplace_pin_memory, file) for file in files]
523531
for future in concurrent.futures.as_completed(futures):
524532
memory_buffer = future.result()
525-
memory_buffers.append(memory_buffer)
533+
with lock:
534+
local_memory_buffers.append(memory_buffer)
535+
logger.info(
536+
f"[rank{rank}] register pin_memory for file in /dev/shm {idx + 1}/{len(files)} finished"
537+
)
538+
idx += 1
539+
return local_memory_buffers
526540

527-
else:
541+
def normal_pin_memory(
542+
files: list[str], named_tensors: dict[str, torch.Tensor]
543+
) -> list[MemoryBuffer]:
528544
parameters = _load_checkpoint(files)
529545
if named_tensors:
530546
parameters.update(named_tensors)
@@ -534,7 +550,8 @@ class MemoryBucket(BaseModel):
534550
size: int
535551
metas: list[ParameterMeta]
536552

537-
buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
553+
buckets: list[MemoryBucket] = []
554+
buckets.append(MemoryBucket(size=0, metas=[]))
538555
for name, tensor in sorted(parameters.items()):
539556
size = _align_size(tensor.dtype, tensor.shape)
540557
if buckets[-1].size + size > bucket_size:
@@ -545,7 +562,7 @@ class MemoryBucket(BaseModel):
545562
)
546563
buckets[-1].size += size
547564

548-
memory_buffers = [
565+
local_memory_buffers = [
549566
MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
550567
for bucket in buckets
551568
]
@@ -568,7 +585,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
568585
assert buffer.numel() == buckets[idx].size, (
569586
f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
570587
)
571-
memory_buffers[idx].buffer = buffer
588+
local_memory_buffers[idx].buffer = buffer
572589
logger.info(
573590
f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
574591
f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
@@ -585,6 +602,20 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
585602
offset += size
586603
for future in concurrent.futures.as_completed(new_futures):
587604
future.result()
605+
return local_memory_buffers
606+
607+
files_to_inplace_pin = [
608+
file
609+
for file in files
610+
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
611+
]
612+
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
613+
if files_to_normal_pin or named_tensors:
614+
memory_buffers.extend(
615+
normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors)
616+
)
617+
if files_to_inplace_pin:
618+
memory_buffers.extend(inplace_pin_memory(files_to_inplace_pin))
588619

589620
return memory_buffers
590621

@@ -634,7 +665,11 @@ def _gen_h2d_buckets(
634665
for idx, metas in enumerate(items.memory_buffer_metas_list):
635666
start_offset, offset = 0, 0
636667
for meta in metas.metas:
637-
s = _align_size(meta.dtype, meta.shape)
668+
s = (
669+
_align_size(meta.dtype, meta.shape)
670+
if meta.manually_aligned
671+
else meta.dtype.itemsize * meta.shape.numel()
672+
)
638673
if buckets[-1][1].size + s > bucket_size:
639674
if offset - start_offset > 0:
640675
buckets[-1][1].ranges.append(
@@ -1106,7 +1141,12 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
11061141
for items in self._current_global_parameter_metas.values():
11071142
for metas_list in items.memory_buffer_metas_list:
11081143
for meta in metas_list.metas:
1109-
max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
1144+
max_tensor_bytes = max(
1145+
max_tensor_bytes,
1146+
_align_size(meta.dtype, meta.shape)
1147+
if meta.manually_aligned
1148+
else meta.dtype.itemsize * meta.shape.numel(),
1149+
)
11101150
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
11111151
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
11121152
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")

0 commit comments

Comments
 (0)