From cfbf7e2f477aaafaa7ed7dbcb6ee5c2e3b7702fc Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 21 Jan 2025 14:25:53 -0800 Subject: [PATCH] Update on "[WIP] Add DiLoCo" ## API Usage ```python # LocalSGD example model = SimpleModel() optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) with LocalSGD(manager, model, optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() optimizer.step() # DiLoCo example model = SimpleModel() inner_optimizer = torch.optim.AdamW( model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) ) outer_optimizer = torch.optim.SGD( model.parameters(), lr=0.7, momentum=0.9, nesterov=True ) manager = create_autospec(Manager) with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() inner_optimizer.step() ``` ## Changes - Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a pre_forward_hook to the model to start the quorum. - Added DiLoCo. 
This is a subclass of LocalSGD since a lot of code is shared - TODO: should be working, but still validating some tests discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk [ghstack-poisoned] --- torchft/local_sgd.py | 1 - torchft/local_sgd_test.py | 2 +- torchft/manager_integ_test.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 6956907..b325b85 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -16,7 +16,6 @@ import torch from torch import nn, optim - from torch.nn.parameter import Parameter from torch.optim.optimizer import Optimizer diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py index 10db999..7872fc2 100644 --- a/torchft/local_sgd_test.py +++ b/torchft/local_sgd_test.py @@ -11,7 +11,7 @@ import torch from torch import nn, optim -from torchft.local_sgd import DiLoCo, DiLoCoOptimizer, LocalSGD +from torchft.local_sgd import DiLoCo, LocalSGD from torchft.manager import Manager diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index bb0d1ff..3f5aa6d 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -1,8 +1,8 @@ import logging import threading import time -from concurrent.futures import as_completed, ThreadPoolExecutor -from contextlib import contextmanager, ExitStack +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field from datetime import timedelta from typing import Dict, Generator, List, Optional, Protocol, Set, Tuple