This is a tracking issue for adding LocalSGD support to torchft. There's been interest in LocalSGD and it's something we'd like to support.
This should be fairly straightforward: we can use the Manager + quorum in an outer loop and then only periodically allreduce a copy of the weights.
Something like:
```python
import torch
from torchft import Manager

manager = Manager(...)
model = ...

while True:
    for step in range(local_steps):
        inputs, labels = next(dataloader_iter)
        optimizer.zero_grad()
        criterion(model(inputs), labels).backward()
        optimizer.step()

    # update quorum and PGs (could overlap with the optimizer steps above)
    manager.step()

    # free gradient memory to make room for averaged weights
    optimizer.zero_grad(set_to_none=True)

    # copy the model weights and start the allreduce mean
    # we need a temporary copy to gracefully handle failures
    params = {}
    for name, param in model.named_parameters():
        copy = param.detach().clone()
        manager.allreduce_grad(copy)
        params[name] = copy

    # this will wait for all transfers to complete successfully
    if manager.should_commit():
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(params[name])
                del params[name]
```
DiLoCo should be a small modification of this algorithm, using a separate outer optimizer instead of just averaging the weights.
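A rough sketch of what that outer step could look like, reusing the Manager calls from the example above. The `outer_optimizer`, its hyperparameters, the `global_params` bookkeeping, and the `diloco_outer_step` helper are assumptions for illustration, not an existing torchft API:

```python
import torch

# hypothetical DiLoCo-style outer step: average pseudo-gradients across
# replicas and apply them with a separate outer optimizer
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)

# snapshot of the weights agreed upon at the last successful sync
global_params = {
    name: param.detach().clone() for name, param in model.named_parameters()
}

def diloco_outer_step(manager, model):
    # pseudo-gradient = last synced weights - current local weights
    pseudo_grads = {}
    for name, param in model.named_parameters():
        delta = global_params[name] - param.detach()
        manager.allreduce_grad(delta)  # average the deltas across replicas
        pseudo_grads[name] = delta

    # only apply the outer step if all transfers completed successfully
    if manager.should_commit():
        with torch.no_grad():
            for name, param in model.named_parameters():
                # rewind to the last synced weights and hand the averaged
                # pseudo-gradient to the outer optimizer
                param.copy_(global_params[name])
                param.grad = pseudo_grads[name]
        outer_optimizer.step()
        outer_optimizer.zero_grad(set_to_none=True)
        with torch.no_grad():
            for name, param in model.named_parameters():
                global_params[name].copy_(param)
```

The weight-averaging block in the loop above would then be replaced by a call to `diloco_outer_step(manager, model)`.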
For efficiency we should probably use the DDP reducer on the parameters directly and copy the underlying Storage to make a backup copy (a rough sketch of the backup half is below).
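A minimal sketch of that backup, assuming PyTorch's `Tensor.untyped_storage()` API; `backup_storages` / `restore_storages` are hypothetical helpers, not part of torchft or DDP:

```python
import torch

def backup_storages(model: torch.nn.Module) -> dict[str, torch.UntypedStorage]:
    # clone the raw storage behind each parameter so the weights can be
    # rolled back if should_commit() fails
    return {
        name: param.untyped_storage().clone()
        for name, param in model.named_parameters()
    }

def restore_storages(model: torch.nn.Module, backups: dict[str, torch.UntypedStorage]) -> None:
    # copy the saved bytes back into the live parameters in place
    for name, param in model.named_parameters():
        param.untyped_storage().copy_(backups[name])
```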
References:
- LocalSGD: https://arxiv.org/abs/2311.08105
- DiLoCo: https://arxiv.org/abs/2311.08105