@@ -182,7 +182,6 @@ def _manager_state_dict() -> Dict[str, T]:
182
182
self ._batches_committed = 0
183
183
184
184
# first step is 1
185
- self ._should_step = True
186
185
self ._participating_rank : Optional [int ] = None
187
186
self ._participating_world_size : int = 0
188
187
@@ -218,8 +217,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
218
217
fut .set_result (grad )
219
218
return fut
220
219
221
- assert self ._quorum_future is not None , "must call step before allreduce_grad"
222
- self ._quorum_future .result ()
220
+ self .wait_quorum ()
223
221
224
222
if not self .is_participating ():
225
223
grad .zero_ ()
@@ -315,21 +313,28 @@ def callback(
315
313
self ._pending_work .append (cast (torch .futures .Future [object ], fut ))
316
314
return fut
317
315
318
- def start_step (self ) -> None :
316
+ def start_quorum (self , allow_heal : bool = True ) -> None :
319
317
"""
320
318
.. note::
321
319
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
322
320
323
321
Computes a new quorum (potentially asynchronously) and readies the
324
322
manager for a new step.
325
323
326
- Must be called before the forwards pass of each step for best
324
+ It's best practice to call this before the forwards pass of each step for
327
325
performance as computing quorum may take some time.
326
+
327
+ If allow_heal is set, the manager will attempt to heal either
328
+ synchronously before returning or asynchronously prior to any network
329
+ calls.
330
+
331
+ Args:
332
+ allow_heal: whether to allow healing at the beginning of the step
328
333
"""
329
334
330
- if self . _should_step :
331
- self ._step += 1
332
- self ._batches_committed += self . num_participants ()
335
+ # wait for previous quorum to complete
336
+ if self ._quorum_future is not None :
337
+ self ._quorum_future . result ()
333
338
334
339
self ._errored = None
335
340
self ._healing = False
@@ -338,9 +343,9 @@ def start_step(self) -> None:
338
343
# TODO: we should really be wrapping this whole section in a try-except
339
344
# block to allow gracefully recovering from issues in PG setup and quorum.
340
345
341
- self ._quorum_future = self ._executor .submit (self ._async_quorum )
346
+ self ._quorum_future = self ._executor .submit (self ._async_quorum , allow_heal )
342
347
if not self ._use_async_quorum :
343
- self ._quorum_future . result ()
348
+ self .wait_quorum ()
344
349
345
350
if self ._healing :
346
351
# eagerly apply pending state_dict so we can run the forwards pass
@@ -350,7 +355,18 @@ def start_step(self) -> None:
350
355
# and don't need to zero_grad
351
356
self ._healing = False
352
357
353
- def _async_quorum (self ) -> None :
358
+ def wait_quorum (self ) -> None :
359
+ """
360
+ Wait for the quorum to complete.
361
+
362
+ ProcessGroup will be in a healthy state after this returns.
363
+ """
364
+ assert (
365
+ self ._quorum_future is not None
366
+ ), "must call start_quorum before wait_quorum"
367
+ self ._quorum_future .result ()
368
+
369
+ def _async_quorum (self , allow_heal : bool ) -> None :
354
370
(
355
371
quorum_id ,
356
372
replica_rank ,
@@ -372,7 +388,7 @@ def _async_quorum(self) -> None:
372
388
# workers will be healthy.
373
389
self ._participating_rank , self ._participating_world_size = (
374
390
(max_rank , max_world_size )
375
- if self ._use_async_quorum
391
+ if self ._use_async_quorum or not allow_heal
376
392
else (replica_rank , replica_world_size )
377
393
)
378
394
@@ -397,7 +413,7 @@ def _async_quorum(self) -> None:
397
413
self ._quorum_id = quorum_id
398
414
399
415
# See manager.rs for healing conditions
400
- if heal :
416
+ if heal and allow_heal :
401
417
self ._healing = True
402
418
self ._logger .info (
403
419
f"healing required, fetching checkpoint server address from { address = } { max_step = } "
@@ -475,7 +491,9 @@ def should_commit(self) -> bool:
475
491
self ._ckpt_server .disallow_checkpoint ()
476
492
477
493
# decide whether we're in a healthy state to increase the step count
478
- self ._should_step = should_commit
494
+ if should_commit :
495
+ self ._step += 1
496
+ self ._batches_committed += self .num_participants ()
479
497
480
498
return should_commit
481
499
0 commit comments