@@ -1027,7 +1027,7 @@ def register_checkpoint(
             self.unregister_checkpoint(checkpoint_name)
             raise
 
-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
@@ -1041,10 +1041,7 @@ def unregister_checkpoint(self, checkpoint_name: str):
             )
             return
 
-        # TODO: currently, we just mark the shared memory pool as unused when unregistering.
-        # Physically releasing the shared memory pool is not supported yet.
-        # We may add unregister shared memory pool logic in the future if necessary.
-        if checkpoint_name == self._current_shared_memory_pool_user:
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
             self._current_shared_memory_pool_user = ""
             return
@@ -1054,7 +1051,12 @@ def unregister_checkpoint(self, checkpoint_name: str):
             f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
         )
 
-        del self._memory_pool[checkpoint_name]
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
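Taken together, the hunks above give unregister_checkpoint three paths: without force, the shared memory pool's current user is merely detached so the pool can be reused; with force, the shared pool's buffers are dropped and the pool reset to empty; a checkpoint that owns a dedicated pool is deleted outright as before. Below is a minimal, self-contained sketch of that control flow; ToyServer and its setup are illustrative stand-ins, not the real parameter-server class:

# Minimal sketch of the new force semantics; ToyServer and its fields are
# assumed stand-ins for the real class in this diff.
class ToyServer:
    shared_memory_pool_name = "shared"

    def __init__(self) -> None:
        self._memory_pool = {self.shared_memory_pool_name: ["buf0", "buf1"]}
        self._current_shared_memory_pool_user = "ckpt-a"

    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
        if checkpoint_name == self._current_shared_memory_pool_user and not force:
            # Default: mark the shared pool unused but keep its buffers for reuse.
            self._current_shared_memory_pool_user = ""
            return
        if checkpoint_name == self._current_shared_memory_pool_user:
            # force=True on the pool user: drop the buffers and reset the pool.
            self._current_shared_memory_pool_user = ""
            self._memory_pool[self.shared_memory_pool_name] = []
        else:
            # A checkpoint with a dedicated pool: remove its entry outright.
            del self._memory_pool[checkpoint_name]

server = ToyServer()
server.unregister_checkpoint("ckpt-a")              # buffers kept for the next user
assert server._memory_pool["shared"] == ["buf0", "buf1"]

server = ToyServer()
server.unregister_checkpoint("ckpt-a", force=True)  # buffers released
assert server._memory_pool["shared"] == []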
@@ -1353,8 +1355,13 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)
 
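The register_name switch keeps the p2p keys stable: buffers that live in the shared pool are published under the pool's own name rather than under whichever checkpoint happens to be using it, so a later user of the same pool resolves to the same keys. A small sketch of that naming rule; only the memory_pool_{name}_{idx} key format comes from the diff, the helper and its arguments are assumed:

# Illustrative helper; only the "memory_pool_{name}_{idx}" key format is taken
# from the diff, the function itself is assumed for this sketch.
def p2p_tensor_names(checkpoint_name: str, current_user: str,
                     shared_pool_name: str, pool_len: int) -> list[str]:
    name = (
        checkpoint_name
        if checkpoint_name != current_user
        else shared_pool_name
    )
    return [f"memory_pool_{name}_{idx}" for idx in range(pool_len)]

# The shared pool's current user is keyed by the pool, not the checkpoint:
assert p2p_tensor_names("ckpt-a", "ckpt-a", "shared", 2) == [
    "memory_pool_shared_0",
    "memory_pool_shared_1",
]
# A checkpoint with its own pool keeps checkpoint-scoped keys:
assert p2p_tensor_names("ckpt-b", "ckpt-a", "shared", 1) == ["memory_pool_ckpt-b_0"]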
@@ -1363,8 +1370,13 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )
 
     def _update_per_bucket(
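_unregister_parameters_from_p2p_store repeats the same derivation, so register and unregister agree only while checkpoint_name is still recorded as the pool user. That is why, in unregister_checkpoint above, the p2p unregister runs before _current_shared_memory_pool_user is cleared. A self-contained sketch of that ordering invariant (all names assumed):

# Sketch of the ordering invariant; derive_name mirrors the conditional used by
# both the register and unregister paths (names assumed for illustration).
def derive_name(checkpoint_name: str, current_user: str, shared: str) -> str:
    return checkpoint_name if checkpoint_name != current_user else shared

registered = derive_name("ckpt-a", "ckpt-a", "shared")    # keys: memory_pool_shared_*
# Correct order: derive the unregister keys first, then clear the user.
unregistered = derive_name("ckpt-a", "ckpt-a", "shared")
assert registered == unregistered == "shared"
# Clearing the user first would derive checkpoint-scoped keys and miss the
# entries that were registered under the shared pool's name:
assert derive_name("ckpt-a", "", "shared") == "ckpt-a"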