11import argparse
22import concurrent .futures
33import ctypes
4+ import json
45import os
56import pickle
67import random
1819import zmq
1920from loguru import logger
2021from pydantic import BaseModel , PlainSerializer , PlainValidator , WithJsonSchema
21- from safetensors .torch import safe_open
22+ from safetensors .torch import _getdtype , safe_open
2223from torch .multiprocessing .reductions import reduce_tensor
2324
2425from checkpoint_engine .device_utils import DeviceManager , get_ip , npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
9293 name : str
9394 dtype : _TorchDtype
9495 shape : _TorchSize
96+ aligned_size : int
9597
9698
9799class BucketRange (NamedTuple ):
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
140142def _to_named_tensor (metas : list [ParameterMeta ], offset : int = 0 ) -> list [dict ]:
141143 ret = []
142144 for meta in metas :
143- size = _align_size ( meta .dtype , meta . shape )
145+ size = meta .aligned_size
144146 ret .append (
145147 {
146148 "name" : meta .name ,
@@ -422,6 +424,7 @@ class TPMeta(BaseModel):
422424 name = parameter_name ,
423425 shape = meta ["shape" ],
424426 dtype = meta ["dtype" ],
427+ aligned_size = _align_size (meta ["dtype" ], meta ["shape" ]),
425428 )
426429 tp_meta = tp_metas [parameter_name ]
427430 if tp_meta .concat_dim != - 1 :
@@ -431,7 +434,10 @@ class TPMeta(BaseModel):
431434 shape = list (parameter_metas [name ].shape )
432435 shape [tp_meta .concat_dim ] = shape [tp_meta .concat_dim ] * tp_meta .size
433436 parameter_metas [name ] = ParameterMeta (
434- name = name , shape = torch .Size (shape ), dtype = parameter_metas [name ].dtype
437+ name = name ,
438+ shape = torch .Size (shape ),
439+ dtype = parameter_metas [name ].dtype ,
440+ aligned_size = _align_size (parameter_metas [name ].dtype , torch .Size (shape )),
435441 )
436442 weights_in_cpu = [parameters_with_tp [name ][key ] for key in sorted (parameters_with_tp [name ])]
437443 # TODO: here concat is serial, which may be slow
@@ -449,18 +455,85 @@ class TPMeta(BaseModel):
449455 return parameters
450456
451457
452- def _register_checkpoint (
453- * ,
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
    """Map each safetensors file into memory and pin its data section in place.

    For every path in *files*, the file is memory-mapped, its header is parsed
    manually to build :class:`ParameterMeta` entries, the file is then REMOVED
    from disk, and the remaining tensor-data bytes are page-locked (pinned)
    in place via ``cudaHostRegister``.

    Args:
        files: Paths to ``.safetensors`` files (callers pass /dev/shm paths;
            the files are deleted after a successful load).
        rank: Optional rank, used only for log messages.

    Returns:
        One ``MemoryBuffer`` per input file, in input order.

    NOTE(review): pinning uses ``torch.cuda.cudart()``, so this path assumes a
    CUDA device — presumably the NPU path never reaches here; confirm.
    """

    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
        """
        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
        The actual tensor data is in the remaining bytes and is naturally aligned.
        We pin the remaining bytes as the buffer, making pinning faster.
        """

        def _pin(t: torch.Tensor):
            """
            Pin the memory of tensor in-place (page-lock the existing mapping
            instead of copying into a freshly allocated pinned buffer).
            See: https://github.com/pytorch/pytorch/issues/32167
            """
            cudart = torch.cuda.cudart()
            # Register the tensor's backing memory as page-locked; flag 0 is
            # cudaHostRegisterDefault.
            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
            assert r == 0, f"pin memory error, error code: {r}"

        # TODO: should only support /dev/shm? but we found files in disk also work?
        size = os.stat(file_path).st_size
        # First 8 bytes of a safetensors file: little-endian u64 header length.
        flag_size = 8
        # shared=True memory-maps the file, so no copy is made here.
        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
        assert t.nbytes > flag_size, (
            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
        )
        # start_pos = absolute offset where the tensor data section begins.
        start_pos = (
            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
            + flag_size
        )
        header_tensor = t[flag_size:start_pos]
        header = json.loads(header_tensor.numpy().tobytes())
        # "__metadata__" is an optional free-form entry, not a tensor.
        if "__metadata__" in header:
            header.pop("__metadata__")

        metas: list[ParameterMeta] = []
        offset = 0
        try:
            # Iterate tensors in data-section order so offsets are contiguous.
            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
                start, end = meta["data_offsets"]
                # safetensors format ensures offsets are aligned
                assert offset == start, f"offset {offset} should be equal to start {start}"
                metas.append(
                    ParameterMeta(
                        name=name,
                        dtype=_getdtype(meta["dtype"]),
                        shape=torch.Size(meta["shape"]),
                        # Use the on-disk extent directly as the aligned size.
                        aligned_size=end - start,
                    )
                )
                offset = end
        except Exception as e:
            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
            raise

        buffer = t[start_pos:]
        # The data section must be exactly covered by the parsed tensors.
        assert offset == buffer.nbytes, (
            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
        )
        # Remove the file after successfully loading. This will avoid doubling the memory usage.
        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
        # NOTE(review): on Linux the mmap keeps the pages alive after unlink,
        # but if _pin fails below the source file is already gone — confirm
        # this is acceptable for callers.
        os.remove(file_path)
        _pin(buffer)
        logger.info(
            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f} MiB"
        )
        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)

    memory_buffers: list[MemoryBuffer] = []
    # Parse/pin files concurrently; each worker is dominated by I/O and the
    # cudaHostRegister call, so threads (not processes) are sufficient.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
    return memory_buffers
529+
530+
531+ def _normal_pin_memory (
454532 files : list [str ],
455533 named_tensors : dict [str , torch .Tensor ],
456534 rank : int | None = None ,
457535 shared_pin_memory : list [MemoryBuffer ] | None = None ,
458536) -> list [MemoryBuffer ]:
459- logger .info (
460- f"[rank{ rank } ] start to register checkpoint with { len (files )} files and { len (named_tensors )} named_tensors"
461- )
462- if not files and not named_tensors :
463- return []
464537 parameters = _load_checkpoint (files )
465538 if named_tensors :
466539 parameters .update (named_tensors )
@@ -470,13 +543,16 @@ class MemoryBucket(BaseModel):
470543 size : int
471544 metas : list [ParameterMeta ]
472545
473- buckets : list [MemoryBucket ] = [MemoryBucket (size = 0 , metas = [])]
546+ buckets : list [MemoryBucket ] = []
547+ buckets .append (MemoryBucket (size = 0 , metas = []))
474548 for name , tensor in sorted (parameters .items ()):
475549 size = _align_size (tensor .dtype , tensor .shape )
476550 if buckets [- 1 ].size + size > bucket_size :
477551 assert buckets [- 1 ], f"buckets[{ len (buckets ) - 1 } ] should not be empty"
478552 buckets .append (MemoryBucket (size = 0 , metas = []))
479- buckets [- 1 ].metas .append (ParameterMeta (name = name , shape = tensor .shape , dtype = tensor .dtype ))
553+ buckets [- 1 ].metas .append (
554+ ParameterMeta (name = name , shape = tensor .shape , dtype = tensor .dtype , aligned_size = size )
555+ )
480556 buckets [- 1 ].size += size
481557
482558 memory_buffers = [
@@ -537,6 +613,39 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
537613 offset += size
538614 for future in concurrent .futures .as_completed (new_futures ):
539615 future .result ()
616+ return memory_buffers
617+
618+
def _register_checkpoint(
    *,
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    """Register a checkpoint's files and tensors into pinned host memory.

    ``.safetensors`` files located in ``/dev/shm/`` are pinned in place (and
    removed afterwards) via ``_inplace_pin_memory``; all other files, plus any
    ``named_tensors``, go through the regular copy-and-pin path
    (``_normal_pin_memory``).

    Args:
        files: Checkpoint file paths to load.
        named_tensors: Extra tensors to register alongside the files.
        rank: Optional rank, used only for log messages.
        shared_pin_memory: Optional pre-allocated pinned buffers to reuse.

    Returns:
        The ``MemoryBuffer`` list backing the registered checkpoint; empty if
        there is nothing to register.
    """
    logger.info(
        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
    )
    if not files and not named_tensors:
        return []
    # Partition in a single pass (the original membership test against a list
    # was O(n*m)); input order is preserved within each partition.
    files_to_inplace_pin: list[str] = []
    files_to_normal_pin: list[str] = []
    for file in files:
        if file.startswith("/dev/shm/") and file.endswith(".safetensors"):  # noqa: S108
            files_to_inplace_pin.append(file)
        else:
            files_to_normal_pin.append(file)
    memory_buffers: list[MemoryBuffer] = []
    if files_to_normal_pin or named_tensors:
        memory_buffers.extend(
            _normal_pin_memory(
                files=files_to_normal_pin,
                named_tensors=named_tensors,
                rank=rank,
                shared_pin_memory=shared_pin_memory,
            )
        )
    if files_to_inplace_pin:
        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
    return memory_buffers
541650
542651
@@ -585,7 +694,7 @@ def _gen_h2d_buckets(
585694 for idx , metas in enumerate (items .memory_buffer_metas_list ):
586695 start_offset , offset = 0 , 0
587696 for meta in metas .metas :
588- s = _align_size ( meta .dtype , meta . shape )
697+ s = meta .aligned_size
589698 if buckets [- 1 ][1 ].size + s > bucket_size :
590699 if offset - start_offset > 0 :
591700 buckets [- 1 ][1 ].ranges .append (
@@ -867,6 +976,8 @@ def register_checkpoint(
867976 ) -> None :
868977 """
869978 Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
    Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
    Please make sure to copy the files to disk if you need to keep them.
870981
871982 Args:
872983 checkpoint_name: The name of the checkpoint.
@@ -1138,7 +1249,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
11381249 for items in self ._current_global_parameter_metas .values ():
11391250 for metas_list in items .memory_buffer_metas_list :
11401251 for meta in metas_list .metas :
1141- max_tensor_bytes = max (max_tensor_bytes , _align_size ( meta .dtype , meta . shape ) )
1252+ max_tensor_bytes = max (max_tensor_bytes , meta .aligned_size )
11421253 free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE ) * _ALIGN_SIZE
11431254 if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer :
11441255 self ._logger_rank0 (f"[rank{ self ._rank } ] use h2d buffer" )
0 commit comments