@@ -487,12 +487,18 @@ class MemoryBucket(BaseModel):
487487 def register_pin_memory (
488488 idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
489489 ) -> tuple [int , torch .Tensor ]:
490- buffer = (
491- torch .empty (size , dtype = torch .uint8 , pin_memory = True )
492- if not shared_pin_memory
493- else shared_pin_memory [idx ].buffer
494- )
495- return idx , buffer
490+ if shared_pin_memory :
491+ # Reusing pinned memory only supports a fixed checkpoint shape, which is registered the first time
492+ assert idx < len (shared_pin_memory ), (
493+ f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
494+ )
495+ assert shared_pin_memory [idx ].size == size , (
496+ f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
497+ )
498+ return idx , shared_pin_memory [idx ].buffer
499+ else :
500+ buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
501+ return idx , buffer
496502
497503 def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
498504 buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
@@ -794,8 +800,7 @@ def __init__(
794800 self ._zmq_addr_counter = 0
795801
796802 self .shared_memory_pool_name = "__shared_memory_pool__"
797- # this dict stores all currently registered checkpoints
798- # dict key is checkpoint_name, value is whether use shared memory pool
803+ # stores the name of the checkpoint currently using the shared memory pool, or an empty string if none
799804 self ._current_shared_memory_pool_user : str = ""
800805 self ._memory_pool : dict [str , list [MemoryBuffer ]] = {}
801806 self ._memory_pool [self .shared_memory_pool_name ] = []
@@ -813,16 +818,16 @@ def __init__(
813818 self ._rdma_device = None if self ._p2p_store is None else self ._p2p_store .device
814819
815820 def _get_memory_pool (self , checkpoint_name : str ) -> list [MemoryBuffer ]:
816- if (
817- checkpoint_name not in self ._memory_pool
818- and checkpoint_name != self . _current_shared_memory_pool_user
819- ):
820- raise RuntimeError ( f"checkpoint { checkpoint_name } not registered in memory pool" )
821- return (
822- self ._memory_pool [ checkpoint_name ]
823- if checkpoint_name != self ._current_shared_memory_pool_user
824- else self . _memory_pool [ self . shared_memory_pool_name ]
825- )
821+ if checkpoint_name == self . _current_shared_memory_pool_user :
822+ if not self ._memory_pool [ self . shared_memory_pool_name ]:
823+ raise RuntimeError (
824+ f"shared memory pool is not initialized, but checkpoint { checkpoint_name } is using it"
825+ )
826+ return self . _memory_pool [ self . shared_memory_pool_name ]
827+ elif checkpoint_name in self ._memory_pool :
828+ return self ._memory_pool [ checkpoint_name ]
829+ else :
830+ raise RuntimeError ( f"checkpoint { checkpoint_name } is not registered" )
826831
827832 def _logger_rank0 (self , msg : str ):
828833 if self ._local_rank == 0 :
@@ -856,20 +861,20 @@ def register_checkpoint(
856861 checkpoint_name: The name of the checkpoint.
857862 files: The safetensors files to register.
858863 named_tensors: The named tensors to register.
864+ use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
865+ Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
866+ cannot accommodate checkpoints with different memory requirements.
859867 """
860868 try :
861869 if use_shared_memory_pool :
862870 logger .info (
863871 f"[rank{ self ._rank } ] checkpoint { checkpoint_name } use shared memory pool"
864872 )
865- if self ._current_shared_memory_pool_user :
866- logger .error (
867- f"cannot register checkpoint { checkpoint_name } to shared memory pool, "
868- f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
869- f"This registration may cause unexpected conflicts."
870- )
871- return
872-
873+ assert self ._current_shared_memory_pool_user == "" , (
874+ f"cannot register checkpoint { checkpoint_name } to shared memory pool, "
875+ f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
876+ f"This registration may cause unexpected conflicts."
877+ )
873878 self ._memory_pool [self .shared_memory_pool_name ] = _register_checkpoint (
874879 files = files or [],
875880 named_tensors = named_tensors or {},
0 commit comments