From 58a436d6212f34c2c3407918edbb4a68908b1092 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Fri, 13 Dec 2024 17:39:16 -0800
Subject: [PATCH] manager_integ_tests: added multi rank recovery and sync tests (#40)

---
 pyproject.toml                |   3 +-
 torchft/ddp.py                |   6 ++
 torchft/manager.py            |  70 ++++++++++---
 torchft/manager_integ_test.py | 192 ++++++++++++++++++++++++++++------
 torchft/manager_test.py       |   7 +-
 5 files changed, 225 insertions(+), 53 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2d61fb6..313dc5c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,8 @@ dependencies = [
 dev = [
     "pytest",
     "black",
-    "pyre-check"
+    "pyre-check",
+    "parameterized"
 ]
 
 [tool.maturin]
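Note: parameterized is a dev-only dependency; it powers the @parameterized.expand decorator that test_ddp_recovery below uses to run the same test body under both quorum modes. For illustration, expand() generates one test method per tuple, with the first element appended to the generated name (the class name T here is hypothetical):

    from unittest import TestCase

    from parameterized import parameterized

    class T(TestCase):
        @parameterized.expand([("async_quorum", True), ("sync_quorum", False)])
        def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
            self.assertIsInstance(use_async_quorum, bool)

    # pytest/unittest collect these as:
    #   T::test_ddp_recovery_0_async_quorum
    #   T::test_ddp_recovery_1_sync_quorum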
diff --git a/torchft/ddp.py b/torchft/ddp.py
index fc6913b..e1d00a1 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -52,6 +52,12 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
         super().__init__(
             module,
             process_group=pg,
+            # HACK: This forces the reducer to never rebuild buckets.
+            # The reducer normally rebuilds the buckets after the first training
+            # step which can improve performance but is incompatible with
+            # torchft as it will cause the buckets to diverge for recovering
+            # replicas.
+            find_unused_parameters=True,
             # pyre-fixme[6]: got object
             **kwargs,
         )
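Note: the reducer's bucket rebuild is purely a performance optimization, so pinning the initial bucket assignment keeps every replica's allreduce schedule identical, which is what allows a replica restored from a peer's checkpoint to rejoin mid-run. A sketch of how the wrapper is wired to a Manager, condensed from train_loop below; only arguments that appear in this patch are spelled out, and the remaining Manager kwargs plus a running Lighthouse at lighthouse_addr are assumed:

    import torch.distributed as dist
    from torch import optim

    from torchft.ddp import DistributedDataParallel
    from torchft.manager import Manager
    from torchft.process_group import ProcessGroupGloo

    store = dist.TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    m = MyModel()  # the toy model defined in the integration test below
    optimizer = optim.Adam(m.parameters())

    manager = Manager(
        pg=ProcessGroupGloo(),
        state_dict=lambda: {"model": m.state_dict(), "optim": optimizer.state_dict()},
        replica_id="replica_0",
        store_addr="localhost",
        store_port=store.port,
        rank=0,
        world_size=1,
        lighthouse_addr=lighthouse_addr,  # assumed: address of a running Lighthouse
        # load_state_dict= and the other required kwargs are elided here
    )
    m = DistributedDataParallel(manager, m)
    # buckets now keep their initial assignment for the life of the wrapper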
diff --git a/torchft/manager.py b/torchft/manager.py
index 3e05751..c47de22 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -45,10 +45,9 @@
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
 
-logger: logging.Logger = logging.getLogger(__name__)
-
 MANAGER_ADDR_KEY: str = "manager_addr"
 MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
+REPLICA_ID_KEY: str = "replica_id"
 
 T = TypeVar("T")
 
@@ -132,7 +131,9 @@ def _manager_state_dict() -> Dict[str, T]:
             }
 
         self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
-        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._executor = ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="async_quorum"
+        )
         self._quorum_future: Optional[concurrent.futures.Future] = None
 
         self._store = TCPStore(
@@ -163,10 +164,16 @@ def _manager_state_dict() -> Dict[str, T]:
             )
 
             self._store.set(MANAGER_ADDR_KEY, addr)
+            self._store.set(REPLICA_ID_KEY, replica_id)
 
         addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
         self._client = ManagerClient(addr, timeout=timeout)
 
+        replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
+        self._logger = _ManagerLogger(
+            manager=self, replica_id=replica_id or "", rank=rank
+        )
+
         self._step = 0
         self._quorum_id = -1
         self._errored: Optional[Exception] = None
@@ -230,6 +237,7 @@ def callback(
             ) -> torch.Tensor:
                 nonlocal grad
 
+                # check for exceptions
                 fut.value()
 
                 grad /= self.num_participants()
@@ -241,7 +249,9 @@ def callback(
 
             return fut
         except Exception as e:
-            logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
+            self._logger.exception(
+                f"got exception in all reduce -- skipping remaining: {e}"
+            )
             self.report_error(e)
 
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
@@ -294,7 +304,9 @@ def callback(
             try:
                 return fut.value()
             except Exception as e:
-                logger.exception(f"got exception in future -- skipping remaining: {e}")
+                self._logger.exception(
+                    f"got exception in future -- skipping remaining: {e}"
+                )
                 self.report_error(e)
                 return default
 
@@ -328,12 +340,13 @@ def step(self) -> None:
         if not self._use_async_quorum:
             self._quorum_future.result()
 
-            # eagerly apply pending state_dict so we can run the forwards pass
-            self._apply_pending_state_dict()
+            if self._healing:
+                # eagerly apply pending state_dict so we can run the forwards pass
+                self._apply_pending_state_dict()
 
-            # we are forcing healing at the beginning so we're in a good state
-            # and don't need to zero_grad
-            self._healing = False
+                # we are forcing healing at the beginning so we're in a good state
+                # and don't need to zero_grad
+                self._healing = False
 
     def _async_quorum(self) -> None:
         (
@@ -374,8 +387,9 @@ def _async_quorum(self) -> None:
             self._participating_rank = None
 
         if quorum_id != self._quorum_id:
-            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
+
+            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
             self._quorum_id = quorum_id
@@ -383,12 +397,13 @@ def _async_quorum(self) -> None:
         # See manager.rs for healing conditions
         if heal:
             self._healing = True
-            logger.info(f"{replica_rank}= healing required")
-
-            logger.info(f"fetching checkpoint server address from {address}")
+            self._logger.info(
+                f"healing required, fetching checkpoint server address from {address=} {max_step=}"
+            )
             primary_client = ManagerClient(address, timeout=self._timeout)
             checkpoint_server_address = primary_client.checkpoint_address(self._rank)
 
+            self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
             self._pending_state_dict = CheckpointServer.load_from_address(
                 checkpoint_server_address
             )
@@ -406,8 +421,9 @@ def _apply_pending_state_dict(self) -> None:
         assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()
 
-        assert self._pending_state_dict is not None, "checkpoint was not staged"
+        self._logger.info("applying pending state dict")
 
+        assert self._pending_state_dict is not None, "checkpoint was not staged"
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
 
@@ -450,7 +466,7 @@ def should_commit(self) -> bool:
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
         )
-        logger.info(
+        self._logger.info(
             f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
         )
 
@@ -534,3 +550,25 @@ def is_participating(self) -> bool:
             assert self._use_async_quorum
             return False
         return True
+
+
+class _ManagerLogger:
+    def __init__(self, manager: Manager, replica_id: str, rank: int) -> None:
+        self._logger: logging.Logger = logging.getLogger(__name__)
+        self._replica_id = replica_id
+        self._rank = rank
+        self._manager = manager
+
+    def prefix(self) -> str:
+        return (
+            f"[{self._replica_id}/{self._rank} - step {self._manager.current_step()}]"
+        )
+
+    def info(self, msg: str) -> None:
+        self._logger.info(f"{self.prefix()} {msg}")
+
+    def warn(self, msg: str) -> None:
+        self._logger.warning(f"{self.prefix()} {msg}")
+
+    def exception(self, msg: str) -> None:
+        self._logger.exception(f"{self.prefix()} {msg}")
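Note: all manager log lines now flow through _ManagerLogger, which stamps each message with replica id, rank, and current step; without this the interleaved output of the multi-rank tests below is nearly impossible to attribute. A standalone restatement of the format, with illustrative values:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("torchft.manager")

    # _ManagerLogger.prefix() renders as "[<replica_id>/<rank> - step <step>]"
    replica_id, rank, step = "replica_0", 1, 3
    logger.info(f"[{replica_id}/{rank} - step {step}] reconfiguring for quorum_id=2")
    # INFO:torchft.manager:[replica_0/1 - step 3] reconfiguring for quorum_id=2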
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 8251cab..4401770 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -1,10 +1,15 @@
+import logging
+import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from contextlib import ExitStack
-from typing import Dict, Set, Tuple
+from contextlib import ExitStack, nullcontext
+from typing import Dict, List, Set, Tuple
 from unittest import TestCase
 
 import torch
 import torch.distributed as dist
+
+# pyre-fixme[21]: missing module
+from parameterized import parameterized
 from torch import nn, optim
 
 from torchft.ddp import DistributedDataParallel
@@ -13,6 +18,8 @@
 from torchft.process_group import ProcessGroupGloo
 from torchft.torchft import Lighthouse
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 class MyModel(nn.Module):
     def __init__(self) -> None:
@@ -32,32 +39,85 @@ class InjectedFailure(Exception):
 
 class FailureInjector:
     def __init__(self) -> None:
-        self._failures: Set[int] = set()
+        self._lock = threading.Lock()
+        self._failures: Set[Tuple[int, int]] = set()
         self.count = 0
 
-    def fail_at(self, step: int) -> "FailureInjector":
-        self._failures.add(step)
-        return self
+    def fail_at(self, rank: int, step: int) -> "FailureInjector":
+        with self._lock:
+            self._failures.add((rank, step))
+            return self
 
-    def check(self, step: int) -> None:
-        if step in self._failures:
-            self.count += 1
-            self._failures.remove(step)
-            print(f"injecting failure {step=}")
-            raise InjectedFailure(f"injected failure {step=}")
+    def check(self, rank: int, step: int) -> None:
+        with self._lock:
+            key = (rank, step)
+            if key in self._failures:
+                self.count += 1
+                self._failures.remove(key)
+                print(f"injecting failure {rank=} {step=}")
+                raise InjectedFailure(f"injected failure {rank=} {step=}")
+
+
+def replica_main(
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    world_size: int,
+    manager_args: Dict[str, object],
+) -> List[Dict[str, Dict[str, object]]]:
+    store = dist.TCPStore(
+        host_name="localhost",
+        port=0,
+        is_master=True,
+        wait_for_workers=False,
+    )
+
+    with ThreadPoolExecutor(
+        max_workers=world_size, thread_name_prefix=f"replica{replica_id}"
+    ) as executor:
+        futures = []
+        for rank in range(world_size):
+            futures.append(
+                executor.submit(
+                    train_loop,
+                    replica_id,
+                    lighthouse_address,
+                    failure_injector=failure_injector,
+                    rank=rank,
+                    world_size=world_size,
+                    store_port=store.port,
+                    manager_args=manager_args,
+                )
+            )
+
+        for fut in as_completed(futures):
+            try:
+                fut.result()
+            except Exception as e:
+                logger.exception(f"worker threw exception: {e}")
+                raise
+
+        return [fut.result() for fut in futures]
 
 
 def worker_manager(
     replica_id: int,
     lighthouse_address: str,
     failure_injector: FailureInjector,
+    manager_args: Dict[str, object],
     attempts: int = 3,
-) -> Dict[str, Dict[str, object]]:
+    world_size: int = 1,
+) -> List[Dict[str, Dict[str, object]]]:
+
     for i in range(attempts):
         try:
-            print(f"starting worker {replica_id} attempt {i}")
-            return train_loop(
-                replica_id, lighthouse_address, failure_injector=failure_injector
+            print(f"starting replica group {replica_id=} {world_size=} attempt {i}")
+            return replica_main(
+                replica_id,
+                lighthouse_address,
+                failure_injector=failure_injector,
+                world_size=world_size,
+                manager_args=manager_args,
             )
         except InjectedFailure as e:
             print("got injected failure", i, e)
@@ -69,15 +129,15 @@
 
 
 def train_loop(
-    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    rank: int,
+    world_size: int,
+    store_port: int,
+    manager_args: Dict[str, object],
 ) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
-        store = dist.TCPStore(
-            host_name="localhost",
-            port=0,
-            is_master=True,
-            wait_for_workers=False,
-        )
 
         def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
             m.load_state_dict(state_dict["model"])
@@ -89,6 +149,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:
                 "optim": optimizer.state_dict(),
             }
 
+        print(f"worker {replica_id=} {rank=} {world_size=} starting")
+
         pg = ProcessGroupGloo()
         manager = Manager(
             pg=pg,
@@ -97,11 +159,13 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             state_dict=state_dict,
             replica_id=str(replica_id),
             store_addr="localhost",
-            store_port=store.port,
-            rank=0,
-            world_size=1,
+            store_port=store_port,
+            rank=rank,
+            world_size=world_size,
             lighthouse_addr=lighthouse_address,
             port=19530 + replica_id,
+            # pyre-fixme[6]: Incompatible parameter type
+            **manager_args,
         )
         stack.callback(manager.shutdown)
 
@@ -112,7 +176,6 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         criterion = nn.CrossEntropyLoss()
 
         while True:
-            print(f"worker {replica_id} starting step {manager.current_step()}")
             inputs = torch.rand(2, 3)
             labels = torch.randint(4, (2,))
 
@@ -121,12 +184,13 @@ def state_dict() -> Dict[str, Dict[str, object]]:
             loss = criterion(out, labels)
             loss.backward()
+
             optimizer.step()
 
-            if manager.current_step() >= 5:
+            if manager.current_step() >= 4:
                 break
 
-            failure_injector.check(manager.current_step())
+            failure_injector.check(rank, manager.current_step())
 
         # return state_dict so we can check consistency
         return state_dict()
@@ -150,6 +214,7 @@ def test_ddp_healthy(self) -> None:
                         replica_id,
                         lighthouse.address(),
                         failure_injector=failure_injector,
+                        manager_args={},
                     )
                 )
 
@@ -163,7 +228,20 @@ def test_ddp_healthy(self) -> None:
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
-    def test_ddp_recovery(self) -> None:
+    # pyre-fixme[56]: couldn't infer type of decorator
+    @parameterized.expand(
+        [
+            (
+                "async_quorum",
+                True,
+            ),
+            (
+                "sync_quorum",
+                False,
+            ),
+        ]
+    )
+    def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
@@ -173,7 +251,7 @@ def test_ddp_recovery(self) -> None:
 
         failure_injectors = [
             FailureInjector(),
-            FailureInjector().fail_at(2),
+            FailureInjector().fail_at(0, 2),
         ]
 
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
@@ -186,13 +264,16 @@ def test_ddp_recovery(self) -> None:
                         replica_id,
                         lighthouse.address(),
                         failure_injector=failure_injector,
+                        manager_args={
+                            "use_async_quorum": use_async_quorum,
+                        },
                     )
                 )
 
-        state_dicts = []
+            state_dicts = []
 
-        for fut in as_completed(futures):
-            state_dicts.append(fut.result())
+            for fut in as_completed(futures):
+                state_dicts.append(fut.result())
 
         lighthouse.shutdown()
 
@@ -200,3 +281,46 @@ def test_ddp_recovery(self) -> None:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
         self.assertEqual(failure_injectors[1].count, 1)
+
+    def test_ddp_recovery_multi_rank(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        world_size = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2).fail_at(1, 2),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                futures.append(
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                        world_size=world_size,
+                        manager_args={},
+                    )
+                )
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
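Note: failure schedules are now keyed on (rank, step) rather than step alone, so a test can kill a single rank of a replica group or, as in test_ddp_recovery_multi_rank, every rank at once; the lock makes one injector safe to share across the worker threads spawned by replica_main. Given the FailureInjector class above, the intended semantics (values illustrative):

    injector = FailureInjector().fail_at(rank=1, step=2)

    injector.check(rank=0, step=2)  # different rank: no-op
    injector.check(rank=1, step=1)  # different step: no-op
    try:
        injector.check(rank=1, step=2)  # fires exactly once and is consumed
    except InjectedFailure:
        assert injector.count == 1
    injector.check(rank=1, step=2)  # already consumed: no-op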
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 38d0d52..054685f 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -11,7 +11,7 @@
 import torch
 from torch.distributed import TCPStore
 
-from torchft.manager import MANAGER_ADDR_KEY, Manager, WorldSizeMode
+from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
 from torchft.torchft import ManagerClient
 
@@ -32,6 +32,7 @@ def _create_manager(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
         self.store.set(MANAGER_ADDR_KEY, "dummy")
+        self.store.set(REPLICA_ID_KEY, "dummy_id")
         with patch(
             "os.environ",
             {
@@ -315,14 +316,16 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         )
         manager.step()
 
+        self.assertFalse(manager._errored)
+
         bad_fut = torch.futures.Future()  # pyre-fixme[29]: not a function
         bad_fut.set_exception(RuntimeError("injected failure"))
         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
         manager.allreduce_grad(torch.tensor([1.0])).wait()
+        self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
         self.assertTrue(manager._errored)
         self.assertFalse(manager.should_commit())
         self.assertTrue(manager._errored)
-        self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
 
         # cleanup
         manager._pg.allreduce.reset_mock(return_value=True)
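Note: the reordered assertions in test_allreduce_error pin down the error contract that the manager.py changes above implement: a failing allreduce never raises into the training loop; the manager records the error, the returned future still completes, and the poisoned step is rejected at commit time. In sketch form, assuming a manager set up as in that test:

    assert manager._errored is None                  # healthy before the injected failure
    fut = manager.allreduce_grad(torch.tensor([1.0]))
    fut.wait()                                       # completes despite the bad future
    assert manager._errored is not None              # error recorded, not raised
    assert manager.should_commit() is False          # the failed step is discarded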