
Commit 927f8b4: Refactor local_sgd integration tests (#96)
Parent: 118d1a2
File tree: 2 files changed (+316, -297 lines)


torchft/local_sgd_integ_test.py

+316
@@ -0,0 +1,316 @@
import copy
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from typing import Any, Dict
from unittest import TestCase

import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.manager_integ_test import FailureInjector, MyModel, Runner
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse

logger: logging.Logger = logging.getLogger(__name__)


def local_sgd_train_loop(
    rank: int,
    store_port: int,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:

        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            optimizer.load_state_dict(state_dict["optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:
            return {
                "model": m.state_dict(),
                "optim": optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(lambda: manager.shutdown(wait=False))

        m: nn.Module = MyModel()
        optimizer: optim.Optimizer = optim.Adam(m.parameters())
        criterion = nn.CrossEntropyLoss()

        with LocalSGD(manager, m, optimizer, sync_every=2):
            while True:
                inputs = torch.rand(2, 3)
                labels = torch.randint(4, (2,))

                optimizer.zero_grad()
                out = m(inputs)
                loss = criterion(out, labels)

                loss.backward()

                optimizer.step()

                if manager.current_step() >= 4:
                    break

                runner.failure_injector.check(rank, manager.current_step())

        # return state_dict so we can check consistency
        return state_dict()


def diloco_train_loop(
    rank: int,
    store_port: int,
    runner: Runner,
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:
        # Declare the model and optimizers
        m: nn.Module = MyModel()
        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
        m.load_state_dict(model_state_dict)

        # Setup optimizers
        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_optimizer: optim.Optimizer = torch.optim.SGD(
            m.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )

        # pyre-ignore[53]
        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            # TODO: make this cleaner so we don't have to save this
            diloco._backup_parameters = state_dict["backup_params"]
            inner_optimizer.load_state_dict(state_dict["inner_optim"])
            outer_optimizer.load_state_dict(state_dict["outer_optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
            return {
                "model": m.state_dict(),
                "backup_params": copy.deepcopy(diloco._backup_parameters),
                "inner_optim": inner_optimizer.state_dict(),
                "outer_optim": outer_optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(manager.shutdown)

        criterion = nn.CrossEntropyLoss()
        all_state_dicts = {}
        with DiLoCo(
            manager, m, inner_optimizer, outer_optimizer, sync_every=2
        ) as diloco:
            while True:
                inputs = torch.rand(2, 3)
                labels = torch.randint(4, (2,))

                out = m(inputs)
                loss = criterion(out, labels)

                inner_optimizer.zero_grad()
                loss.backward()
                inner_optimizer.step()
                manager_step_str = str(manager.current_step())
                all_state_dicts[manager_step_str] = state_dict()

                # break after 4 model updates
                if manager.current_step() >= 4:
                    break

                runner.failure_injector.check(rank, manager.current_step())

        # return the per-step state_dicts so we can check consistency
        return all_state_dicts


class ManagerIntegTest(TestCase):
    def test_local_sgd_recovery(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, failure_injector in zip(
                range(num_replicas), failure_injectors
            ):
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=local_sgd_train_loop,
                    manager_args={
                        "use_async_quorum": False,
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result())
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            # LocalSGD only guarantees that the model is consistent across
            # replicas but uses separate optimizer states.
            torch.testing.assert_close(
                state_dict[0]["model"], state_dicts[0][0]["model"]
            )

        self.assertEqual(failure_injectors[1].count, 1)

    def test_diloco_healthy(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel()

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                failure_injector = FailureInjector()
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                state_dicts.append(fut.result()[0])

        lighthouse.shutdown()

        for replica_group in state_dicts:
            for step, state_dict in replica_group.items():
                # inner optimizer will be different, outer optimizer and model should be the same
                torch.testing.assert_close(
                    state_dict["backup_params"],
                    state_dicts[0][str(step)]["backup_params"],
                )
                torch.testing.assert_close(
                    state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
                )

    def test_diloco_recovery(self) -> None:
        lighthouse = Lighthouse(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        failure_injectors = [
            FailureInjector(),
            FailureInjector().fail_at(0, 2),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MyModel()

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, failure_injector in zip(
                range(num_replicas), failure_injectors
            ):
                runner = Runner(
                    replica_id=replica_id,
                    lighthouse_address=lighthouse.address(),
                    failure_injector=failure_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result()[0])
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()
        for replica_group in state_dicts:
            for step, state_dict in replica_group.items():
                str_step = str(step)
                if str_step in state_dicts[0]:
                    # inner optimizer will be different, outer optimizer and model should be the same
                    torch.testing.assert_close(
                        state_dict["backup_params"],
                        state_dicts[0][str_step]["backup_params"],
                    )
                    torch.testing.assert_close(
                        state_dict["outer_optim"],
                        state_dicts[0][str_step]["outer_optim"],
                    )

        self.assertEqual(failure_injectors[1].count, 1)
