Clean up localsgd backup params
H-Huang committed Feb 24, 2025
1 parent 5e65330 commit 7d52799
Showing 5 changed files with 192 additions and 100 deletions.
139 changes: 91 additions & 48 deletions torchft/local_sgd.py
@@ -10,7 +10,7 @@
 """
 import logging
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
 from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
         model: nn.Module,
         optimizer: optim.Optimizer,
         sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
     ) -> None:
         """
         Args:
@@ -78,21 +76,8 @@ def __init__(
         self._local_step = 0
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t
 
         self._hooks: List[RemovableHandle] = []
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()
 
     def __enter__(self) -> "LocalSGD":
         # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -108,37 +93,26 @@ def __exit__(
         traceback: Optional[TracebackType],
     ) -> bool:
         # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
         # Clean up hooks
         for hook in self._hooks:
             hook.remove()
         self._hooks.clear()
 
         return False  # Propagate exceptions
 
-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
     def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
-        self._local_step += 1
-        if self._local_step >= self._sync_every:
-            self.sync()
+        try:
+            self._local_step += 1
+            if self._local_step >= self._sync_every:
+                self.sync()
+        except Exception as e:
+            self._manager.report_error(e)
+            raise
 
     def sync(self) -> None:
         """
@@ -151,15 +125,9 @@ def sync(self) -> None:
     def _perform_sync(self) -> None:
         """
         Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
         """
-        self._average()
         if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
+            self._average()
 
     def _average(self) -> None:
         # TODO: do we need to broadcast buffers like DDP does?
@@ -174,7 +142,7 @@ def _average(self) -> None:
             work.wait()
 
 
-class DiLoCo(LocalSGD):
+class DiLoCo:
     """
     DiLoCo is a subclass of LocalSGD that overrides the synchronization
     mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
@@ -197,21 +165,96 @@ def __init__(
                 "Using DiLoCo require synchronous quorum to be enabled. "
                 "Ensure that the manager is initialized with use_async_quorum=False"
             )
-        super().__init__(
-            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-        )
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = inner_optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+
+        self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+        self._original_parameters: Dict[str, torch.Tensor] = {}
+        for name, p in self._model.named_parameters():
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)
+            if (
+                pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self._original_parameters[name] = t
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+
+    def _save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._original_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._original_parameters[name], non_blocking=False)
+
+    def __enter__(self) -> "DiLoCo":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        try:
+            self._local_step += 1
+            if self._local_step >= self._sync_every:
+                self.sync()
+        except Exception as e:
+            self._manager.report_error(e)
+            raise
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
 
     def _perform_sync(self) -> None:
         """
         Overrides the sync method to calculate the pseugradient, average them across the manager group, and
         step using the outer optimizer.
         """
+        print("Performing DiLoCo sync", flush=True)
 
         # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            assert name in self._backup_parameters
-            pseudogradient = p.data - self._backup_parameters[name]
+            assert name in self._original_parameters
+            pseudogradient = p.data - self._original_parameters[name]
            p.grad = pseudogradient
 
         self._average_grads()
(The diffs for the remaining 4 changed files are not shown here.)
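For orientation, here is a minimal, hypothetical usage sketch of the refactored DiLoCo class, pieced together from the constructor assignments, the context-manager methods, and the step post-hook visible in this diff. The model, optimizers, loss, and especially the Manager construction are placeholder assumptions, not part of the commit, and the constructor is assumed to take manager, model, inner_optimizer, outer_optimizer, sync_every in that order, which this hunk does not show in full:

# Hypothetical usage sketch inferred from the interface above; names,
# hyperparameters, and the Manager construction are assumptions.
import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager

model = nn.Linear(32, 32)
inner_opt = optim.AdamW(model.parameters(), lr=1e-3)
outer_opt = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
manager = Manager(..., use_async_quorum=False)  # other Manager args elided; DiLoCo asserts synchronous quorum

with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=100):
    for _ in range(300):
        inner_opt.zero_grad()
        loss = model(torch.randn(8, 32)).sum()
        loss.backward()
        inner_opt.step()  # the registered post-step hook counts steps and calls sync() every 100 steps

Because the hook is registered on the inner optimizer, the training loop never calls sync() directly; it fires automatically every sync_every steps.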

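Both LocalSGD and DiLoCo drive synchronization through torch.optim.Optimizer.register_step_post_hook rather than requiring explicit sync() calls in the training loop. A self-contained, torchft-free sketch of that counting pattern, mirroring _step_post_hook above:

# Standalone illustration (plain PyTorch, no torchft) of the hook pattern:
# a post-step hook counts optimizer steps and triggers a "sync" every N steps.
import torch
from torch import nn, optim

model = nn.Linear(4, 4)
opt = optim.SGD(model.parameters(), lr=0.1)

sync_every = 3
state = {"local_step": 0}

def step_post_hook(_optim, _args, _kwargs):
    state["local_step"] += 1
    if state["local_step"] >= sync_every:
        print("sync would run here (average weights / outer step)")
        state["local_step"] = 0

handle = opt.register_step_post_hook(step_post_hook)

for _ in range(7):
    opt.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    opt.step()  # hook fires after every step; the stand-in "sync" runs on steps 3 and 6

handle.remove()  # mirrors the hook cleanup done in __exit__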

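Finally, the DiLoCo docstring describes the pseudogradient as the delta between the previous global weights and the current local weights. A toy, single-replica sketch of the bookkeeping that _perform_sync sets up for the outer optimizer, using the same p.data - original convention as the diff; the cross-replica averaging and the remainder of the outer step are outside the portion shown above:

# Toy illustration (not torchft code): form the pseudogradient from a saved
# snapshot and hand it to an "outer" optimizer via .grad.
import torch
from torch import nn, optim

param = nn.Parameter(torch.tensor([1.0, 2.0]))
original = param.detach().clone()  # snapshot saved at the previous sync

with torch.no_grad():  # stand-in for several local optimizer steps
    param += torch.tensor([0.5, -0.5])

param.grad = param.data - original  # pseudogradient, same convention as _perform_sync
outer_opt = optim.SGD([param], lr=1.0)
# torchft would first allreduce-average the .grad fields across replicas
# (_average_grads); a single replica can simply step the outer optimizer.
outer_opt.step()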