Commit b671401

Add DiLoCo
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
1 parent beb94f0 commit b671401

3 files changed: +430 -139 lines changed

torchft/local_sgd.py: +134 -79
@@ -3,25 +3,29 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 LocalSGD
 =========
-
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
 
 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle
 
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
@@ -60,26 +64,22 @@ def __init__(
     ) -> None:
         """
         Args:
-            manager: The manager to use.
-            model: The model to wrap.
-            optimizer: The optimizer used by the model.
-            sync_every: How often to sync the model weights.
-            backup_device: The device to store the backup of the model parameters on. (default cpu)
-            pin_memory: Whether to pin the memory used for the backup of the model parameters.
+            manager (Manager): The manager to use.
+            model (nn.Module): The model to wrap.
+            optimizer (optim.Optimizer): The optimizer used by the model.
+            sync_every (int): How often to sync the model weights.
+            backup_device (Optional[torch.device]): The device to store the backup of the model parameters on. (default cpu)
+            pin_memory (bool): Whether to pin the memory used for the backup of the model parameters.
         """
         super().__init__()
-
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
         device = backup_device or torch.device("cpu")
-
         self._backup_parameters: Dict[str, torch.Tensor] = {}
-
         for name, p in self._model.named_parameters():
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
             if (
@@ -90,95 +90,150 @@ def __init__(
                 t = t.pin_memory()
             self._backup_parameters[name] = t
 
+        self._hooks: List[RemovableHandle] = []
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
 
     def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)
 
     def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._backup_parameters[name], non_blocking=False)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
-        Loads the state dict to the model and the backup parameters.
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+    def sync(self) -> None:
        """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
        """
-        Run the model parameters.
+        self._average()
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-        This should be called before the optimizer step.
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?
 
-        This will start the quorum and save the parameters if this is the first step.
-        """
-        if self._local_step == 0:
-            self._manager.start_quorum()
+        works = []
+
+        for p in self._model.parameters():
+            # TODO: bucketize parameters
+            works.append(self._manager.allreduce(p.data.detach()))
 
-        self._started_step = True
+        for work in works:
+            work.wait()
 
-        return self._model.forward(*args, **kwargs)
 
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
-        """
-        This hook is registered on the optimizer and is called after the optimizer step.
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (delta of the previous global weights and current local weights).
 
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
+    diloco: https://arxiv.org/pdf/2311.08105
+    """
 
-        ``forward`` must be called before this function.
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo requires synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
+        step using the outer optimizer.
        """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
 
-        self._local_step += 1
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
 
-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        self._average_grads()
+        # Restore the parameters back to the previous state
+        self._restore_parameters()
 
-            if self._manager.should_commit():
-                # save the parameters so we can restore from them later if necessary.
-                self._save_parameters()
-            else:
-                # commit failed, restore from the backup parameters
-                self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+        if self._manager.should_commit():
+            # Use the outer optimizer to update the model parameters
+            self._outer_optimizer.step()
+            self._save_parameters()
+        self._outer_optimizer.zero_grad()
 
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
         works = []
-
         for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Perform allreduce on the pseudogradients
+            assert p.grad is not None
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
        # Wait for all allreduce operations to complete
         for work in works:
             work.wait()
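
For reference, a minimal usage sketch of the context-manager API added above. This is illustrative, not part of the commit: it assumes a torchft Manager already constructed elsewhere with use_async_quorum=False, and the names model, data_loader, criterion, and the optimizer hyperparameters (an AdamW inner optimizer and a Nesterov SGD outer optimizer, following the DiLoCo paper) are placeholders.

# Illustrative sketch (assumed setup): wiring DiLoCo into a training loop.
# `manager` must be a torchft Manager created with use_async_quorum=False;
# `model`, `data_loader`, and `criterion` are assumed to be defined elsewhere.
import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager


def train(manager: Manager, model: nn.Module, data_loader, criterion) -> None:
    # Inner optimizer steps on every batch; hyperparameters are placeholders.
    inner_opt = optim.AdamW(model.parameters(), lr=4e-4)
    # Outer optimizer applies the averaged pseudogradients (Nesterov SGD per the DiLoCo paper).
    outer_opt = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

    # __enter__ registers a post-step hook on the inner optimizer; every sync_every
    # inner steps the hook calls sync(), which starts a quorum, allreduces the
    # pseudogradients, and steps the outer optimizer if the commit succeeds.
    with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=100):
        for batch, target in data_loader:
            inner_opt.zero_grad()
            loss = criterion(model(batch), target)
            loss.backward()
            inner_opt.step()  # the hook may trigger a synchronization here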
