20
20
import threading
21
21
from abc import ABC
22
22
from datetime import timedelta
23
- from typing import Callable , List , Optional , Tuple , Type
23
+ from typing import Callable , List , Optional , Tuple , Type , TYPE_CHECKING
24
24
25
25
import torch
26
26
import torch .distributed as dist
44
44
45
45
from torch .futures import Future
46
46
47
+ if TYPE_CHECKING :
48
+ from torchft .manager import Manager
49
+
47
50
logger = logging .getLogger (__name__ )
48
51
49
52
# TODO: use non strings which are cheaper
@@ -177,18 +180,25 @@ def unregister(self) -> None:
177
180
"""
178
181
dist .destroy_process_group (self )
179
182
183
def __repr__(self) -> str:
    """Debug representation: the concrete subclass name with no state."""
    # e.g. "ProcessGroupGloo()" -- subclasses carrying state override this.
    return f"{type(self).__name__}()"
180
186
181
187
class ProcessGroupWrapper (ProcessGroup ):
182
188
PG_CLASS : Type [BaseProcessGroup ]
183
189
"""
184
190
This is a wrapper around any ProcessGroup with a reconfiguration method.
185
191
"""
186
192
187
- def __init__ (self ) -> None :
193
def __init__(self, pg: Optional[ProcessGroup] = None) -> None:
    """Wrap an existing process group, or start unconfigured when ``pg`` is None.

    Args:
        pg: optional process group to delegate to; when None, a concrete
            group is created later via ``configure``.
    """
    # Placeholder rank 0 / world size 1; real values arrive via configure().
    super().__init__(0, 1)
    self._pg = pg
191
197
def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
198
+ if isinstance (self ._pg , ProcessGroup ):
199
+ self ._pg .configure (store_addr , rank , world_size )
200
+ return
201
+
192
202
if self ._pg is not None :
193
203
if hasattr (self ._pg , "abort" ):
194
204
self ._pg .abort ()
@@ -216,6 +226,12 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
216
226
def size(self) -> int:
    """Return the world size of the wrapped process group."""
    wrapped = self._pg
    return wrapped.size()
218
228
229
def parent(self) -> ProcessGroup:
    """Expose the underlying wrapped ProcessGroup (for introspection)."""
    return self._pg
231
+
232
def __repr__(self) -> str:
    """Debug representation including the wrapped group, e.g. ``Wrapper(pg=...)``."""
    # Embedding the wrapped pg lets nested wrappers print their full chain.
    return f"{type(self).__name__}(pg={self._pg})"
234
+
219
235
220
236
class ProcessGroupGloo (ProcessGroupWrapper ):
221
237
"""
@@ -252,7 +268,7 @@ def __init__(self, result):
252
268
self .future_ = torch .futures .Future ()
253
269
self .future_ .set_result (result )
254
270
255
- def wait (self , timeout ):
271
def wait(self, timeout=None):
    """Report immediate success; the dummy result was set at construction.

    ``timeout`` is accepted for Work-interface compatibility and ignored,
    since there is nothing asynchronous to block on.
    """
    return True
257
273
258
274
def get_future (self ):
@@ -278,6 +294,10 @@ def __init__(self, rank: int, world: int) -> None:
278
294
self .wait_count = 0
279
295
self .get_future_count = 0
280
296
self ._work = []
297
+ self .configure_count = 0
298
+
299
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
    """Record the (re)configuration call; no real store or group is set up.

    The counter lets tests assert how many times reconfiguration happened.
    """
    self.configure_count = self.configure_count + 1
281
301
282
302
def broadcast (self , tensor_list , opts ):
283
303
res = _DummyWork (tensor_list )
@@ -304,6 +324,102 @@ def getBackendName(self):
304
324
return "torchft-dummy"
305
325
306
326
327
class _ErrorSwallowingWork(Work):
    """Work wrapper that converts failures into reported errors.

    Any exception raised by the underlying work (or by its future) is passed
    to ``pg.report_error`` and swallowed; callers then observe
    ``default_result`` instead of the exception.
    """

    def __init__(
        self,
        # Fixed forward reference: the owning class is
        # ErrorSwallowingProcessGroupWrapper, not "ErrorSwallowingProcessGroup"
        # (no class by that name exists).
        pg: "ErrorSwallowingProcessGroupWrapper",
        work: Work,
        default_result: object,
    ):
        super().__init__()

        self._pg = pg  # wrapper to notify on failure
        self._work = work  # underlying Work being guarded
        self._default_result = default_result  # returned in place of a failed result

    def wait(self, timeout=None) -> bool:
        """Wait on the wrapped work, reporting (not raising) any error.

        Always returns True: the error, if any, was routed to the process
        group so subsequent operations can be skipped.
        """
        try:
            self._work.wait()
        except Exception as e:
            self._pg.report_error(e)

        return True

    def get_future(self) -> Future:
        """Return the wrapped work's future with error-swallowing attached."""
        fut = self._work.get_future()

        # schedule error handling as a continuation on the Future
        def callback(
            fut: torch.futures.Future[List[torch.Tensor]],
            # Fixed return annotation: the continuation returns the unwrapped
            # value (or the default result), not another Future.
        ) -> List[torch.Tensor]:
            try:
                return fut.value()
            except Exception as e:
                logger.exception(f"got exception in future -- skipping remaining: {e}")
                self._pg.report_error(e)
                return self._default_result

        fut = fut.then(callback)
        return fut
364
+
365
+
366
class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
    """ProcessGroup wrapper that swallows errors and returns dummy results.

    This allows errors to be handled outside of the training loop, so
    modeling code does not need to be modified for error handling.

    Once an error occurs, every subsequent operation is skipped until the
    group is reconfigured via ``configure``.
    """

    def __init__(self, pg: ProcessGroup) -> None:
        super().__init__(pg)

        # Last reported error; None means the group is healthy.
        self._error = None

    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
        # Reconfiguring clears any previously reported error.
        self._error = None

        super().configure(store_addr, rank, world_size)

    def report_error(self, e: Exception) -> None:
        """Record an error on this process group.

        All subsequent operations will be skipped until the group is
        reconfigured via ``configure``.

        Args:
            e: exception to report
        """
        self._error = e

    def error(self) -> Optional[Exception]:
        """Return the last reported error, or None if none was reported."""
        return self._error

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        # Already failed: skip the collective entirely and hand back the
        # inputs unchanged via a dummy Work.
        if self._error is not None:
            return _DummyWork(tensors)

        try:
            work = super().allreduce(tensors, opts)
        except Exception as e:
            # Synchronous failure: record it and degrade to a dummy result.
            self.report_error(e)
            return _DummyWork(tensors)

        # Asynchronous failures are caught by the error-swallowing wrapper.
        return _ErrorSwallowingWork(self, work, tensors)
421
+
422
+
307
423
class _BabyWork (Work ):
308
424
def __init__ (
309
425
self ,
0 commit comments