diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 80e1407..2a224e3 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -99,10 +99,6 @@ def __enter__(self) -> "LocalSGD": self._hooks.append( self._local_optimizer.register_step_post_hook(self._step_post_hook) ) - # Register a forward prehook to check for quorum - self._hooks.append( - self._model.register_forward_pre_hook(self._forward_step_pre_hook) - ) return self def __exit__( @@ -132,7 +128,7 @@ def _restore_parameters(self) -> None: with torch.no_grad(): # TODO: consider running copy on a separate stream for name, p in self._model.named_parameters(): - p.copy_(self._backup_parameters[name], non_blocking=False) + p.data.copy_(self._backup_parameters[name], non_blocking=False) def _step_post_hook( self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object] @@ -144,25 +140,12 @@ def _step_post_hook( if self._local_step >= self._sync_every: self.sync() - def _forward_step_pre_hook(self, _module: nn.Module, _args: List[object]) -> None: - """ - Start the quorum before each module forward. - """ - if self._local_step == 0: - self._manager.start_quorum() - def sync(self) -> None: """ Synchronizes and averages the model weights across the manager. """ + self._manager.start_quorum() self._perform_sync() - - if self._manager.should_commit(): - self._save_parameters() - else: - # commit failed, restore from the backup parameters - self._restore_parameters() - self._local_step = 0 def _perform_sync(self) -> None: @@ -172,6 +155,11 @@ def _perform_sync(self) -> None: synchronization logic. """ self._average() + if self._manager.should_commit(): + self._save_parameters() + else: + # commit failed, restore from the backup parameters + self._restore_parameters() def _average(self) -> None: # TODO: do we need to broadcast buffers like DDP does? @@ -227,12 +215,13 @@ def _perform_sync(self) -> None: p.grad = pseudogradient self._average_grads() - # Restore the parameters back to the previous state self._restore_parameters() - # Use the outer optimizer to update the model parameters - self._outer_optimizer.step() + if self._manager.should_commit(): + # Use the outer optimizer to update the model parameters + self._outer_optimizer.step() + self._save_parameters() self._outer_optimizer.zero_grad() def _average_grads(self) -> None: diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py index 7872fc2..7956cd1 100644 --- a/torchft/local_sgd_test.py +++ b/torchft/local_sgd_test.py @@ -100,7 +100,7 @@ def test_local_sgd_recovery(self) -> None: class DiLoCoTest(TestCase): - def test_diloco_healt(self) -> None: + def test_diloco_healthy(self) -> None: model = SimpleModel() # Setup optimizers @@ -112,6 +112,7 @@ def test_diloco_healt(self) -> None: ) manager = create_autospec(Manager) + manager._use_async_quorum = False with DiLoCo( manager, model, inner_optimizer, outer_optimizer, sync_every=2 ) as diloco: diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index 0aca7a3..ca178c2 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -1,17 +1,18 @@ +import copy import logging import threading import time -from concurrent.futures import as_completed, ThreadPoolExecutor -from contextlib import contextmanager, ExitStack +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Dict, Generator, List, Optional, Protocol, Set, Tuple +from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, Union from unittest import TestCase import torch import torch.distributed as dist from parameterized import parameterized -from torch import nn, optim +from torch import Tensor, nn, optim from torchft.ddp import DistributedDataParallel from torchft.local_sgd import DiLoCo, LocalSGD @@ -76,6 +77,7 @@ class Runner: world_size: int = 1 attempts: int = 3 manager_args: Dict[str, object] = field(default_factory=dict) + train_loop_args: Dict[str, Any] = field(default_factory=dict) def _replica_main(self) -> List[Dict[str, Dict[str, object]]]: store = dist.TCPStore( @@ -103,7 +105,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]: try: fut.result() except Exception as e: - logger.exception(f"worker threw exception: {e}") + logger.exception(f"worker {self.replica_id=} threw exception: {e}") raise return [fut.result() for fut in futures] @@ -257,27 +259,31 @@ def diloco_train_loop( runner: Runner, ) -> Dict[str, Dict[str, object]]: with ExitStack() as stack: - torch.manual_seed(42) - # Declare the model and optimizers m: nn.Module = MyModel() + model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"] + m.load_state_dict(model_state_dict) # Setup optimizers inner_optimizer: optim.Optimizer = torch.optim.AdamW( m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) ) - outer_optimizer = torch.optim.SGD( + outer_optimizer: optim.Optimizer = torch.optim.SGD( m.parameters(), lr=0.7, momentum=0.9, nesterov=True ) + # pyre-ignore[53] def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: m.load_state_dict(state_dict["model"]) + # TODO: make this cleaner so we don't have to save this + diloco._backup_parameters = state_dict["backup_params"] inner_optimizer.load_state_dict(state_dict["inner_optim"]) outer_optimizer.load_state_dict(state_dict["outer_optim"]) - def state_dict() -> Dict[str, Dict[str, object]]: + def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] return { "model": m.state_dict(), + "backup_params": copy.deepcopy(diloco._backup_parameters), "inner_optim": inner_optimizer.state_dict(), "outer_optim": outer_optimizer.state_dict(), } @@ -303,14 +309,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: ) stack.callback(manager.shutdown) - # TODO: where in the training loop should we do this? - # Ensure all models have the same starting state - # We set manual seed so the models start with the same weights - manager.start_quorum() - for param in m.parameters(): - manager.allreduce(param.data) - criterion = nn.CrossEntropyLoss() + all_state_dicts = {} with DiLoCo( manager, m, inner_optimizer, outer_optimizer, sync_every=2 ) as diloco: @@ -324,6 +324,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: inner_optimizer.zero_grad() loss.backward() inner_optimizer.step() + manager_step_str = str(manager.current_step()) + all_state_dicts[manager_step_str] = state_dict() # after 4 model updates then break if manager.current_step() >= 4: @@ -331,9 +333,8 @@ def state_dict() -> Dict[str, Dict[str, object]]: runner.failure_injector.check(rank, manager.current_step()) - return_state_dict = state_dict() # return state_dict so we can check consistency - return return_state_dict + return all_state_dicts class ManagerIntegTest(TestCase): @@ -524,6 +525,11 @@ def test_diloco_healthy(self) -> None: num_replicas = 2 futures = [] + torch.manual_seed(42) + # Initialize the model so we can pass in the state_dict + m: nn.Module = MyModel() + print(m.state_dict()) + with ThreadPoolExecutor(max_workers=num_replicas) as executor: for replica_id in range(num_replicas): failure_injector = FailureInjector() @@ -532,6 +538,9 @@ def test_diloco_healthy(self) -> None: lighthouse_address=lighthouse.address(), failure_injector=failure_injector, train_loop=diloco_train_loop, + train_loop_args={ + "model_state_dict": m.state_dict(), + }, ) futures.append(executor.submit(runner.run_replica)) @@ -542,12 +551,16 @@ def test_diloco_healthy(self) -> None: lighthouse.shutdown() - for state_dict in state_dicts: - # inner optimizer will be different, outer optimizer and model should be the same - torch.testing.assert_close(state_dict["model"], state_dicts[0]["model"]) - torch.testing.assert_close( - state_dict["outer_optim"], state_dicts[0]["outer_optim"] - ) + for replica_group in state_dicts: + for step, state_dict in replica_group.items(): + # inner optimizer will be different, outer optimizer and model should be the same + torch.testing.assert_close( + state_dict["backup_params"], + state_dicts[0][str(step)]["backup_params"], + ) + torch.testing.assert_close( + state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"] + ) def test_diloco_recovery(self) -> None: lighthouse = Lighthouse( @@ -562,6 +575,10 @@ def test_diloco_recovery(self) -> None: FailureInjector().fail_at(0, 2), ] + torch.manual_seed(42) + # Initialize the model so we can pass in the state_dict + m: nn.Module = MyModel() + with ThreadPoolExecutor(max_workers=num_replicas) as executor: for replica_id, failure_injector in zip( range(num_replicas), failure_injectors @@ -571,6 +588,9 @@ def test_diloco_recovery(self) -> None: lighthouse_address=lighthouse.address(), failure_injector=failure_injector, train_loop=diloco_train_loop, + train_loop_args={ + "model_state_dict": m.state_dict(), + }, ) futures.append(executor.submit(runner.run_replica)) @@ -584,12 +604,19 @@ def test_diloco_recovery(self) -> None: raise lighthouse.shutdown() - # for state_dict in state_dicts: - # # inner optimizer will be different, outer optimizer and model should be the same - # torch.testing.assert_close(state_dict["model"], state_dicts[0]["model"]) - # torch.testing.assert_close( - # state_dict["outer_optim"], state_dicts[0]["outer_optim"] - # ) + for replica_group in state_dicts: + for step, state_dict in replica_group.items(): + str_step = str(step) + if str_step in state_dicts[0]: + # inner optimizer will be different, outer optimizer and model should be the same + torch.testing.assert_close( + state_dict["backup_params"], + state_dicts[0][str_step]["backup_params"], + ) + torch.testing.assert_close( + state_dict["outer_optim"], + state_dicts[0][str_step]["outer_optim"], + ) self.assertEqual(failure_injectors[1].count, 1)