Commit da0b802

Refactor local_sgd integration tests
1 parent 87290f5 commit da0b802

File tree

2 files changed: +320 -297


torchft/local_sgd_integ_test.py (+320)
@@ -0,0 +1,320 @@
import copy
import logging
import threading
import time
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
from unittest import TestCase

import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.manager_integ_test import FailureInjector, MyModel, Runner
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse

logger: logging.Logger = logging.getLogger(__name__)


def local_sgd_train_loop(
    rank: int,
    store_port: int,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:

        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            optimizer.load_state_dict(state_dict["optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:
            return {
                "model": m.state_dict(),
                "optim": optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(lambda: manager.shutdown(wait=False))

        m: nn.Module = MyModel()
        optimizer: optim.Optimizer = optim.Adam(m.parameters())
        criterion = nn.CrossEntropyLoss()

        with LocalSGD(manager, m, optimizer, sync_every=2):
            while True:
                inputs = torch.rand(2, 3)
                labels = torch.randint(4, (2,))

                optimizer.zero_grad()
                out = m(inputs)
                loss = criterion(out, labels)

                loss.backward()

                optimizer.step()

                if manager.current_step() >= 4:
                    break

                runner.failure_injector.check(rank, manager.current_step())

        # return state_dict so we can check consistency
        return state_dict()

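FailureInjector, MyModel, and Runner are imported from torchft.manager_integ_test and are not part of this diff. Purely as orientation, here is a hypothetical sketch of Runner inferred from how it is used in this file; the field names match the call sites, but the defaults and the run_replica behavior are assumptions, not the real definition:

# Hypothetical sketch of the Runner helper, inferred from its usage in this
# file; the real class lives in torchft/manager_integ_test.py and may differ.
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class Runner:
    replica_id: int
    lighthouse_address: str
    failure_injector: "FailureInjector"
    train_loop: Callable[..., object]
    world_size: int = 1  # assumed default; the tests never set it explicitly
    manager_args: Dict[str, Any] = field(default_factory=dict)
    train_loop_args: Dict[str, Any] = field(default_factory=dict)

    def run_replica(self) -> List[object]:
        # Presumably starts a store for this replica, runs
        # self.train_loop(rank, store_port, self) once per rank, restarting
        # after injected failures, and returns the per-rank results
        # (the tests index fut.result()[0]).
        ...

The diff continues with the DiLoCo variant of the train loop:
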
def diloco_train_loop(
    rank: int,
    store_port: int,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:
        # Declare the model and optimizers
        m: nn.Module = MyModel()
        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
        m.load_state_dict(model_state_dict)

        # Set up optimizers
        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_optimizer: optim.Optimizer = torch.optim.SGD(
            m.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )

        # pyre-ignore[53]
        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            # TODO: make this cleaner so we don't have to save this
            diloco._backup_parameters = state_dict["backup_params"]
            inner_optimizer.load_state_dict(state_dict["inner_optim"])
            outer_optimizer.load_state_dict(state_dict["outer_optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
            return {
                "model": m.state_dict(),
                "backup_params": copy.deepcopy(diloco._backup_parameters),
                "inner_optim": inner_optimizer.state_dict(),
                "outer_optim": outer_optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(manager.shutdown)

        criterion = nn.CrossEntropyLoss()
        all_state_dicts = {}
        with DiLoCo(
            manager, m, inner_optimizer, outer_optimizer, sync_every=2
        ) as diloco:
            while True:
                inputs = torch.rand(2, 3)
                labels = torch.randint(4, (2,))

                out = m(inputs)
                loss = criterion(out, labels)

                inner_optimizer.zero_grad()
                loss.backward()
                inner_optimizer.step()
                manager_step_str = str(manager.current_step())
                all_state_dicts[manager_step_str] = state_dict()

                # break after 4 model updates
                if manager.current_step() >= 4:
                    break

                runner.failure_injector.check(rank, manager.current_step())

        # return the captured per-step state dicts so we can check consistency
        return all_state_dicts

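Why the assertions below compare backup_params and outer_optim but not inner_optim: in DiLoCo each replica takes sync_every local AdamW steps on its own data, and only at sync points do the replicas average a pseudogradient and take a shared outer SGD step. Only the synced base weights (the backup parameters) and the outer optimizer state are therefore expected to agree across replicas. As a rough sketch of one sync point, assuming torchft follows the standard DiLoCo recipe (this is not torchft's actual code, and allreduce_avg is a hypothetical stand-in for the manager's averaging collective):

# Rough sketch of one DiLoCo sync point, under the assumption that torchft
# follows the standard DiLoCo recipe; not torchft's actual implementation.
import torch
from torch import nn, optim
from typing import Callable, Dict


def diloco_sync(
    model: nn.Module,
    backup_params: Dict[str, torch.Tensor],
    outer_optimizer: optim.Optimizer,
    allreduce_avg: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    for name, p in model.named_parameters():
        # pseudogradient: how far this replica drifted from the last
        # synced weights during its local inner steps
        pseudograd = backup_params[name] - p.data
        p.grad = allreduce_avg(pseudograd)  # average the drift across replicas
        p.data.copy_(backup_params[name])  # rewind to the shared base point
    outer_optimizer.step()  # one outer (SGD + Nesterov) step from the base
    # the post-step weights become the backup for the next round
    for name, p in model.named_parameters():
        backup_params[name] = p.data.clone()

The test class then exercises both algorithms, including recovery after an injected failure:
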
class ManagerIntegTest(TestCase):
    def test_local_sgd_recovery(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, failure_injector in zip(
                range(num_replicas), failure_injectors
            ):
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=local_sgd_train_loop,
                    manager_args={
                        "use_async_quorum": False,
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result())
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            # LocalSGD only guarantees that the model is consistent across
            # replicas; each replica keeps its own optimizer state.
            torch.testing.assert_close(
                state_dict[0]["model"], state_dicts[0][0]["model"]
            )

        self.assertEqual(failure_injectors[1].count, 1)

    def test_diloco_healthy(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel()

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                failure_injector = FailureInjector()
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                state_dicts.append(fut.result()[0])

        lighthouse.shutdown()

        for replica_group in state_dicts:
            for step, state_dict in replica_group.items():
                # the inner optimizer state will differ across replicas; the
                # backup parameters and outer optimizer state should match
                torch.testing.assert_close(
                    state_dict["backup_params"],
                    state_dicts[0][str(step)]["backup_params"],
                )
                torch.testing.assert_close(
                    state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
                )

    def test_diloco_recovery(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel()

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, failure_injector in zip(
                range(num_replicas), failure_injectors
            ):
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result()[0])
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        for replica_group in state_dicts:
            for step, state_dict in replica_group.items():
                str_step = str(step)
                if str_step in state_dicts[0]:
                    # the inner optimizer state will differ across replicas;
                    # the backup parameters and outer optimizer state should
                    # match
                    torch.testing.assert_close(
                        state_dict["backup_params"],
                        state_dicts[0][str_step]["backup_params"],
                    )
                    torch.testing.assert_close(
                        state_dict["outer_optim"],
                        state_dicts[0][str_step]["outer_optim"],
                    )

        self.assertEqual(failure_injectors[1].count, 1)
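A similar caveat applies to FailureInjector: fail_at(0, 2) arms a failure for rank 0 at step 2, each train loop calls check(rank, step), and the recovery tests assert that count equals 1 afterwards. A minimal sketch consistent with that contract (hypothetical; the real implementation lives in torchft/manager_integ_test.py):

# Hypothetical sketch of FailureInjector, inferred from its usage in the
# tests above; the real class may differ.
import threading
from typing import Set, Tuple


class FailureInjector:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._failures: Set[Tuple[int, int]] = set()
        self.count = 0

    def fail_at(self, rank: int, step: int) -> "FailureInjector":
        # arm a one-shot failure for (rank, step); returns self so it can
        # be chained as FailureInjector().fail_at(0, 2)
        with self._lock:
            self._failures.add((rank, step))
        return self

    def check(self, rank: int, step: int) -> None:
        # called from the training loop; raises once at the armed point,
        # simulating a worker crash the manager must recover from
        with self._lock:
            if (rank, step) in self._failures:
                self._failures.remove((rank, step))
                self.count += 1
                raise RuntimeError(f"injected failure {rank=} {step=}")

Assuming a torchft checkout with the Rust extension built (torchft.torchft.Lighthouse comes from it), the suite should run under the standard unittest runner:

python -m unittest torchft.local_sgd_integ_test -v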
