1- from __future__ import annotations
2-
31import argparse
42import concurrent .futures
53import ctypes
108import threading
119import time
1210from collections import defaultdict
11+ from collections .abc import Callable
1312from datetime import timedelta
1413from functools import lru_cache
1514from typing import TYPE_CHECKING , Annotated , Any , BinaryIO , NamedTuple
2625
2726
2827if TYPE_CHECKING :
29- from collections .abc import Callable
30-
3128 from typing_extensions import TypedDict
3229
3330 class FileMeta (TypedDict ):
@@ -152,8 +149,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
152149 return ret
153150
154151
155- def _load_checkpoint_file (file_path : str ) -> tuple [int , dict [str , tuple [FileMeta , torch .Tensor ]]]:
156- def _safetensors_load (fn : str ) -> dict [str , tuple [FileMeta , torch .Tensor ]]:
152+ def _load_checkpoint_file (file_path : str ) -> tuple [int , dict [str , tuple [" FileMeta" , torch .Tensor ]]]:
153+ def _safetensors_load (fn : str ) -> dict [str , tuple [" FileMeta" , torch .Tensor ]]:
157154 ret = {}
158155 with safe_open (fn , framework = "pt" ) as f :
159156 for name in f .keys (): # noqa: SIM118
@@ -169,7 +166,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
169166 return ret
170167
171168 # deprecated, will be removed in the future
172- def _fast_np_load (fn : str ) -> dict [str , tuple [FileMeta , torch .Tensor ]]:
169+ def _fast_np_load (fn : str ) -> dict [str , tuple [" FileMeta" , torch .Tensor ]]:
173170 """load *.np file and return memmap and related tensor meta"""
174171
175172 def parse_npy_header (fin : BinaryIO ) -> dict [str , Any ]:
@@ -654,27 +651,43 @@ def batch_transfer_sync_read(
654651
655652class ParameterServer :
656653 def __init__ (
657- self , * , rank : int | None = None , world_size : int | None = None , auto_pg : bool = False
654+ self ,
655+ * ,
656+ rank : int | None = None ,
657+ world_size : int | None = None ,
658+ auto_pg : bool = False ,
659+ gpu_count : int | None = None ,
660+ mem_fraction : float | None = None ,
658661 ):
659662 """
660663 Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
661664
662665 Args:
663666 auto_pg: Whether to automatically initialize the process group.
664667 Note that if auto_pg is True, the process group is destroyed after update.
668+ mem_fraction: Fraction of the currently free CUDA memory to use for allocation. Must be in (0, 1]; defaults to 0.9 when not given.
665669 """
666670 self ._rank = rank or int (os .environ .get ("RANK" , None ))
667671 self ._world_size = world_size or int (os .environ .get ("WORLD_SIZE" , None ))
668- self ._gpu_count = torch .cuda .device_count ()
672+ self ._gpu_count = gpu_count or torch .cuda .device_count ()
669673 self ._local_rank = self ._rank % self ._gpu_count
670674 self ._auto_pg = auto_pg
671675 self ._all_hosts = []
672676 self ._global_device_uuids : list [str ] = []
673677 self ._local_rdma_devices : dict [str , set [int ]] = defaultdict (set )
674678 self ._remote_rdma_devices : dict [str , set [int ]] = defaultdict (set )
679+ self ._mem_fraction = mem_fraction or 0.9
675680
676681 assert self ._rank is not None and self ._rank >= 0 , self ._rank
677682 assert self ._world_size and self ._world_size > 0 , self ._world_size
683+ assert (
684+ self ._gpu_count is not None
685+ and self ._gpu_count > 0
686+ and self ._gpu_count <= torch .cuda .device_count ()
687+ ), self ._gpu_count
688+ assert (
689+ self ._mem_fraction is not None and self ._mem_fraction > 0 and self ._mem_fraction <= 1
690+ ), self ._mem_fraction
678691
679692 self ._zmq_ctx = zmq .Context ()
680693 self ._zmq_addr_counter = 0
@@ -879,6 +892,8 @@ def update(
879892 dist .destroy_process_group ()
880893 # HACK: wait 2s to ensure destroy is finished
881894 time .sleep (2 )
895+ if self ._rank not in ranks :
896+ return
882897 self .init_process_group_for_ranks (ranks )
883898 if self ._rank not in ranks :
884899 return
@@ -914,8 +929,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
914929 # auto detect bucket size
915930 tensor = torch .tensor (
916931 [
917- # 90% of current cuda free memory bytes
918- int (float (torch .cuda .mem_get_info ()[0 ]) * 0.9 ),
932+ # self._mem_fraction of the currently free CUDA memory, in bytes
933+ int (float (torch .cuda .mem_get_info ()[0 ]) * self . _mem_fraction ),
919934 # we use negative value to reuse allreduce min operation
920935 # for getting the max value of zmq_addr_counter in all ranks
921936 - self ._zmq_addr_counter ,
0 commit comments