Commit 52ea453

feat: inplace pin memory for safetensors in /dev/shm/
1 parent 67b0020 commit 52ea453

File tree

1 file changed: +121 −53 lines


checkpoint_engine/ps.py

Lines changed: 121 additions & 53 deletions
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -460,64 +461,131 @@ def _register_checkpoint(
     )
     if not files and not named_tensors:
         return []
-    parameters = _load_checkpoint(files)
-    if named_tensors:
-        parameters.update(named_tensors)
-    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
+    memory_buffers: list[MemoryBuffer] = []
+    inplace_pin = all(
+        file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+        for file in files or []
+    )
+    if inplace_pin:
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of a tensor in place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r.value}"
+
+        def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
+            # TODO: should this only support /dev/shm? We found that files on disk also work.
+            size = os.stat(file_path).st_size
+            t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+
+            # safetensors format: see https://huggingface.co/docs/safetensors/en/index#format.
+            # We load the safetensors file as bytes, then parse the header manually to get parameter metas,
+            # and the actual tensor data is in the remaining bytes.
+            # We pin only the remaining bytes as the buffer, making pinning faster.
+            flag_size = 8
+            with open(file_path, "rb") as f:
+                n = bytearray(flag_size)
+                data = f.readinto(n)
+                assert data == flag_size, f"data {data} should be equal to flag_size {flag_size}"
+                n = int.from_bytes(n, byteorder="little", signed=False)
+            start_pos = n + flag_size
+
+            time.sleep(3)
+            header_tensor = t[flag_size:start_pos]
+            header = json.loads(header_tensor.numpy().tobytes())
+
+            metas: list[ParameterMeta] = []
+            offset = 0
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name, dtype=_getdtype(meta["dtype"]), shape=torch.Size(meta["shape"])
+                    )
+                )
+                offset = end
 
-    class MemoryBucket(BaseModel):
-        size: int
-        metas: list[ParameterMeta]
-
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
-    for name, tensor in sorted(parameters.items()):
-        size = _align_size(tensor.dtype, tensor.shape)
-        if buckets[-1].size + size > bucket_size:
-            assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
-            buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
-        buckets[-1].size += size
-
-    memory_buffers = [
-        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
-        for bucket in buckets
-    ]
+            buffer = t[start_pos:]
+            assert offset == buffer.nbytes, (
+                f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+            )
+            _pin(buffer)
+            return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
 
-    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-        return idx, buffer
+        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+            futures = [executor.submit(_inplace_pin_memory, file) for file in files]
+            for future in concurrent.futures.as_completed(futures):
+                memory_buffer = future.result()
+                memory_buffers.append(memory_buffer)
 
-    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
-        buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
+    else:
+        parameters = _load_checkpoint(files)
+        if named_tensors:
+            parameters.update(named_tensors)
+        bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
+
+        class MemoryBucket(BaseModel):
+            size: int
+            metas: list[ParameterMeta]
+
+        buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+        for name, tensor in sorted(parameters.items()):
+            size = _align_size(tensor.dtype, tensor.shape)
+            if buckets[-1].size + size > bucket_size:
+                assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
+                buckets.append(MemoryBucket(size=0, metas=[]))
+            buckets[-1].metas.append(
+                ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype)
+            )
+            buckets[-1].size += size
 
-    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-        futures = [
-            executor.submit(register_pin_memory, idx, bucket.size)
-            for idx, bucket in enumerate(buckets)
+        memory_buffers = [
+            MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
+            for bucket in buckets
         ]
-        new_futures = []
-        for future in concurrent.futures.as_completed(futures):
-            idx, buffer = future.result()
-            assert buffer.numel() == buckets[idx].size, (
-                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
-            )
-            memory_buffers[idx].buffer = buffer
-            logger.info(
-                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
-                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
-            )
-            offset = 0
-            for meta in buckets[idx].metas:
-                name = meta.name
-                tensor = parameters[name]
-                size = _align_size(tensor.dtype, tensor.shape)
-                assert size == _align_size(meta.dtype, meta.shape), (
-                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
+
+        def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer
+
+        def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
+            buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+            futures = [
+                executor.submit(register_pin_memory, idx, bucket.size)
+                for idx, bucket in enumerate(buckets)
+            ]
+            new_futures = []
+            for future in concurrent.futures.as_completed(futures):
+                idx, buffer = future.result()
+                assert buffer.numel() == buckets[idx].size, (
+                    f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
+                )
+                memory_buffers[idx].buffer = buffer
+                logger.info(
+                    f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
+                    f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
                 )
-                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
-                offset += size
-        for future in concurrent.futures.as_completed(new_futures):
-            future.result()
+                offset = 0
+                for meta in buckets[idx].metas:
+                    name = meta.name
+                    tensor = parameters[name]
+                    size = _align_size(tensor.dtype, tensor.shape)
+                    assert size == _align_size(meta.dtype, meta.shape), (
+                        f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
+                    )
+                    new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
+                    offset += size
+            for future in concurrent.futures.as_completed(new_futures):
+                future.result()
+
     return memory_buffers
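
Note: the core of this change is the cudaHostRegister trick referenced in the _pin docstring. Instead of allocating a fresh pinned buffer and copying the checkpoint into it, the existing /dev/shm mapping is page-locked where it already lives. Below is a minimal standalone sketch of that trick; it is an illustration rather than code from this commit, and it assumes a CUDA-enabled PyTorch build plus a throwaway demo path (/dev/shm/demo.bin) invented for the example.

    import os
    import torch

    # Hypothetical demo file; the commit operates on real safetensors files in /dev/shm.
    path = "/dev/shm/demo.bin"
    with open(path, "wb") as f:
        f.write(os.urandom(1 << 20))  # 1 MiB of placeholder bytes

    size = os.stat(path).st_size
    # shared=True makes torch.from_file memory-map the file instead of copying it.
    t = torch.from_file(path, True, size, dtype=torch.uint8)

    # cudaHostRegister page-locks the existing mapping in place, so the bytes become
    # usable for fast async host-to-device copies without an extra staging copy
    # (see https://github.com/pytorch/pytorch/issues/32167).
    cudart = torch.cuda.cudart()
    ret = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    assert t.is_pinned()  # the mapped bytes are now page-locked host memory

    gpu = t.to("cuda", non_blocking=True)  # async copy straight from the pinned mapping
    torch.cuda.synchronize()
    # cudaHostUnregister(t.data_ptr()) would undo the registration before the mapping goes away.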

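_inplace_pin_memory parses the safetensors container by hand: the first 8 bytes are a little-endian u64 giving the JSON header length, the header maps each tensor name to its dtype, shape, and data_offsets, and the raw tensor bytes follow immediately after the header (https://huggingface.co/docs/safetensors/en/index#format). A small self-contained reader along those lines is sketched below; the file path is a placeholder, and the optional __metadata__ entry, which the diff does not handle explicitly, is skipped here.

    import json
    import struct

    def read_safetensors_layout(path: str) -> tuple[int, dict]:
        """Return (absolute offset of the data section, name -> tensor info mapping)."""
        with open(path, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)  # optional free-form metadata, not a tensor entry
        return 8 + header_len, header

    # Placeholder path for illustration.
    data_start, header = read_safetensors_layout("/dev/shm/model.safetensors")
    for name, info in sorted(header.items(), key=lambda kv: kv[1]["data_offsets"]):
        begin, end = info["data_offsets"]  # offsets are relative to data_start
        print(f"{name}: dtype={info['dtype']} shape={info['shape']} "
              f"bytes=[{data_start + begin}, {data_start + end})")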

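Both branches ultimately hand back a MemoryBuffer whose payload is a flat uint8 buffer plus per-tensor metadata. The round trip below illustrates, with made-up names and sizes, how register_tensor's byte copy pairs with a later dtype/shape reinterpretation; it is not code from ps.py, and the buffer is left unpinned so the snippet runs on any PyTorch build.

    import torch

    flat = torch.empty(1 << 16, dtype=torch.uint8)  # stand-in for one (normally pinned) bucket

    src = torch.randn(128, 64, dtype=torch.float32)
    nbytes = src.numel() * src.element_size()
    offset = 0

    # Write: the same byte-level copy register_tensor performs in the diff.
    flat[offset : offset + nbytes] = src.view(-1).view(torch.uint8)

    # Read back: slice the byte range, reinterpret the dtype, then restore the shape.
    restored = flat[offset : offset + nbytes].view(torch.float32).view(128, 64)
    assert torch.equal(restored, src)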