@@ -488,6 +488,7 @@ def register_pin_memory(
488488 idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
489489 ) -> tuple [int , torch .Tensor ]:
490490 if shared_pin_memory :
491+ # If shared_pin_memory is provided, reuse the pin memory buffer; do not allocate a new one
491492 # Reusing pin memory only supports checkpoints with a fixed shape, which is registered the first time
492493 assert idx < len (shared_pin_memory ), (
493494 f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
@@ -756,6 +757,8 @@ def batch_transfer_sync_read(
756757
757758
758759class ParameterServer :
760+ shared_memory_pool_name = "__shared_memory_pool__"
761+
759762 def __init__ (
760763 self ,
761764 * ,
@@ -799,7 +802,6 @@ def __init__(
799802 self ._zmq_ctx = zmq .Context ()
800803 self ._zmq_addr_counter = 0
801804
802- self .shared_memory_pool_name = "__shared_memory_pool__"
803805 # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
804806 self ._current_shared_memory_pool_user : str = ""
805807 self ._memory_pool : dict [str , list [MemoryBuffer ]] = {}
@@ -819,10 +821,9 @@ def __init__(
819821
820822 def _get_memory_pool (self , checkpoint_name : str ) -> list [MemoryBuffer ]:
821823 if checkpoint_name == self ._current_shared_memory_pool_user :
822- if not self ._memory_pool [self .shared_memory_pool_name ]:
823- raise RuntimeError (
824- f"shared memory pool is not initialized, but checkpoint { checkpoint_name } is using it"
825- )
824+ assert self ._memory_pool [self .shared_memory_pool_name ], (
825+ f"shared memory pool is not initialized, but checkpoint { checkpoint_name } is using it"
826+ )
826827 return self ._memory_pool [self .shared_memory_pool_name ]
827828 elif checkpoint_name in self ._memory_pool :
828829 return self ._memory_pool [checkpoint_name ]
@@ -875,27 +876,32 @@ def register_checkpoint(
875876 f"since checkpoint { self ._current_shared_memory_pool_user } is already using shared memory pool. "
876877 f"This registration may cause unexpected conflicts."
877878 )
879+ # Since the uninitialized shared memory pool is set to an empty list,
880+ # we can check whether this is the first time the shared memory pool is used
881+ _is_first_time = not self ._memory_pool [self .shared_memory_pool_name ]
878882 self ._memory_pool [self .shared_memory_pool_name ] = _register_checkpoint (
879883 files = files or [],
880884 named_tensors = named_tensors or {},
881885 rank = self ._rank ,
882886 shared_pin_memory = self ._memory_pool [self .shared_memory_pool_name ],
883887 )
884888 self ._current_shared_memory_pool_user = checkpoint_name
889+ if self ._p2p_store is not None and _is_first_time :
890+ self ._register_parameters_to_p2p_store (checkpoint_name )
885891 else :
886892 assert checkpoint_name not in self ._memory_pool , (
887893 f"checkpoint { checkpoint_name } already registered"
888894 )
889895 self ._memory_pool [checkpoint_name ] = _register_checkpoint (
890896 files = files or [], named_tensors = named_tensors or {}, rank = self ._rank
891897 )
892- if self ._p2p_store is not None :
893- self ._register_parameters_to_p2p_store (checkpoint_name )
898+ if self ._p2p_store is not None :
899+ self ._register_parameters_to_p2p_store (checkpoint_name )
894900 except Exception :
895901 logger .exception (
896902 f"[rank{ self ._rank } ] fail to register checkpoint { checkpoint_name } with files { files } "
897903 )
898- if self ._p2p_store is not None :
904+ if self ._p2p_store is not None and not use_shared_memory_pool :
899905 self ._unregister_parameters_from_p2p_store (checkpoint_name )
900906 self .unregister_checkpoint (checkpoint_name )
901907 raise
@@ -910,20 +916,19 @@ def unregister_checkpoint(self, checkpoint_name: str):
910916 and checkpoint_name != self ._current_shared_memory_pool_user
911917 ):
912918 logger .warning (
913- f"[rank{ self ._rank } ] unregister checkpoint failed, checkpoint name { checkpoint_name } not found"
919+ f"[rank{ self ._rank } ] unregister checkpoint name { checkpoint_name } not found"
914920 )
915921 return
922+
923+ if checkpoint_name == self ._current_shared_memory_pool_user :
924+ self ._current_shared_memory_pool_user = ""
925+ return
926+
916927 if self ._p2p_store is not None :
917928 num_unregistered = self ._unregister_parameters_from_p2p_store (checkpoint_name )
918929 logger .info (
919930 f"[rank{ self ._rank } ] unregister { num_unregistered } parameters from p2p store for checkpoint { checkpoint_name } "
920931 )
921- if checkpoint_name == self ._current_shared_memory_pool_user :
922- logger .info (
923- f"[rank{ self ._rank } ] unregister shared memory pool from p2p store, skip unregistering from memory pool"
924- )
925- self ._current_shared_memory_pool_user = ""
926- return
927932
928933 del self ._memory_pool [checkpoint_name ]
929934 # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
0 commit comments