@@ -488,6 +488,7 @@ def register_pin_memory(
     idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
 ) -> tuple[int, torch.Tensor]:
     if shared_pin_memory:
+        # If shared_pin_memory is provided, reuse its pinned buffer instead of allocating a new one
         # Reusing pin memory only supports checkpoints with a fixed shape, which is registered the first time
         assert idx < len(shared_pin_memory), (
             f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
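
A minimal sketch of the reuse contract this hunk relies on, under assumed names: `FakeMemoryBuffer` and `register_pin_memory_sketch` are hypothetical stand-ins for the real `MemoryBuffer` and `register_pin_memory`, reduced to a pinned host tensor plus its size. The first registration allocates pinned memory; later registrations with the same index and size hand back the existing tensor. (Allocating with `pin_memory=True` requires a CUDA-enabled PyTorch runtime.)

    import torch

    # Hypothetical stand-in for MemoryBuffer: a pinned host tensor plus its size.
    class FakeMemoryBuffer:
        def __init__(self, size: int):
            self.size = size
            self.buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)

    def register_pin_memory_sketch(
        idx: int, size: int, shared_pin_memory: list[FakeMemoryBuffer] | None = None
    ) -> tuple[int, torch.Tensor]:
        if shared_pin_memory:
            # Reuse path: the pool was filled by the first registration, so later
            # checkpoints must match the shape that was registered back then.
            assert idx < len(shared_pin_memory), (
                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
            )
            buf = shared_pin_memory[idx]
            assert buf.size == size, "reuse only supports a fixed checkpoint shape"
            return idx, buf.buffer
        # First-time path: allocate fresh pinned host memory.
        return idx, torch.empty(size, dtype=torch.uint8, pin_memory=True)

    pool = [FakeMemoryBuffer(1024)]
    _, t1 = register_pin_memory_sketch(0, 1024, pool)
    _, t2 = register_pin_memory_sketch(0, 1024, pool)
    assert t1.data_ptr() == t2.data_ptr()  # same pinned buffer, no new allocation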
@@ -765,6 +766,8 @@ def batch_transfer_sync_read(
 
 
 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -808,7 +811,6 @@ def __init__(
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
 
-        self.shared_memory_pool_name = "__shared_memory_pool__"
         # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
         self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
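
Hoisting `shared_memory_pool_name` from `__init__` to a class attribute keeps every existing `self.shared_memory_pool_name` lookup working while also making the sentinel reachable without an instance. A minimal sketch of the lookup rules being relied on, using a hypothetical `Example` class:

    class Example:
        pool_name = "__shared_memory_pool__"  # class attribute, shared by all instances

    e = Example()
    assert e.pool_name == "__shared_memory_pool__"        # instance lookup falls back to the class
    assert Example.pool_name == "__shared_memory_pool__"  # reachable with no instance at all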
@@ -829,10 +831,9 @@ def __init__(
 
     def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
         if checkpoint_name == self._current_shared_memory_pool_user:
-            if not self._memory_pool[self.shared_memory_pool_name]:
-                raise RuntimeError(
-                    f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
-                )
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
             return self._memory_pool[self.shared_memory_pool_name]
         elif checkpoint_name in self._memory_pool:
             return self._memory_pool[checkpoint_name]
@@ -885,27 +886,32 @@ def register_checkpoint(
                         f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
                         f"This registration may cause unexpected conflicts."
                     )
+                # Since the uninitialized shared memory pool is set to an empty list,
+                # an empty entry means this is the first time the shared memory pool is used
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
                 self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
                     files=files or [],
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
             else:
                 assert checkpoint_name not in self._memory_pool, (
                     f"checkpoint {checkpoint_name} already registered"
                 )
                 self._memory_pool[checkpoint_name] = _register_checkpoint(
                     files=files or [], named_tensors=named_tensors or {}, rank=self._rank
                 )
-            if self._p2p_store is not None:
-                self._register_parameters_to_p2p_store(checkpoint_name)
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
             )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise
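
The ordering in this hunk matters: `_is_first_time` has to be read before `_register_checkpoint` overwrites the pool entry, because the empty list is the "never allocated" sentinel, and afterwards the entry is always non-empty. A minimal sketch of that pattern with plain lists and hypothetical names:

    pool: dict[str, list[int]] = {"__shared_memory_pool__": []}

    def register(name: str, buffers: list[int]) -> bool:
        # Capture the sentinel *before* the pool entry is replaced; once it is
        # replaced, the first-time signal is lost for good.
        is_first_time = not pool["__shared_memory_pool__"]
        pool["__shared_memory_pool__"] = buffers
        return is_first_time

    assert register("ckpt-a", [1, 2, 3]) is True   # first user: would register with the p2p store
    assert register("ckpt-b", [1, 2, 3]) is False  # reuse: buffers (and their addresses) are unchanged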
@@ -920,20 +926,19 @@ def unregister_checkpoint(self, checkpoint_name: str):
             and checkpoint_name != self._current_shared_memory_pool_user
         ):
             logger.warning(
-                f"[rank{self._rank}] unregister checkpoint failed, checkpoint name {checkpoint_name} not found"
+                f"[rank{self._rank}] unregister checkpoint {checkpoint_name} not found"
             )
             return
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
-        if checkpoint_name == self._current_shared_memory_pool_user:
-            logger.info(
-                f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool"
-            )
-            self._current_shared_memory_pool_user = ""
-            return
 
         del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
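
With this hunk, unregistering the shared pool's current user only clears the user marker and returns early, so the pinned buffers stay allocated and stay registered with the p2p store for the next checkpoint; only non-shared checkpoints fall through to the p2p unregister and the `del self._memory_pool[...]`. A minimal sketch of that control flow, with the p2p store reduced to a set of names and a hypothetical `UnregisterSketch` class:

    class UnregisterSketch:
        def __init__(self):
            self.pool = {"__shared_memory_pool__": ["pinned buffers"]}
            self.current_user = "ckpt-a"
            self.p2p_registered = {"ckpt-a"}

        def unregister(self, name: str) -> None:
            if name == self.current_user:
                # Shared-pool user: release the claim but keep the buffers and
                # the p2p registration alive for whoever reuses the pool next.
                self.current_user = ""
                return
            self.p2p_registered.discard(name)
            self.pool.pop(name, None)

    s = UnregisterSketch()
    s.unregister("ckpt-a")
    assert s.pool["__shared_memory_pool__"]  # buffers survive
    assert "ckpt-a" in s.p2p_registered      # p2p registration survives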