@@ -176,6 +176,8 @@ def __init__(
176176 auto_pg : bool = True ,
177177 gpu_count : int | None = None ,
178178 mem_fraction : float | None = None ,
179+ master_addr : str | None = None ,
180+ master_port : int | None = None ,
179181 ):
180182 """
181183 Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -229,6 +231,17 @@ def __init__(
229231 self ._device_uuid = _get_physical_gpu_id (self .device_manager , device_index )
230232 self ._rdma_device = None if self ._p2p_store is None else self ._p2p_store .device
231233
234+ master_addr = master_addr or os .getenv ("MASTER_ADDR" )
235+ assert master_addr , "master_addr is required"
236+ self ._store = torch .distributed .TCPStore (
237+ master_addr ,
238+ _get_master_port (master_port ),
239+ self ._world_size ,
240+ timeout = timedelta (minutes = 10 ),
241+ is_master = self ._rank == 0 ,
242+ )
243+ self ._store_counter = 0
244+
232245 def _get_memory_pool (self , checkpoint_name : str ) -> list [MemoryBuffer ]:
233246 if checkpoint_name == self ._current_shared_memory_pool_user :
234247 assert self ._memory_pool [self .shared_memory_pool_name ], (
@@ -392,7 +405,11 @@ def _unpin(t: torch.Tensor):
392405 )
393406 cudart = torch .cuda .cudart ()
394407 r = cudart .cudaHostUnregister (t .data_ptr ())
395- assert r == 0 , f"unpin memory error, error code: { r } "
408+ if r != 0 :
409+ error_msg = cudart .cudaGetErrorString (r )
410+ raise RuntimeError (
411+ f"unpin memory error, error code: { r } , error message: { error_msg } "
412+ )
396413
397414 # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
398415 try :
@@ -408,7 +425,13 @@ def _unpin(t: torch.Tensor):
408425 del self ._memory_pool [checkpoint_name ]
409426 # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
410427 # this works by using torch>=2.5.0
411- torch ._C ._host_emptyCache ()
428+ if self .device_manager .device_type == "cuda" :
429+ torch ._C ._host_emptyCache ()
430+ else :
431+ # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
432+ import gc
433+
434+ gc .collect ()
412435
413436 def gather_metas (self , checkpoint_name : str ):
414437 """
@@ -478,8 +501,6 @@ def gather_metas(self, checkpoint_name: str):
478501 def init_process_group (
479502 self ,
480503 * ,
481- master_addr : str | None = None ,
482- master_port : int | None = None ,
483504 timeout : timedelta = timedelta (minutes = 10 ),
484505 ):
485506 """
@@ -489,21 +510,18 @@ def init_process_group(
489510 master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
490511 timeout: The timeout of the process group.
491512 """
492- master_addr = master_addr or os .getenv ("MASTER_ADDR" )
493- assert master_addr , "master_addr is required"
513+ self ._store_counter += 1
494514 dist .init_process_group (
495- host = master_addr ,
496- port = _get_master_port (master_port ),
497515 rank = self ._rank ,
498516 world_size = self ._world_size ,
517+ store = self ._store ,
499518 timeout = timeout ,
500519 backend = self .device_manager .backend ,
520+ store_counter = self ._store_counter ,
501521 )
502522 logger .info (f"[rank{ self ._rank } ] init process group successfully." )
503523
504- def store_based_barrier (
505- self , store : torch .distributed .TCPStore , timeout : timedelta = timedelta (minutes = 5 )
506- ) -> None :
524+ def store_based_barrier (self , timeout : timedelta = timedelta (minutes = 5 )) -> None :
507525 """
508526 Perform a store-based barrier synchronization across all ranks.
509527
@@ -516,7 +534,7 @@ def store_based_barrier(
516534 """
517535 torch .distributed .distributed_c10d ._store_based_barrier (
518536 rank = self ._rank ,
519- store = store ,
537+ store = self . _store ,
520538 group_name = "parameter_server_barrier" ,
521539 rendezvous_count = self ._world_size ,
522540 timeout = timeout ,
@@ -529,8 +547,6 @@ def update(
529547 * ,
530548 timeout : timedelta = timedelta (minutes = 10 ),
531549 ranks : list [int ] | None = None ,
532- master_addr : str | None = None ,
533- master_port : int | None = None ,
534550 ) -> None :
535551 """
536552 Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -551,25 +567,12 @@ def update(
551567 assert req_func is not None , "req_func is required"
552568 ranks_group = None
553569 try :
554- master_addr = os .getenv ("MASTER_ADDR" ) or master_addr
555- assert master_addr , "master_addr is required"
556570 if self ._auto_pg and not dist .is_initialized ():
557- self .init_process_group (
558- timeout = timeout , master_addr = master_addr , master_port = master_port
559- )
560- # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
561- # If master_port is provided, use master_port+1 for barrier store
562- manager_store = torch .distributed .TCPStore (
563- master_addr ,
564- _get_master_port (master_port ) + 1 ,
565- self ._world_size ,
566- timeout = timeout ,
567- is_master = self ._rank == 0 ,
568- )
571+ self .init_process_group (timeout = timeout )
569572 # if ranks is None or [], it will use fully broadcast to update to all ranks
570573 ranks_group = dist .new_group (ranks ) if ranks else None
571574 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
572- self .store_based_barrier (manager_store )
575+ self .store_based_barrier ()
573576 except Exception as e :
574577 logger .exception (
575578 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } "
@@ -580,7 +583,6 @@ def update(
580583 dist .destroy_process_group (ranks_group )
581584 if self ._auto_pg and dist .is_initialized ():
582585 dist .destroy_process_group ()
583- del manager_store
584586 self .device_manager .device_module .empty_cache ()
585587 logger .info (
586588 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } done. "