@@ -436,29 +436,62 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
436
436
return _DummyWork (tensors )
437
437
438
438
439
- class ManagedProcessGroup (ErrorSwallowingProcessGroupWrapper ):
439
+ class _ManagedWork (Work ):
440
+ def __init__ (self , manager : "Manager" , work : Work , default_result : object ) -> None :
441
+ super ().__init__ ()
442
+
443
+ self ._manager = manager
444
+ self ._work = work
445
+ self ._default_result = default_result
446
+
447
+ def wait (self , timeout : Optional [timedelta ] = None ) -> bool :
448
+ try :
449
+ if timeout is not None :
450
+ self ._work .wait (timeout )
451
+ else :
452
+ self ._work .wait ()
453
+ except Exception as e :
454
+ self ._manager .report_error (e )
455
+
456
+ return True
457
+
458
+ def get_future (self ) -> Future [object ]:
459
+ return self ._manager .wrap_future (self ._work .get_future (), self ._default_result )
460
+
461
+
462
class ManagedProcessGroup(ProcessGroupWrapper):
    """
    This is a wrapper around any ProcessGroup that is managed by a torchft
    Manager.

    This uses the ProcessGroup that is configured in the Manager. The world size
    is dynamic and will report the number of active participants in the quorum to
    the model.

    Any errors will be asynchronously reported to the manager and only successes
    will be returned to the caller.
    """

    def __init__(self, manager: "Manager") -> None:
        super().__init__(manager._pg)

        self._manager = manager

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        """Allreduce via the managed PG, degrading to a no-op on error."""
        # Once the manager is already in an errored state, skip the
        # collective entirely; the dummy work hands the tensors back as-is.
        if self._manager.errored() is not None:
            return _DummyWork(tensors)

        try:
            work = super().allreduce(tensors, opts)
        except Exception as e:
            # Report the failure asynchronously and pretend success so the
            # caller keeps running.
            self._manager.report_error(e)
            return _DummyWork(tensors)

        return _ManagedWork(self._manager, work, tensors)

    def size(self) -> int:
        # The world size is the number of live participants in the quorum,
        # not the size of the underlying process group.
        return self._manager.num_participants()
0 commit comments