Skip to content

Commit

Permalink
Updates based on comments
Browse files (browse the repository at this point in the history)
1) Fixed lint issues
2) Changed variable name for bucket size
3) Added parameterised unit test
  • Loading branch information
Krishn1412 committed Feb 24, 2025
1 parent 2de4bdf commit 6d25cc8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
10 changes: 8 additions & 2 deletions torchft/local_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ class DiLoCo(LocalSGD):
diloco: https://arxiv.org/pdf/2311.08105
"""

BUCKET_SIZE_BYTES = 32 * 1024 * 1024
bucket_cap_mb = 32 * 1024 * 1024
use_bucketization = False

def __init__(
self,
Expand All @@ -194,7 +195,8 @@ def __init__(
sync_every: int,
backup_device: Optional[torch.device] = None,
pin_memory: bool = True,
use_bucketization=False,
use_bucketization: bool = False,
bucket_cap_mb: int = None,
) -> None:
if manager._use_async_quorum:
raise ValueError(
Expand All @@ -205,6 +207,10 @@ def __init__(
manager, model, inner_optimizer, sync_every, backup_device, pin_memory
)
self._outer_optimizer = outer_optimizer
if bucket_cap_mb is not None:
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

self.use_bucketization = use_bucketization

def _perform_sync(self) -> None:
"""
Expand Down
53 changes: 23 additions & 30 deletions torchft/local_sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unittest.mock import create_autospec

import torch
from parameterized import parameterized
from torch import nn, optim

from torchft.local_sgd import DiLoCo, LocalSGD
Expand Down Expand Up @@ -145,42 +146,33 @@ def test_diloco_healthy(self) -> None:
outer_opt_state = outer_optimizer.state_dict()
self.assertEqual(len(outer_opt_state["state"]), parameter_count)

def test_diloco_without_bucketization(self):
@parameterized.expand(
[
(
"without_bucketization",
False,
lambda self, manager, model: self.assertEqual(
manager.allreduce.call_count, len(list(model.parameters()))
),
),
(
"with_bucketization",
True,
lambda self, manager, model: self.assertGreaterEqual(
manager.allreduce.call_count, 1
),
),
]
)
def test_diloco_all_reduce(self, name, use_bucketization, assert_func):
model = SimpleModel()
inner_optimizer = optim.AdamW(
model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = optim.SGD(
model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
manager._use_async_quorum = False

with DiLoCo(
manager,
model,
inner_optimizer,
outer_optimizer,
sync_every=2,
use_bucketization=False,
) as diloco:
inp = torch.rand(2, 3)
loss = model(inp).mean()
loss.backward()
inner_optimizer.step()
self.assertEqual(diloco._local_step, 1)
self.assertEqual(
manager.allreduce.call_count, len(list(model.parameters()))
)

def test_diloco_with_bucketization(self):
model = SimpleModel()
inner_optimizer = optim.AdamW(
model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = optim.SGD(
model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
manager._use_async_quorum = False

Expand All @@ -190,11 +182,12 @@ def test_diloco_with_bucketization(self):
inner_optimizer,
outer_optimizer,
sync_every=2,
use_bucketization=True,
use_bucketization=use_bucketization,
) as diloco:
inp = torch.rand(2, 3)
loss = model(inp).mean()
loss.backward()
inner_optimizer.step()

self.assertEqual(diloco._local_step, 1)
self.assertGreaterEqual(manager.allreduce.call_count, 1)
assert_func(self, manager, model)

0 comments on commit 6d25cc8

Please sign in to comment.