@@ -28,10 +28,10 @@
 import logging
 import os
 import socket
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
+from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
@@ -54,6 +54,24 @@
 T = TypeVar("T")
 
 
+class WorldSizeMode(Enum):
+    """
+    This controls the numerics for the job when doing allreduces across replicas
+    when the world size is larger than ``min_replica_size``. The world size will
+    never be smaller than ``min_replica_size``.
+
+    DYNAMIC:
+        The world size will dynamically increase to use all available
+        replicas and normalize the gradient by the world size.
+    FIXED_WITH_SPARES:
+        The number of active replicas is ``min_replica_size`` and any spares
+        will contribute zero gradients.
+    """
+
+    DYNAMIC = 0
+    FIXED_WITH_SPARES = 1
+
+
 class Manager:
     """
     Manager manages the full fault tolerant training loop.
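
The docstring above describes the two modes in terms of how many replicas' gradients count toward the average. The standalone sketch below is illustrative only and not part of this change: `effective_world_size` is a hypothetical helper, and the enum is a simplified copy of the one added above.

```python
from enum import Enum


class WorldSizeMode(Enum):
    # Simplified copy of the enum added in this diff, for a self-contained example.
    DYNAMIC = 0
    FIXED_WITH_SPARES = 1


def effective_world_size(live_replicas: int, min_replica_size: int, mode: WorldSizeMode) -> int:
    """Illustrative helper: how many replicas count toward the gradient average."""
    if mode is WorldSizeMode.DYNAMIC:
        # Use every live replica and normalize by that count.
        return live_replicas
    # FIXED_WITH_SPARES: numerics stay pinned at min_replica_size; any extra
    # replicas are spares whose gradients contribute zero.
    return min(live_replicas, min_replica_size)


assert effective_world_size(5, 2, WorldSizeMode.DYNAMIC) == 5
assert effective_world_size(5, 2, WorldSizeMode.FIXED_WITH_SPARES) == 2
```
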
@@ -73,6 +91,7 @@ def __init__(
         timeout: timedelta = timedelta(seconds=60),
         rank: Optional[int] = None,
         world_size: Optional[int] = None,
+        world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
         store_addr: Optional[str] = None,
         store_port: Optional[int] = None,
         lighthouse_addr: Optional[str] = None,
@@ -98,6 +117,7 @@ def __init__(
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
+        self._world_size_mode = world_size_mode
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -150,12 +170,13 @@ def __init__(
         self._quorum_id = -1
         self._errored = False
         self._healing = False
-        self._participating_replicas = 0
         self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
 
         # first step is 1
         self._should_step = True
+        self._participating_rank: Optional[int] = None
+        self._participating_world_size: int = 0
 
     def shutdown(self) -> None:
         """
@@ -287,7 +308,7 @@ def step(self) -> None:
 
         if self._should_step:
             self._step += 1
-            self._batches_committed += self._participating_replicas
+            self._batches_committed += self.num_participants()
 
         self._errored = False
         self._healing = False
@@ -311,25 +332,45 @@ def _async_quorum(self) -> None:
         (
             quorum_id,
             replica_rank,
-            replica_world,
+            replica_world_size,
             address,
             store_address,
             max_step,
-            num_max,
+            max_rank,
+            max_world_size,
             heal,
         ) = self._client.quorum(
             rank=self._rank,
             step=self._step,
             checkpoint_server_addr=self._ckpt_server.address(),
         )
-        self._participating_replicas = (
-            num_max if self._use_async_quorum else replica_world
+
+        # When using async quorum we need to take the recovered workers.
+        # When not using async quorum we need to take the max world size as all
+        # workers will be healthy.
+        self._participating_rank, self._participating_world_size = (
+            (max_rank, max_world_size)
+            if self._use_async_quorum
+            else (replica_rank, replica_world_size)
         )
 
+        # For fixed with spares we need to ensure that we don't have more
+        # participating replicas than the min replica size.
+        if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
+            self._participating_world_size = min(
+                self._participating_world_size, self._min_replica_size
+            )
+            if (
+                self._participating_rank is not None
+                and self._participating_rank >= self._min_replica_size
+            ):
+                self._participating_rank = None
+
         if quorum_id != self._quorum_id:
             logger.info(f"reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
-            self._pg.configure(store_prefixed_addr, replica_rank, replica_world)
+            # We use the replica rank and world as we want all replicas in the PG.
+            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
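
The FIXED_WITH_SPARES branch above clamps both the participating world size and this replica's participating rank. A self-contained sketch of that clamping follows; `clamp_participation` is an illustrative helper, not Manager code, and it assumes the same semantics as the block above.

```python
from typing import Optional, Tuple


def clamp_participation(
    replica_rank: int, world_size: int, min_replica_size: int
) -> Tuple[Optional[int], int]:
    """Mirror of the FIXED_WITH_SPARES clamping: the participating world size
    is capped at min_replica_size, and ranks beyond it become non-participating
    spares (rank None)."""
    participating_world_size = min(world_size, min_replica_size)
    participating_rank: Optional[int] = (
        replica_rank if replica_rank < min_replica_size else None
    )
    return participating_rank, participating_world_size


# With min_replica_size=2 and 4 live replicas, ranks 2 and 3 are spares.
assert clamp_participation(0, 4, 2) == (0, 2)
assert clamp_participation(3, 4, 2) == (None, 2)
```
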
@@ -396,7 +437,7 @@ def should_commit(self) -> bool:
         if self._healing:
             self._apply_pending_state_dict()
 
-        enough_replicas = self._participating_replicas >= self._min_replica_size
+        enough_replicas = self.num_participants() >= self._min_replica_size
         local_should_commit = enough_replicas and not self._errored
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
@@ -469,7 +510,8 @@ def num_participants(self) -> int:
         Returns:
             the number of participants in the current quorum
         """
-        return self._participating_replicas
+        assert self._participating_world_size >= 0, "internal error"
+        return self._participating_world_size
 
     def is_participating(self) -> bool:
         """
@@ -478,6 +520,8 @@ def is_participating(self) -> bool:
         Returns:
             whether this replica is participating in the current quorum
         """
+        if self._participating_rank is None:
+            return False
         if self._healing:
             assert self._use_async_quorum
             return False
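
Taken together, `should_commit()`, `num_participants()`, and `is_participating()` let a training loop react to whatever participation the quorum decided. The sketch below is a hedged illustration of one possible loop body, not the library's prescribed usage: the `manager`, `model`, `optimizer`, and `batch` arguments are assumed to follow the usual torchft pattern, and gradient allreduce is assumed to happen elsewhere through the Manager.

```python
import torch


def train_step(manager, model, optimizer, batch) -> None:
    """Illustrative loop body only; `manager` is assumed to expose the
    Manager methods touched in this diff."""
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    # should_commit() already checks num_participants() >= min_replica_size,
    # so the optimizer only steps when the quorum is large enough and healthy.
    if manager.should_commit():
        optimizer.step()
    # Spare replicas under FIXED_WITH_SPARES report is_participating() == False
    # and can skip work that assumes a real gradient contribution.
    if manager.is_participating():
        print(f"participants this step: {manager.num_participants()}")
```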