
Commit 2a6e9b7

[WIP] Add DiLoCo
ghstack-source-id: 588c13d633fc69574f8a87e46bbe8ae4069d4a3c
Pull Request resolved: #76
1 parent: beb94f0

3 files changed: +409 −135 lines

torchft/local_sgd.py (+141 −75)
@@ -3,25 +3,29 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 LocalSGD
 =========
-
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
 
 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle
 
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
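Since LocalSGD becomes a context manager rather than an nn.Module wrapper, the call site changes. A minimal usage sketch, assuming a conventional training loop; manager, model, optimizer, criterion, and trainloader are placeholder names and not part of this diff:

# Hypothetical call site for the context-manager API introduced above.
with LocalSGD(manager, model, optimizer, sync_every=100):
    for batch, labels in trainloader:
        optimizer.zero_grad()
        out = model(batch)             # forward pre-hook starts the quorum on the first local step
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()               # post-step hook counts local steps and syncs every sync_every steps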
@@ -60,26 +64,22 @@ def __init__(
     ) -> None:
         """
         Args:
-            manager: The manager to use.
-            model: The model to wrap.
-            optimizer: The optimizer used by the model.
-            sync_every: How often to sync the model weights.
-            backup_device: The device to store the backup of the model parameters on. (default cpu)
-            pin_memory: Whether to pin the memory used for the backup of the model parameters.
+            manager (Manager): The manager to use.
+            model (nn.Module): The model to wrap.
+            optimizer (optim.Optimizer): The optimizer used by the model.
+            sync_every (int): How often to sync the model weights.
+            backup_device (Optional[torch.device]): The device to store the backup of the model parameters on. (default cpu)
+            pin_memory (bool): Whether to pin the memory used for the backup of the model parameters.
         """
         super().__init__()
-
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
         device = backup_device or torch.device("cpu")
-
         self._backup_parameters: Dict[str, torch.Tensor] = {}
-
         for name, p in self._model.named_parameters():
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
             if (
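The constructor above allocates one host-side backup tensor per parameter; when pin_memory is set and the parameter lives on a CUDA device, the backup is pinned so the non_blocking copies in _save_parameters can overlap with compute. A self-contained sketch of the same pattern (the Linear model and shapes are illustrative only):

import torch
from torch import nn

# Illustrative sketch: pinned CPU backups plus asynchronous device-to-host copies.
model = nn.Linear(16, 16)
if torch.cuda.is_available():
    model = model.cuda()

backup = {}
for name, p in model.named_parameters():
    t = torch.empty(*tuple(p.shape), dtype=p.dtype, device="cpu")
    if p.device.type == "cuda":
        t = t.pin_memory()  # pinned host memory enables async copies from the GPU
    backup[name] = t

with torch.no_grad():
    for name, p in model.named_parameters():
        backup[name].copy_(p.data, non_blocking=True)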
@@ -90,86 +90,88 @@ def __init__(
                 t = t.pin_memory()
             self._backup_parameters[name] = t
 
+        self._hooks: List[RemovableHandle] = []
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        # Register a forward prehook to check for quorum
+        self._hooks.append(
+            self._model.register_forward_pre_hook(self._forward_step_pre_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
 
     def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)
 
     def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.copy_(self._backup_parameters[name], non_blocking=False)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
     ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
-
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+        This hook is registered on the optimizer and is called after the optimizer step.
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _forward_step_pre_hook(self, _module: nn.Module, _args: List[object]) -> None:
         """
-        Run the model parameters.
-
-        This should be called before the optimizer step.
-
-        This will start the quorum and save the parameters if this is the first step.
+        Start the quorum before each module forward.
         """
         if self._local_step == 0:
             self._manager.start_quorum()
 
-        self._started_step = True
-
-        return self._model.forward(*args, **kwargs)
-
-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
+    def sync(self) -> None:
         """
-        This hook is registered on the optimizer and is called after the optimizer step.
-
-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
-
-        ``forward`` must be called before this function.
+        Synchronizes and averages the model weights across the manager.
         """
-        assert self._started_step, "forward must be called before step"
-        self._started_step = False
+        self._perform_sync()
 
-        self._local_step += 1
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-        if self._local_step >= self._sync_every:
-            self._local_step = 0
-            self._average()
+        self._local_step = 0
 
-        if self._manager.should_commit():
-            # save the parameters so we can restore from them later if necessary.
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
+        """
+        self._average()
 
     def _average(self) -> None:
         # TODO: do we need to broadcast buffers like DDP does?
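The hook-driven step counting above relies on torch.optim's post-step hooks firing after every optimizer.step() call. A small self-contained check of that mechanism (the toy model and counter are illustrative; in LocalSGD the hook body increments _local_step and calls sync()):

import torch
from torch import nn, optim

# Illustrative sketch: a registered post-step hook fires once per optimizer.step().
model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

calls = {"n": 0}

def post_step(_optim, _args, _kwargs):
    calls["n"] += 1  # LocalSGD increments its local step counter here

opt.register_step_post_hook(post_step)

for _ in range(3):
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    opt.zero_grad()

assert calls["n"] == 3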
@@ -182,3 +184,67 @@ def _average(self) -> None:
 
         for work in works:
             work.wait()
+
+
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between the previous global weights and the current local weights).
+
+    diloco: https://arxiv.org/pdf/2311.08105
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo requires synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+        super().__init__(
+            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
+        )
+        self._outer_optimizer = outer_optimizer
+
+    def _perform_sync(self) -> None:
+        """
+        Overrides the sync method to calculate the pseudogradients, average them
+        across the manager group, and step using the outer optimizer.
+        """
+
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model.named_parameters():
+            assert name in self._backup_parameters
+            pseudogradient = p.data - self._backup_parameters[name]
+            p.grad = pseudogradient
+
+        self._average_grads()
+
+        # Restore the parameters back to the previous state
+        self._restore_parameters()
+
+        # Use the outer optimizer to update the model parameters
+        self._outer_optimizer.step()
+        self._outer_optimizer.zero_grad()
+
+    def _average_grads(self) -> None:
+        """
+        Average the gradients across the diloco group.
+        """
+        works = []
+        for p in self._model.parameters():
+            # Perform allreduce on the pseudogradients
+            assert p.grad is not None
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+        # Wait for all allreduce operations to complete
+        for work in works:
+            work.wait()
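For completeness, a hedged usage sketch of the DiLoCo wrapper added above. The inner AdamW / outer Nesterov-SGD pairing follows the DiLoCo paper rather than anything mandated by this diff; manager, model, criterion, and trainloader are placeholders, the hyperparameters are illustrative, and the manager must be constructed with use_async_quorum=False:

# Hypothetical call site for the DiLoCo class introduced above.
inner_opt = torch.optim.AdamW(model.parameters(), lr=4e-4)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=500):
    for batch, labels in trainloader:
        inner_opt.zero_grad()
        loss = criterion(model(batch), labels)
        loss.backward()
        inner_opt.step()  # every sync_every steps the pseudogradients are allreduced
                          # and the outer optimizer applies the averaged update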
