
Commit 4e86676

Authored Dec 4, 2024
process_group: wrapper updates and ErrorSwallowingProcessGroup (#21)
1 parent: 03f5350 · commit: 4e86676

File tree: 3 files changed, +158 -6 lines changed

 

torchft/manager.py (+1 -2)

@@ -252,8 +252,7 @@ def wrap_future(self, fut: torch.futures.Future[object], default: object) -> None:
             default: the default value to complete the Future with if an error occurs
         """
 
-        # schedule error handling and grad normalization as a continuation
-        # on the Future
+        # schedule error handling as a continuation on the Future
         def callback(
             fut: torch.futures.Future[List[torch.Tensor]],
         ) -> torch.futures.Future[torch.Tensor]:

torchft/process_group.py (+120 -4)

@@ -20,7 +20,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import Callable, List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type, TYPE_CHECKING
 
 import torch
 import torch.distributed as dist
@@ -44,6 +44,9 @@
 
 from torch.futures import Future
 
+if TYPE_CHECKING:
+    from torchft.manager import Manager
+
 logger = logging.getLogger(__name__)
 
 # TODO: use non strings which are cheaper
@@ -177,18 +180,25 @@ def unregister(self) -> None:
         """
         dist.destroy_process_group(self)
 
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
 
 class ProcessGroupWrapper(ProcessGroup):
     PG_CLASS: Type[BaseProcessGroup]
     """
     This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, pg: Optional[ProcessGroup] = None) -> None:
         super().__init__(0, 1)
-        self._pg = None
+        self._pg = pg
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if isinstance(self._pg, ProcessGroup):
+            self._pg.configure(store_addr, rank, world_size)
+            return
+
         if self._pg is not None:
             if hasattr(self._pg, "abort"):
                 self._pg.abort()
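
As a minimal sketch of the reworked constructor (illustration only, not part of the diff): passing an already-constructed ProcessGroup makes configure() delegate to that group instead of building a new one from PG_CLASS. ProcessGroupDummy and the "addr" store address below are the same placeholders the new tests use.

# Sketch only (assumes torchft at this commit): wrapping a pre-built ProcessGroup.
from torchft.process_group import ProcessGroupDummy, ProcessGroupWrapper

inner = ProcessGroupDummy(0, 1)       # rank 0, world size 1 test double
wrapper = ProcessGroupWrapper(inner)  # pg argument added by this commit

# Because `inner` is already a ProcessGroup, configure() is forwarded to it
# rather than constructing a new backend group from PG_CLASS.
wrapper.configure("addr", 0, 1)
assert inner.configure_count == 1
assert wrapper.parent() is inner
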
@@ -216,6 +226,12 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+    def parent(self) -> ProcessGroup:
+        return self._pg
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(pg={self._pg})"
+
 
 class ProcessGroupGloo(ProcessGroupWrapper):
     """
@@ -252,7 +268,7 @@ def __init__(self, result):
         self.future_ = torch.futures.Future()
         self.future_.set_result(result)
 
-    def wait(self, timeout):
+    def wait(self, timeout=None):
         return True
 
     def get_future(self):
@@ -278,6 +294,10 @@ def __init__(self, rank: int, world: int) -> None:
         self.wait_count = 0
         self.get_future_count = 0
         self._work = []
+        self.configure_count = 0
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        self.configure_count += 1
 
     def broadcast(self, tensor_list, opts):
         res = _DummyWork(tensor_list)
@@ -304,6 +324,102 @@ def getBackendName(self):
         return "torchft-dummy"
 
 
+class _ErrorSwallowingWork(Work):
+    def __init__(
+        self,
+        pg: "ErrorSwallowingProcessGroupWrapper",
+        work: Work,
+        default_result: object,
+    ):
+        super().__init__()
+
+        self._pg = pg
+        self._work = work
+        self._default_result = default_result
+
+    def wait(self, timeout=None) -> bool:
+        try:
+            self._work.wait()
+        except Exception as e:
+            self._pg.report_error(e)
+
+        return True
+
+    def get_future(self) -> Future:
+        fut = self._work.get_future()
+
+        # schedule error handling as a continuation on the Future
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.futures.Future[torch.Tensor]:
+            try:
+                return fut.value()
+            except Exception as e:
+                logger.exception(f"got exception in future -- skipping remaining: {e}")
+                self._pg.report_error(e)
+                return self._default_result
+
+        fut = fut.then(callback)
+        return fut
+
+
+class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
+    """
+    This is a wrapper around any ProcessGroup that will swallow errors and
+    return dummy results on error.
+
+    This is intended to allow handling errors outside of the training loop to
+    avoid having to modify modeling code to support error handling.
+
+    After an error occurs all future operations will be skipped until the
+    process group is reconfigured via ``configure``.
+    """
+
+    def __init__(self, pg: ProcessGroup) -> None:
+        super().__init__(pg)
+
+        self._error = None
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        self._error = None
+
+        super().configure(store_addr, rank, world_size)
+
+    def report_error(self, e: Exception) -> None:
+        """
+        Report an error to this process group. This will cause all future
+        operations to be skipped until the process group is reconfigured via
+        ``configure``.
+
+        Args:
+            e: exception to report
+        """
+        self._error = e
+
+    def error(self) -> Optional[Exception]:
+        """
+        Returns the error that was reported to this process group.
+
+        Returns:
+            exception that was reported
+        """
+        return self._error
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        if self._error is not None:
+            return _DummyWork(tensors)
+
+        try:
+            return _ErrorSwallowingWork(
+                self,
+                super().allreduce(tensors, opts),
+                tensors,
+            )
+        except Exception as e:
+            self.report_error(e)
+            return _DummyWork(tensors)
+
+
 class _BabyWork(Work):
     def __init__(
         self,
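
To make the intended flow concrete, here is a hedged sketch (illustration only, not part of the diff) of how a caller outside the modeling code might drive ErrorSwallowingProcessGroupWrapper: collectives keep returning Work objects after a failure, the swallowed error is surfaced via error(), and configure() clears it before training resumes. The dummy group, "addr" store address, and tensor size are placeholders.

# Sketch only (assumes torchft at this commit): error handling outside the training loop.
import torch
from torch.distributed import ReduceOp
from torchft.process_group import ErrorSwallowingProcessGroupWrapper, ProcessGroupDummy

pg = ErrorSwallowingProcessGroupWrapper(ProcessGroupDummy(0, 1))
pg.configure("addr", 0, 1)  # also resets any previously reported error

grad = torch.ones(8)
work = pg.allreduce([grad], ReduceOp.SUM)
work.wait()  # never raises; failures are routed to report_error() instead

if pg.error() is not None:
    # An error was swallowed during this step: skip the optimizer update and
    # reconfigure the group (e.g. after a new quorum) before continuing.
    pg.configure("addr", 0, 1)

This mirrors the new test_error_swallowing_process_group_wrapper test added below.
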

torchft/process_group_test.py (+37 -0)

@@ -7,6 +7,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor
 from unittest import skipUnless, TestCase
+from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
@@ -16,13 +17,17 @@
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.process_group import (
+    _DummyWork,
+    _ErrorSwallowingWork,
+    ErrorSwallowingProcessGroupWrapper,
     extend_device_mesh,
     ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
     ProcessGroupGloo,
     ProcessGroupNCCL,
+    ProcessGroupWrapper,
 )
 
 
@@ -194,3 +199,35 @@ def test_functional_collectives(self) -> None:
             _functional_collectives.all_reduce(t, "sum", pg).wait()
         finally:
             pg.unregister()
+
+    def test_process_group_wrapper(self) -> None:
+        pg = ProcessGroupDummy(0, 1)
+        wrapper = ProcessGroupWrapper(pg)
+        self.assertIs(wrapper.parent(), pg)
+
+        wrapper.configure("addr", 0, 1)
+        self.assertEqual(pg.configure_count, 1)
+
+        self.assertEqual(repr(wrapper), "ProcessGroupWrapper(pg=ProcessGroupDummy())")
+
+    def test_error_swallowing_process_group_wrapper(self) -> None:
+        pg = ProcessGroupDummy(0, 1)
+        wrapper = ErrorSwallowingProcessGroupWrapper(pg)
+        self.assertIs(wrapper.parent(), pg)
+
+        t = torch.zeros(10)
+        work = wrapper.allreduce([t], ReduceOp.SUM)
+        self.assertIsInstance(work, _ErrorSwallowingWork)
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
+
+        err = RuntimeError("test")
+        wrapper.report_error(err)
+        self.assertEqual(wrapper.error(), err)
+
+        work = wrapper.allreduce([t], ReduceOp.SUM)
+        self.assertIsInstance(work, _DummyWork)
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
