
Commit 7d52799

Clean up localsgd backup params
1 parent 5e65330 commit 7d52799

File tree

5 files changed: +192 −100 lines


torchft/local_sgd.py

Lines changed: 91 additions & 48 deletions
@@ -10,7 +10,7 @@
 """
 import logging
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

 import torch
 from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
         model: nn.Module,
         optimizer: optim.Optimizer,
         sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
     ) -> None:
         """
         Args:
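The two dropped keyword arguments mean LocalSGD is now constructed with just a manager, model, optimizer, and sync interval. A minimal usage sketch under that assumption (the Manager construction is not part of this diff and is only stubbed here; the import path for Manager is assumed):

# Hypothetical usage of LocalSGD after this commit; argument order follows the
# surrounding diff, and the Manager setup is elided.
import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager  # assumed import path

model = nn.Linear(10, 10)
optimizer = optim.SGD(model.parameters(), lr=0.01)
manager: Manager = ...  # constructed elsewhere

with LocalSGD(manager, model, optimizer, sync_every=32):
    for _ in range(64):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).sum()
        loss.backward()
        optimizer.step()  # the registered post-step hook calls sync() every 32 steps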
@@ -78,21 +76,8 @@ def __init__(
         self._local_step = 0
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t

         self._hooks: List[RemovableHandle] = []
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()

     def __enter__(self) -> "LocalSGD":
         # Add optimizer hook which increments the local step counter and syncs if necessary
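With the backup-parameter bookkeeping gone, LocalSGD's constructor is left with the step counter and the optimizer hook list. The hook machinery itself is plain PyTorch; a standalone sketch of register_step_post_hook, which the __enter__ shown above relies on:

# Standalone sketch (no torchft): Optimizer.register_step_post_hook runs the
# hook after every optimizer.step(), which is how LocalSGD counts local steps.
from typing import Any, Dict, Tuple

import torch
from torch import nn, optim

model = nn.Linear(4, 4)
opt = optim.SGD(model.parameters(), lr=0.1)
steps = 0

def post_hook(_optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]) -> None:
    global steps
    steps += 1  # LocalSGD increments self._local_step here and syncs when due

handle = opt.register_step_post_hook(post_hook)
for _ in range(3):
    model(torch.randn(2, 4)).sum().backward()
    opt.step()
assert steps == 3
handle.remove()  # __exit__ removes its handles the same way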
@@ -108,37 +93,26 @@ def __exit__(
         traceback: Optional[TracebackType],
     ) -> bool:
         # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
         # Clean up hooks
         for hook in self._hooks:
             hook.remove()
         self._hooks.clear()

         return False  # Propagate exceptions

-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
     def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
-        self._local_step += 1
-        if self._local_step >= self._sync_every:
-            self.sync()
+        try:
+            self._local_step += 1
+            if self._local_step >= self._sync_every:
+                self.sync()
+        except Exception as e:
+            self._manager.report_error(e)
+            raise

     def sync(self) -> None:
         """
@@ -151,15 +125,9 @@ def sync(self) -> None:
     def _perform_sync(self) -> None:
         """
         Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
         """
-        self._average()
         if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
+            self._average()

     def _average(self) -> None:
         # TODO: do we need to broadcast buffers like DDP does?
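After this change LocalSGD._perform_sync only averages when the manager agrees to commit; there is no save/restore fallback anymore. The averaging itself is only partially visible in this diff, but it amounts to an allreduce-and-divide over the parameters; a rough sketch using plain torch.distributed rather than the torchft Manager:

# Rough sketch of parameter averaging, for illustration only. Assumes an
# initialized process group; torchft routes this through its Manager instead.
import torch.distributed as dist
from torch import nn

def average_parameters(model: nn.Module) -> None:
    world_size = dist.get_world_size()
    works = [
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM, async_op=True)
        for p in model.parameters()
    ]
    for work in works:
        work.wait()
    for p in model.parameters():
        p.data.div_(world_size)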
@@ -174,7 +142,7 @@ def _average(self) -> None:
         work.wait()


-class DiLoCo(LocalSGD):
+class DiLoCo:
     """
     DiLoCo is a subclass of LocalSGD that overrides the synchronization
     mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
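Since DiLoCo no longer inherits from LocalSGD, it carries its own context-manager plumbing, shown in the next hunk. A hypothetical usage sketch; the argument order follows the diff, the Manager setup is assumed, and the optimizer choices are purely illustrative:

# Hypothetical usage of DiLoCo after this commit. The Manager must be created
# with use_async_quorum=False, per the check in __init__ shown below.
import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo

model = nn.Linear(10, 10)
inner_opt = optim.AdamW(model.parameters(), lr=4e-4)
outer_opt = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
manager = ...  # torchft Manager, constructed elsewhere

with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=500):
    for _ in range(500):
        inner_opt.zero_grad()
        loss = model(torch.randn(8, 10)).sum()
        loss.backward()
        inner_opt.step()  # after sync_every inner steps, the outer sync runs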
@@ -197,21 +165,96 @@ def __init__(
                 "Using DiLoCo require synchronous quorum to be enabled. "
                 "Ensure that the manager is initialized with use_async_quorum=False"
             )
-        super().__init__(
-            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-        )
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = inner_optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+
+        self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+        self._original_parameters: Dict[str, torch.Tensor] = {}
+        for name, p in self._model.named_parameters():
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)
+            if (
+                pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self._original_parameters[name] = t
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+
+    def _save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._original_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._original_parameters[name], non_blocking=False)
+
+    def __enter__(self) -> "DiLoCo":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        try:
+            self._local_step += 1
+            if self._local_step >= self._sync_every:
+                self.sync()
+        except Exception as e:
+            self._manager.report_error(e)
+            raise
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0

     def _perform_sync(self) -> None:
         """
         Overrides the sync method to calculate the pseugradient, average them across the manager group, and
         step using the outer optimizer.
         """
+        print("Performing DiLoCo sync", flush=True)

         # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            assert name in self._backup_parameters
-            pseudogradient = p.data - self._backup_parameters[name]
+            assert name in self._original_parameters
+            pseudogradient = p.data - self._original_parameters[name]
             p.grad = pseudogradient

         self._average_grads()
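The last hunk ends at the pseudogradient computation: each parameter's .grad is set to the current local weights minus the saved copy in _original_parameters, and those pseudogradients are then averaged across replicas. The rest of _perform_sync is not shown in this diff; the numeric sketch below assumes the usual DiLoCo ordering, in which the saved weights are restored before the outer optimizer applies the averaged pseudogradient.

# Single-parameter illustration of the pseudogradient step (no torchft, no
# distributed averaging). The ordering after the .grad assignment is assumed,
# not shown in this diff.
import torch
from torch import nn, optim

param = nn.Parameter(torch.tensor([1.0, 2.0]))
original = param.data.clone()              # what _save_parameters() keeps

with torch.no_grad():                      # pretend several inner steps ran
    param.add_(torch.tensor([0.5, -0.5]))

param.grad = param.data - original         # pseudogradient, as in _perform_sync
# (across replicas these pseudogradients would be averaged here)

outer = optim.SGD([param], lr=1.0)
with torch.no_grad():
    param.copy_(original)                  # step the outer optimizer from the saved weights
outer.step()
print(param.data)                          # tensor([0.5000, 2.5000])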

0 commit comments
