Skip to content

Commit b49ff3e

Browse files
author
kip-cxj
committed
modify init pg
1 parent 0901e9f commit b49ff3e

File tree

4 files changed

+10
-25
lines changed

4 files changed

+10
-25
lines changed

checkpoint_engine/distributed/base.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,17 @@ def init_process_group(
106106
rank: int,
107107
world_size: int,
108108
store: torch.distributed.TCPStore,
109-
timeout: timedelta,
110109
**kwargs,
111110
):
112111
backend = kwargs.get("backend", "nccl")
113-
store_counter = kwargs.get("store_counter", "1")
114-
sub_store = torch.distributed.PrefixStore(f"prefix-{store_counter}", store)
112+
timeout = kwargs.get("timeout", timedelta(minutes=10))
113+
115114
torch.distributed.init_process_group(
116115
backend=backend,
117116
world_size=world_size,
118117
rank=rank,
119118
timeout=timeout,
120-
store=sub_store,
119+
store=store,
121120
)
122121

123122
def destroy_process_group(self, group: DistributedProcessGroup | None = None):
@@ -244,10 +243,9 @@ def init_process_group(
244243
rank: int,
245244
world_size: int,
246245
store: torch.distributed.TCPStore,
247-
timeout: timedelta = timedelta(seconds=300),
248246
**kwargs,
249247
):
250-
_BACKEND_INSTANCE.init_process_group(rank, world_size, store, timeout, **kwargs)
248+
_BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
251249

252250

253251
def destroy_process_group(group: DistributedProcessGroup | None = None):

checkpoint_engine/distributed/hccl.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ctypes
22
from contextlib import contextmanager
3-
from datetime import timedelta
43
from typing import Any, ClassVar
54

65
import torch
@@ -233,20 +232,15 @@ def init_process_group(
233232
rank: int,
234233
world_size: int,
235234
store: torch.distributed.TCPStore,
236-
timeout: timedelta = timedelta(seconds=300),
237235
**kwargs,
238236
):
239237
assert not self.initialized, "already initialized"
240238

241-
self.host = store.host
242-
self.port = store.port + 1
243239
self.rank = rank
244240
self.world_size = world_size
245241
self.device = torch.device("npu", torch.npu.current_device())
246242

247-
self.pg = StatelessProcessGroup.create(
248-
self.host, self.port, rank, world_size, store_timeout=int(timeout.total_seconds())
249-
)
243+
self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
250244
self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
251245
self.comm = self.pyhccl.comm
252246
self.initialized = True

checkpoint_engine/distributed/nccl.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ctypes
22
from contextlib import contextmanager
3-
from datetime import timedelta
43
from typing import Any, ClassVar
54

65
import torch
@@ -136,21 +135,15 @@ def init_process_group(
136135
rank: int,
137136
world_size: int,
138137
store: torch.distributed.TCPStore,
139-
timeout: timedelta = timedelta(seconds=300),
140138
**kwargs,
141139
):
142140
assert not self.initialized, "already initialized"
143141

144-
self.host = store.host
145-
self.port = store.port + 1
146142
self.rank = rank
147143
self.world_size = world_size
148144
self.device = torch.device("cuda", torch.cuda.current_device())
149145

150-
self.pg = StatelessProcessGroup.create(
151-
self.host, self.port, rank, world_size, store_timeout=int(timeout.total_seconds())
152-
)
153-
146+
self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
154147
self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device)
155148
self.comm = self.pynccl.comm
156149
self.initialized = True

checkpoint_engine/ps.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,13 +511,13 @@ def init_process_group(
511511
timeout: The timeout of the process group.
512512
"""
513513
self._store_counter += 1
514+
sub_store = torch.distributed.PrefixStore(f"prefix-{self._store_counter}", self._store)
514515
dist.init_process_group(
515-
rank=self._rank,
516+
backend=self.device_manager.backend,
516517
world_size=self._world_size,
517-
store=self._store,
518+
rank=self._rank,
518519
timeout=timeout,
519-
backend=self.device_manager.backend,
520-
store_counter=self._store_counter,
520+
store=sub_store,
521521
)
522522
logger.info(f"[rank{self._rank}] init process group successfully.")
523523

0 commit comments

Comments
 (0)