Update on "[WIP] Add DiLoCo"
## API Usage

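```python
from unittest.mock import create_autospec

import torch
from torch import optim

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager

# SimpleModel and dataloader are assumed to be defined elsewhere.

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)  # mock standing in for a real Manager
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        # Only the inner optimizer is stepped in the training loop.
        inner_optimizer.step()
```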
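
For orientation, the outer update that a DiLoCo-style setup relies on can be written out with plain numbers. This is a sketch of the arithmetic of the published algorithm, not of torchft's code: each replica trains locally for `sync_every` steps, the pseudogradient is the last synced parameters minus the local parameters, the pseudogradients are averaged across replicas, and the outer optimizer applies that average to the synced parameters. The function names and toy values below are made up for illustration; the Nesterov update follows the formulation used by `torch.optim.SGD` with zero dampening.

```python
from typing import List, Tuple


def average_pseudogradients(
    synced: List[float], replicas: List[List[float]]
) -> List[float]:
    """Average (synced - local) over all replicas, element by element."""
    n = len(replicas)
    return [
        sum(s - local[i] for local in replicas) / n
        for i, s in enumerate(synced)
    ]


def outer_nesterov_step(
    synced: List[float],
    pseudograd: List[float],
    momentum_buf: List[float],
    lr: float = 0.7,
    momentum: float = 0.9,
) -> Tuple[List[float], List[float]]:
    """One SGD step with Nesterov momentum, treating pseudograd as the gradient."""
    new_buf = [momentum * b + g for b, g in zip(momentum_buf, pseudograd)]
    update = [g + momentum * b for g, b in zip(pseudograd, new_buf)]
    new_params = [p - lr * u for p, u in zip(synced, update)]
    return new_params, new_buf


# Toy values (hypothetical): two parameters, two replicas after their inner steps.
synced_params = [1.0, 2.0]
replica_params = [[0.8, 2.2], [0.6, 2.0]]

pseudograd = average_pseudogradients(synced_params, replica_params)
new_params, momentum_buf = outer_nesterov_step(
    synced_params, pseudograd, momentum_buf=[0.0, 0.0]
)
print(pseudograd, new_params)
```

After the outer step, every replica would restart its inner loop from `new_params`, which is why the outer optimizer in the example above is never stepped by hand inside the training loop.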

## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a pre_forward_hook to the model to start the quorum.
- Added DiLoCo. This is a subclass of `LocalSGD`, since much of the code is shared.
- TODO: should be working, but still validating some tests
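
The context-manager design from the first bullet can be sketched as follows. This is a minimal illustration under stated assumptions, not torchft's implementation: `PeriodicSync`, `start_fn`, and `sync_fn` are hypothetical names, and the two callbacks stand in for whatever the manager does to start a quorum and to synchronize parameters. The sketch relies only on `nn.Module.register_forward_pre_hook` and `Optimizer.register_step_post_hook` (the latter requires PyTorch 2.0 or newer).

```python
from typing import Callable, List

import torch
from torch import nn, optim


class PeriodicSync:
    """Hypothetical sketch: install hooks on entry, remove them on exit."""

    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        sync_every: int,
        start_fn: Callable[[], None],
        sync_fn: Callable[[], None],
    ) -> None:
        self._model = model
        self._optimizer = optimizer
        self._sync_every = sync_every
        self._start_fn = start_fn
        self._sync_fn = sync_fn
        self._local_step = 0
        self._handles: List = []

    def __enter__(self) -> "PeriodicSync":
        self._handles.append(
            self._model.register_forward_pre_hook(self._pre_forward)
        )
        self._handles.append(
            self._optimizer.register_step_post_hook(self._post_step)
        )
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Remove the hooks so the model and optimizer behave normally afterwards.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False  # do not swallow exceptions

    def _pre_forward(self, module, args) -> None:
        # Runs before every forward pass; acts only at the start of a window.
        if self._local_step == 0:
            self._start_fn()

    def _post_step(self, optimizer, args, kwargs) -> None:
        # Runs after every optimizer.step(); syncs once per window.
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self._sync_fn()
            self._local_step = 0


events: List[str] = []
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
with PeriodicSync(
    model,
    optimizer,
    sync_every=2,
    start_fn=lambda: events.append("start"),
    sync_fn=lambda: events.append("sync"),
):
    for _ in range(4):
        optimizer.zero_grad()
        model(torch.randn(8, 4)).mean().backward()
        optimizer.step()
print(events)
```

Compared with an `nn.Module` wrapper, this shape leaves the training loop and the model's attribute names untouched; the trade-off is that the window bookkeeping lives in hooks rather than in an explicit `forward` override.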

discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk

[ghstack-poisoned]
H-Huang committed Jan 21, 2025
1 parent 5c5d064 commit cfbf7e2
Showing 3 changed files with 3 additions and 4 deletions.
1 change: 0 additions & 1 deletion torchft/local_sgd.py
@@ -16,7 +16,6 @@

import torch
from torch import nn, optim
-
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer

2 changes: 1 addition & 1 deletion torchft/local_sgd_test.py
@@ -11,7 +11,7 @@
import torch
from torch import nn, optim

- from torchft.local_sgd import DiLoCo, DiLoCoOptimizer, LocalSGD
+ from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager


4 changes: 2 additions & 2 deletions torchft/manager_integ_test.py
@@ -1,8 +1,8 @@
import logging
import threading
import time
- from concurrent.futures import as_completed, ThreadPoolExecutor
- from contextlib import contextmanager, ExitStack
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Dict, Generator, List, Optional, Protocol, Set, Tuple
