
Commit ab66c7c

manager: added FIXED_WITH_SPARES mode (#24)
1 parent 1d5464d commit ab66c7c

File tree

7 files changed: +151 -33 lines changed


build.rs (+3 -1)

@@ -5,6 +5,8 @@
 // LICENSE file in the root directory of this source tree.
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    tonic_build::compile_protos("proto/torchft.proto")?;
+    tonic_build::configure()
+        .protoc_arg("--experimental_allow_proto3_optional")
+        .compile_protos(&["proto/torchft.proto"], &["proto"])?;
     Ok(())
 }

proto/torchft.proto (+7 -4)

@@ -78,11 +78,14 @@ message ManagerQuorumResponse {
   int64 quorum_id = 1;
   string address = 2;
   string store_address = 3;
+  // These are information for the replicas which are at the max step.
   int64 max_step = 4;
-  int64 num_max = 5;
-  int64 replica_rank = 6;
-  int64 replica_world = 7;
-  bool heal = 8;
+  optional int64 max_rank = 5;
+  int64 max_world_size = 6;
+  // These are information for all replicas including behind replicas.
+  int64 replica_rank = 7;
+  int64 replica_world_size = 8;
+  bool heal = 9;
 }
 
 message CheckpointAddressRequest {

src/lib.rs (+4 -3)

@@ -102,7 +102,7 @@ impl ManagerClient {
         rank: i64,
         step: i64,
         checkpoint_server_addr: String,
-    ) -> PyResult<(i64, i64, i64, String, String, i64, i64, bool)> {
+    ) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
                 rank: rank,
@@ -121,11 +121,12 @@ impl ManagerClient {
             Ok((
                 resp.quorum_id,
                 resp.replica_rank,
-                resp.replica_world,
+                resp.replica_world_size,
                 resp.address,
                 resp.store_address,
                 resp.max_step,
-                resp.num_max,
+                resp.max_rank,
+                resp.max_world_size,
                 resp.heal,
             ))
         })
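
Because the tuple returned by quorum now carries an Option<i64> in the max_rank slot, the Python caller receives either an int or None there. A minimal sketch of interpreting that value (the helper name is hypothetical; the tuple order follows the Ok((...)) above):

from typing import Optional

# Sketch: max_rank is None when this replica is not part of the max-step
# group (it is behind and will heal); otherwise it is the replica's rank
# within that group.
def describe_max_rank(max_rank: Optional[int], max_world_size: int) -> str:
    if max_rank is None:
        return "behind the max step; will heal"
    return f"rank {max_rank} of {max_world_size} in the max-step group"
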

src/manager.rs (+11 -2)

@@ -234,6 +234,14 @@ impl ManagerService for Arc<Manager> {
 
         let primary = max_participants[rank as usize % max_participants.len()];
 
+        let mut max_rank = None;
+        for (i, p) in max_participants.iter().enumerate() {
+            if p.replica_id == self.replica_id {
+                max_rank = Some(i as i64);
+                break;
+            }
+        }
+
         // Decide whether we should be healing:
         // 1. if we're not at the max step
         // 2. if everyone is at the first step and we're not the primary
@@ -251,9 +259,10 @@ impl ManagerService for Arc<Manager> {
             address: primary.address.clone(),
             store_address: primary.store_address.clone(),
             max_step: max_step,
-            num_max: max_participants.len() as i64,
+            max_rank: max_rank,
+            max_world_size: max_participants.len() as i64,
             replica_rank: replica_rank as i64,
-            replica_world: participants.len() as i64,
+            replica_world_size: participants.len() as i64,
             heal: heal,
         };
 
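The new max_rank value is this replica's index among the participants that are already at the max step, or None if the replica is behind. A Python illustration of the same lookup (names mirror the Rust hunk above; this is a sketch of the logic, not the server implementation):

def compute_max_rank(max_participants, replica_id):
    """Index of this replica within the max-step group, or None if behind."""
    for i, p in enumerate(max_participants):
        if p.replica_id == replica_id:
            return i
    return None
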
torchft/manager.py (+54 -10)

@@ -28,10 +28,10 @@
 import logging
 import os
 import socket
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
+from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
@@ -54,6 +54,24 @@
 T = TypeVar("T")
 
 
+class WorldSizeMode(Enum):
+    """
+    This controls the numerics for the job when doing allreduces across replicas
+    when the world size is larger than ``min_replica_size``. The world size will
+    never be smaller than ``min_replica_size``.
+
+    DYNAMIC:
+        The world size will dynamically increase to use all available
+        replicas and normalize the gradient by the world size.
+    FIXED_WITH_SPARES:
+        The number of active replicas is ``min_replica_size`` and any spares
+        will contribute zero gradients.
+    """
+
+    DYNAMIC = 0
+    FIXED_WITH_SPARES = 1
+
+
 class Manager:
     """
     Manager manages the full fault tolerant training loop.
@@ -73,6 +91,7 @@ def __init__(
         timeout: timedelta = timedelta(seconds=60),
         rank: Optional[int] = None,
         world_size: Optional[int] = None,
+        world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
         store_addr: Optional[str] = None,
         store_port: Optional[int] = None,
         lighthouse_addr: Optional[str] = None,
@@ -98,6 +117,7 @@ def __init__(
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
+        self._world_size_mode = world_size_mode
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -150,12 +170,13 @@ def __init__(
         self._quorum_id = -1
         self._errored = False
         self._healing = False
-        self._participating_replicas = 0
         self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
 
         # first step is 1
         self._should_step = True
+        self._participating_rank: Optional[int] = None
+        self._participating_world_size: int = 0
 
     def shutdown(self) -> None:
         """
@@ -287,7 +308,7 @@ def step(self) -> None:
 
         if self._should_step:
             self._step += 1
-            self._batches_committed += self._participating_replicas
+            self._batches_committed += self.num_participants()
 
         self._errored = False
         self._healing = False
@@ -311,25 +332,45 @@ def _async_quorum(self) -> None:
         (
             quorum_id,
             replica_rank,
-            replica_world,
+            replica_world_size,
             address,
             store_address,
             max_step,
-            num_max,
+            max_rank,
+            max_world_size,
             heal,
         ) = self._client.quorum(
             rank=self._rank,
             step=self._step,
             checkpoint_server_addr=self._ckpt_server.address(),
         )
-        self._participating_replicas = (
-            num_max if self._use_async_quorum else replica_world
+
+        # When using async quorum we need to take the recovered workers.
+        # When not using async quorum we need to take the max world size as all
+        # workers will be healthy.
+        self._participating_rank, self._participating_world_size = (
+            (max_rank, max_world_size)
+            if self._use_async_quorum
+            else (replica_rank, replica_world_size)
         )
 
+        # For fixed with spares we need to ensure that we don't have more
+        # participating replicas than the min replica size.
+        if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
+            self._participating_world_size = min(
+                self._participating_world_size, self._min_replica_size
+            )
+            if (
+                self._participating_rank is not None
+                and self._participating_rank >= self._min_replica_size
+            ):
+                self._participating_rank = None
+
         if quorum_id != self._quorum_id:
             logger.info(f"reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
-            self._pg.configure(store_prefixed_addr, replica_rank, replica_world)
+            # We use the replica rank and world as we want all replicas in the PG.
+            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
 
         # See manager.rs for healing conditions
@@ -396,7 +437,7 @@ def should_commit(self) -> bool:
         if self._healing:
             self._apply_pending_state_dict()
 
-        enough_replicas = self._participating_replicas >= self._min_replica_size
+        enough_replicas = self.num_participants() >= self._min_replica_size
         local_should_commit = enough_replicas and not self._errored
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
@@ -469,7 +510,8 @@ def num_participants(self) -> int:
         Returns:
             the number of participants in the current quorum
         """
-        return self._participating_replicas
+        assert self._participating_world_size >= 0, "internal error"
+        return self._participating_world_size
 
     def is_participating(self) -> bool:
         """
@@ -478,6 +520,8 @@ def is_participating(self) -> bool:
         Returns:
             whether this replica is participating in the current quorum
         """
+        if self._participating_rank is None:
+            return False
         if self._healing:
             assert self._use_async_quorum
             return False
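
A minimal sketch of opting into the new mode from a training script. world_size_mode is the only new argument; the other constructor arguments shown are illustrative placeholders for whatever an existing Manager setup already passes, not a verified signature:

from torchft.manager import Manager, WorldSizeMode

manager = Manager(
    # ... process group, state_dict/load_state_dict hooks, etc. ...
    min_replica_size=2,
    world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
)

# With FIXED_WITH_SPARES at most min_replica_size replicas take part in the
# numerics; extra replicas get a participating rank of None, so
# is_participating() returns False and they contribute zero gradients.
if not manager.is_participating():
    print("running as a hot spare for this quorum")

Under DYNAMIC (the default) all healthy replicas participate and the gradient is normalized by the full world size, as described in the WorldSizeMode docstring above.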
