@@ -592,7 +592,7 @@ def batch_transfer_sync_read(
592592
593593class ParameterServer :
594594 def __init__ (
595- self , * , rank : int | None = None , world_size : int | None = None , auto_pg : bool = False
595+ self , * , rank : int | None = None , world_size : int | None = None , auto_pg : bool = False , mem_fraction : float | None = None
596596 ):
597597 """
598598 Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -608,13 +608,18 @@ def __init__(
608608 self ._auto_pg = auto_pg
609609 self ._all_hosts = []
610610 self ._global_device_uuids : list [str ] = []
611+ self ._mem_fraction = mem_fraction or 0.9
611612
612613 assert self ._rank is not None and self ._rank >= 0 , self ._rank
613614 assert self ._world_size and self ._world_size > 0 , self ._world_size
614615 assert (self ._gpu_count is not None and
615616 self ._gpu_count > 0 and
616617 self ._gpu_count <= torch .cuda .device_count ()
617618 ), self ._gpu_count
619+ assert (self ._mem_fraction is not None and
620+ self ._mem_fraction > 0 and
621+ self ._mem_fraction <= 1
622+ ), self ._mem_fraction
618623
619624 self ._zmq_ctx = zmq .Context ()
620625 self ._zmq_addr_counter = 0
@@ -836,8 +841,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
836841 # auto detect bucket size
837842 tensor = torch .tensor (
838843 [
839- # 90% of current cuda free memory bytes
840- int (float (torch .cuda .mem_get_info ()[0 ]) * 0.9 ),
844+ # self._mem_fraction (default 0.9) of the currently free CUDA memory, in bytes
845+ int (float (torch .cuda .mem_get_info ()[0 ]) * self . _mem_fraction ),
841846 # we use negative value to reuse allreduce min operation
842847 # for getting the max value of zmq_addr_counter in all ranks
843848 - self ._zmq_addr_counter ,
0 commit comments