diff --git a/docs/source/coordination.rst b/docs/source/coordination.rst
new file mode 100644
index 0000000..e111170
--- /dev/null
+++ b/docs/source/coordination.rst
@@ -0,0 +1,4 @@
+.. automodule:: torchft.coordination
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4d2a5af..361aec4 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,6 +21,7 @@ the entire training job.
    data
    checkpointing
    parameter_server
+   coordination
 
 
 License
diff --git a/pyproject.toml b/pyproject.toml
index b76e204..a2a1751 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ dev = [
 
 [tool.maturin]
 features = ["pyo3/extension-module"]
+module-name = "torchft._torchft"
 
 [project.scripts]
 torchft_lighthouse = "torchft.torchft:lighthouse_main"
diff --git a/src/lib.rs b/src/lib.rs
index 529532d..59fad6a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,15 +30,30 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
 use pyo3::prelude::*;
 
+/// ManagerServer is a GRPC server for the manager service.
+/// There should be one manager server per replica group (typically running on
+/// the rank 0 host). The individual ranks within a replica group should use
+/// ManagerClient to communicate with the manager server and participate in
+/// quorum operations.
+///
+/// Args:
+///     replica_id (str): The ID of the replica group.
+///     lighthouse_addr (str): The HTTP address of the lighthouse server.
+///     hostname (str): The hostname of the manager server.
+///     bind (str): The HTTP address to bind the server to.
+///     store_addr (str): The address of the store server (TCPStore host:port).
+///     world_size (int): The world size of the replica group.
+///     heartbeat_interval (timedelta): The interval at which heartbeats are sent.
+///     connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
 #[pyclass]
-struct Manager {
+struct ManagerServer {
     handle: JoinHandle<Result<()>>,
     manager: Arc<manager::Manager>,
     _runtime: Runtime,
 }
 
 #[pymethods]
-impl Manager {
+impl ManagerServer {
     #[new]
     fn new(
         py: Python<'_>,
@@ -74,10 +89,15 @@ impl Manager {
         })
     }
 
+    /// address returns the address of the manager server.
+    ///
+    /// Returns:
+    ///     str: The address of the manager server.
     fn address(&self) -> PyResult<String> {
         Ok(self.manager.address().to_string())
     }
 
+    /// shutdown shuts down the manager server.
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
@@ -85,6 +105,13 @@
     }
 }
 
+/// ManagerClient is a GRPC client to the manager service.
+///
+/// It is used by the trainer to communicate with the ManagerServer.
+///
+/// Args:
+///     addr (str): The HTTP address of the manager server.
+///     connect_timeout (timedelta): The timeout for connecting to the manager server.
 #[pyclass]
 struct ManagerClient {
     runtime: Runtime,
@@ -108,7 +135,7 @@ impl ManagerClient {
         })
     }
 
-    fn quorum(
+    fn _quorum(
         &self,
         py: Python<'_>,
         rank: i64,
@@ -147,7 +174,7 @@ impl ManagerClient {
         })
     }
 
-    fn checkpoint_metadata(
+    fn _checkpoint_metadata(
         &self,
         py: Python<'_>,
         rank: i64,
@@ -168,6 +195,20 @@ impl ManagerClient {
         })
     }
 
+    /// should_commit makes a request to the manager to determine if the trainer
+    /// should commit the current step. This waits until all ranks check in at
+    /// the specified step and will return false if any worker passes
+    /// ``should_commit=False``.
+    ///
+    /// Args:
+    ///     rank (int): The rank of the trainer.
+    ///     step (int): The step of the trainer.
+    ///     should_commit (bool): Whether the trainer should commit the current step.
+    ///     timeout (timedelta): The timeout for the request. If the request
+    ///         times out, a TimeoutError is raised.
+    ///
+    /// Returns:
+    ///     bool: Whether the trainer should commit the current step.
     fn should_commit(
         &self,
         py: Python<'_>,
@@ -263,15 +304,28 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
     Ok(())
 }
 
+/// LighthouseServer is a GRPC server for the lighthouse service.
+///
+/// It is used to coordinate the ManagerServer for each replica group.
+///
+/// This entrypoint is primarily for testing and debugging purposes. The
+/// ``torchft_lighthouse`` command is recommended for most use cases.
+///
+/// Args:
+///     bind (str): The HTTP address to bind the server to.
+///     min_replicas (int): The minimum number of replicas required to form a quorum.
+///     join_timeout_ms (int): The timeout for joining the quorum.
+///     quorum_tick_ms (int): The interval at which the quorum is checked.
+///     heartbeat_timeout_ms (int): The timeout for heartbeats.
 #[pyclass]
-struct Lighthouse {
+struct LighthouseServer {
     lighthouse: Arc<lighthouse::Lighthouse>,
     handle: JoinHandle<Result<()>>,
     _runtime: Runtime,
 }
 
 #[pymethods]
-impl Lighthouse {
+impl LighthouseServer {
     #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
     #[new]
     fn new(
@@ -307,10 +361,15 @@ impl Lighthouse {
         })
     }
 
+    /// address returns the address of the lighthouse server.
+    ///
+    /// Returns:
+    ///     str: The address of the lighthouse server.
     fn address(&self) -> PyResult<String> {
         Ok(self.lighthouse.address().to_string())
     }
 
+    /// shutdown shuts down the lighthouse server.
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
@@ -339,7 +398,7 @@ impl From<Status> for StatusError {
 }
 
 #[pymodule]
-fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
+fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
     let mut log = stderrlog::new();
     log.verbosity(2)
@@ -353,9 +412,9 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
-    m.add_class::<Manager>()?;
+    m.add_class::<ManagerServer>()?;
     m.add_class::<ManagerClient>()?;
-    m.add_class::<Lighthouse>()?;
+    m.add_class::<LighthouseServer>()?;
     m.add_class::<QuorumResult>()?;
 
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
diff --git a/torchft/torchft.pyi b/torchft/_torchft.pyi
similarity index 91%
rename from torchft/torchft.pyi
rename to torchft/_torchft.pyi
index 2c8c6cd..b4afde6 100644
--- a/torchft/torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -3,7 +3,7 @@ from typing import List, Optional
 
 class ManagerClient:
     def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
-    def quorum(
+    def _quorum(
         self,
         rank: int,
         step: int,
@@ -11,7 +11,7 @@ class ManagerClient:
         shrink_only: bool,
         timeout: timedelta,
     ) -> QuorumResult: ...
-    def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
+    def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
         self,
         rank: int,
@@ -33,7 +33,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool
 
-class Manager:
+class ManagerServer:
     def __init__(
         self,
         replica_id: str,
@@ -48,7 +48,7 @@ class Manager:
     def address(self) -> str: ...
     def shutdown(self) -> None: ...
 
-class Lighthouse:
+class LighthouseServer:
     def __init__(
         self,
         bind: str,
diff --git a/torchft/coordination.py b/torchft/coordination.py
new file mode 100644
index 0000000..48c6075
--- /dev/null
+++ b/torchft/coordination.py
@@ -0,0 +1,23 @@
+"""
+Coordination (Low Level API)
+============================
+
+.. warning::
+    As torchft is still in development, the APIs in this module are subject to change.
+
+This module exposes low-level coordination APIs to allow you to build your own
+custom fault tolerance algorithms on top of torchft.
+
+If you're looking for a more complete solution, please use the other modules in
+torchft.
+
+This provides direct access to the Lighthouse and Manager servers and clients.
+"""
+
+from torchft._torchft import LighthouseServer, ManagerClient, ManagerServer
+
+__all__ = [
+    "LighthouseServer",
+    "ManagerServer",
+    "ManagerClient",
+]
diff --git a/torchft/coordination_test.py b/torchft/coordination_test.py
new file mode 100644
index 0000000..fc13f7b
--- /dev/null
+++ b/torchft/coordination_test.py
@@ -0,0 +1,19 @@
+import inspect
+from unittest import TestCase
+
+from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer
+
+
+class TestCoordination(TestCase):
+    def test_coordination_docs(self) -> None:
+        classes = [
+            ManagerClient,
+            ManagerServer,
+            LighthouseServer,
+        ]
+        for cls in classes:
+            self.assertIn("Args:", str(cls.__doc__), cls)
+            for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
+                if name.startswith("_"):
+                    continue
+                self.assertIn("Args:", str(method.__doc__), f"{cls} {name}")
diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py
index 38700b6..f755a7a 100644
--- a/torchft/lighthouse_test.py
+++ b/torchft/lighthouse_test.py
@@ -4,7 +4,7 @@
 import torch.distributed as dist
 
 from torchft import Manager, ProcessGroupGloo
-from torchft.torchft import Lighthouse
+from torchft._torchft import LighthouseServer
 
 
 class TestLighthouse(TestCase):
@@ -12,7 +12,7 @@ def test_join_timeout_behavior(self) -> None:
         """Test that join_timeout_ms affects joining behavior"""
         # To test, we create a lighthouse with 100ms and 400ms join timeouts
         # and measure the time taken to validate the quorum.
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             join_timeout_ms=100,
@@ -52,14 +52,14 @@ def test_join_timeout_behavior(self) -> None:
             if "manager" in locals():
                 manager.shutdown()
 
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             join_timeout_ms=400,
         )
 
     def test_heartbeat_timeout_ms_sanity(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             heartbeat_timeout_ms=100,
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index bf54c8f..55ca5c3 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -8,11 +8,11 @@
 
 import torch
 from torch import nn, optim
 
+from torchft._torchft import LighthouseServer
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
 from torchft.process_group import ProcessGroupGloo
-from torchft.torchft import Lighthouse
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
 
 class ManagerIntegTest(TestCase):
     def test_local_sgd_recovery(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -214,7 +214,7 @@ def test_local_sgd_recovery(self) -> None:
         self.assertEqual(failure_injectors[1].count, 1)
 
     def test_diloco_healthy(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -258,7 +258,7 @@ def test_diloco_healthy(self) -> None:
         )
 
     def test_diloco_recovery(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
diff --git a/torchft/manager.py b/torchft/manager.py
index a355689..668189c 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -38,9 +38,9 @@
 import torch
 from torch.distributed import ReduceOp, TCPStore
 
+from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.torchft import Manager as _Manager, ManagerClient
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -180,7 +180,7 @@ def __init__(
             wait_for_workers=False,
         )
         self._pg = pg
-        self._manager: Optional[_Manager] = None
+        self._manager: Optional[ManagerServer] = None
 
         if rank == 0:
             if port is None:
@@ -192,7 +192,7 @@ def __init__(
             if replica_id is None:
                 replica_id = ""
             replica_id = replica_id + str(uuid.uuid4())
-            self._manager = _Manager(
+            self._manager = ManagerServer(
                 replica_id=replica_id,
                 lighthouse_addr=lighthouse_addr,
                 hostname=hostname,
@@ -429,7 +429,7 @@ def wait_quorum(self) -> None:
     def _async_quorum(
         self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
     ) -> None:
-        quorum = self._client.quorum(
+        quorum = self._client._quorum(
             rank=self._rank,
             step=self._step,
             checkpoint_metadata=self._checkpoint_transport.metadata(),
@@ -498,7 +498,7 @@ def _async_quorum(
                 primary_client = ManagerClient(
                     recover_src_manager_address, connect_timeout=self._connect_timeout
                 )
-                checkpoint_metadata = primary_client.checkpoint_metadata(
+                checkpoint_metadata = primary_client._checkpoint_metadata(
                     self._rank, timeout=self._timeout
                 )
                 recover_src_rank = quorum.recover_src_rank
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 636f07f..0f69a7b 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -14,12 +14,12 @@
 
 from parameterized import parameterized
 from torch import nn, optim
 
+from torchft._torchft import LighthouseServer
 from torchft.ddp import DistributedDataParallel
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
 from torchft.process_group import ProcessGroupGloo
-from torchft.torchft import Lighthouse
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -201,7 +201,7 @@ def assertElapsedLessThan(
         self.assertLess(elapsed, timeout, msg)
 
     def test_ddp_healthy(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -242,7 +242,7 @@ def test_ddp_healthy(self) -> None:
         ]
     )
     def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -282,7 +282,7 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
         self.assertEqual(failure_injectors[1].count, 1)
 
     def test_ddp_recovery_multi_rank(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -324,7 +324,7 @@ def test_ddp_recovery_multi_rank(self) -> None:
 
     def test_quorum_timeout(self) -> None:
         with ExitStack() as stack:
-            lighthouse = Lighthouse(
+            lighthouse = LighthouseServer(
                 bind="[::]:0",
                 min_replicas=2,
             )
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index d262787..05793e1 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -13,9 +13,9 @@
 import torch
 from torch.distributed import TCPStore
 
+from torchft._torchft import QuorumResult
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup, _DummyWork
-from torchft.torchft import QuorumResult
 
 
 def mock_should_commit(
@@ -143,7 +143,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 2
         quorum.heal = False
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -180,7 +180,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 2
         quorum.heal = True
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         # forcible increment checkpoint server to compute correct address
         manager._checkpoint_transport.send_checkpoint(
@@ -189,7 +189,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
             state_dict=manager._manager_state_dict(),
             timeout=timedelta(seconds=10),
         )
-        client_mock().checkpoint_metadata.return_value = (
+        client_mock()._checkpoint_metadata.return_value = (
            manager._checkpoint_transport.metadata()
        )
 
@@ -234,7 +234,7 @@ def test_quorum_heal_async_not_enough_participants(
         quorum.max_world_size = 1
         quorum.heal = True
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         # forcible increment checkpoint server to compute correct address
         manager._checkpoint_transport.send_checkpoint(
@@ -243,7 +243,7 @@ def test_quorum_heal_async_not_enough_participants(
             state_dict=manager._manager_state_dict(),
             timeout=timedelta(seconds=10),
         )
-        client_mock().checkpoint_metadata.return_value = (
+        client_mock()._checkpoint_metadata.return_value = (
             manager._checkpoint_transport.metadata()
         )
 
@@ -296,7 +296,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 1
         quorum.heal = True
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         # forceable increment checkpoint server to compute correct address
         manager._checkpoint_transport.send_checkpoint(
@@ -305,7 +305,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
             state_dict=manager._manager_state_dict(),
             timeout=timedelta(seconds=10),
         )
-        client_mock().checkpoint_metadata.return_value = (
+        client_mock()._checkpoint_metadata.return_value = (
             manager._checkpoint_transport.metadata()
         )
 
@@ -355,7 +355,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 2
         quorum.heal = False
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -429,7 +429,7 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 3
         quorum.heal = False
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -463,7 +463,7 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None:
         quorum.max_rank = None
         quorum.max_world_size = 2
         quorum.heal = True
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         self.assertEqual(manager._quorum_id, -1)
         self.assertEqual(manager.current_step(), 0)
@@ -567,11 +567,11 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
         quorum.max_world_size = 2
         quorum.heal = False
 
-        client_mock().quorum.return_value = quorum
+        client_mock()._quorum.return_value = quorum
 
         manager.start_quorum(timeout=timedelta(seconds=12))
         self.assertEqual(
-            client_mock().quorum.call_args.kwargs["timeout"], timedelta(seconds=12)
+            client_mock()._quorum.call_args.kwargs["timeout"], timedelta(seconds=12)
         )
         self.assertTrue(manager.should_commit(timeout=timedelta(seconds=23)))
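
Usage sketch (not part of the patch): the snippet below shows how the renamed classes exposed through the new torchft.coordination module fit together. The LighthouseServer arguments follow torchft/lighthouse_test.py; the ManagerServer and ManagerClient arguments mirror the docstrings added in src/lib.rs above, but the concrete values (replica_id, hostnames, ports, world_size, timeouts) and the single-rank flow are illustrative assumptions, not a tested configuration.

    from datetime import timedelta

    from torch.distributed import TCPStore

    from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer

    # Start a lighthouse that accepts a quorum as soon as one replica joins.
    lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)
    print("lighthouse listening on", lighthouse.address())

    # One ManagerServer per replica group (rank 0 host). It needs a TCPStore
    # that the ranks of the group share; port 0 picks a free port.
    store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    manager = ManagerServer(
        replica_id="replica_0",
        lighthouse_addr=lighthouse.address(),
        hostname="localhost",
        bind="[::]:0",
        store_addr=f"localhost:{store.port}",
        world_size=1,
        heartbeat_interval=timedelta(milliseconds=100),
        connect_timeout=timedelta(seconds=10),
    )

    # Each trainer rank talks to the manager through ManagerClient, e.g. to
    # vote on whether the current step should be committed. should_commit
    # blocks until all world_size ranks have checked in for the step.
    client = ManagerClient(manager.address(), connect_timeout=timedelta(seconds=10))
    ok = client.should_commit(
        rank=0, step=0, should_commit=True, timeout=timedelta(seconds=10)
    )
    print("commit step 0:", ok)

    manager.shutdown()
    lighthouse.shutdown()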