@@ -28,10 +28,10 @@
 import logging
 import os
 import socket
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
+from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
@@ -54,6 +54,24 @@
 T = TypeVar("T")
 
 
+class WorldSizeMode(Enum):
+    """
+    This controls the numerics for the job when doing allreduces across replicas
+    when the world size is larger than ``min_replica_size``. The world size will
+    never be smaller than ``min_replica_size``.
+
+    DYNAMIC:
+        The world size will dynamically increase to use all available
+        replicas and normalize the gradient by the world size.
+    FIXED_WITH_SPARES:
+        The number of active replicas is ``min_replica_size`` and any spares
+        will contribute zero gradients.
+    """
+
+    DYNAMIC = 0
+    FIXED_WITH_SPARES = 1
+
+
 class Manager:
     """
     Manager manages the full fault tolerant training loop.
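How the mode is selected in practice: a minimal construction sketch, assuming Manager also accepts the min_replica_size, process group, and state-dict callbacks referenced elsewhere in this file (pg, load_state_dict, and state_dict are placeholders):

    # Pin the allreduce numerics to 2 replicas; any extra healthy replicas
    # act as spares and contribute zero gradients.
    manager = Manager(
        pg=pg,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        min_replica_size=2,
        world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
    )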
@@ -73,6 +91,7 @@ def __init__(
         timeout: timedelta = timedelta(seconds=60),
         rank: Optional[int] = None,
         world_size: Optional[int] = None,
+        world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
         store_addr: Optional[str] = None,
         store_port: Optional[int] = None,
         lighthouse_addr: Optional[str] = None,
@@ -98,6 +117,7 @@ def __init__(
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
+        self._world_size_mode = world_size_mode
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -150,12 +170,13 @@ def __init__(
         self._quorum_id = -1
         self._errored = False
         self._healing = False
-        self._participating_replicas = 0
         self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
 
         # first step is 1
         self._should_step = True
+        self._participating_rank: Optional[int] = None
+        self._participating_world_size: int = 0
 
     def shutdown(self) -> None:
         """
@@ -287,7 +308,7 @@ def step(self) -> None:
 
         if self._should_step:
             self._step += 1
-            self._batches_committed += self._participating_replicas
+            self._batches_committed += self.num_participants()
 
         self._errored = False
         self._healing = False
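To make the accounting concrete (participant counts below are made up): _batches_committed advances by the participant count of each committed step, so a temporarily shrunken quorum contributes proportionally fewer batches:

    # step 1: num_participants() == 3 -> batches_committed = 0 + 3 = 3
    # step 2: num_participants() == 2 -> batches_committed = 3 + 2 = 5  (one replica lost)
    # step 3: num_participants() == 3 -> batches_committed = 5 + 3 = 8  (replica recovered)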
@@ -311,25 +332,45 @@ def _async_quorum(self) -> None:
         (
             quorum_id,
             replica_rank,
-            replica_world,
+            replica_world_size,
             address,
             store_address,
             max_step,
-            num_max,
+            max_rank,
+            max_world_size,
             heal,
         ) = self._client.quorum(
             rank=self._rank,
             step=self._step,
             checkpoint_server_addr=self._ckpt_server.address(),
         )
-        self._participating_replicas = (
-            num_max if self._use_async_quorum else replica_world
+
+        # When using async quorum we only take the workers at the max step,
+        # as recovering workers are excluded. When not using async quorum all
+        # workers will be healthy, so we take the full replica world.
+        self._participating_rank, self._participating_world_size = (
+            (max_rank, max_world_size)
+            if self._use_async_quorum
+            else (replica_rank, replica_world_size)
         )
 
+        # For fixed with spares we need to ensure that we don't have more
+        # participating replicas than the min replica size.
+        if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
+            self._participating_world_size = min(
+                self._participating_world_size, self._min_replica_size
+            )
+            if (
+                self._participating_rank is not None
+                and self._participating_rank >= self._min_replica_size
+            ):
+                self._participating_rank = None
+
         if quorum_id != self._quorum_id:
             logger.info(f"reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
-            self._pg.configure(store_prefixed_addr, replica_rank, replica_world)
+            # We use the replica rank and world size as we want all replicas in the PG.
+            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
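The FIXED_WITH_SPARES clamp above can be read as a small pure function. A sketch under the same semantics (clamp_to_spares is a hypothetical helper, not part of this diff):

    from typing import Optional, Tuple

    def clamp_to_spares(
        rank: Optional[int], world_size: int, min_replica_size: int
    ) -> Tuple[Optional[int], int]:
        # Cap the participating world size at min_replica_size and turn any
        # replica whose rank falls outside that range into a spare (rank None).
        world_size = min(world_size, min_replica_size)
        if rank is not None and rank >= min_replica_size:
            rank = None
        return rank, world_size

    # With min_replica_size=2 and 3 healthy replicas:
    assert clamp_to_spares(0, 3, 2) == (0, 2)     # ranks 0 and 1 participate
    assert clamp_to_spares(2, 3, 2) == (None, 2)  # rank 2 becomes a spare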
@@ -396,7 +437,7 @@ def should_commit(self) -> bool:
         if self._healing:
             self._apply_pending_state_dict()
 
-        enough_replicas = self._participating_replicas >= self._min_replica_size
+        enough_replicas = self.num_participants() >= self._min_replica_size
         local_should_commit = enough_replicas and not self._errored
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
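For intuition, the local half of the commit decision behaves as follows (a sketch; the final answer is still coordinated across replicas via self._client.should_commit):

    # With min_replica_size=2:
    #   num_participants() == 3, no error    -> local_should_commit == True
    #   num_participants() == 1              -> False (too few replicas)
    #   num_participants() == 2, but errored -> False (this step's work failed)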
@@ -469,7 +510,8 @@ def num_participants(self) -> int:
         Returns:
             the number of participants in the current quorum
         """
-        return self._participating_replicas
+        assert self._participating_world_size >= 0, "internal error"
+        return self._participating_world_size
 
     def is_participating(self) -> bool:
         """
@@ -478,6 +520,8 @@ def is_participating(self) -> bool:
         Returns:
             whether this replica is participating in the current quorum
         """
+        if self._participating_rank is None:
+            return False
         if self._healing:
             assert self._use_async_quorum
             return False
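Taken together, the two accessors let a training loop keep spares in lockstep without biasing the averaged gradients. A sketch (manager, model, and dataloader are placeholders; how gradients reach the allreduce is elided):

    for batch in dataloader:
        loss = model(batch).mean()
        if manager.is_participating():
            loss.backward()
        else:
            # a spare (or still-healing) replica contributes exactly zero
            for p in model.parameters():
                p.grad = torch.zeros_like(p)
        # the averaged gradient is normalized by manager.num_participants()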