Skip to content

Commit 4fc8d6f

Browse files
committed
Allow set mem_fraction from ENV, allow set float max_bucket_size_gb, update use of os.environ
1 parent 88370e2 commit 4fc8d6f

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

checkpoint_engine/ps.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -789,15 +789,17 @@ def _get_master_port(master_port: int | None = None) -> int:
789789
if master_port is None:
790790
# HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
791791
# TODO: check whether master_port is available or use a more elegant way
792-
master_port = int(os.getenv("MASTER_PORT")) + 1
792+
master_port_str = os.getenv("MASTER_PORT")
793+
assert master_port_str, "MASTER_PORT is required if no master_port is provided."
794+
master_port = int(master_port_str) + 1
793795
return master_port
794796

795797

796798
class P2PStore:
797799
def __init__(self, device_manager: DeviceManager):
798800
from mooncake.engine import TransferEngine
799801

800-
self.rank = int(os.getenv("RANK"))
802+
self.rank = int(os.environ["RANK"]) # ENV RANK is required
801803
gpu_count = device_manager.device_module.device_count()
802804
local_rank = self.rank % gpu_count
803805
device_type = device_manager.device_type
@@ -887,8 +889,8 @@ def __init__(
887889
Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
888890
mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
889891
"""
890-
self._rank = rank or int(os.environ.get("RANK", None))
891-
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
892+
self._rank = rank or int(os.environ["RANK"])
893+
self._world_size = world_size or int(os.environ["WORLD_SIZE"])
892894
self.device_manager = DeviceManager()
893895
self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
894896
self._local_rank = self._rank % self._gpu_count
@@ -897,7 +899,7 @@ def __init__(
897899
self._global_device_uuids: list[str] = []
898900
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
899901
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
900-
self._mem_fraction = mem_fraction or 0.9
902+
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
901903

902904
assert self._rank is not None and self._rank >= 0, self._rank
903905
assert self._world_size and self._world_size > 0, self._world_size
@@ -1352,7 +1354,7 @@ def _detect_bucket_size(
13521354
f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
13531355
)
13541356
disable_h2d_buffer = True
1355-
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
1357+
max_bytes = int(float(os.getenv("PS_MAX_BUCKET_SIZE_GB", "8")) * GiB)
13561358
bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
13571359
logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
13581360
return bucket_size, disable_h2d_buffer

0 commit comments

Comments
 (0)