
Commit 9dd3e3e

Merge branch 'MoonshotAI:main' into feat/force-unregister

2 parents: 2d9a832 + f69e116

3 files changed: +225 -18 lines

checkpoint_engine/ps.py

Lines changed: 125 additions & 14 deletions
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int
 
 
 class BucketRange(NamedTuple):
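Note: `aligned_size` caches the padded byte footprint that earlier revisions recomputed through `_align_size` at every use site. As a rough sketch of what such a helper computes (assumed 256-byte alignment; the repo's actual `_align_size` and `_ALIGN_SIZE` may differ in detail):

    import torch

    _ALIGN_SIZE = 256  # assumed alignment boundary, not the repo's actual constant

    def align_size(dtype: torch.dtype, shape: torch.Size) -> int:
        # Round the raw byte size up to the next multiple of _ALIGN_SIZE.
        nbytes = dtype.itemsize * shape.numel()  # dtype.itemsize needs PyTorch >= 2.1
        return (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE

    assert align_size(torch.float16, torch.Size([3, 5])) == 256  # 30 bytes round up to 256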
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size = _align_size(meta.dtype, meta.shape)
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ class TPMeta(BaseModel):
            name=parameter_name,
            shape=meta["shape"],
            dtype=meta["dtype"],
+           aligned_size=_align_size(meta["dtype"], meta["shape"]),
        )
        tp_meta = tp_metas[parameter_name]
        if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ class TPMeta(BaseModel):
        shape = list(parameter_metas[name].shape)
        shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
        parameter_metas[name] = ParameterMeta(
-           name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+           name=name,
+           shape=torch.Size(shape),
+           dtype=parameter_metas[name].dtype,
+           aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
        )
        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
        # TODO: here concat is serial, which may be slow
@@ -449,18 +455,85 @@ class TPMeta(BaseModel):
     return parameters
 
 
-def _register_checkpoint(
-    *,
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+        The actual tensor data is in the remaining bytes and is naturally aligned.
+        We pin the remaining bytes as the buffer, making pinning faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of tensor in-place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should only support /dev/shm? but we found files in disk also work?
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file after successfully loading. This will avoid doubling the memory usage.
+        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
     shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
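For reference, the layout `_parse_and_pin_from_safetensors` walks is the documented safetensors format: an 8-byte little-endian length prefix, a JSON header mapping tensor names to dtype, shape, and `data_offsets`, then the raw tensor bytes. A minimal standalone header reader built only on those format guarantees (illustrative, not the repo's code):

    import json

    def read_safetensors_header(path: str) -> tuple[dict, int]:
        # Returns the parsed header and the offset where the tensor data begins.
        with open(path, "rb") as f:
            header_len = int.from_bytes(f.read(8), byteorder="little", signed=False)
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)  # optional metadata block, not a tensor
        return header, 8 + header_len

Each entry's `data_offsets` pair is relative to that data start, which is why the diff slices `t[start_pos:]` and asserts that consecutive offsets tile the buffer exactly.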
@@ -470,13 +543,16 @@ class MemoryBucket(BaseModel):
         size: int
         metas: list[ParameterMeta]
 
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size
 
     memory_buffers = [
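The loop above packs parameters greedily: tensors are visited in sorted-name order, and a fresh bucket is opened whenever the current one would overflow `bucket_size`. A toy walk-through with made-up aligned sizes:

    bucket_size = 100
    aligned_sizes = {"a": 60, "b": 50, "c": 30}  # hypothetical aligned byte sizes
    buckets, used = [[]], [0]
    for name in sorted(aligned_sizes):
        s = aligned_sizes[name]
        if used[-1] + s > bucket_size:  # would overflow, so open a new bucket
            buckets.append([])
            used.append(0)
        buckets[-1].append(name)
        used[-1] += s
    assert buckets == [["a"], ["b", "c"]]  # "a" + "b" = 110 > 100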
@@ -537,6 +613,39 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
             offset += size
     for future in concurrent.futures.as_completed(new_futures):
         future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    files_to_inplace_pin = [
+        file
+        for file in files
+        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+    ]
+    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers
 
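The rewritten `_register_checkpoint` splits incoming files by a simple path rule before pinning. An illustration of that rule with hypothetical paths:

    files = [
        "/dev/shm/a.safetensors",  # in-place pin; the file is deleted after pinning
        "/data/b.safetensors",     # normal pin: not under /dev/shm/
        "/dev/shm/c.bin",          # normal pin: not a .safetensors file
    ]
    inplace = [f for f in files if f.startswith("/dev/shm/") and f.endswith(".safetensors")]
    normal = [f for f in files if f not in inplace]
    assert inplace == ["/dev/shm/a.safetensors"]
    assert normal == ["/data/b.safetensors", "/dev/shm/c.bin"]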

@@ -585,7 +694,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s = _align_size(meta.dtype, meta.shape)
+            s = meta.aligned_size
             if buckets[-1][1].size + s > bucket_size:
                 if offset - start_offset > 0:
                     buckets[-1][1].ranges.append(
@@ -867,6 +976,8 @@ def register_checkpoint(
 ) -> None:
     """
     Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+    Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+    Please make sure to copy the files to disks if you need to keep them.
 
     Args:
         checkpoint_name: The name of the checkpoint.
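Given the destructive behavior documented in this warning, callers that still need the /dev/shm copy should duplicate it before registering. A minimal usage sketch, assuming an already-constructed `ParameterServer` instance `ps`; the paths are hypothetical:

    import shutil

    src = "/dev/shm/model-00001-of-00002.safetensors"
    shutil.copy(src, "/data/checkpoints/")  # durable copy survives the in-place pin
    ps.register_checkpoint("demo-ckpt", files=[src], named_tensors={})
    # src has now been pinned in place and removed from /dev/shm/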
@@ -1138,7 +1249,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
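The `free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE` context line computes the largest alignment-multiple not exceeding a third of free memory. A quick numeric check, assuming a 256-byte `_ALIGN_SIZE`:

    _ALIGN_SIZE = 256  # assumed; the repo defines its own constant
    free_bytes = 1_000_000
    third = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
    assert third == 333_312         # 1302 * 256
    assert third <= free_bytes / 3  # floored to an aligned multiple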

examples/update.py

Lines changed: 5 additions & 2 deletions
@@ -100,8 +100,9 @@ def update_weights(
     update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
     uds: str | None = None,
 ):
-    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     ps.init_process_group()
+    dist.barrier()
+    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     check_vllm_ready(endpoint, inference_parallel_size, uds)
     dist.barrier()
     with timer("Gather metas"):
@@ -173,7 +174,9 @@ def join(
             args.uds,
         )
     else:
-        if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
+        if os.path.exists(
+            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
+        ) and not args.checkpoint_path.startswith("/dev/shm/"):  # noqa: S108
             named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
             checkpoint_files = []
         else:

tests/test_update.py

Lines changed: 95 additions & 2 deletions
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
     try:
         trigger_error(socket_paths)
     except RuntimeError as e:
-        assert str(e) == "Failed to update weights due to remote errors"
+        assert str(e) == "Some workers failed to update weights"
 
 
 def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor
         for name, weight in weights:
             if name not in named_tensors:
                 continue
-            assert (weight == named_tensors[name]).all()
+            assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
             names_to_check[name] = True
 
     def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,66 @@ def run(
     assert proc.exitcode == 0
 
 
+def run_with_files(
+    checker_func: callable,
+):
+    rank = int(os.getenv("RANK"))
+    ctx = get_context("spawn")
+    queue = ctx.Queue()
+    _device_uuid = _get_physical_gpu_id(device_manager, rank)
+    ps = ParameterServer(auto_pg=True)
+    _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
+    named_tensors = dict(gen_test_tensors(rank))
+
+    # Save 1/3 tensors to /dev/shm/ as .safetensors files
+    # Save 1/3 tensors to ./tmp (disk) as .safetensors files
+    # Keep 1/3 tensors in memory
+    import safetensors.torch
+
+    files = []
+    dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+    disk_dir = "/tmp/checkpoint_engine_tests"  # noqa: S108
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    os.makedirs(disk_dir, exist_ok=True)
+    tensors_items = list(named_tensors.items())
+    tensors_in_dev_shm = named_tensors
+    tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
+    tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
+    tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
+    disk_files = [
+        os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
+        for _rank in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
+    time.sleep(1)
+    files.append(disk_files[rank])
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
+        for _ in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+
+    checkpoint_name = "test_with_files"
+    proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
+    proc.start()
+    ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
+    ps.gather_metas(checkpoint_name)
+    ps.update(checkpoint_name, queue.put, ranks=[])
+    # sleep 3s to wait process group is destroyed
+    time.sleep(3)
+    ps.unregister_checkpoint(checkpoint_name)
+    queue.put(None)
+    proc.join()
+    if rank == 0:
+        import shutil
+
+        os.removedirs(dev_shm_dir)
+        shutil.rmtree(disk_dir)
+    assert proc.exitcode == 0
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize(
     "test_name,rank_list",
@@ -211,6 +271,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
     assert result.returncode == 0
 
 
+@pytest.mark.gpu
+def test_update_with_files(test_name: str = "test_with_files"):
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+        test_name,
+        "[]",
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
 if __name__ == "__main__":
     run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
     if not run_with_pytest:
@@ -230,5 +321,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
             expected_exception=RuntimeError,
             exception_msg="Failed to update weights due to remote errors",
         )
+    elif test_type == "test_with_files":
+        run_with_files(checker_proc)
     else:
         raise ValueError(f"Unknown TEST_TYPE: {test_type}")
