Skip to content

Commit 9c13c5f

Browse files
committed
manager: error reporting APIs and numerics test
1 parent bbaf95e commit 9c13c5f

File tree

2 files changed

+94
-14
lines changed

2 files changed

+94
-14
lines changed

torchft/manager.py

+56-13
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
178178
Returns:
179179
a Future that will be completed with the allreduced gradient
180180
"""
181-
if self._errored:
181+
if self.errored():
182182
fut = torch.futures.Future()
183183
fut.set_result(grad)
184184
return fut
@@ -195,38 +195,81 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
195195
work = self._pg.allreduce([grad], ReduceOp.SUM)
196196
fut = work.get_future()
197197

198-
# schedule error handling and grad normalization as a continuation
198+
# schedule grad normalization as a continuation
199199
# on the Future
200200
def callback(
201201
fut: torch.futures.Future[List[torch.Tensor]],
202202
) -> torch.futures.Future[torch.Tensor]:
203203
nonlocal grad
204204

205-
try:
206-
val = fut.value()
207-
except Exception:
208-
logger.exception(
209-
"got exception in all reduce future -- skipping remaining"
210-
)
211-
self._errored = True
212-
return grad
205+
fut.value()
213206

214207
grad /= self.num_participants()
215208

216209
return grad
217210

218211
fut = fut.then(callback)
219-
self._pending_work.append(fut)
212+
fut = self.wrap_future(fut, grad)
220213
return fut
221214

222215
except Exception as e:
223-
logger.exception("got exception in all reduce -- skipping remaining")
224-
self._errored = True
216+
logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
217+
self.report_error()
225218

226219
fut = torch.futures.Future()
227220
fut.set_result(grad)
228221
return fut
229222

223+
def report_error(self) -> None:
224+
"""
225+
Report an error to the manager.
226+
227+
This will cause the manager to skip the current step and will be
228+
reconfigured on the next step.
229+
230+
This should be called when an error occurs that leads to a corrupted
231+
gradient that needs to be discarded.
232+
"""
233+
self._errored = True
234+
235+
def errored(self) -> bool:
236+
"""
237+
Get whether an error has occurred.
238+
239+
Returns:
240+
whether an error has occurred
241+
"""
242+
return self._errored
243+
244+
def wrap_future(self, fut: torch.futures.Future[object], default: object) -> None:
245+
"""
246+
Wrap a Future and swallow any errors that occur and report them to the manager.
247+
248+
If an error occurs, the Future will be completed with the default value.
249+
250+
Args:
251+
fut: the Future to wrap
252+
default: the default value to complete the Future with if an error occurs
253+
"""
254+
255+
# schedule error handling and grad normalization as a continuation
256+
# on the Future
257+
def callback(
258+
fut: torch.futures.Future[List[torch.Tensor]],
259+
) -> torch.futures.Future[torch.Tensor]:
260+
nonlocal default
261+
262+
try:
263+
return fut.value()
264+
except Exception as e:
265+
logger.exception(f"got exception in future -- skipping remaining: {e}")
266+
self.report_error()
267+
return default
268+
269+
fut = fut.then(callback)
270+
self._pending_work.append(fut)
271+
return fut
272+
230273
def step(self) -> None:
231274
"""
232275
.. note::

torchft/manager_test.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
from torch.distributed import TCPStore
1212
from torchft.manager import Manager, MANAGER_ADDR_KEY
13-
from torchft.process_group import ProcessGroup
13+
from torchft.process_group import _DummyWork, ProcessGroup
1414

1515
from torchft.torchft import ManagerClient
1616

@@ -311,3 +311,40 @@ def test_allreduce_error(self, client_mock) -> None:
311311
manager.step()
312312
manager.allreduce_grad(torch.tensor([1.0])).wait()
313313
self.assertTrue(manager.should_commit())
314+
315+
@patch("torchft.manager.ManagerClient", autospec=True)
316+
def test_manager_report_error(self, client_mock) -> None:
317+
manager = self._create_manager()
318+
319+
self.assertFalse(manager.errored())
320+
manager.report_error()
321+
self.assertTrue(manager.errored())
322+
323+
@patch("torchft.manager.ManagerClient", autospec=True)
324+
def test_manager_wrap_future(self, client_mock) -> None:
325+
manager = self._create_manager()
326+
327+
self.assertFalse(manager.errored())
328+
329+
fut = torch.futures.Future()
330+
wrapped_fut = manager.wrap_future(fut, 2)
331+
332+
fut.set_exception(RuntimeError("injected failure"))
333+
334+
self.assertEqual(wrapped_fut.value(), 2)
335+
self.assertTrue(manager.errored())
336+
self.assertEqual(manager._pending_work, [wrapped_fut])
337+
338+
@patch("torchft.manager.ManagerClient", autospec=True)
339+
def test_manager_numerics(self, client_mock) -> None:
340+
manager = self._create_manager()
341+
342+
manager._quorum_future = MagicMock()
343+
manager._participating_replicas = 5
344+
self.assertEqual(manager.num_participants(), 5)
345+
manager._pg.allreduce.return_value = _DummyWork(None)
346+
347+
fut = torch.futures.Future()
348+
fut = manager.allreduce_grad(torch.tensor([1.0]))
349+
result = fut.value()
350+
torch.testing.assert_close(result, torch.tensor([1.0 / 5]))

0 commit comments

Comments (0)