@@ -93,6 +93,7 @@ class ParameterMeta(BaseModel):
9393 name : str
9494 dtype : _TorchDtype
9595 shape : _TorchSize
96+ manually_aligned : bool = True
9697
9798
9899class BucketRange (NamedTuple ):
@@ -141,7 +142,11 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
141142def _to_named_tensor (metas : list [ParameterMeta ], offset : int = 0 ) -> list [dict ]:
142143 ret = []
143144 for meta in metas :
144- size = _align_size (meta .dtype , meta .shape )
145+ size = (
146+ _align_size (meta .dtype , meta .shape )
147+ if meta .manually_aligned
148+ else meta .dtype .itemsize * meta .shape .numel ()
149+ )
145150 ret .append (
146151 {
147152 "name" : meta .name ,
@@ -462,12 +467,8 @@ def _register_checkpoint(
462467 if not files and not named_tensors :
463468 return []
464469 memory_buffers : list [MemoryBuffer ] = []
465- inplace_pin = all (
466- file .startswith ("/dev/shm/" ) and file .endswith (".safetensors" ) # noqa: S108
467- for file in files or []
468- )
469- if inplace_pin :
470470
471+ def inplace_pin_memory (files : list [str ]) -> list [MemoryBuffer ]:
471472 def _pin (t : torch .Tensor ):
472473 """
473474 Pin the memory of tensor in-place.
@@ -494,6 +495,7 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
494495 n = int .from_bytes (n , byteorder = "little" , signed = False )
495496 start_pos = n + flag_size
496497
498+ os .remove (file_path )
497499 time .sleep (3 )
498500 header_tensor = t [flag_size :start_pos ]
499501 header = json .loads (header_tensor .numpy ().tobytes ())
@@ -506,7 +508,10 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
506508 assert offset == start , f"offset { offset } should be equal to start { start } "
507509 metas .append (
508510 ParameterMeta (
509- name = name , dtype = _getdtype (meta ["dtype" ]), shape = torch .Size (meta ["shape" ])
511+ name = name ,
512+ dtype = _getdtype (meta ["dtype" ]),
513+ shape = torch .Size (meta ["shape" ]),
514+ manually_aligned = False ,
510515 )
511516 )
512517 offset = end
@@ -518,13 +523,24 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
518523 _pin (buffer )
519524 return MemoryBuffer (buffer = buffer , size = buffer .nbytes , metas = metas )
520525
526+ local_memory_buffers : list [MemoryBuffer ] = []
527+ lock = threading .Lock ()
528+ idx = 0
521529 with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
522530 futures = [executor .submit (_inplace_pin_memory , file ) for file in files ]
523531 for future in concurrent .futures .as_completed (futures ):
524532 memory_buffer = future .result ()
525- memory_buffers .append (memory_buffer )
533+ with lock :
534+ local_memory_buffers .append (memory_buffer )
535+ logger .info (
536+ f"[rank{ rank } ] register pin_memory for file in /dev/shm { idx + 1 } /{ len (files )} finished"
537+ )
538+ idx += 1
539+ return local_memory_buffers
526540
527- else :
541+ def normal_pin_memory (
542+ files : list [str ], named_tensors : dict [str , torch .Tensor ]
543+ ) -> list [MemoryBuffer ]:
528544 parameters = _load_checkpoint (files )
529545 if named_tensors :
530546 parameters .update (named_tensors )
@@ -534,7 +550,8 @@ class MemoryBucket(BaseModel):
534550 size : int
535551 metas : list [ParameterMeta ]
536552
537- buckets : list [MemoryBucket ] = [MemoryBucket (size = 0 , metas = [])]
553+ buckets : list [MemoryBucket ] = []
554+ buckets .append (MemoryBucket (size = 0 , metas = []))
538555 for name , tensor in sorted (parameters .items ()):
539556 size = _align_size (tensor .dtype , tensor .shape )
540557 if buckets [- 1 ].size + size > bucket_size :
@@ -545,7 +562,7 @@ class MemoryBucket(BaseModel):
545562 )
546563 buckets [- 1 ].size += size
547564
548- memory_buffers = [
565+ local_memory_buffers = [
549566 MemoryBuffer (buffer = torch .empty (0 ), size = bucket .size , metas = bucket .metas )
550567 for bucket in buckets
551568 ]
@@ -568,7 +585,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
568585 assert buffer .numel () == buckets [idx ].size , (
569586 f"buffer numel { buffer .numel ()} should be equal to bucket size { buckets [idx ].size } "
570587 )
571- memory_buffers [idx ].buffer = buffer
588+ local_memory_buffers [idx ].buffer = buffer
572589 logger .info (
573590 f"[rank{ rank } ] register pin_memory for bucket { idx + 1 } /{ len (buckets )} finished, "
574591 f"size { buffer .numel () / 1024 / 1024 :.2f} MiB, start to copy tensors to buffer"
@@ -585,6 +602,20 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
585602 offset += size
586603 for future in concurrent .futures .as_completed (new_futures ):
587604 future .result ()
605+ return local_memory_buffers
606+
607+ files_to_inplace_pin = [
608+ file
609+ for file in files
610+ if file .startswith ("/dev/shm/" ) and file .endswith (".safetensors" ) # noqa: S108
611+ ]
612+ files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin ]
613+ if files_to_normal_pin or named_tensors :
614+ memory_buffers .extend (
615+ normal_pin_memory (files = files_to_normal_pin , named_tensors = named_tensors )
616+ )
617+ if files_to_inplace_pin :
618+ memory_buffers .extend (inplace_pin_memory (files_to_inplace_pin ))
588619
589620 return memory_buffers
590621
@@ -634,7 +665,11 @@ def _gen_h2d_buckets(
634665 for idx , metas in enumerate (items .memory_buffer_metas_list ):
635666 start_offset , offset = 0 , 0
636667 for meta in metas .metas :
637- s = _align_size (meta .dtype , meta .shape )
668+ s = (
669+ _align_size (meta .dtype , meta .shape )
670+ if meta .manually_aligned
671+ else meta .dtype .itemsize * meta .shape .numel ()
672+ )
638673 if buckets [- 1 ][1 ].size + s > bucket_size :
639674 if offset - start_offset > 0 :
640675 buckets [- 1 ][1 ].ranges .append (
@@ -1106,7 +1141,12 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
11061141 for items in self ._current_global_parameter_metas .values ():
11071142 for metas_list in items .memory_buffer_metas_list :
11081143 for meta in metas_list .metas :
1109- max_tensor_bytes = max (max_tensor_bytes , _align_size (meta .dtype , meta .shape ))
1144+ max_tensor_bytes = max (
1145+ max_tensor_bytes ,
1146+ _align_size (meta .dtype , meta .shape )
1147+ if meta .manually_aligned
1148+ else meta .dtype .itemsize * meta .shape .numel (),
1149+ )
11101150 free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE ) * _ALIGN_SIZE
11111151 if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer :
11121152 self ._logger_rank0 (f"[rank{ self ._rank } ] use h2d buffer" )
0 commit comments