3232import uuid
3333from concurrent .futures import ThreadPoolExecutor
3434from datetime import timedelta
35- from typing import Dict , List , Optional
35+ from typing import Callable , Dict , List , Optional , TYPE_CHECKING
3636
3737import torch
3838from torch .distributed import PrefixStore , ReduceOp , TCPStore , Work
4242# pyre-fixme[21]: can't find rust module
4343from torchft .torchft import Manager as _Manager , ManagerClient
4444
45+ if TYPE_CHECKING :
46+ from torchft .process_group import ProcessGroup
47+
4548logger : logging .Logger = logging .getLogger (__name__ )
4649
4750MANAGER_ADDR_KEY : str = "manager_addr"
@@ -58,9 +61,9 @@ class Manager:
5861
5962 def __init__ (
6063 self ,
61- pg ,
62- load_state_dict ,
63- state_dict ,
64+ pg : "ProcessGroup" ,
65+ load_state_dict : Callable [[ object ], None ] ,
66+ state_dict : Callable [[], object ] ,
6467 min_replica_size : int ,
6568 port : int = MANAGER_DEFAULT_PORT ,
6669 use_async_quorum : bool = True ,
@@ -175,15 +178,14 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
175178 Returns:
176179 a Future that will be completed with the allreduced gradient
177180 """
178- if self ._errored :
181+ if self .errored () :
179182 fut = torch .futures .Future ()
180183 fut .set_result (grad )
181184 return fut
182185
183186 self ._quorum_future .result ()
184187
185- if self ._healing :
186- assert self ._use_async_quorum
188+ if not self .is_participating ():
187189 grad .zero_ ()
188190
189191 # TODO: increase timeout when waiting when healing
@@ -193,38 +195,81 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
193195 work = self ._pg .allreduce ([grad ], ReduceOp .SUM )
194196 fut = work .get_future ()
195197
196- # schedule error handling and grad normalization as a continuation
198+ # schedule grad normalization as a continuation
197199 # on the Future
198200 def callback (
199201 fut : torch .futures .Future [List [torch .Tensor ]],
200202 ) -> torch .futures .Future [torch .Tensor ]:
201203 nonlocal grad
202204
203- try :
204- val = fut .value ()
205- except Exception :
206- logger .exception (
207- "got exception in all reduce future -- skipping remaining"
208- )
209- self ._errored = True
210- return grad
205+ fut .value ()
211206
212- grad /= self ._participating_replicas
207+ grad /= self .num_participants ()
213208
214209 return grad
215210
216211 fut = fut .then (callback )
217- self ._pending_work . append (fut )
212+ fut = self .wrap_future (fut , grad )
218213 return fut
219214
220215 except Exception as e :
221- logger .exception ("got exception in all reduce -- skipping remaining" )
222- self ._errored = True
216+ logger .exception (f "got exception in all reduce -- skipping remaining: { e } " )
217+ self .report_error ()
223218
224219 fut = torch .futures .Future ()
225220 fut .set_result (grad )
226221 return fut
227222
223+ def report_error (self ) -> None :
224+ """
225+ Report an error to the manager.
226+
227+ This will cause the manager to skip the current step and will be
228+ reconfigured on the next step.
229+
230+ This should be called when an error occurs that leads to a corrupted
231+ gradient that needs to be discarded.
232+ """
233+ self ._errored = True
234+
235+ def errored (self ) -> bool :
236+ """
237+ Get whether an error has occurred.
238+
239+ Returns:
240+ whether an error has occurred
241+ """
242+ return self ._errored
243+
244+ def wrap_future (self , fut : torch .futures .Future [object ], default : object ) -> None :
245+ """
246+ Wrap a Future and swallow any errors that occur and report them to the manager.
247+
248+ If an error occurs, the Future will be completed with the default value.
249+
250+ Args:
251+ fut: the Future to wrap
252+ default: the default value to complete the Future with if an error occurs
253+ """
254+
255+ # schedule error handling and grad normalization as a continuation
256+ # on the Future
257+ def callback (
258+ fut : torch .futures .Future [List [torch .Tensor ]],
259+ ) -> torch .futures .Future [torch .Tensor ]:
260+ nonlocal default
261+
262+ try :
263+ return fut .value ()
264+ except Exception as e :
265+ logger .exception (f"got exception in future -- skipping remaining: { e } " )
266+ self .report_error ()
267+ return default
268+
269+ fut = fut .then (callback )
270+ self ._pending_work .append (fut )
271+ return fut
272+
228273 def step (self ) -> None :
229274 """
230275 .. note::
@@ -411,3 +456,26 @@ def batches_committed(self) -> int:
411456 the total number of batches committed
412457 """
413458 return self ._batches_committed
459+
460+ def num_participants (self ) -> int :
461+ """
462+ Get the number of participants in the current quorum.
463+
464+ This is the number of replicas participating in the current step.
465+
466+ Returns:
467+ the number of participants in the current quorum
468+ """
469+ return self ._participating_replicas
470+
471+ def is_participating (self ) -> bool :
472+ """
473+ Get whether this replica is participating in the current quorum.
474+
475+ Returns:
476+ whether this replica is participating in the current quorum
477+ """
478+ if self ._healing :
479+ assert self ._use_async_quorum
480+ return False
481+ return True
0 commit comments