Commit

Add DiLoCo (#92)
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
H-Huang authored Jan 31, 2025
1 parent 2b23017 commit 2e2a3cb
Showing 3 changed files with 430 additions and 137 deletions.
201 changes: 128 additions & 73 deletions torchft/local_sgd.py
@@ -3,25 +3,29 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
LocalSGD
=========
This module implements a fault tolerant version of LocalSGD and related methods.
"""
import logging
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

logger: logging.Logger = logging.getLogger(__name__)


class LocalSGD:
    """
    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
@@ -68,18 +72,14 @@ def __init__(
            pin_memory: Whether to pin the memory used for the backup of the model parameters.
        """
        self._manager = manager
        self._model = model
        self._local_optimizer = optimizer
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"

        device = backup_device or torch.device("cpu")
        self._backup_parameters: Dict[str, torch.Tensor] = {}
        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
            if (
@@ -90,95 +90,150 @@ def __init__(
                t = t.pin_memory()
            self._backup_parameters[name] = t

        self._hooks: List[RemovableHandle] = []
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

    def __enter__(self) -> "LocalSGD":
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        if exc_type is not None:
            # If an exception occurred, restore parameters
            self._restore_parameters()
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
                self._backup_parameters[name].copy_(p.data, non_blocking=True)

    def _restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
                p.data.copy_(self._backup_parameters[name], non_blocking=False)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self.sync()

    def sync(self) -> None:
        """
        Synchronizes and averages the model weights across the manager.
        """
        self._manager.start_quorum()
        self._perform_sync()
        self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Performs the synchronization of the model weights across the manager.
        This method is intended to be overridden by subclasses to implement custom
        synchronization logic.
        """
        self._average()
        if self._manager.should_commit():
            self._save_parameters()
        else:
            # commit failed, restore from the backup parameters
            self._restore_parameters()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
        works = []

        for p in self._model.parameters():
            # TODO: bucketize parameters
            works.append(self._manager.allreduce(p.data.detach()))

        for work in works:
            work.wait()
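
The context-manager API above suggests a training loop along the lines of the following sketch. It is not part of this commit: the Manager construction, the model, the data loader, and the choice of sync_every=100 are placeholders, and only the LocalSGD call itself comes from the code above. Note that if an exception escapes the with block, __exit__ restores the last synchronized parameters before the exception propagates.

from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


def train_local_sgd(manager: Manager, model: nn.Module, loader, num_steps: int) -> None:
    # Hypothetical training setup; not part of this diff.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Entering the context registers a post-step hook on the optimizer; every
    # 100 optimizer.step() calls it starts a quorum, allreduces the parameters,
    # and either commits the averaged weights or restores the backup copy.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for step, (inputs, labels) in zip(range(num_steps), loader):
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()  # the LocalSGD hook runs after this call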

class DiLoCo(LocalSGD):
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (the delta between
    the previous global weights and the current local weights).

    DiLoCo paper: https://arxiv.org/pdf/2311.08105
    """

    def __init__(
        self,
        manager: Manager,
        model: nn.Module,
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
    ) -> None:
        if manager._use_async_quorum:
            raise ValueError(
                "Using DiLoCo requires synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )
        super().__init__(
            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
        )
        self._outer_optimizer = outer_optimizer

    def _perform_sync(self) -> None:
        """
        Overrides the sync method to compute the pseudogradients, average them
        across the manager group, and step using the outer optimizer.
        """
        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
            assert name in self._backup_parameters
            pseudogradient = p.data - self._backup_parameters[name]
            p.grad = pseudogradient

        self._average_grads()
        # Restore the parameters back to the previous state
        self._restore_parameters()

        if self._manager.should_commit():
            # Use the outer optimizer to update the model parameters
            self._outer_optimizer.step()
            self._save_parameters()
        self._outer_optimizer.zero_grad()

    def _average_grads(self) -> None:
        """
        Average the pseudogradients across the DiLoCo group.
        """
        works = []

        for p in self._model.parameters():
            # Perform allreduce on the pseudogradients
            assert p.grad is not None
            works.append(self._manager.allreduce(p.grad))
        # Wait for all allreduce operations to complete
        for work in works:
            work.wait()
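
DiLoCo follows the same pattern but splits optimization into an inner optimizer for the local steps and an outer optimizer that is applied to the averaged pseudogradients (the delta between the local weights and the last synchronized weights) at each sync point. The sketch below is likewise not part of this commit: the AdamW and Nesterov SGD pairing and the hyperparameters mirror the DiLoCo paper rather than this diff, and manager, model, and loader are assumed to be constructed elsewhere, with use_async_quorum=False set on the manager.

from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager


def train_diloco(manager: Manager, model: nn.Module, loader, num_steps: int) -> None:
    # Fast inner optimizer for the local steps, slow outer optimizer applied to
    # the averaged pseudogradients at every sync point. Hyperparameters are
    # illustrative assumptions, not values taken from this commit.
    inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
    outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)
    criterion = nn.CrossEntropyLoss()

    with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
        for step, (inputs, labels) in zip(range(num_steps), loader):
            inner_optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            # Every 500 inner steps the registered hook starts a quorum,
            # allreduces the pseudogradients, and applies outer_optimizer.step()
            # if the commit succeeds.
            inner_optimizer.step()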
