Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def batch_transfer_sync_read(

class ParameterServer:
def __init__(
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False, mem_fraction: float | None = None
):
"""
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
Expand All @@ -603,14 +603,23 @@ def __init__(
"""
self._rank = rank or int(os.environ.get("RANK", None))
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
self._gpu_count = torch.cuda.device_count()
self._gpu_count = self._world_size or torch.cuda.device_count()
self._local_rank = self._rank % self._gpu_count
self._auto_pg = auto_pg
self._all_hosts = []
self._global_device_uuids: list[str] = []
self._mem_fraction = mem_fraction or 0.9

assert self._rank is not None and self._rank >= 0, self._rank
assert self._world_size and self._world_size > 0, self._world_size
assert (self._gpu_count is not None and
self._gpu_count > 0 and
self._gpu_count <= torch.cuda.device_count()
), self._gpu_count
assert (self._mem_fraction is not None and
self._mem_fraction > 0 and
self._mem_fraction <= 1
), self._mem_fraction

self._zmq_ctx = zmq.Context()
self._zmq_addr_counter = 0
Expand Down Expand Up @@ -832,8 +841,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
# auto detect bucket size
tensor = torch.tensor(
[
# 90% of current cuda free memory bytes
int(float(torch.cuda.mem_get_info()[0]) * 0.9),
# proportion of current cuda free memory bytes
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
# we use negative value to reuse allreduce min operation
# for getting the max value of zmq_addr_counter in all ranks
-self._zmq_addr_counter,
Expand Down
Loading