@@ -590,17 +590,18 @@ def batch_transfer_sync_read(
590590
591591
592592class ParameterServer :
593- def __init__ (self , * , auto_pg : bool = False ):
593+ def __init__ (
594+ self , * , rank : int | None = None , world_size : int | None = None , auto_pg : bool = False
595+ ):
594596 """
595597 Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
596598
597599 Args:
598600 auto_pg: Whether to automatically initialize the process group.
599601 Notice that if auto_pg is True, will destroy the process group after update.
600602 """
601- self ._rank = int (os .environ .get ("RANK" , None ))
602- self ._world_size = int (os .environ .get ("WORLD_SIZE" , None ))
603- self ._master_addr = os .getenv ("MASTER_ADDR" )
603+ self ._rank = rank or int (os .environ .get ("RANK" , None ))
604+ self ._world_size = world_size or int (os .environ .get ("WORLD_SIZE" , None ))
604605 self ._gpu_count = torch .cuda .device_count ()
605606 self ._local_rank = self ._rank % self ._gpu_count
606607 self ._auto_pg = auto_pg
@@ -733,7 +734,11 @@ def gather_metas(self, checkpoint_name: str):
733734 )
734735
735736 def init_process_group (
736- self , * , master_port : int | None = None , timeout : timedelta = timedelta (minutes = 10 )
737+ self ,
738+ * ,
739+ master_addr : str | None = None ,
740+ master_port : int | None = None ,
741+ timeout : timedelta = timedelta (minutes = 10 ),
737742 ):
738743 """
739744 Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
@@ -742,8 +747,10 @@ def init_process_group(
742747 master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
743748 timeout: The timeout of the process group.
744749 """
750+ master_addr = master_addr or os .getenv ("MASTER_ADDR" )
751+ assert master_addr , "master_addr is required"
745752 store = dist .TCPStore (
746- self . _master_addr ,
753+ master_addr ,
747754 _get_master_port (master_port ),
748755 self ._world_size ,
749756 timeout = timeout ,
0 commit comments