Skip to content

Commit 63ee40c

Browse files
authored
manager: expand API to include errors, participant information and numeric test (#19)
* manager: added participant information * manager: error reporting APIs and numerics test
1 parent 4ad6f1b commit 63ee40c

File tree

2 files changed

+132
-23
lines changed

2 files changed

+132
-23
lines changed

torchft/manager.py

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import uuid
3333
from concurrent.futures import ThreadPoolExecutor
3434
from datetime import timedelta
35-
from typing import Dict, List, Optional
35+
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
3636

3737
import torch
3838
from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work
@@ -42,6 +42,9 @@
4242
# pyre-fixme[21]: can't find rust module
4343
from torchft.torchft import Manager as _Manager, ManagerClient
4444

45+
if TYPE_CHECKING:
46+
from torchft.process_group import ProcessGroup
47+
4548
logger: logging.Logger = logging.getLogger(__name__)
4649

4750
MANAGER_ADDR_KEY: str = "manager_addr"
@@ -58,9 +61,9 @@ class Manager:
5861

5962
def __init__(
6063
self,
61-
pg,
62-
load_state_dict,
63-
state_dict,
64+
pg: "ProcessGroup",
65+
load_state_dict: Callable[[object], None],
66+
state_dict: Callable[[], object],
6467
min_replica_size: int,
6568
port: int = MANAGER_DEFAULT_PORT,
6669
use_async_quorum: bool = True,
@@ -175,15 +178,14 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
175178
Returns:
176179
a Future that will be completed with the allreduced gradient
177180
"""
178-
if self._errored:
181+
if self.errored():
179182
fut = torch.futures.Future()
180183
fut.set_result(grad)
181184
return fut
182185

183186
self._quorum_future.result()
184187

185-
if self._healing:
186-
assert self._use_async_quorum
188+
if not self.is_participating():
187189
grad.zero_()
188190

189191
# TODO: increase timeout when waiting when healing
@@ -193,38 +195,81 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
193195
work = self._pg.allreduce([grad], ReduceOp.SUM)
194196
fut = work.get_future()
195197

196-
# schedule error handling and grad normalization as a continuation
198+
# schedule grad normalization as a continuation
197199
# on the Future
198200
def callback(
199201
fut: torch.futures.Future[List[torch.Tensor]],
200202
) -> torch.futures.Future[torch.Tensor]:
201203
nonlocal grad
202204

203-
try:
204-
val = fut.value()
205-
except Exception:
206-
logger.exception(
207-
"got exception in all reduce future -- skipping remaining"
208-
)
209-
self._errored = True
210-
return grad
205+
fut.value()
211206

212-
grad /= self._participating_replicas
207+
grad /= self.num_participants()
213208

214209
return grad
215210

216211
fut = fut.then(callback)
217-
self._pending_work.append(fut)
212+
fut = self.wrap_future(fut, grad)
218213
return fut
219214

220215
except Exception as e:
221-
logger.exception("got exception in all reduce -- skipping remaining")
222-
self._errored = True
216+
logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
217+
self.report_error()
223218

224219
fut = torch.futures.Future()
225220
fut.set_result(grad)
226221
return fut
227222

223+
def report_error(self) -> None:
    """
    Flag the current step as failed.

    Once an error has been reported, the manager skips the current step
    and is reconfigured on the next step.

    Call this whenever a failure may have corrupted the gradient so that
    it has to be discarded.
    """
    # Only the flag is recorded here; it is read back through errored().
    self._errored = True
234+
235+
def errored(self) -> bool:
    """
    Tell whether an error has been recorded on this manager.

    Returns:
        the current value of the internal error flag
    """
    has_error = self._errored
    return has_error
243+
244+
def wrap_future(
    self, fut: torch.futures.Future[object], default: object
) -> torch.futures.Future[object]:
    """
    Wrap a Future, swallowing any error it raises and reporting it to the manager.

    If the wrapped Future fails, the returned Future completes with
    ``default`` instead of raising. The returned Future is also appended
    to the manager's pending work list.

    Args:
        fut: the Future to wrap
        default: the value to complete the returned Future with if an error occurs

    Returns:
        a new Future that resolves to the result of ``fut``, or to
        ``default`` if ``fut`` completed with an exception
    """

    # schedule error handling as a continuation on the Future
    def callback(completed: torch.futures.Future[object]) -> object:
        try:
            return completed.value()
        except Exception as e:
            # logger.exception already attaches the traceback; the message
            # text is kept identical, only the formatting is made lazy.
            logger.exception("got exception in future -- skipping remaining: %s", e)
            self.report_error()
            return default

    wrapped = fut.then(callback)
    self._pending_work.append(wrapped)
    return wrapped
272+
228273
def step(self) -> None:
229274
"""
230275
.. note::
@@ -411,3 +456,26 @@ def batches_committed(self) -> int:
411456
the total number of batches committed
412457
"""
413458
return self._batches_committed
459+
460+
def num_participants(self) -> int:
    """
    Return how many replicas take part in the current quorum.

    This is the number of replicas participating in the current step.

    Returns:
        the participant count of the current quorum
    """
    participants = self._participating_replicas
    return participants
470+
471+
def is_participating(self) -> bool:
    """
    Tell whether this replica takes part in the current quorum.

    Returns:
        False while this replica is healing, True otherwise
    """
    if not self._healing:
        return True

    # A healing replica is only expected when async quorum is in use.
    assert self._use_async_quorum
    return False

torchft/manager_test.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from unittest import TestCase
8-
from unittest.mock import patch, create_autospec, MagicMock
8+
from unittest.mock import create_autospec, MagicMock, patch
99

1010
import torch
1111
from torch.distributed import TCPStore
12+
from torchft.manager import Manager, MANAGER_ADDR_KEY
13+
from torchft.process_group import _DummyWork, ProcessGroup
1214

1315
from torchft.torchft import ManagerClient
14-
from torchft.manager import Manager, MANAGER_ADDR_KEY
15-
from torchft.process_group import ProcessGroup
1616

1717

1818
class TestManager(TestCase):
@@ -129,6 +129,8 @@ def test_quorum_heal_sync(self, client_mock) -> None:
129129
manager.step()
130130
manager.allreduce_grad(torch.tensor([1.0])).wait()
131131
self.assertFalse(manager._healing)
132+
self.assertTrue(manager.is_participating())
133+
self.assertEqual(manager.num_participants(), 2)
132134
self.assertTrue(manager.should_commit())
133135

134136
self.assertEqual(manager._quorum_id, 123)
@@ -164,6 +166,8 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
164166
manager.step()
165167
manager._quorum_future.result()
166168
self.assertTrue(manager._healing)
169+
self.assertFalse(manager.is_participating())
170+
self.assertEqual(manager.num_participants(), 1)
167171

168172
grad = torch.tensor([1.0])
169173
manager.allreduce_grad(grad).wait()
@@ -307,3 +311,40 @@ def test_allreduce_error(self, client_mock) -> None:
307311
manager.step()
308312
manager.allreduce_grad(torch.tensor([1.0])).wait()
309313
self.assertTrue(manager.should_commit())
314+
315+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_report_error(self, client_mock) -> None:
    # The error flag starts out cleared and is set by report_error().
    mgr = self._create_manager()
    self.assertFalse(mgr.errored())

    mgr.report_error()
    self.assertTrue(mgr.errored())
322+
323+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock) -> None:
    mgr = self._create_manager()
    self.assertFalse(mgr.errored())

    inner = torch.futures.Future()
    outer = mgr.wrap_future(inner, 2)

    # Failing the inner future must surface as the default value, set the
    # error flag, and leave the wrapped future tracked as pending work.
    inner.set_exception(RuntimeError("injected failure"))

    self.assertEqual(outer.value(), 2)
    self.assertTrue(mgr.errored())
    self.assertEqual(mgr._pending_work, [outer])
337+
338+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock) -> None:
    # Checks that allreduce_grad divides the gradient by the participant count.
    manager = self._create_manager()

    # allreduce_grad calls _quorum_future.result(); a MagicMock stands in for it.
    manager._quorum_future = MagicMock()
    manager._participating_replicas = 5
    self.assertEqual(manager.num_participants(), 5)
    manager._pg.allreduce.return_value = _DummyWork(None)

    # The previously pre-created, immediately overwritten Future was dead code.
    fut = manager.allreduce_grad(torch.tensor([1.0]))
    result = fut.value()
    torch.testing.assert_close(result, torch.tensor([1.0 / 5]))

0 commit comments

Comments
 (0)