Add DiLoCo #76

Merged · 1 commit merged on Jan 30, 2025

Changes from all commits

torchft/local_sgd.py · 201 changes: 128 additions & 73 deletions

[... file lines 1-2 collapsed in the diff view ...]
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
LocalSGD
=========

This module implements a fault tolerant version of LocalSGD and related methods.
"""
import logging
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

logger: logging.Logger = logging.getLogger(__name__)


class LocalSGD:
    """
    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant

[... file lines 32-71 unchanged and collapsed in the diff view: the remainder of this docstring, the __init__ signature, and the other argument descriptions ...]

        pin_memory: Whether to pin the memory used for the backup of the model parameters.
        """
        super().__init__()
        self._manager = manager
        self._model = model
        self._local_optimizer = optimizer
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
        device = backup_device or torch.device("cpu")
        self._backup_parameters: Dict[str, torch.Tensor] = {}
        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
            if (
                pin_memory
                and t.device == torch.device("cpu")
                and torch.cuda.is_available()
            ):
                t = t.pin_memory()
            self._backup_parameters[name] = t

        self._hooks: List[RemovableHandle] = []
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

    def __enter__(self) -> "LocalSGD":
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        if exc_type is not None:
            # If an exception occurred, restore parameters
            self._restore_parameters()
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
                self._backup_parameters[name].copy_(p.data, non_blocking=True)

    def _restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
                p.data.copy_(self._backup_parameters[name], non_blocking=False)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self.sync()

    def sync(self) -> None:
        """
        Synchronizes and averages the model weights across the manager.
        """
        self._manager.start_quorum()
        self._perform_sync()
        self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Performs the synchronization of the model weights across the manager.
        This method is intended to be overridden by subclasses to implement custom
        synchronization logic.
        """
        self._average()
        if self._manager.should_commit():
            self._save_parameters()
        else:
            # commit failed, restore from the backup parameters
            self._restore_parameters()

    def _average(self) -> None:
        # TODO: do we need to broadcast buffers like DDP does?
        works = []

        for p in self._model.parameters():
            # TODO: bucketize parameters
            works.append(self._manager.allreduce(p.data.detach()))

        for work in works:
            work.wait()


class DiLoCo(LocalSGD):
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (the delta between the
    previous global weights and the current local weights).

    diloco: https://arxiv.org/pdf/2311.08105
    """

    def __init__(
        self,
        manager: Manager,
        model: nn.Module,
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
    ) -> None:
        if manager._use_async_quorum:
            raise ValueError(
                "Using DiLoCo requires synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )
        super().__init__(
            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
        )
        self._outer_optimizer = outer_optimizer

    def _perform_sync(self) -> None:
        """
        Overrides the sync method to calculate the pseudogradients, average them
        across the manager group, and step using the outer optimizer.
        """
        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
            assert name in self._backup_parameters
            pseudogradient = p.data - self._backup_parameters[name]

Review comment (Member):
    Do we need to check which device the backup_parameters are on? We need a .to call here if they're on CPU, right?

Reply (Member, author):
    Oh yeah, good point. I will change up the backup_parameters and the load_state_dict in a follow-up PR.
            p.grad = pseudogradient

        self._average_grads()
        # Restore the parameters back to the previous state
        self._restore_parameters()

        if self._manager.should_commit():
            # Use the outer optimizer to update the model parameters
            self._outer_optimizer.step()
            self._save_parameters()
        self._outer_optimizer.zero_grad()

    def _average_grads(self) -> None:
        """
        Average the gradients across the diloco group.
        """
        works = []
        for p in self._model.parameters():
            # Perform allreduce on the pseudogradients
            assert p.grad is not None
            work = self._manager.allreduce(p.grad)
            works.append(work)
        # Wait for all allreduce operations to complete
        for work in works:
            work.wait()
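
Since LocalSGD is now a context manager rather than an nn.Module wrapper, a minimal usage sketch (not part of this PR) might look like the following; manager is assumed to be an already constructed torchft Manager, and dataloader an ordinary iterable of (inputs, targets) batches:

    import torch
    from torch import nn, optim

    from torchft.local_sgd import LocalSGD

    model = nn.Linear(32, 4)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Parameters are all-reduced across the replica group every 100 local steps;
    # if the commit fails, they are restored from the backup copy.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            optimizer.step()  # the post-step hook counts steps and calls sync() when due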
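Likewise, a hedged sketch of driving DiLoCo with separate inner and outer optimizers. The hyperparameters are illustrative (loosely following the inner-AdamW / outer-Nesterov-SGD setup of the DiLoCo paper), and as above manager (which must be created with use_async_quorum=False), model, and dataloader are assumed to already exist:

    from torchft.local_sgd import DiLoCo

    inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4, betas=(0.9, 0.95))
    outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

    with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=500):
        for inputs, targets in dataloader:
            inner_optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            # Every sync_every steps the post-step hook all-reduces the pseudogradients
            # and applies them to the last synchronized weights via the outer optimizer.
            inner_optimizer.step()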