Skip to content

Commit b93e062

Browse files
yexin (author) committed "remove dist_wrapper.py"
1 parent 76c35e8 commit b93e062

File tree

4 files changed: 29 additions (+29) and 32 deletions (-32).

checkpoint_engine/dist_wrapper.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

checkpoint_engine/distributed/base.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import importlib
77

88
import torch
9-
from torch.distributed import ReduceOp
9+
import torch.distributed as torch_dist
1010

1111

1212
class Distributed(ABC):
@@ -45,7 +45,7 @@ def all_gather_object(
4545
def all_reduce(
4646
self,
4747
tensor: torch.Tensor,
48-
op :ReduceOp,
48+
op :torch_dist.ReduceOp,
4949
group,
5050
):
5151
raise NotImplementedError
@@ -159,13 +159,14 @@ def init_process_group(
159159

160160
def destroy_process_group(group=None):
161161
if _BACKEND_INSTANCE is None:
162-
raise RuntimeError("distribute module not initialized")
162+
torch_dist.destroy_process_group(group)
163+
return
163164
_BACKEND_INSTANCE.destroy_process_group(group)
164165

165166

166167
def is_initialized() -> bool:
167168
if _BACKEND_INSTANCE is None:
168-
return False
169+
return torch_dist.is_initialized()
169170
return _BACKEND_INSTANCE.is_initialized()
170171

171172
def all_gather_object(
@@ -174,37 +175,43 @@ def all_gather_object(
174175
group=None,
175176
):
176177
if _BACKEND_INSTANCE is None:
177-
raise RuntimeError("distribute module not initialized")
178+
torch_dist.all_gather_object(object_list, obj, group)
179+
return
178180
_BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
179181

180182

181183
def all_reduce(
182184
tensor: torch.Tensor,
183-
op=ReduceOp.SUM,
185+
op=torch_dist.ReduceOp.SUM,
184186
group=None,
187+
**kwargs,
185188
):
186189
if _BACKEND_INSTANCE is None:
187-
raise RuntimeError("distribute module not initialized")
190+
torch_dist.all_reduce(tensor, op, group, **kwargs)
191+
return
188192
_BACKEND_INSTANCE.all_reduce(tensor, op, group)
189193

190194

191195
def broadcast(
192196
tensor: torch.Tensor,
193-
src= None,
197+
src=None,
194198
group=None,
199+
**kwargs,
195200
):
196201
if _BACKEND_INSTANCE is None:
197-
raise RuntimeError("distribute module not initialized")
202+
torch_dist.broadcast(tensor, src, group, **kwargs)
203+
return
198204
_BACKEND_INSTANCE.broadcast(tensor, src, group)
199205

200206

201-
def barrier(group=None):
207+
def barrier(group=None, **kwargs):
202208
if _BACKEND_INSTANCE is None:
203-
raise RuntimeError("distribute module not initialized")
209+
torch_dist.barrier(group, **kwargs)
210+
return
204211
_BACKEND_INSTANCE.barrier(group)
205212

206213

207-
def new_group(ranks: list[int]):
214+
def new_group(ranks: list[int], **kwargs):
208215
if _BACKEND_INSTANCE is None:
209-
raise RuntimeError("distribute module not initialized")
216+
return torch_dist.new_group(ranks, **kwargs)
210217
return _BACKEND_INSTANCE.new_group(ranks)

checkpoint_engine/ps.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
2525
from checkpoint_engine.p2p_store import P2PStore
2626
from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint
27-
from checkpoint_engine.dist_wrapper import dist
27+
import checkpoint_engine.distributed as dist
2828

2929

3030
if TYPE_CHECKING:
@@ -176,6 +176,7 @@ def __init__(
176176
auto_pg: bool = True,
177177
gpu_count: int | None = None,
178178
mem_fraction: float | None = None,
179+
custom_dist: bool = False,
179180
):
180181
"""
181182
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -196,6 +197,7 @@ def __init__(
196197
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
197198
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
198199
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
200+
self._custom_dist = custom_dist
199201

200202
assert self._rank is not None and self._rank >= 0, self._rank
201203
assert self._world_size and self._world_size > 0, self._world_size
@@ -491,7 +493,7 @@ def init_process_group(
491493
"""
492494
master_addr = master_addr or os.getenv("MASTER_ADDR")
493495
assert master_addr, "master_addr is required"
494-
if dist is torch.distributed:
496+
if not self._custom_dist:
495497
store = torch.distributed.TCPStore(
496498
master_addr,
497499
_get_master_port(master_port),
@@ -518,7 +520,7 @@ def init_process_group(
518520
logger.info(f"[rank{self._rank}] init process group successfully.")
519521

520522
def store_based_barrier(
521-
self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
523+
self, store, timeout: timedelta = timedelta(minutes=5)
522524
) -> None:
523525
"""
524526
Perform a store-based barrier synchronization across all ranks.
@@ -606,7 +608,7 @@ def zmq_handle(device_uuid: str) -> str:
606608
return socket, socket_paths
607609

608610
def _detect_bucket_size(
609-
self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
611+
self, ranks_group, *, disable_h2d_buffer: bool = False
610612
) -> tuple[int, bool]:
611613
GiB = 1 << 30 # noqa: N806
612614
# auto detect bucket size
@@ -725,7 +727,7 @@ def _update_per_bucket(
725727
self,
726728
checkpoint_name: str,
727729
req_func: Callable[[list[tuple[str, str]]], None],
728-
ranks_group: dist.ProcessGroup | None,
730+
ranks_group,
729731
ranks: list[int] | None = None,
730732
):
731733
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"

examples/update.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from checkpoint_engine.ps import ParameterServer
1717
from checkpoint_engine.api import request_inference_to_update
18-
from checkpoint_engine.dist_wrapper import dist, setup_dist
18+
import checkpoint_engine.distributed as dist
1919

2020

2121
@contextmanager
@@ -164,11 +164,8 @@ def join(
164164
rank = int(os.getenv("RANK"))
165165
world_size = int(os.getenv("WORLD_SIZE"))
166166

167-
if args.custom_dist:
168-
setup_dist()
169-
170167
req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
171-
ps = ParameterServer(auto_pg=True)
168+
ps = ParameterServer(auto_pg=True, custom_dist=args.custom_dist)
172169
if args.load_metas_file:
173170
join(
174171
ps,

0 commit comments

Comments (0)