Skip to content

Commit 3b80669

Browse files
committed
feat: support configurable mem fraction
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
1 parent 1d18b03 commit 3b80669

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def batch_transfer_sync_read(
592592

593593
class ParameterServer:
594594
def __init__(
595-
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
595+
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False, mem_fraction: float | None = None
596596
):
597597
"""
598598
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -608,13 +608,18 @@ def __init__(
608608
self._auto_pg = auto_pg
609609
self._all_hosts = []
610610
self._global_device_uuids: list[str] = []
611+
self._mem_fraction = mem_fraction or 0.9
611612

612613
assert self._rank is not None and self._rank >= 0, self._rank
613614
assert self._world_size and self._world_size > 0, self._world_size
614615
assert (self._gpu_count is not None and
615616
self._gpu_count > 0 and
616617
self._gpu_count <= torch.cuda.device_count()
617618
), self._gpu_count
619+
assert (self._mem_fraction is not None and
620+
self._mem_fraction > 0 and
621+
self._mem_fraction <= 1
622+
), self._mem_fraction
618623

619624
self._zmq_ctx = zmq.Context()
620625
self._zmq_addr_counter = 0
@@ -836,8 +841,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
836841
# auto detect bucket size
837842
tensor = torch.tensor(
838843
[
839-
# 90% of current cuda free memory bytes
840-
int(float(torch.cuda.mem_get_info()[0]) * 0.9),
844+
# proportion of current cuda free memory bytes
845+
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
841846
# we use negative value to reuse allreduce min operation
842847
# for getting the max value of zmq_addr_counter in all ranks
843848
-self._zmq_addr_counter,

0 commit comments

Comments
 (0)