Skip to content

Commit 3def1a2

Browse files
feat: add rank and world_size args in ParameterServer (#20)
1 parent 03ff7e7 commit 3def1a2

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

checkpoint_engine/ps.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,17 +590,18 @@ def batch_transfer_sync_read(
590590

591591

592592
class ParameterServer:
593-
def __init__(self, *, auto_pg: bool = False):
593+
def __init__(
594+
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
595+
):
594596
"""
595597
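Initialize the parameter server. rank and world_size fall back to env RANK and WORLD_SIZE when not passed; MASTER_ADDR is only needed later by init_process_group (as an argument or from env).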
596598
597599
Args:
598600
auto_pg: Whether to automatically initialize the process group.
599601
Note that if auto_pg is True, the process group will be destroyed after update.
600602
"""
601-
self._rank = int(os.environ.get("RANK", None))
602-
self._world_size = int(os.environ.get("WORLD_SIZE", None))
603-
self._master_addr = os.getenv("MASTER_ADDR")
603+
self._rank = rank or int(os.environ.get("RANK", None))
604+
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
604605
self._gpu_count = torch.cuda.device_count()
605606
self._local_rank = self._rank % self._gpu_count
606607
self._auto_pg = auto_pg
@@ -733,7 +734,11 @@ def gather_metas(self, checkpoint_name: str):
733734
)
734735

735736
def init_process_group(
736-
self, *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)
737+
self,
738+
*,
739+
master_addr: str | None = None,
740+
master_port: int | None = None,
741+
timeout: timedelta = timedelta(minutes=10),
737742
):
738743
"""
739744
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
@@ -742,8 +747,10 @@ def init_process_group(
742747
master_port: The specified port of the master node. If not set, _get_master_port is used to get the port.
743748
timeout: The timeout of the process group.
744749
"""
750+
master_addr = master_addr or os.getenv("MASTER_ADDR")
751+
assert master_addr, "master_addr is required"
745752
store = dist.TCPStore(
746-
self._master_addr,
753+
master_addr,
747754
_get_master_port(master_port),
748755
self._world_size,
749756
timeout=timeout,

0 commit comments

Comments
 (0)