@@ -789,15 +789,17 @@ def _get_master_port(master_port: int | None = None) -> int:
789789 if master_port is None :
790790 # HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
791791 # TODO: check whether master_port is available or use a more elegant way
792- master_port = int (os .getenv ("MASTER_PORT" )) + 1
792+ master_port_str = os .getenv ("MASTER_PORT" )
793+ assert master_port_str , "MASTER_PORT is required if no master_port is provided."
794+ master_port = int (master_port_str ) + 1
793795 return master_port
794796
795797
796798class P2PStore :
797799 def __init__ (self , device_manager : DeviceManager ):
798800 from mooncake .engine import TransferEngine
799801
800- self .rank = int (os .getenv ( "RANK" ))
802+ self .rank = int (os .environ [ "RANK" ]) # ENV RANK is required
801803 gpu_count = device_manager .device_module .device_count ()
802804 local_rank = self .rank % gpu_count
803805 device_type = device_manager .device_type
@@ -887,8 +889,8 @@ def __init__(
887889 Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
888890 mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
889891 """
890- self ._rank = rank or int (os .environ . get ( "RANK" , None ) )
891- self ._world_size = world_size or int (os .environ . get ( "WORLD_SIZE" , None ) )
892+ self ._rank = rank or int (os .environ [ "RANK" ] )
893+ self ._world_size = world_size or int (os .environ [ "WORLD_SIZE" ] )
892894 self .device_manager = DeviceManager ()
893895 self ._gpu_count = gpu_count or self .device_manager .device_module .device_count ()
894896 self ._local_rank = self ._rank % self ._gpu_count
@@ -897,7 +899,7 @@ def __init__(
897899 self ._global_device_uuids : list [str ] = []
898900 self ._local_rdma_devices : dict [str , set [int ]] = defaultdict (set )
899901 self ._remote_rdma_devices : dict [str , set [int ]] = defaultdict (set )
900- self ._mem_fraction = mem_fraction or 0.9
902+ self ._mem_fraction = mem_fraction or float ( os . getenv ( "PS_MEM_FRACTION" , " 0.9" ))
901903
902904 assert self ._rank is not None and self ._rank >= 0 , self ._rank
903905 assert self ._world_size and self ._world_size > 0 , self ._world_size
@@ -1352,7 +1354,7 @@ def _detect_bucket_size(
13521354 f"max_tensor_bytes { max_tensor_bytes } should be less than free_bytes { free_bytes } "
13531355 )
13541356 disable_h2d_buffer = True
1355- max_bytes = int (os .getenv ("PS_MAX_BUCKET_SIZE_GB" , 8 )) * GiB
1357+ max_bytes = int (float ( os .getenv ("PS_MAX_BUCKET_SIZE_GB" , "8" )) * GiB )
13561358 bucket_size = min (max (max_bytes , max_tensor_bytes ), free_bytes )
13571359 logger .info (f"[rank{ self ._rank } ] auto detect bucket size { bucket_size / GiB :.2f} GiB" )
13581360 return bucket_size , disable_h2d_buffer
0 commit comments