
Commit 141e419

process_group: add PG timeouts + automatically assign manager port (#60)
1 parent 517f300 commit 141e419

File tree

3 files changed: +125 −40 lines

  torchft/manager.py (+18 −4)
  torchft/process_group.py (+69 −25)
  torchft/process_group_test.py (+38 −11)


torchft/manager.py

+18 −4

@@ -46,7 +46,7 @@
 from torchft.process_group import ProcessGroup
 
 MANAGER_ADDR_KEY: str = "manager_addr"
-MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
+MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
 REPLICA_ID_KEY: str = "replica_id"
 
 T = TypeVar("T")
@@ -74,6 +74,12 @@ class Manager:
     """
     Manager manages the full fault tolerant training loop.
 
+    This requires that the TCPStore specified by the store_addr and
+    store_port or MASTER_ADDR and MASTER_PORT environment variables be
+    started prior to creating this manager. If using a modern version of
+    torchelastic this will already be the case. Otherwise, it should be started
+    via torch.distributed.init_process_group prior to creating this manager.
+
     NOTE: when saving periodic checkpoints you must save and restore the
     Manager's state_dict as well to avoid synchronization issues.
     """
@@ -84,7 +90,6 @@ def __init__(
         load_state_dict: Callable[[T], None],
         state_dict: Callable[[], T],
         min_replica_size: int,
-        port: int = MANAGER_DEFAULT_PORT,
         use_async_quorum: bool = True,
         timeout: timedelta = timedelta(seconds=60),
         rank: Optional[int] = None,
@@ -94,13 +99,18 @@ def __init__(
         store_port: Optional[int] = None,
         lighthouse_addr: Optional[str] = None,
         replica_id: Optional[str] = None,
+        port: Optional[int] = None,
     ) -> None:
         """
         Args:
             load_state_dict: function to load the state dict when recovering
             state_dict: function to save the state dict when recovering
             min_replica_size: minimum number of replicas on each step
-            port: if rank==0, the port to run the manager server on
+            port: if rank==0, the port to run the manager server on.
+                Port assignment priority:
+                1. this argument
+                2. TORCHFT_MANAGER_PORT env var
+                3. arbitrary port assigned via 0
             use_async_quorum: whether to run the quorum asynchronously during the forward pass
             timeout:
                 the default timeout for all operation, if you're using per
@@ -150,6 +160,10 @@ def _manager_state_dict() -> Dict[str, T]:
 
         if rank == 0:
             hostname = socket.gethostname()
+
+            if port is None:
+                port = int(os.environ.get(MANAGER_PORT_ENV, 0))
+
             addr = f"http://{hostname}:{port}"
             bind = f"[::]:{port}"
             lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
@@ -166,7 +180,7 @@ def _manager_state_dict() -> Dict[str, T]:
                 world_size=world_size,
             )
 
-            self._store.set(MANAGER_ADDR_KEY, addr)
+            self._store.set(MANAGER_ADDR_KEY, self._manager.address())
             self._store.set(REPLICA_ID_KEY, replica_id)
 
         addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
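
To make the new precedence concrete, here is a minimal standalone sketch of the resolution order; resolve_manager_port is a hypothetical helper written for illustration, not a function in torchft:

# Hypothetical helper mirroring the priority described in the docstring:
# explicit argument, then the TORCHFT_MANAGER_PORT env var, then 0 so the
# OS assigns an arbitrary free port.
import os
from typing import Optional

MANAGER_PORT_ENV = "TORCHFT_MANAGER_PORT"


def resolve_manager_port(port: Optional[int] = None) -> int:
    if port is not None:
        return port  # 1. explicit argument wins
    # 2. env var if set, otherwise 3. bind to port 0 and let the OS pick
    return int(os.environ.get(MANAGER_PORT_ENV, 0))


assert resolve_manager_port(29511) == 29511
os.environ[MANAGER_PORT_ENV] = "29600"
assert resolve_manager_port() == 29600
del os.environ[MANAGER_PORT_ENV]
assert resolve_manager_port() == 0  # bind to 0 -> OS picks a free port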

torchft/process_group.py

+69 −25

@@ -17,10 +17,11 @@
 """
 
 import logging
+import queue
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -53,8 +54,23 @@
 _FUTURE_EXCEPTION = "fut_exception"
 
 
-def _get(queue: mp.Queue, timeout: float) -> object:
-    v = queue.get(timeout=timeout)
+def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
+    """
+    Gets an item from a queue with a timeout. If the timeout is exceeded then
+    a TimeoutError is raised.
+
+    If an exception is returned from the queue then it is raised.
+
+    Args:
+        q: queue to get from
+        timeout: timeout in seconds
+    """
+    if isinstance(timeout, timedelta):
+        timeout = timeout.total_seconds()
+    try:
+        v = q.get(timeout=timeout)
+    except queue.Empty as e:
+        raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
     if isinstance(v, Exception):
         raise v
     return v
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         Every time this is called it must be provided with a unique prefixed
         store address. I.e. localhost:1234/my/prefix/1
 
+        This function will block until the underlying ProcessGroup is created.
+        If an error occurs this will throw.
+
         Args:
             store_addr: address of the store to use
             rank: rank of this process
@@ -187,7 +206,6 @@ def __repr__(self) -> str:
 
 
 class ProcessGroupWrapper(ProcessGroup):
-    PG_CLASS: Type[BaseProcessGroup]  # pyre-fixme[13]: never initialized
     """
     This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         store = create_store_client(store_addr)
 
-        # TODO: set global timeout
-        # pyre-fixme[20]: expects argument options
-        self._pg = self.PG_CLASS(store, rank, world_size)
+        self._pg = self._create_pg(store, rank, world_size)
+
+    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        raise NotImplementedError("not implemented")
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self.parent.allreduce(tensors, opts)
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
     This is a reconfigurable version of ProcessGroupGloo.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
+        super().__init__()
+        self._timeout = timeout
+
+    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        # pyre-fixme[16]: no attribute ProcessGroupGloo
+        return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
 
     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     abort when reconfiguring, we need to ensure this is safe.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
+    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        # pyre-fixme[16]: no attribute ProcessGroupNCCL
+        return BaseProcessGroupNCCL(store, rank, world_size)
 
     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -546,10 +569,9 @@ class ProcessGroupBaby(ProcessGroup):
 
     """
 
-    PG_CLASS: Type[BaseProcessGroup]  # pyre-fixme[13]: never initialized
     WORK_CLASS: Type[_BabyWork] = _BabyWork
 
-    def __init__(self, timeout: float = 60.0) -> None:
+    def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
         self._world_size = -1
@@ -562,7 +584,10 @@ def __init__(self, timeout: float = 60.0) -> None:
         self._futures: Dict[int, Future[object]] = {}
         self._futures_lock = threading.Lock()
 
-        self._timeout = timeout
+        if isinstance(timeout, timedelta):
+            timeout = timeout.total_seconds()
+
+        self._timeout: float = timeout
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         if self._p is not None:
@@ -581,7 +606,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         ctx = mp.get_context("spawn")
         self._tx = ctx.Queue()
-        self._rx = ctx.Queue()
+        self._rx = rx = ctx.Queue()
 
         # futures need thread to fire callbacks
         self._future_queue = ctx.Queue()
@@ -602,6 +627,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         )
         self._p.start()
 
+        # fetch the status of the PG init
+        # if an exception was returned _get will throw
+        assert _get(rx, self._timeout) is None
+
+    @classmethod
+    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        """
+        This is a class method to avoid pickling the class.
+        """
+        raise NotImplementedError("not implemented")
+
     @classmethod
     def _worker(
         cls,
@@ -615,8 +651,13 @@ def _worker(
         try:
             store = create_store_client(store_addr)
 
-            # pyre-fixme[20]: expects argument options
-            pg = cls.PG_CLASS(store, rank, world_size)
+            try:
+                pg = cls._create_pg(store, rank, world_size)
+            except Exception as e:
+                logger.exception(f"got exception in worker: {e}")
+                tx.put(e)
+                return
+            tx.put(None)
 
             work = {}
             next_op_id: int = 0
@@ -737,9 +778,10 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
     ProcessGroupBabyNCCL.
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupGloo  # pyre-fixme[16]: no attribute ProcessGroupGloo
-    )
+    @classmethod
+    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        # pyre-fixme[16]: no attribute ProcessGroupGloo
+        return BaseProcessGroupGloo(store, rank, world_size)
 
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -761,11 +803,13 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    PG_CLASS: Type[BaseProcessGroup] = (
-        BaseProcessGroupNCCL  # pyre-fixme[16]: no attribute ProcessGroupNCCL
-    )
     WORK_CLASS = _BabyWorkNCCL
 
+    @classmethod
+    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        # pyre-fixme[16]: no attribute ProcessGroupNCCL
+        return BaseProcessGroupNCCL(store, rank, world_size)
+
     def getBackendName(self) -> str:
         return "torchft-baby-nccl"

torchft/process_group_test.py

+38 −11

@@ -6,6 +6,7 @@
 
 import os
 from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
 from typing import Any, Dict, Tuple
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock
@@ -122,6 +123,16 @@ def test_gloo(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
+    def test_gloo_timeout(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+        pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
+        with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"):
+            pg.configure(store_addr, 0, 2)
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipUnless(torch.cuda.is_available(), "needs CUDA")
     def test_nccl(self) -> None:
@@ -155,28 +166,44 @@ def test_baby_gloo(self) -> None:
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
 
-        store_addr = f"localhost:{store.port}/prefix"
+        store_addr: str = f"localhost:{store.port}/prefix"
+
+        def run(rank: int) -> Tuple[torch.Tensor, Work]:
+            a = ProcessGroupBabyGloo()
+            a.configure(store_addr, rank, 2)
 
-        a = ProcessGroupBabyGloo()
-        b = ProcessGroupBabyGloo()
+            self.assertEqual(a.size(), 2)
 
-        a.configure(store_addr, 0, 2)
-        b.configure(store_addr, 1, 2)
+            at = torch.tensor([rank + 1])
 
-        self.assertEqual(a.size(), 2)
+            a_work = a.allreduce([at], ReduceOp.SUM)
+            return at, a_work
 
-        at = torch.tensor([1])
-        bt = torch.tensor([2])
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            a_fut = executor.submit(run, 0)
+            b_fut = executor.submit(run, 1)
 
-        a_work = a.allreduce([at], ReduceOp.SUM)
-        b_work = b.allreduce([bt], ReduceOp.SUM)
+        at, a_work = a_fut.result()
+        bt, b_work = b_fut.result()
 
         a_work.wait()
         fut = b_work.get_future()
 
         fut.wait()
 
-        torch.testing.assert_close(at, bt)
+        torch.testing.assert_close(at, torch.tensor([3]))
+        torch.testing.assert_close(bt, torch.tensor([3]))
+
+    def test_baby_gloo_timeout(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
+        with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
+            a.configure(store_addr, 0, 2)
 
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
