diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 52229d4..698f310 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -592,7 +592,13 @@ def batch_transfer_sync_read( class ParameterServer: def __init__( - self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False + self, + *, + rank: int | None = None, + world_size: int | None = None, + auto_pg: bool = False, + gpu_count: int | None = None, + mem_fraction: float | None = None, ): """ Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set. @@ -600,17 +606,27 @@ def __init__( Args: auto_pg: Whether to automatically initialize the process group. Notice that if auto_pg is True, will destroy the process group after update. + mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation. """ self._rank = rank or int(os.environ.get("RANK", None)) self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None)) - self._gpu_count = torch.cuda.device_count() + self._gpu_count = gpu_count or torch.cuda.device_count() self._local_rank = self._rank % self._gpu_count self._auto_pg = auto_pg self._all_hosts = [] self._global_device_uuids: list[str] = [] + self._mem_fraction = mem_fraction or 0.9 assert self._rank is not None and self._rank >= 0, self._rank assert self._world_size and self._world_size > 0, self._world_size + assert ( + self._gpu_count is not None + and self._gpu_count > 0 + and self._gpu_count <= torch.cuda.device_count() + ), self._gpu_count + assert ( + self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1 + ), self._mem_fraction self._zmq_ctx = zmq.Context() self._zmq_addr_counter = 0 @@ -832,8 +848,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, # auto detect bucket size tensor = torch.tensor( [ - # 90% of current cuda free memory bytes - int(float(torch.cuda.mem_get_info()[0]) * 0.9), + # proportion of current cuda free memory bytes + int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction), # we use negative value to reuse allreduce min operation # for getting the max value of zmq_addr_counter in all ranks -self._zmq_addr_counter,