
Commit 5bc2a69

process_group: make ManagedProcessGroup use wrap_future
1 parent f82a1a2 commit 5bc2a69

File tree

torchft/process_group.py
torchft/process_group_test.py

2 files changed: +55 -15 lines changed

torchft/process_group.py

+44-11
@@ -436,29 +436,62 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return _DummyWork(tensors)
 
 
-class ManagedProcessGroup(ErrorSwallowingProcessGroupWrapper):
+class _ManagedWork(Work):
+    def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
+        super().__init__()
+
+        self._manager = manager
+        self._work = work
+        self._default_result = default_result
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        try:
+            if timeout is not None:
+                self._work.wait(timeout)
+            else:
+                self._work.wait()
+        except Exception as e:
+            self._manager.report_error(e)
+
+        return True
+
+    def get_future(self) -> Future[object]:
+        return self._manager.wrap_future(self._work.get_future(), self._default_result)
+
+
+class ManagedProcessGroup(ProcessGroupWrapper):
     """
     This is a wrapper around any ProcessGroup that is managed by a torchft
     Manager.
+
+    This uses the ProcessGroup that is configured in the Manager. The world size
+    is dynamic and will report the number of active participants in the quorum to
+    the model.
+
+    Any errors will be asynchronously reported to the manager and only successes
+    will be returned to the caller.
     """
 
     def __init__(self, manager: "Manager") -> None:
         super().__init__(manager._pg)
 
         self._manager = manager
 
-    def report_error(self, e: Exception) -> None:
-        """
-        Report an error to this process group. This will cause all future
-        operations to be skipped until the process group is reconfigured via
-        ``configure``.
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        if self._manager.errored() is not None:
+            return _DummyWork(tensors)
 
-        Args:
-            e: exception to report
-        """
-        super().report_error(e)
+        try:
+            work = super().allreduce(tensors, opts)
+        except Exception as e:
+            self._manager.report_error(e)
+            return _DummyWork(tensors)
 
-        self._manager.report_error(e)
+        return _ManagedWork(
+            self._manager,
+            work,
+            tensors,
+        )
 
     def size(self) -> int:
         return self._manager.num_participants()
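
The diff above routes the collective's Future through Manager.wrap_future, whose exact behavior lives in the Manager and is not shown in this commit. As a rough standalone sketch (not torchft code), one plausible reading is that the wrapped future reports any error to the manager and resolves to the default result instead of raising to the caller:

# Standalone sketch (not torchft code): assumes Manager.wrap_future behaves
# roughly like this -- report the error and fall back to a default result
# instead of raising it to the caller.
from typing import TypeVar

import torch
from torch.futures import Future

T = TypeVar("T")


def wrap_future_sketch(fut: Future, default_result: T) -> Future:
    def handler(completed: Future) -> T:
        try:
            # Re-raises if the underlying collective failed.
            return completed.wait()
        except Exception as e:
            # Stand-in for manager.report_error(e); real handling is async.
            print(f"reported to manager: {e}")
            return default_result

    # then() returns a new Future holding the callback's return value.
    return fut.then(handler)


if __name__ == "__main__":
    failed = Future()
    failed.set_exception(RuntimeError("allreduce failed"))
    wrapped = wrap_future_sketch(failed, torch.zeros(3))
    print(wrapped.wait())  # tensor([0., 0., 0.]) instead of an exception

Here wrap_future_sketch, handler, and the printed "reported to manager" line are purely illustrative stand-ins for the Manager's real error-reporting path.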

torchft/process_group_test.py

+11-4
@@ -29,6 +29,7 @@
     ProcessGroupWrapper,
     _DummyWork,
     _ErrorSwallowingWork,
+    _ManagedWork,
     extend_device_mesh,
 )
 
@@ -238,13 +239,19 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
 
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
+        manager.errored.return_value = None
         manager._pg = ProcessGroupDummy(0, 1)
         pg = ManagedProcessGroup(manager)
         manager.num_participants.return_value = 123
 
         self.assertEqual(pg.size(), 123)
 
-        err = RuntimeError("test")
-        pg.report_error(err)
-        self.assertEqual(pg.error(), err)
-        self.assertEqual(manager.report_error.call_count, 1)
+        t = torch.zeros(10)
+        work = pg.allreduce([t], ReduceOp.SUM)
+        self.assertIsInstance(work, _ManagedWork)
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
+
+        self.assertEqual(manager.report_error.call_count, 0)
+        self.assertEqual(manager.wrap_future.call_count, 1)
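
One detail worth noting about the updated test: because manager is Mock(spec=Manager), wrap_future returns another Mock, so fut.wait() records a call rather than blocking on a real Future; the call-count assertions are what actually verify the new wiring. A minimal illustration of that mock behavior, using a hypothetical FakeManager stand-in rather than the real torchft Manager:

# Minimal illustration (hypothetical FakeManager, not torchft's Manager):
# a spec'd Mock returns child Mocks from its methods, so chained calls like
# wrap_future(...).wait() are simply recorded instead of doing real work.
from unittest.mock import Mock


class FakeManager:
    def wrap_future(self, fut: object, default: object) -> object: ...
    def report_error(self, e: Exception) -> None: ...


manager = Mock(spec=FakeManager)
fut = manager.wrap_future("future", "default")  # returns a child Mock
fut.wait()                                      # recorded, nothing awaited
print(manager.wrap_future.call_count)           # 1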
