Skip to content

Commit 745e4c2

Browse files
authored
Merge branch 'main' into feat/optimize_p2p
Signed-off-by: specture724 <149605198+specture724@users.noreply.github.com>
2 parents 59b8b38 + 8a60e65 commit 745e4c2

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

checkpoint_engine/ps.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
import argparse
42
import concurrent.futures
53
import ctypes
@@ -10,6 +8,7 @@
108
import threading
119
import time
1210
from collections import defaultdict
11+
from collections.abc import Callable
1312
from datetime import timedelta
1413
from functools import lru_cache
1514
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
@@ -26,8 +25,6 @@
2625

2726

2827
if TYPE_CHECKING:
29-
from collections.abc import Callable
30-
3128
from typing_extensions import TypedDict
3229

3330
class FileMeta(TypedDict):
@@ -152,8 +149,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
152149
return ret
153150

154151

155-
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
156-
def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
152+
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
153+
def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
157154
ret = {}
158155
with safe_open(fn, framework="pt") as f:
159156
for name in f.keys(): # noqa: SIM118
@@ -169,7 +166,7 @@ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
169166
return ret
170167

171168
# deprecated, will be removed in the future
172-
def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
169+
def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
173170
"""load *.np file and return memmap and related tensor meta"""
174171

175172
def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
@@ -654,27 +651,43 @@ def batch_transfer_sync_read(
654651

655652
class ParameterServer:
656653
def __init__(
657-
self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
654+
self,
655+
*,
656+
rank: int | None = None,
657+
world_size: int | None = None,
658+
auto_pg: bool = False,
659+
gpu_count: int | None = None,
660+
mem_fraction: float | None = None,
658661
):
659662
"""
660663
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
661664
662665
Args:
663666
auto_pg: Whether to automatically initialize the process group.
664667
Notice that if auto_pg is True, will destroy the process group after update.
668+
mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
665669
"""
666670
self._rank = rank or int(os.environ.get("RANK", None))
667671
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
668-
self._gpu_count = torch.cuda.device_count()
672+
self._gpu_count = gpu_count or torch.cuda.device_count()
669673
self._local_rank = self._rank % self._gpu_count
670674
self._auto_pg = auto_pg
671675
self._all_hosts = []
672676
self._global_device_uuids: list[str] = []
673677
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
674678
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
679+
self._mem_fraction = mem_fraction or 0.9
675680

676681
assert self._rank is not None and self._rank >= 0, self._rank
677682
assert self._world_size and self._world_size > 0, self._world_size
683+
assert (
684+
self._gpu_count is not None
685+
and self._gpu_count > 0
686+
and self._gpu_count <= torch.cuda.device_count()
687+
), self._gpu_count
688+
assert (
689+
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
690+
), self._mem_fraction
678691

679692
self._zmq_ctx = zmq.Context()
680693
self._zmq_addr_counter = 0
@@ -879,6 +892,8 @@ def update(
879892
dist.destroy_process_group()
880893
# HACK: wait 2s to ensure destroy is finished
881894
time.sleep(2)
895+
if self._rank not in ranks:
896+
return
882897
self.init_process_group_for_ranks(ranks)
883898
if self._rank not in ranks:
884899
return
@@ -914,8 +929,8 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
914929
# auto detect bucket size
915930
tensor = torch.tensor(
916931
[
917-
# 90% of current cuda free memory bytes
918-
int(float(torch.cuda.mem_get_info()[0]) * 0.9),
932+
# proportion of current cuda free memory bytes
933+
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
919934
# we use negative value to reuse allreduce min operation
920935
# for getting the max value of zmq_addr_counter in all ranks
921936
-self._zmq_addr_counter,

0 commit comments

Comments
 (0)