@@ -487,12 +487,18 @@ class MemoryBucket(BaseModel):
487487 def register_pin_memory (
488488 idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
489489 ) -> tuple [int , torch .Tensor ]:
490- buffer = (
491- torch .empty (size , dtype = torch .uint8 , pin_memory = True )
492- if not shared_pin_memory
493- else shared_pin_memory [idx ].buffer
494- )
495- return idx , buffer
490+ if shared_pin_memory :
491+ # Reusing pinned memory only supports checkpoints with a fixed shape, which is registered on first use
492+ assert idx < len (shared_pin_memory ), (
493+ f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
494+ )
495+ assert shared_pin_memory [idx ].size == size , (
496+ f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
497+ )
498+ return idx , shared_pin_memory [idx ].buffer
499+ else :
500+ buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
501+ return idx , buffer
496502
497503 def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
498504 buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
@@ -803,8 +809,7 @@ def __init__(
803809 self ._zmq_addr_counter = 0
804810
805811 self .shared_memory_pool_name = "__shared_memory_pool__"
806- # this dict stores all currently registered checkpoints
807- # dict key is checkpoint_name, value is whether use shared memory pool
812+ # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
808813 self ._current_shared_memory_pool_user : str = ""
809814 self ._memory_pool : dict [str , list [MemoryBuffer ]] = {}
810815 self ._memory_pool [self .shared_memory_pool_name ] = []
@@ -823,16 +828,16 @@ def __init__(
823828 self ._rdma_device = None if self ._p2p_store is None else self ._p2p_store .device
824829
825830 def _get_memory_pool (self , checkpoint_name : str ) -> list [MemoryBuffer ]:
826- if (
827- checkpoint_name not in self ._memory_pool
828- and checkpoint_name != self . _current_shared_memory_pool_user
829- ):
830- raise RuntimeError ( f"checkpoint { checkpoint_name } not registered in memory pool" )
831- return (
832- self ._memory_pool [ checkpoint_name ]
833- if checkpoint_name != self ._current_shared_memory_pool_user
834- else self . _memory_pool [ self . shared_memory_pool_name ]
835- )
831+ if checkpoint_name == self . _current_shared_memory_pool_user :
832+ if not self ._memory_pool [ self . shared_memory_pool_name ]:
833+ raise RuntimeError (
834+ f"shared memory pool is not initialized, but checkpoint { checkpoint_name } is using it"
835+ )
836+ return self . _memory_pool [ self . shared_memory_pool_name ]
837+ elif checkpoint_name in self ._memory_pool :
838+ return self ._memory_pool [ checkpoint_name ]
839+ else :
840+ raise RuntimeError ( f"checkpoint { checkpoint_name } is not registered" )
836841
837842 def _logger_rank0 (self , msg : str ):
838843 if self ._local_rank == 0 :
@@ -866,20 +871,20 @@ def register_checkpoint(
866871 checkpoint_name: The name of the checkpoint.
867872 files: The safetensors files to register.
868873 named_tensors: The named tensors to register.
874+ use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
875+ Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
876+ cannot accommodate checkpoints with different memory requirements.
869877 """
870878 try :
871879 if use_shared_memory_pool :
872880 logger .info (
873881 f"[rank{ self ._rank } ] checkpoint { checkpoint_name } use shared memory pool"
874882 )
875- if self ._current_shared_memory_pool_user :
876- logger .error (
877- f"cannot register checkpoint { checkpoint_name } to shared memory pool, "
878- f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
879- f"This registration may cause unexpected conflicts."
880- )
881- return
882-
883+ assert self ._current_shared_memory_pool_user == "" , (
884+ f"cannot register checkpoint { checkpoint_name } to shared memory pool, "
885+ f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
886+ f"This registration may cause unexpected conflicts."
887+ )
883888 self ._memory_pool [self .shared_memory_pool_name ] = _register_checkpoint (
884889 files = files or [],
885890 named_tensors = named_tensors or {},
0 commit comments