Commit 9878980

manager_integ_tests: added recovery test (#28)
1 parent 7b93da7 commit 9878980

3 files changed: +174, -72 lines

src/lib.rs (+15 -13)
@@ -44,9 +44,9 @@ impl Manager {
         bind: String,
         store_addr: String,
         world_size: u64,
-    ) -> Self {
+    ) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = Runtime::new().unwrap();
+            let runtime = Runtime::new()?;
             let manager = runtime
                 .block_on(manager::Manager::new(
                     replica_id,
@@ -56,13 +56,13 @@ impl Manager {
                     store_addr,
                     world_size,
                 ))
-                .unwrap();
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());
-            Self {
+            Ok(Self {
                 handle: handle,
                 manager: manager,
                 _runtime: runtime,
-            }
+            })
         })
     }
 
@@ -89,7 +89,7 @@ impl ManagerClient {
     #[new]
     fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = Runtime::new().unwrap();
+            let runtime = Runtime::new()?;
             let client = runtime
                 .block_on(manager::manager_client_new(addr, timeout))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -193,14 +193,16 @@ fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
 }
 
 #[pyfunction]
-fn lighthouse_main(py: Python<'_>) {
-    reset_python_signals(py).unwrap();
+fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
+    reset_python_signals(py)?;
 
     let mut args = env::args();
     args.next(); // discard binary arg
     let opt = lighthouse::LighthouseOpt::from_iter(args);
-    let rt = Runtime::new().unwrap();
-    rt.block_on(lighthouse_main_async(opt)).unwrap();
+    let rt = Runtime::new()?;
+    rt.block_on(lighthouse_main_async(opt))
+        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+    Ok(())
 }
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
@@ -223,7 +225,7 @@ impl Lighthouse {
     #[new]
     fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
         py.allow_threads(move || {
-            let rt = Runtime::new().unwrap();
+            let rt = Runtime::new()?;
 
             let lighthouse = rt
                 .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
@@ -232,7 +234,7 @@ impl Lighthouse {
                     join_timeout_ms: 100,
                     quorum_tick_ms: 100,
                 }))
-                .unwrap();
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
             Ok(Self {
                 handle: rt.spawn(lighthouse.clone().run()),
@@ -261,7 +263,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
         .show_module_names(true)
         .timestamp(stderrlog::Timestamp::Millisecond)
         .init()
-        .unwrap();
+        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
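
Note on the lib.rs changes: startup errors in the PyO3 bindings now propagate to Python as exceptions (PyRuntimeError maps to Python's RuntimeError) instead of panicking inside the extension. A minimal sketch of the user-visible effect; the import path and the idea that a bad bind address fails inside the constructor are assumptions for illustration, not taken from this commit:

# Hypothetical usage sketch: constructor failures surface as a catchable
# RuntimeError rather than aborting the Python process.
from torchft import Lighthouse  # assumed import path

try:
    # An unparseable bind address is one plausible way to hit the new error path.
    lighthouse = Lighthouse(bind="not-an-address", min_replicas=2)
except RuntimeError as e:
    print(f"lighthouse failed to start: {e}")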

torchft/manager.py (+7 -3)
@@ -140,6 +140,7 @@ def __init__(
             wait_for_workers=False,
         )
         self._pg = pg
+        self._manager = None
 
         if rank == 0:
             hostname = socket.gethostname()
@@ -148,7 +149,8 @@ def __init__(
            lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
 
            if replica_id is None:
-                replica_id = str(uuid.uuid4())
+                replica_id = ""
+            replica_id = replica_id + str(uuid.uuid4())
            self._manager = _Manager(
                replica_id=replica_id,
                lighthouse_addr=lighthouse_addr,
@@ -180,6 +182,8 @@ def shutdown(self) -> None:
         Shutdown the manager and checkpoint server.
         """
         self._ckpt_server.shutdown()
+        if self._manager is not None:
+            self._manager.shutdown()
 
     def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
         """
@@ -364,7 +368,7 @@ def _async_quorum(self) -> None:
            self._participating_rank = None
 
        if quorum_id != self._quorum_id:
-            logger.info(f"reconfiguring for quorum_id {quorum_id}")
+            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
            store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
            # We use the replica rank and world as we want all replicas in the PG.
            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
@@ -373,7 +377,7 @@ def _async_quorum(self) -> None:
        # See manager.rs for healing conditions
        if heal:
            self._healing = True
-            logger.info("healing required")
+            logger.info(f"{replica_rank}= healing required")
 
            logger.info(f"fetching checkpoint server address from {address}")
            primary_client = ManagerClient(address, timeout=self._timeout)
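
Note on the manager.py changes: shutdown() now guards against self._manager never having been created (it only exists on rank 0), and a caller-supplied replica_id becomes a prefix with a fresh uuid4 suffix always appended, so a restarted replica keeps a recognizable name but never reuses its previous identity. A small sketch of that replica_id logic; make_replica_id is a hypothetical helper for illustration only, the real code runs inline in Manager.__init__:

# Hypothetical helper mirroring the replica_id handling in the diff above.
import uuid
from typing import Optional

def make_replica_id(replica_id: Optional[str] = None) -> str:
    if replica_id is None:
        replica_id = ""
    # A uuid4 suffix is always appended, so two restarts of "replica0"
    # get distinct ids such as "replica0b1c2..." and "replica0f9e8...".
    return replica_id + str(uuid.uuid4())

print(make_replica_id("replica0"))  # prefixed, unique per process
print(make_replica_id())            # bare uuid when no prefix is given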

torchft/manager_integ_test.py (+152 -56)
@@ -1,4 +1,6 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from contextlib import ExitStack
+from typing import Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -24,63 +26,108 @@ def forward(self, x):
         return self.model(x)
 
 
-def train_loop(replica_id: int, lighthouse_address: str) -> None:
-    store = dist.TCPStore(
-        host_name="localhost",
-        port=0,
-        is_master=True,
-        wait_for_workers=False,
-    )
-
-    def load_state_dict(state_dict):
-        m.load_state_dict(state_dict["model"])
-        optimizer.load_state_dict(state_dict["optim"])
-
-    def state_dict():
-        return {
-            "model": m.state_dict(),
-            "optim": optimizer.state_dict(),
-        }
-
-    pg = ProcessGroupGloo()
-    manager = Manager(
-        pg=pg,
-        min_replica_size=2,
-        load_state_dict=load_state_dict,
-        state_dict=state_dict,
-        replica_id=str(replica_id),
-        store_addr="localhost",
-        store_port=store.port,
-        rank=0,
-        world_size=1,
-        lighthouse_addr=lighthouse_address,
-        port=19530 + replica_id,
-    )
-    m = DistributedDataParallel(manager, MyModel())
-    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
-    criterion = nn.CrossEntropyLoss()
-
-    while True:
-        inputs = torch.rand(2, 3)
-        labels = torch.randint(4, (2,))
-
-        optimizer.zero_grad()
-        out = m(inputs)
-        loss = criterion(out, labels)
-
-        loss.backward()
-        optimizer.step()
-
-        # TODO: assert weights are equal across replicas
-
-        if manager.current_step() >= 5:
-            break
-
-    manager.shutdown()
+class InjectedFailure(Exception):
+    pass
+
+
+class FailureInjector:
+    def __init__(self) -> None:
+        self._failures: Set[int] = set()
+        self.count = 0
+
+    def fail_at(self, step: int) -> "FailureInjector":
+        self._failures.add(step)
+        return self
+
+    def check(self, step: int) -> None:
+        if step in self._failures:
+            self.count += 1
+            self._failures.remove(step)
+            print(f"injecting failure {step=}")
+            raise InjectedFailure(f"injected failure {step=}")
+
+
+def worker_manager(
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    attempts: int = 3,
+) -> None:
+    for i in range(attempts):
+        try:
+            print(f"starting worker {replica_id} attempt {i}")
+            return train_loop(
+                replica_id, lighthouse_address, failure_injector=failure_injector
+            )
+        except InjectedFailure as e:
+            print("got injected failure", i, e)
+            if i == attempts - 1:
+                raise
+            continue
+
+
+def train_loop(
+    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
+) -> None:
+    with ExitStack() as stack:
+        store = dist.TCPStore(
+            host_name="localhost",
+            port=0,
+            is_master=True,
+            wait_for_workers=False,
+        )
+
+        def load_state_dict(state_dict):
+            m.load_state_dict(state_dict["model"])
+            optimizer.load_state_dict(state_dict["optim"])
+
+        def state_dict():
+            return {
+                "model": m.state_dict(),
+                "optim": optimizer.state_dict(),
+            }
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(replica_id),
+            store_addr="localhost",
+            store_port=store.port,
+            rank=0,
+            world_size=1,
+            lighthouse_addr=lighthouse_address,
+            port=19530 + replica_id,
+        )
+        stack.callback(manager.shutdown)
+
+        m = DistributedDataParallel(manager, MyModel())
+        optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+        criterion = nn.CrossEntropyLoss()
+
+        while True:
+            print(f"worker {replica_id} starting step {manager.current_step()}")
+            inputs = torch.rand(2, 3)
+            labels = torch.randint(4, (2,))
+
+            optimizer.zero_grad()
+            out = m(inputs)
+            loss = criterion(out, labels)
+
+            loss.backward()
+            optimizer.step()
+
+            if manager.current_step() >= 5:
+                # return state_dict so we can check consistency
+                return state_dict()
+
+            failure_injector.check(manager.current_step())
 
 
 class ManagerIntegTest(TestCase):
-    def test_ddp(self):
+    def test_ddp_healthy(self):
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
@@ -90,11 +137,60 @@ def test_ddp(self):
 
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                futures.append(
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
+                )
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            state_dicts.append(fut.result())
+
+        lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
+    def test_ddp_recovery(self):
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(2),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
                 futures.append(
-                    executor.submit(train_loop, replica_id, lighthouse.address())
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
                 )
 
+        state_dicts = []
+
         for fut in as_completed(futures):
-            fut.result()
+            state_dicts.append(fut.result())
 
         lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
+        self.assertEqual(failure_injectors[1].count, 1)
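
Note on the recovery test: it relies on FailureInjector raising at a chosen step and worker_manager retrying train_loop, after which the recovered replica heals from a peer and all final state dicts must match. A standalone sketch of that failure-injection contract, assuming FailureInjector and InjectedFailure are defined as in the diff above:

# Minimal sketch of the contract used by test_ddp_recovery: check() raises
# InjectedFailure exactly once per registered step, and the caller is
# expected to restart the worker from the last healthy state.
injector = FailureInjector().fail_at(2)

for step in range(5):
    try:
        injector.check(step)
        print(f"step {step}: ok")
    except InjectedFailure:
        print(f"step {step}: injected failure, worker would restart and heal")

assert injector.count == 1  # mirrors the assertion at the end of test_ddp_recovery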
