@@ -178,7 +178,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
         Returns:
             a Future that will be completed with the allreduced gradient
         """
-        if self._errored:
+        if self.errored():
             fut = torch.futures.Future()
             fut.set_result(grad)
             return fut
@@ -195,38 +195,81 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             work = self._pg.allreduce([grad], ReduceOp.SUM)
             fut = work.get_future()
 
-            # schedule error handling and grad normalization as a continuation
+            # schedule grad normalization as a continuation
             # on the Future
            def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.futures.Future[torch.Tensor]:
                 nonlocal grad
 
-                try:
-                    val = fut.value()
-                except Exception:
-                    logger.exception(
-                        "got exception in all reduce future -- skipping remaining"
-                    )
-                    self._errored = True
-                    return grad
+                fut.value()
 
                 grad /= self.num_participants()
 
                 return grad
 
             fut = fut.then(callback)
-            self._pending_work.append(fut)
+            fut = self.wrap_future(fut, grad)
             return fut
 
         except Exception as e:
-            logger.exception("got exception in all reduce -- skipping remaining")
-            self._errored = True
+            logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
+            self.report_error()
 
             fut = torch.futures.Future()
             fut.set_result(grad)
             return fut
 
+    def report_error(self) -> None:
+        """
+        Report an error to the manager.
+
+        This will cause the manager to skip the current step and be
+        reconfigured on the next step.
+
+        This should be called when an error occurs that leads to a corrupted
+        gradient that needs to be discarded.
+        """
+        self._errored = True
+
+    def errored(self) -> bool:
+        """
+        Get whether an error has occurred.
+
+        Returns:
+            whether an error has occurred
+        """
+        return self._errored
+
+    def wrap_future(self, fut: torch.futures.Future[object], default: object) -> torch.futures.Future[object]:
+        """
+        Wrap a Future, swallowing any errors that occur and reporting them to the manager.
+
+        If an error occurs, the Future will be completed with the default value.
+
+        Args:
+            fut: the Future to wrap
+            default: the default value to complete the Future with if an error occurs
+        """
+
+        # schedule error handling as a continuation
+        # on the Future
+        def callback(
+            fut: torch.futures.Future[object],
+        ) -> object:
+            nonlocal default
+
+            try:
+                return fut.value()
+            except Exception as e:
+                logger.exception(f"got exception in future -- skipping remaining: {e}")
+                self.report_error()
+                return default
+
+        fut = fut.then(callback)
+        self._pending_work.append(fut)
+        return fut
+
     def step(self) -> None:
         """
         .. note::
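
For readers unfamiliar with torch.futures, the error-swallowing pattern that wrap_future introduces can be demonstrated standalone. The sketch below is illustrative only: the free-standing report_error function and errored flag are stand-ins for the Manager methods added in this diff, not the PR's actual code.

import torch

errored = False

def report_error() -> None:
    # Stand-in for Manager.report_error(): record that the step is corrupted.
    global errored
    errored = True

def wrap_future(fut: torch.futures.Future, default: object) -> torch.futures.Future:
    # Attach a continuation that swallows any error from the chain,
    # reports it, and completes the wrapped Future with `default` instead.
    def callback(fut: torch.futures.Future) -> object:
        try:
            return fut.value()  # raises if the future completed with an error
        except Exception:
            report_error()
            return default

    return fut.then(callback)

# A future that fails upstream completes downstream with the default value.
fut = torch.futures.Future()
wrapped = wrap_future(fut, default=torch.zeros(3))
fut.set_exception(RuntimeError("allreduce failed"))
print(wrapped.wait())  # tensor([0., 0., 0.])
print(errored)  # True

This mirrors the design choice in the diff: a failed collective no longer needs try/except inside every callback; instead the error is recorded on the manager, the corrupted gradient is replaced with a usable default so the Future chain still completes, and the errored flag causes the manager to skip the step and reconfigure on the next one.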