Adding a flag and unit tests
Some issues while running the unit test cases; will look into them more.
Krishn1412 committed Feb 20, 2025
1 parent 910eb3f commit d774938
Showing 2 changed files with 109 additions and 20 deletions.
75 changes: 55 additions & 20 deletions torchft/local_sgd.py
@@ -13,11 +13,11 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

@@ -183,6 +183,8 @@ class DiLoCo(LocalSGD):
    diloco: https://arxiv.org/pdf/2311.08105
    """

    BUCKET_SIZE_BYTES = 32 * 1024 * 1024

    def __init__(
        self,
        manager: Manager,
@@ -192,6 +194,7 @@ def __init__(
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
    ) -> None:
        if manager._use_async_quorum:
            raise ValueError(
@@ -224,35 +227,67 @@ def _perform_sync(self) -> None:
        self._outer_optimizer.step()
        self._save_parameters()
        self._outer_optimizer.zero_grad()

    def _average_grads(self) -> None:
        """
        Efficiently averages gradients across the group using either:
        - Per-parameter allreduce (old behavior)
        - Bucketized allreduce (new behavior)
        """
        if self.use_bucketization:
            self._allreduce_bucketized()
        else:
            self._allreduce_per_param()

    def _allreduce_per_param(self) -> None:
        """Performs allreduce on each gradient tensor separately (original method)."""
        works = []
        for p in self._model.parameters():
            if p.grad is None:
                continue
            work = self._manager.allreduce(p.grad)
            works.append(work)

        for work in works:
            work.wait()

    def _allreduce_bucketized(self) -> None:
        """
        Averages gradients using bucketized allreduce with a fixed 32MB buffer.
        """

        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
        if not grads:
            return  # No gradients to process

        # Compute total size and remember the gradients' dtype and device
        total_size = sum(g.numel() for g in grads)
        dtype, device = grads[0].dtype, grads[0].device

        # Process gradients in fixed 32MB chunks
        grad_index = 0
        offset = 0
        while offset < total_size:
            # Compute this chunk's size in elements
            chunk_size = min(
                self.BUCKET_SIZE_BYTES // grads[0].element_size(), total_size - offset
            )
            # Make sure at least the next gradient fits, even if it exceeds 32MB
            chunk_size = max(chunk_size, grads[grad_index].numel())

            flat_buffer = torch.zeros(chunk_size, dtype=dtype, device=device)

            # Pack as many whole gradients as fit into the buffer, continuing
            # from where the previous chunk stopped
            pack_offset, bucket_tensors = 0, []
            for g in grads[grad_index:]:
                numel = g.numel()
                if pack_offset + numel > chunk_size:
                    break
                flat_buffer[pack_offset : pack_offset + numel].copy_(g.view(-1))
                bucket_tensors.append((g, pack_offset, numel))
                pack_offset += numel
                grad_index += 1

            # Allreduce the packed buffer and wait for it to complete
            work = self._manager.allreduce(flat_buffer)
            work.wait()

            # Unpack the averaged gradients back into their original tensors
            for g, start, numel in bucket_tensors:
                g.copy_(flat_buffer[start : start + numel].view_as(g))

            offset += pack_offset  # Move to the next chunk
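
For orientation, here is a minimal, self-contained sketch of the pack / allreduce / unpack round trip that _allreduce_bucketized performs. It is illustrative only: the gradient shapes are made up, and the allreduce itself is deliberately left out (a real run would call manager.allreduce on the flat buffer and wait on the returned work object).

import torch

BUCKET_SIZE_BYTES = 32 * 1024 * 1024  # mirrors DiLoCo.BUCKET_SIZE_BYTES

# Pretend gradients; any mix of shapes works as long as the dtypes match.
grads = [torch.randn(4096, 1024), torch.randn(1024)]

# 32MB of fp32 elements per bucket: 32 * 1024 * 1024 / 4 = 8,388,608 elements.
bucket_numel = BUCKET_SIZE_BYTES // grads[0].element_size()

# One bucket is enough here because the pretend gradients fit in 32MB.
flat = torch.zeros(min(bucket_numel, sum(g.numel() for g in grads)), dtype=grads[0].dtype)

# Pack: copy each gradient into a contiguous slice of the flat buffer.
offset = 0
for g in grads:
    flat[offset : offset + g.numel()].copy_(g.view(-1))
    offset += g.numel()

# (A real run would allreduce-average the flat buffer across replicas at this point.)

# Unpack: copy each slice back into its original tensor, preserving its shape.
offset = 0
for g in grads:
    g.copy_(flat[offset : offset + g.numel()].view_as(g))
    offset += g.numel()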
54 changes: 54 additions & 0 deletions torchft/local_sgd_test.py
@@ -144,3 +144,57 @@ def test_diloco_healthy(self) -> None:

        outer_opt_state = outer_optimizer.state_dict()
        self.assertEqual(len(outer_opt_state["state"]), parameter_count)

    def test_diloco_without_bucketization(self) -> None:
        model = SimpleModel()
        inner_optimizer = optim.AdamW(
            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_optimizer = optim.SGD(
            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )
        manager = create_autospec(Manager)
        manager._use_async_quorum = False

        with DiLoCo(
            manager,
            model,
            inner_optimizer,
            outer_optimizer,
            sync_every=2,
            use_bucketization=False,
        ) as diloco:
            inp = torch.rand(2, 3)
            loss = model(inp).mean()
            loss.backward()
            inner_optimizer.step()
            self.assertEqual(diloco._local_step, 1)
            self.assertEqual(
                manager.allreduce.call_count, len(list(model.parameters()))
            )

    def test_diloco_with_bucketization(self) -> None:
        model = SimpleModel()
        inner_optimizer = optim.AdamW(
            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_optimizer = optim.SGD(
            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )
        manager = create_autospec(Manager)
        manager._use_async_quorum = False

        with DiLoCo(
            manager,
            model,
            inner_optimizer,
            outer_optimizer,
            sync_every=2,
            use_bucketization=True,
        ) as diloco:
            inp = torch.rand(2, 3)
            loss = model(inp).mean()
            loss.backward()
            inner_optimizer.step()
            self.assertEqual(diloco._local_step, 1)
            self.assertGreaterEqual(manager.allreduce.call_count, 1)
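
Assuming the repository's usual pytest workflow (the commit itself does not show how tests are invoked), running something like pytest torchft/local_sgd_test.py -k bucketization should exercise just the two new cases.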
