Still WIP but open to feedback on the API
## API Usage
```python
from unittest.mock import create_autospec

import torch
from torch import optim

# LocalSGD, DiLoCo, Manager, SimpleModel and dataloader come from this PR
# and its tests; their imports are omitted here.

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)  # mocked Manager, for illustration only
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer runs every `sync_every` steps; this is hidden from the user
```
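To make the hidden outer step in the DiLoCo example more concrete, here is a rough pure-Python sketch of one outer update as described in the DiLoCo paper. It is not taken from this PR: `diloco_outer_step` is a hypothetical name, parameters are plain lists of floats, and the outer optimizer is simplified to SGD without momentum (the example above uses Nesterov momentum).

```python
def diloco_outer_step(global_params, replica_params, outer_lr):
    """One outer step: plain SGD on the averaged pseudo-gradient (assumed form)."""
    n = len(replica_params)
    new_params = []
    for i, g in enumerate(global_params):
        # Pseudo-gradient: how far each replica drifted from the last synced value.
        pseudo_grad = sum(g - replica[i] for replica in replica_params) / n
        new_params.append(g - outer_lr * pseudo_grad)
    return new_params


global_params = [1.0, 2.0]                  # parameters at the last sync
replica_params = [[0.5, 2.5], [0.0, 2.0]]   # two replicas after their inner steps
print(diloco_outer_step(global_params, replica_params, outer_lr=1.0))
# [0.25, 2.25], the plain average of the replicas
```

With `outer_lr=1.0` and no momentum, this outer step reduces to averaging the replicas, which is the LocalSGD behavior. That is one way to see why the two classes can share code.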
## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a forward pre-hook to the model to start the quorum
- Added `DiLoCo`. This is a subclass of `LocalSGD` since a lot of the code is shared
- TODO: should be working, but still validating some tests
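
The context-manager approach can be sketched roughly as below. This is an illustration of the pattern, not the code in this PR. `SyncEvery`, `on_forward` and `on_sync` are hypothetical names standing in for "start the quorum" and "sync parameters", and counting steps with an optimizer post-step hook is an assumption. It uses `nn.Module.register_forward_pre_hook` and `Optimizer.register_step_post_hook`, so it assumes PyTorch 2.0 or newer.

```python
import torch
from torch import nn, optim


class SyncEvery:
    """Hypothetical helper: attach hooks on enter, detach them on exit."""

    def __init__(self, model, optimizer, sync_every, on_forward, on_sync):
        self.model = model
        self.optimizer = optimizer
        self.sync_every = sync_every
        self.on_forward = on_forward  # stand-in for starting the quorum
        self.on_sync = on_sync        # stand-in for syncing parameters
        self.local_step = 0
        self._handles = []

    def __enter__(self):
        # Both register_* calls return a handle whose remove() detaches the hook.
        self._handles = [
            self.model.register_forward_pre_hook(self._pre_forward),
            self.optimizer.register_step_post_hook(self._post_step),
        ]
        return self

    def __exit__(self, exc_type, exc, tb):
        for handle in self._handles:
            handle.remove()
        self._handles = []
        return False  # do not swallow exceptions from the training loop

    def _pre_forward(self, module, args):
        self.on_forward()

    def _post_step(self, optimizer, args, kwargs):
        self.local_step += 1
        if self.local_step % self.sync_every == 0:
            self.on_sync()


events = []
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
with SyncEvery(model, optimizer, sync_every=2,
               on_forward=lambda: events.append("forward"),
               on_sync=lambda: events.append("sync")):
    for _ in range(4):
        optimizer.zero_grad()
        model(torch.randn(8, 4)).mean().backward()
        optimizer.step()
print(events)
```

Because the model is not wrapped, the training loop keeps calling `model(...)` and `optimizer.step()` directly, and the hooks are removed once the `with` block exits.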
discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk