Commit 22a474c
Update base for Update on "[WIP] Add DiLoCo"
Still WIP but open to feedback on the API
## API Usage
```python
from unittest.mock import create_autospec

import torch
from torch import optim

# Import paths below are assumed from the torchft package layout.
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager

# SimpleModel and dataloader are placeholders defined elsewhere.

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer runs every `sync_every` steps; this is hidden from the user
```
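For reviewers less familiar with DiLoCo, here is a rough sketch of what the hidden outer step amounts to, assuming the usual DiLoCo formulation: each replica's pseudo-gradient is the last synced weights minus its current weights, the pseudo-gradients are averaged, and the outer optimizer applies the result. This is not the code in this PR. It uses plain Python lists instead of tensors and plain SGD instead of the Nesterov SGD in the example above, and `average` and `diloco_outer_step` are hypothetical names.

```python
from typing import List


def average(vectors: List[List[float]]) -> List[float]:
    """Element-wise mean of equally sized vectors (stands in for an allreduce)."""
    n = len(vectors)
    return [sum(vals) / n for vals in zip(*vectors)]


def diloco_outer_step(
    global_params: List[float],
    replica_params: List[List[float]],
    outer_lr: float,
) -> List[float]:
    """One outer update with plain SGD (no momentum), as an illustration only."""
    # Pseudo-gradient per replica: how far the inner optimizer moved
    # away from the last synced weights.
    pseudo_grads = [
        [g - p for g, p in zip(global_params, params)]
        for params in replica_params
    ]
    avg_pseudo_grad = average(pseudo_grads)
    # Plain SGD step on the averaged pseudo-gradient.
    return [g - outer_lr * d for g, d in zip(global_params, avg_pseudo_grad)]


# Two replicas that started from the same synced weights.
synced = [1.0, 2.0]
replicas = [[0.5, 2.5], [1.0, 2.0]]
new_params = diloco_outer_step(synced, replicas, outer_lr=1.0)
```

With plain SGD and `outer_lr=1.0`, the update reduces to averaging the replica weights, which is why sharing code between `LocalSGD` and `DiLoCo` is natural.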
## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a forward pre-hook to the model to start the quorum
- Added DiLoCo. This is a subclass of LocalSGD since a lot of code is shared
- TODO: should be working, but still validating some tests
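To make the context-manager design easier to discuss, here is a minimal sketch of the hook pattern, not the implementation in this PR. `SyncEvery`, `on_start`, and `on_sync` are hypothetical names: `on_start` stands in for starting the quorum and `on_sync` for the weight sync. The sketch assumes PyTorch 2.0 or later for `Optimizer.register_step_post_hook`, and it calls `on_start` on every forward pass for simplicity, which may differ from what the real class does.

```python
import torch
from torch import nn, optim


class SyncEvery:
    """Hypothetical stand-in for a LocalSGD-style context manager."""

    def __init__(self, model, optimizer, sync_every, on_start, on_sync):
        self._model = model
        self._optimizer = optimizer
        self._sync_every = sync_every
        self._on_start = on_start  # e.g. start the quorum
        self._on_sync = on_sync    # e.g. average weights across replicas
        self._local_step = 0
        self._handles = []

    def __enter__(self):
        # Called before every forward pass of the model.
        self._handles.append(
            self._model.register_forward_pre_hook(self._before_forward)
        )
        # Called after every optimizer.step().
        self._handles.append(
            self._optimizer.register_step_post_hook(self._after_step)
        )
        return self

    def __exit__(self, exc_type, exc, tb):
        # Remove the hooks so the model and optimizer behave normally again.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

    def _before_forward(self, module, args):
        self._on_start()  # returns None, so the inputs are left unchanged

    def _after_step(self, optimizer, args, kwargs):
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self._local_step = 0
            self._on_sync()


events = []
model = nn.Linear(2, 1)
opt = optim.SGD(model.parameters(), lr=0.1)
with SyncEvery(model, opt, 2, lambda: events.append("start"), lambda: events.append("sync")):
    for _ in range(4):
        opt.zero_grad()
        model(torch.ones(1, 2)).mean().backward()
        opt.step()
```

Because the hooks are registered in `__enter__` and removed in `__exit__`, the training loop itself stays unchanged, which is the main appeal over wrapping the model in another `nn.Module`.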
discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk
[ghstack-poisoned]
1 parent 68d4059, commit 22a474c