
Commit 7357b03

[WIP] Add DiLoCo

ghstack-source-id: 68e071e88b5b238d137e0ecdaa33d97b79370b22
Pull Request resolved: #76
1 parent: 97ad397

File tree: 3 files changed, +365 -116 lines


torchft/local_sgd.py (+127 -56)
@@ -11,17 +11,23 @@
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
 
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional
 
 import torch
 from torch import nn, optim
 
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
@@ -71,8 +77,8 @@ def __init__(
 
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
 
@@ -93,7 +99,30 @@ def __init__(
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self):
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._opt_hook = self._local_optimizer.register_step_post_hook(
+            self._step_post_hook
+        )
+
+        # Register a forward prehook to check for quorum
+        self._forward_pre_hook = self._model.register_forward_pre_hook(
+            self._forward_step_pre_hook
+        )
+
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+
+        # Clean up hooks
+        self._opt_hook.remove()
+        self._forward_pre_hook.remove()
+
+        return False  # Propagate exceptions
 
     def _save_parameters(self) -> None:
         # TODO: consider running copy on a separate stream
@@ -105,71 +134,53 @@ def _restore_parameters(self) -> None:
         for name, p in self._model.named_parameters():
             p.data.copy_(self._backup_parameters[name], non_blocking=True)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
     ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
-
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+        This hook is registered on the optimizer and is called after the optimizer step.
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _forward_step_pre_hook(self, _module, _args):
         """
-        Run the model parameters.
-
-        This should be called before the optimizer step.
-
-        This will start the quorum and save the parameters if this is the first step.
+        Start the quorum before each module forward.
         """
         if self._local_step == 0:
             self._manager.start_quorum()
 
-        self._started_step = True
-
-        return self._model.forward(*args, **kwargs)
+    # def should_sync(self) -> bool:
+    #     """
+    #     Checks if the model should be synchronized.
+    #     """
+    #     if self._local_step >= self._sync_every:
+    #         return True
+    #     else:
+    #         return False
 
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
+    def sync(self) -> None:
         """
-        This hook is registered on the optimizer and is called after the optimizer step.
-
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
-
-        ``forward`` must be called before this function.
+        Synchronizes and averages the model weights across the manager.
         """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
-
-        self._local_step += 1
+        self._local_step = 0
+        self._perform_sync()
 
-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        if self._manager.should_commit():
+            # save the parameters so we can restore from them later if necessary.
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
+        """
+        self._average()
 
     def _average(self) -> None:
         # TODO: do we need to broadcast buffers like DDP does?
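Taken together, the hunks above turn LocalSGD from an nn.Module wrapper into a context manager: __enter__ registers a forward pre-hook that starts the quorum and an optimizer post-step hook that counts local steps and calls sync() every sync_every steps, while __exit__ removes the hooks and restores the backup parameters if an exception escaped. A minimal usage sketch of that API, mirroring the updated tests below; the Manager is mocked with create_autospec rather than constructed for a real replica group, so treat this as an illustration only:

from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager

model = nn.Linear(3, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# A real job would construct a torchft Manager here; the autospec mock mirrors
# what the unit tests do and lets the sketch run without a replica group.
manager = create_autospec(Manager)
manager.should_commit.return_value = True

with LocalSGD(manager, model, optimizer, sync_every=2):
    for _ in range(4):
        inp = torch.rand(2, 3)
        loss = model(inp).mean()  # forward pre-hook starts the quorum on step 0
        loss.backward()
        optimizer.step()          # post-step hook counts steps; every 2nd step averages weights
        optimizer.zero_grad()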
@@ -182,3 +193,63 @@ def _average(self) -> None:
 
         for work in works:
             work.wait()
+
+
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization mechanism
+    to average and synchronize the pseudogradients (the deltas between the previous global weights and the current local weights).
+
+    DiLoCo: https://arxiv.org/pdf/2311.08105
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _model_sync(self) -> None:
+        """
+        Ensure the model has the same weights.
+        """
+        pass
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to compute the pseudogradients, average them
+        across the manager group, and step with the outer optimizer.
+        """
+
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
+
+        self._average_grads()
+
+        # Use the outer optimizer to update the model parameters
+        self._outer_optimizer.step()
+
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
+        works = []
+        for p in self._model.parameters():
+            # Perform allreduce on the pseudogradients
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
+        for work in works:
+            work.wait()
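The pseudogradient here is, per parameter, p.data - self._backup_parameters[name]: the current local weights minus the weights saved at the last successful commit. Those deltas are allreduced across the group and applied by the outer optimizer, while the inner optimizer keeps taking the frequent local steps. A usage sketch under the same assumptions as the LocalSGD sketch above (mocked Manager; inner/outer optimizer hyperparameters copied from the test file below):

from unittest.mock import create_autospec

import torch
from torch import nn

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager

model = nn.Linear(3, 4)

# Inner optimizer takes the frequent local steps; the outer optimizer applies
# the averaged pseudogradients once per sync.
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)

manager = create_autospec(Manager)  # stands in for a real torchft Manager
manager.should_commit.return_value = True

with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for _ in range(4):
        inp = torch.rand(2, 3)
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()      # every 2nd step: pseudogradients are averaged, outer step runs
        inner_optimizer.zero_grad()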

torchft/local_sgd_test.py (+90 -44)
@@ -11,7 +11,7 @@
 import torch
 from torch import nn, optim
 
-from torchft.local_sgd import LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 
 
@@ -40,57 +40,103 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
 
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
-
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-
-        inp = torch.rand(2, 3)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        self.assertEqual(m._local_step, 1)
-        self.assertEqual(manager.start_quorum.call_count, 1)
-
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
-
-        manager.should_commit.return_value = True
-        self.assertEqual(m._local_step, 0)
-
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, 4)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            self.assertEqual(local_sgd._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            manager.should_commit.return_value = True
+            self.assertEqual(local_sgd._local_step, 0)
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, 4)
 
     def test_local_sgd_recovery(self) -> None:
-        base_m = SimpleModel()
-        optimizer = optim.SGD(base_m.parameters())
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
         manager = create_autospec(Manager)
 
-        m = LocalSGD(manager, base_m, optimizer, sync_every=2)
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+            og_state_dict = _copy_state_dict(model.state_dict())
+
+            inp = torch.rand(2, 3)
+
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
 
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
-        og_state_dict = _copy_state_dict(base_m.state_dict())
+            # Check that the model's state dict has been updated
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
 
-        inp = torch.rand(2, 3)
+            local_sgd._restore_parameters()
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
 
-        loss = m(inp).mean()
-        loss.backward()
-        optimizer.step()
 
-        self.assertEqual(m._local_step, 1)
+class DiLoCoTest(TestCase):
+    def test_diloco_health(self) -> None:
+        model = SimpleModel()
 
-        state_dict = m.state_dict()
-        torch.testing.assert_close(state_dict, m._backup_parameters)
-        torch.testing.assert_close(state_dict, og_state_dict)
+        # Setup optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
 
-        m.load_state_dict(state_dict)
-        torch.testing.assert_close(_params_dict(base_m), state_dict)
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
+        manager = create_autospec(Manager)
+        with DiLoCo(
+            manager, model, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            manager.should_commit.return_value = True
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
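These tests use unittest with a create_autospec mock of Manager, so they run without a real replica group. One possible way to run just this module, assuming the repository root is on the Python path:

import unittest

# Run only the LocalSGD/DiLoCo tests added in this commit.
if __name__ == "__main__":
    unittest.main(module="torchft.local_sgd_test", verbosity=2)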
