Skip to content

Commit 63c82c1

Browse files
committed
process_group: added ManagedProcessGroup
1 parent 4e86676 commit 63c82c1

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

torchft/process_group.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,34 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
420420
return _DummyWork(tensors)
421421

422422

423+
class ManagedProcessGroup(ErrorSwallowingProcessGroupWrapper):
424+
"""
425+
This is a wrapper around any ProcessGroup that is managed by a torchft
426+
Manager.
427+
"""
428+
429+
def __init__(self, manager: "Manager") -> None:
430+
super().__init__(manager._pg)
431+
432+
self._manager = manager
433+
434+
def report_error(self, e: Exception) -> None:
435+
"""
436+
Report an error to this process group. This will cause all future
437+
operations to be skipped until the process group is reconfigured via
438+
``configure``.
439+
440+
Args:
441+
e: exception to report
442+
"""
443+
super().report_error(e)
444+
445+
self._manager.report_error()
446+
447+
def size(self) -> int:
448+
return self._manager.num_participants()
449+
450+
423451
class _BabyWork(Work):
424452
def __init__(
425453
self,

torchft/process_group_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
from torch._C._distributed_c10d import _resolve_process_group
1616
from torch.distributed import _functional_collectives, ReduceOp, TCPStore
1717
from torch.distributed.device_mesh import init_device_mesh
18+
from torchft.manager import Manager
1819

1920
from torchft.process_group import (
2021
_DummyWork,
2122
_ErrorSwallowingWork,
2223
ErrorSwallowingProcessGroupWrapper,
2324
extend_device_mesh,
25+
ManagedProcessGroup,
2426
ProcessGroup,
2527
ProcessGroupBabyGloo,
2628
ProcessGroupBabyNCCL,
@@ -231,3 +233,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
231233
work.wait()
232234
fut = work.get_future()
233235
fut.wait()
236+
237+
def test_managed_process_group(self) -> None:
238+
manager = Mock(spec=Manager)
239+
manager._pg = ProcessGroupDummy(0, 1)
240+
pg = ManagedProcessGroup(manager)
241+
manager.num_participants.return_value = 123
242+
243+
self.assertEqual(pg.size(), 123)
244+
245+
err = RuntimeError("test")
246+
pg.report_error(err)
247+
self.assertEqual(pg.error(), err)
248+
self.assertEqual(manager.report_error.call_count, 1)

0 commit comments

Comments
 (0)