|
1 | 1 | import argparse |
2 | 2 | import concurrent.futures |
3 | 3 | import ctypes |
| 4 | +import json |
4 | 5 | import os |
5 | 6 | import pickle |
6 | 7 | import random |
|
18 | 19 | import zmq |
19 | 20 | from loguru import logger |
20 | 21 | from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema |
21 | | -from safetensors.torch import safe_open |
| 22 | +from safetensors.torch import _getdtype, safe_open |
22 | 23 | from torch.multiprocessing.reductions import reduce_tensor |
23 | 24 |
|
24 | 25 | from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid |
@@ -460,64 +461,131 @@ def _register_checkpoint( |
460 | 461 | ) |
461 | 462 | if not files and not named_tensors: |
462 | 463 | return [] |
463 | | - parameters = _load_checkpoint(files) |
464 | | - if named_tensors: |
465 | | - parameters.update(named_tensors) |
466 | | - bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) |
| 464 | + memory_buffers: list[MemoryBuffer] = [] |
| 465 | + inplace_pin = all( |
| 466 | + file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 |
| 467 | + for file in files or [] |
| 468 | + ) |
| 469 | + if inplace_pin: |
| 470 | + |
def _pin(t: torch.Tensor):
    """Pin an already-allocated CPU tensor's memory in place.

    Registers the tensor's existing storage with the CUDA driver via
    cudaHostRegister instead of allocating a fresh pinned buffer and
    copying into it.
    See: https://github.com/pytorch/pytorch/issues/32167
    """
    nbytes = t.numel() * t.element_size()
    r = torch.cuda.cudart().cudaHostRegister(t.data_ptr(), nbytes, 0)
    # Non-zero return is a cudaError_t; surface its numeric value.
    assert r == 0, f"pin memory error, error code: {r.value}"
| 479 | + |
def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
    """Map a safetensors file into memory and pin its tensor-data region in place.

    safetensors layout (https://huggingface.co/docs/safetensors/en/index#format):
    an 8-byte little-endian header length N, an N-byte JSON header, then the raw
    tensor data. The whole file is mmap'd as a uint8 tensor; the header is parsed
    manually to build the ParameterMeta list, and only the trailing data region is
    pinned, which makes registration fast and avoids copying.

    Fixes vs. previous revision:
    - removed a leftover debug ``time.sleep(3)`` that stalled every worker thread;
    - the optional ``__metadata__`` header entry (which has no ``data_offsets``)
      is now skipped instead of raising KeyError during sorting.
    """
    # TODO: should only support /dev/shm? but we found files in disk also work?
    size = os.stat(file_path).st_size
    # shared=True mmaps the file so the storage aliases the file contents.
    t = torch.from_file(file_path, shared=True, size=size, dtype=torch.uint8)

    # Read the 8-byte little-endian header-length prefix.
    flag_size = 8
    with open(file_path, "rb") as f:
        n = bytearray(flag_size)
        data = f.readinto(n)
        assert data == flag_size, f"data {data} should be equal to flag_size {flag_size}"
    n = int.from_bytes(n, byteorder="little", signed=False)
    start_pos = n + flag_size

    header_tensor = t[flag_size:start_pos]
    header = json.loads(header_tensor.numpy().tobytes())
    # The optional "__metadata__" entry carries free-form strings and has no
    # data_offsets; drop it before sorting by offsets.
    header.pop("__metadata__", None)

    # Walk entries in data order; safetensors guarantees the regions are
    # contiguous and back-to-back, which the running `offset` assert verifies.
    metas: list[ParameterMeta] = []
    offset = 0
    for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
        start, end = meta["data_offsets"]
        # safetensors format ensures offsets are aligned
        assert offset == start, f"offset {offset} should be equal to start {start}"
        metas.append(
            ParameterMeta(
                name=name, dtype=_getdtype(meta["dtype"]), shape=torch.Size(meta["shape"])
            )
        )
        offset = end

    buffer = t[start_pos:]
    assert offset == buffer.nbytes, (
        f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
    )
    _pin(buffer)
    return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
485 | 520 |
|
486 | | - def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]: |
487 | | - buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) |
488 | | - return idx, buffer |
| 521 | + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: |
| 522 | + futures = [executor.submit(_inplace_pin_memory, file) for file in files] |
| 523 | + for future in concurrent.futures.as_completed(futures): |
| 524 | + memory_buffer = future.result() |
| 525 | + memory_buffers.append(memory_buffer) |
489 | 526 |
|
490 | | - def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): |
491 | | - buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) |
| 527 | + else: |
| 528 | + parameters = _load_checkpoint(files) |
| 529 | + if named_tensors: |
| 530 | + parameters.update(named_tensors) |
| 531 | + bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) |
| 532 | + |
class MemoryBucket(BaseModel):
    """Plan for one pinned-memory bucket.

    # size: total aligned byte size of all tensors assigned to this bucket
    # metas: parameter metadata for the tensors packed into it, in pack order
    """

    size: int
    metas: list[ParameterMeta]
| 536 | + |
| 537 | + buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])] |
| 538 | + for name, tensor in sorted(parameters.items()): |
| 539 | + size = _align_size(tensor.dtype, tensor.shape) |
| 540 | + if buckets[-1].size + size > bucket_size: |
| 541 | + assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" |
| 542 | + buckets.append(MemoryBucket(size=0, metas=[])) |
| 543 | + buckets[-1].metas.append( |
| 544 | + ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype) |
| 545 | + ) |
| 546 | + buckets[-1].size += size |
492 | 547 |
|
493 | | - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: |
494 | | - futures = [ |
495 | | - executor.submit(register_pin_memory, idx, bucket.size) |
496 | | - for idx, bucket in enumerate(buckets) |
| 548 | + memory_buffers = [ |
| 549 | + MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) |
| 550 | + for bucket in buckets |
497 | 551 | ] |
498 | | - new_futures = [] |
499 | | - for future in concurrent.futures.as_completed(futures): |
500 | | - idx, buffer = future.result() |
501 | | - assert buffer.numel() == buckets[idx].size, ( |
502 | | - f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" |
503 | | - ) |
504 | | - memory_buffers[idx].buffer = buffer |
505 | | - logger.info( |
506 | | - f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " |
507 | | - f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" |
508 | | - ) |
509 | | - offset = 0 |
510 | | - for meta in buckets[idx].metas: |
511 | | - name = meta.name |
512 | | - tensor = parameters[name] |
513 | | - size = _align_size(tensor.dtype, tensor.shape) |
514 | | - assert size == _align_size(meta.dtype, meta.shape), ( |
515 | | - f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}" |
| 552 | + |
def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
    """Allocate a pinned uint8 staging buffer of `size` bytes for bucket `idx`."""
    pinned = torch.empty(size, dtype=torch.uint8, pin_memory=True)
    return idx, pinned
| 556 | + |
| 557 | + def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): |
| 558 | + buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) |
| 559 | + |
| 560 | + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: |
| 561 | + futures = [ |
| 562 | + executor.submit(register_pin_memory, idx, bucket.size) |
| 563 | + for idx, bucket in enumerate(buckets) |
| 564 | + ] |
| 565 | + new_futures = [] |
| 566 | + for future in concurrent.futures.as_completed(futures): |
| 567 | + idx, buffer = future.result() |
| 568 | + assert buffer.numel() == buckets[idx].size, ( |
| 569 | + f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" |
| 570 | + ) |
| 571 | + memory_buffers[idx].buffer = buffer |
| 572 | + logger.info( |
| 573 | + f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " |
| 574 | + f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" |
516 | 575 | ) |
517 | | - new_futures.append(executor.submit(register_tensor, buffer, offset, tensor)) |
518 | | - offset += size |
519 | | - for future in concurrent.futures.as_completed(new_futures): |
520 | | - future.result() |
| 576 | + offset = 0 |
| 577 | + for meta in buckets[idx].metas: |
| 578 | + name = meta.name |
| 579 | + tensor = parameters[name] |
| 580 | + size = _align_size(tensor.dtype, tensor.shape) |
| 581 | + assert size == _align_size(meta.dtype, meta.shape), ( |
| 582 | + f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}" |
| 583 | + ) |
| 584 | + new_futures.append(executor.submit(register_tensor, buffer, offset, tensor)) |
| 585 | + offset += size |
| 586 | + for future in concurrent.futures.as_completed(new_futures): |
| 587 | + future.result() |
| 588 | + |
521 | 589 | return memory_buffers |
522 | 590 |
|
523 | 591 |
|
|
0 commit comments