Commit 3bd6a3b

[WIP] Add DiLoCo

ghstack-source-id: 2153244514c7ff795ec590804d208c9d9cd2b4ed
Pull Request resolved: #76
1 parent 97ad397 commit 3bd6a3b

File tree

3 files changed: +362, -114 lines


torchft/local_sgd.py (+126, -56)
@@ -11,17 +11,22 @@
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
 
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional
 
 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
 
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
 
-class LocalSGD(nn.Module):
+
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
@@ -71,8 +76,8 @@ def __init__(
 
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
 
@@ -93,7 +98,30 @@ def __init__(
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self):
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._opt_hook = self._local_optimizer.register_step_post_hook(
+            self._step_post_hook
+        )
+
+        # Register a forward pre-hook to check for quorum
+        self._forward_pre_hook = self._model.register_forward_pre_hook(
+            self._forward_step_pre_hook
+        )
+
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+
+        # Clean up hooks
+        self._opt_hook.remove()
+        self._forward_pre_hook.remove()
+
+        return False  # Propagate exceptions
 
     def _save_parameters(self) -> None:
         # TODO: consider running copy on a separate stream
@@ -105,71 +133,53 @@ def _restore_parameters(self) -> None:
         for name, p in self._model.named_parameters():
             p.data.copy_(self._backup_parameters[name], non_blocking=True)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
-
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+        This hook is registered on the optimizer and is called after the optimizer step.
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _forward_step_pre_hook(self, _module, _args):
         """
-        Run the model parameters.
-
-        This should be called before the optimizer step.
-
-        This will start the quorum and save the parameters if this is the first step.
+        Start the quorum before each module forward.
         """
         if self._local_step == 0:
             self._manager.start_quorum()
 
-        self._started_step = True
+    # def should_sync(self) -> bool:
+    #     """
+    #     Checks if the model should be synchronized.
+    #     """
+    #     if self._local_step >= self._sync_every:
+    #         return True
+    #     else:
+    #         return False
 
-        return self._model.forward(*args, **kwargs)
-
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
+    def sync(self) -> None:
         """
-        This hook is registered on the optimizer and is called after the optimizer step.
-
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
-
-        ``forward`` must be called before this function.
+        Synchronizes and averages the model weights across the manager.
         """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
+        self._local_step = 0
+        self._perform_sync()
 
-        self._local_step += 1
+        if self._manager.should_commit():
+            # save the parameters so we can restore from them later if necessary.
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
-
-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
+        """
+        self._average()
 
     def _average(self) -> None:
         # TODO: do we need to broadcast buffers like DDP does?
@@ -182,3 +192,63 @@ def _average(self) -> None:
 
         for work in works:
             work.wait()
+
+
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between the previous global weights and the current local weights).
+
+    DiLoCo: https://arxiv.org/pdf/2311.08105
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _model_sync(self) -> None:
+        """
+        Ensure the model has the same weights.
+        """
+        pass
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to compute the pseudogradients, average them across
+        the manager group, and step using the outer optimizer.
+        """
+
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
+
+        self._average_grads()
+
+        # Use the outer optimizer to update the model parameters
+        self._outer_optimizer.step()
+
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
+        works = []
+        for p in self._model.parameters():
+            # Perform allreduce on the pseudogradients
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
+        for work in works:
+            work.wait()
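
For orientation, the snippet below sketches how the new context-manager API added in this file is meant to be driven from a training loop. It mirrors the tests further down: the mocked Manager, the nn.Linear stand-in for a real model, the loop length, and the optimizer hyperparameters are illustrative assumptions, not part of the commit.

from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager

# Illustrative stand-ins: a mocked Manager (as in the tests below) and a toy model.
manager = create_autospec(Manager)
manager.should_commit.return_value = True
model = nn.Linear(3, 4)

# The inner optimizer drives the usual local steps; the outer optimizer is only
# stepped on the averaged pseudogradients every `sync_every` local steps.
inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2) as diloco:
    for _ in range(4):  # two sync rounds with sync_every=2
        inp = torch.rand(8, 3)
        loss = model(inp).mean()
        loss.backward()
        # The step post-hook registered in __enter__ counts local steps and calls
        # sync() (pseudogradient allreduce + outer optimizer step) every sync_every steps.
        inner_optimizer.step()
        inner_optimizer.zero_grad()

LocalSGD is used the same way, with a single optimizer and whole-weight averaging instead of pseudogradients.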

torchft/local_sgd_test.py (+90, -44)
@@ -11,7 +11,7 @@
 import torch
 from torch import nn, optim
 
-from torchft.local_sgd import LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 
 
@@ -40,57 +40,103 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
 
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
-
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-
-        inp = torch.rand(2, 3)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        self.assertEqual(m._local_step, 1)
-        self.assertEqual(manager.start_quorum.call_count, 1)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        manager.should_commit.return_value = True
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, 4)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            self.assertEqual(local_sgd._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            manager.should_commit.return_value = True
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, 4)
 
     def test_local_sgd_recovery(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
 
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            og_state_dict = _copy_state_dict(model.state_dict())
+
+            inp = torch.rand(2, 3)
+
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
 
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        og_state_dict = _copy_state_dict(base_m.state_dict())
+            # Check that the model's state dict has been updated
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
 
-        inp = torch.rand(2, 3)
+            local_sgd._restore_parameters()
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
 
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
 
-        self.assertEqual(m._local_step, 1)
+class DiLoCoTest(TestCase):
+    def test_diloco_healthy(self) -> None:
+        model = SimpleModel()
 
-        state_dict = m.state_dict()
-        torch.testing.assert_close(state_dict, m._backup_parameters)
-        torch.testing.assert_close(state_dict, og_state_dict)
+        # Set up optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
 
-        m.load_state_dict(state_dict)
-        torch.testing.assert_close(_params_dict(base_m), state_dict)
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
+        manager = create_autospec(Manager)
+        with DiLoCo(
+            manager, model, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            manager.should_commit.return_value = True
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
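
A note on the last two assertions in test_diloco_healthy: torch.optim.SGD with momentum creates its momentum buffers lazily on the first call to step(), so the outer optimizer's state stays empty until the first sync and then holds one entry per parameter. The standalone sketch below (single worker, no Manager or allreduce; the toy nn.Linear model and the random perturbation are illustrative assumptions) walks through the same pseudogradient update that DiLoCo._perform_sync performs.

import torch
from torch import nn, optim

model = nn.Linear(3, 4)
# Backup of the last synchronized ("global") weights, like _backup_parameters.
backup = {name: p.detach().clone() for name, p in model.named_parameters()}
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
assert outer_optimizer.state_dict()["state"] == {}  # no state before the first outer step

# Pretend a few local inner steps moved the weights away from the backup.
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.01 * torch.randn_like(p))

# Pseudogradient = local weights minus the backup weights; in DiLoCo these are
# allreduce-averaged across the replica group before the outer step.
for name, p in model.named_parameters():
    p.grad = p.data - backup[name]

outer_optimizer.step()  # momentum buffers are created here, one per parameter
assert len(outer_optimizer.state_dict()["state"]) == len(list(model.parameters()))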
