@@ -522,16 +522,17 @@ def _pin(t: torch.Tensor):
522522 )
523523 return MemoryBuffer (buffer = buffer , size = buffer .nbytes , metas = metas )
524524
525- local_memory_buffers : list [MemoryBuffer ] = []
525+ memory_buffers : list [MemoryBuffer ] = []
526526 with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
527- local_memory_buffers = list (executor .map (_parse_and_pin_from_safetensors , files ))
528- return local_memory_buffers
527+ memory_buffers = list (executor .map (_parse_and_pin_from_safetensors , files ))
528+ return memory_buffers
529529
530530
531531def _normal_pin_memory (
532532 files : list [str ],
533533 named_tensors : dict [str , torch .Tensor ],
534534 rank : int | None = None ,
535+ shared_pin_memory : list [MemoryBuffer ] | None = None ,
535536) -> list [MemoryBuffer ]:
536537 parameters = _load_checkpoint (files )
537538 if named_tensors :
@@ -554,27 +555,27 @@ class MemoryBucket(BaseModel):
554555 )
555556 buckets [- 1 ].size += size
556557
557- local_memory_buffers = [
558+ memory_buffers = [
558559 MemoryBuffer (buffer = torch .empty (0 ), size = bucket .size , metas = bucket .metas )
559560 for bucket in buckets
560561 ]
561562
562- def register_pin_memory (
563- idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
564- ) -> tuple [int , torch .Tensor ]:
565- if shared_pin_memory :
566- # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
567- # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
568- assert idx < len (shared_pin_memory ), (
569- f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
570- )
571- assert shared_pin_memory [idx ].size == size , (
572- f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
573- )
574- return idx , shared_pin_memory [idx ].buffer
575- else :
576- buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
577- return idx , buffer
563+ def register_pin_memory (
564+ idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
565+ ) -> tuple [int , torch .Tensor ]:
566+ if shared_pin_memory :
567+ # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
568+ # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
569+ assert idx < len (shared_pin_memory ), (
570+ f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
571+ )
572+ assert shared_pin_memory [idx ].size == size , (
573+ f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
574+ )
575+ return idx , shared_pin_memory [idx ].buffer
576+ else :
577+ buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
578+ return idx , buffer
578579
579580 def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
580581 buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
@@ -595,7 +596,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
595596 assert buffer .numel () == buckets [idx ].size , (
596597 f"buffer numel { buffer .numel ()} should be equal to bucket size { buckets [idx ].size } "
597598 )
598- local_memory_buffers [idx ].buffer = buffer
599+ memory_buffers [idx ].buffer = buffer
599600 logger .info (
600601 f"[rank{ rank } ] register pin_memory for bucket { idx + 1 } /{ len (buckets )} finished, "
601602 f"size { buffer .numel () / 1024 / 1024 :.2f} MiB, start to copy tensors to buffer"
@@ -612,14 +613,15 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
612613 offset += size
613614 for future in concurrent .futures .as_completed (new_futures ):
614615 future .result ()
615- return local_memory_buffers
616+ return memory_buffers
616617
617618
618619def _register_checkpoint (
619620 * ,
620621 files : list [str ],
621622 named_tensors : dict [str , torch .Tensor ],
622623 rank : int | None = None ,
624+ shared_pin_memory : list [MemoryBuffer ] | None = None ,
623625) -> list [MemoryBuffer ]:
624626 logger .info (
625627 f"[rank{ rank } ] start to register checkpoint with { len (files )} files and { len (named_tensors )} named_tensors"
@@ -635,7 +637,12 @@ def _register_checkpoint(
635637 files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin ]
636638 if files_to_normal_pin or named_tensors :
637639 memory_buffers .extend (
638- _normal_pin_memory (files = files_to_normal_pin , named_tensors = named_tensors , rank = rank )
640+ _normal_pin_memory (
641+ files = files_to_normal_pin ,
642+ named_tensors = named_tensors ,
643+ rank = rank ,
644+ shared_pin_memory = shared_pin_memory ,
645+ )
639646 )
640647 if files_to_inplace_pin :
641648 memory_buffers .extend (_inplace_pin_memory (files_to_inplace_pin , rank = rank ))
@@ -986,8 +993,8 @@ def register_checkpoint(
986993 f"[rank{ self ._rank } ] checkpoint { checkpoint_name } use shared memory pool"
987994 )
988995 assert self ._current_shared_memory_pool_user == "" , (
989- f"cannot register checkpoint { checkpoint_name } to shared memory pool, "
990- f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
996+ f"cannot register checkpoint ' { checkpoint_name } ' to shared memory pool, "
997+ f"since checkpoint ' { self ._current_shared_memory_pool_user } ' is already using shared memory pool. "
991998 f"This registration may cause unexpected conflicts."
992999 )
9931000 # Since we set the uninitialized shared memory pool to empty list,