Skip to content

Commit 61c475a

Browse files
committed
update
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
1 parent 3b80669 commit 61c475a

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

checkpoint_engine/ps.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -592,18 +592,25 @@ def batch_transfer_sync_read(
592592

593593
class ParameterServer:
594594
def __init__(
595-
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False, mem_fraction: float | None = None
595+
self,
596+
*,
597+
rank: int | None = None,
598+
world_size: int | None = None,
599+
auto_pg: bool = False,
600+
gpu_count: int | None = None,
601+
mem_fraction: float | None = None,
596602
):
597603
"""
598604
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
599605
600606
Args:
601607
auto_pg: Whether to automatically initialize the process group.
602608
Notice that if auto_pg is True, will destroy the process group after update.
609+
mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
603610
"""
604611
self._rank = rank or int(os.environ.get("RANK", None))
605612
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
606-
self._gpu_count = self._world_size or torch.cuda.device_count()
613+
self._gpu_count = gpu_count or torch.cuda.device_count()
607614
self._local_rank = self._rank % self._gpu_count
608615
self._auto_pg = auto_pg
609616
self._all_hosts = []
@@ -612,13 +619,13 @@ def __init__(
612619

613620
assert self._rank is not None and self._rank >= 0, self._rank
614621
assert self._world_size and self._world_size > 0, self._world_size
615-
assert (self._gpu_count is not None and
616-
self._gpu_count > 0 and
617-
self._gpu_count <= torch.cuda.device_count()
622+
assert (
623+
self._gpu_count is not None
624+
and self._gpu_count > 0
625+
and self._gpu_count <= torch.cuda.device_count()
618626
), self._gpu_count
619-
assert (self._mem_fraction is not None and
620-
self._mem_fraction > 0 and
621-
self._mem_fraction <= 1
627+
assert (
628+
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
622629
), self._mem_fraction
623630

624631
self._zmq_ctx = zmq.Context()

0 commit comments

Comments
 (0)