
Commit

Merge branch 'pytorch:main' into bucketized_avg_localSGD
Krishn1412 authored Feb 24, 2025
2 parents d774938 + 5e65330 commit 2de4bdf
Showing 14 changed files with 841 additions and 186 deletions.
4 changes: 4 additions & 0 deletions docs/source/coordination.rst
@@ -0,0 +1,4 @@
.. automodule:: torchft.coordination
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ the entire training job.
data
checkpointing
parameter_server
coordination


License
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -33,9 +33,10 @@ dev = [

[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "torchft._torchft"

[project.scripts]
torchft_lighthouse = "torchft.torchft:lighthouse_main"
torchft_lighthouse = "torchft._torchft:lighthouse_main"

[tool.isort]
multi_line_output = 3
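With `module-name` set, maturin builds the native extension as `torchft._torchft`, which is what the updated console-script target and the Python imports later in this commit rely on. A quick, hedged way to confirm a rebuilt environment picked up the rename (assumes the extension has been rebuilt, e.g. via `maturin develop`):

```python
# Lists the public names exported by the compiled extension; per src/lib.rs
# below these are expected to include LighthouseServer, ManagerServer,
# ManagerClient, QuorumResult and lighthouse_main.
import torchft._torchft as _torchft

print([name for name in dir(_torchft) if not name.startswith("_")])
```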
77 changes: 68 additions & 9 deletions src/lib.rs
@@ -30,15 +30,30 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
use pyo3::prelude::*;

/// ManagerServer is a GRPC server for the manager service.
/// There should be one manager server per replica group (typically running on
/// the rank 0 host). The individual ranks within a replica group should use
/// ManagerClient to communicate with the manager server and participate in
/// quorum operations.
///
/// Args:
/// replica_id (str): The ID of the replica group.
/// lighthouse_addr (str): The HTTP address of the lighthouse server.
/// hostname (str): The hostname of the manager server.
/// bind (str): The HTTP address to bind the server to.
/// store_addr (str): The HTTP address of the store server.
/// world_size (int): The world size of the replica group.
/// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
#[pyclass]
struct Manager {
struct ManagerServer {
handle: JoinHandle<Result<()>>,
manager: Arc<manager::Manager>,
_runtime: Runtime,
}

#[pymethods]
impl Manager {
impl ManagerServer {
#[new]
fn new(
py: Python<'_>,
@@ -74,17 +74,29 @@ impl Manager {
})
}

/// address returns the address of the manager server.
///
/// Returns:
/// str: The address of the manager server.
fn address(&self) -> PyResult<String> {
Ok(self.manager.address().to_string())
}

/// shutdown shuts down the manager server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
})
}
}

/// ManagerClient is a GRPC client to the manager service.
///
/// It is used by the trainer to communicate with the ManagerServer.
///
/// Args:
/// addr (str): The HTTP address of the manager server.
/// connect_timeout (timedelta): The timeout for connecting to the manager server.
#[pyclass]
struct ManagerClient {
runtime: Runtime,
@@ -108,7 +108,7 @@ impl ManagerClient {
})
}

fn quorum(
fn _quorum(
&self,
py: Python<'_>,
rank: i64,
@@ -147,7 +147,7 @@ impl ManagerClient {
})
}

fn checkpoint_metadata(
fn _checkpoint_metadata(
&self,
py: Python<'_>,
rank: i64,
@@ -168,6 +168,20 @@ impl ManagerClient {
})
}

/// should_commit makes a request to the manager to determine if the trainer
/// should commit the current step. This waits until all ranks check in at
/// the specified step and will return false if any worker passes
/// ``should_commit=False``.
///
/// Args:
/// rank (int): The rank of the trainer.
/// step (int): The step of the trainer.
/// should_commit (bool): Whether the trainer should commit the current step.
/// timeout (timedelta): The timeout for the request. If the request
/// times out a TimeoutError is raised.
///
/// Returns:
/// bool: Whether the trainer should commit the current step.
fn should_commit(
&self,
py: Python<'_>,
@@ -263,15 +263,28 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
Ok(())
}

/// LighthouseServer is a GRPC server for the lighthouse service.
///
/// It is used to coordinate the ManagerServer for each replica group.
///
/// This entrypoint is primarily for testing and debugging purposes. The
/// ``torchft_lighthouse`` command is recommended for most use cases.
///
/// Args:
/// bind (str): The HTTP address to bind the server to.
/// min_replicas (int): The minimum number of replicas required to form a quorum.
/// join_timeout_ms (int): The timeout for joining the quorum.
/// quorum_tick_ms (int): The interval at which the quorum is checked.
/// heartbeat_timeout_ms (int): The timeout for heartbeats.
#[pyclass]
struct Lighthouse {
struct LighthouseServer {
lighthouse: Arc<lighthouse::Lighthouse>,
handle: JoinHandle<Result<()>>,
_runtime: Runtime,
}

#[pymethods]
impl Lighthouse {
impl LighthouseServer {
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
#[new]
fn new(
@@ -307,10 +307,15 @@ impl Lighthouse {
})
}

/// address returns the address of the lighthouse server.
///
/// Returns:
/// str: The address of the lighthouse server.
fn address(&self) -> PyResult<String> {
Ok(self.lighthouse.address().to_string())
}

/// shutdown shuts down the lighthouse server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
@@ -339,7 +339,7 @@ impl From<Status> for StatusError {
}

#[pymodule]
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
let mut log = stderrlog::new();
log.verbosity(2)
@@ -353,9 +353,9 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
log.init()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<Manager>()?;
m.add_class::<ManagerServer>()?;
m.add_class::<ManagerClient>()?;
m.add_class::<Lighthouse>()?;
m.add_class::<LighthouseServer>()?;
m.add_class::<QuorumResult>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

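The renamed `LighthouseServer` and `ManagerServer` keep the same lifecycle as before: construct, read `address()`, and `shutdown()`. Below is a minimal Python sketch of driving the bindings directly, assuming the extension has been rebuilt under the new `torchft._torchft` module name; the TCPStore setup, replica_id, and the concrete argument values are illustrative assumptions, not part of this diff.

```python
from datetime import timedelta

from torch.distributed import TCPStore

from torchft._torchft import LighthouseServer, ManagerServer

# One lighthouse per job; an ephemeral port is fine for local experiments.
lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)

# Store shared by the ranks of this replica group (illustrative; the
# ManagerServer only needs its address).
store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)

# One ManagerServer per replica group, typically on the rank 0 host.
manager_server = ManagerServer(
    replica_id="replica_0",
    lighthouse_addr=lighthouse.address(),
    hostname="localhost",
    bind="[::]:0",
    store_addr=f"localhost:{store.port}",
    world_size=1,
    heartbeat_interval=timedelta(milliseconds=100),
    connect_timeout=timedelta(seconds=10),
)

print(manager_server.address())  # gRPC address that ranks pass to ManagerClient

manager_server.shutdown()
lighthouse.shutdown()
```

The same classes are re-exported through the public `torchft.coordination` module added later in this commit.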
8 changes: 4 additions & 4 deletions torchft/torchft.pyi → torchft/_torchft.pyi
@@ -3,15 +3,15 @@ from typing import List, Optional

class ManagerClient:
def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
def quorum(
def _quorum(
self,
rank: int,
step: int,
checkpoint_metadata: str,
shrink_only: bool,
timeout: timedelta,
) -> QuorumResult: ...
def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit(
self,
rank: int,
@@ -33,7 +33,7 @@ class QuorumResult:
max_world_size: int
heal: bool

class Manager:
class ManagerServer:
def __init__(
self,
replica_id: str,
@@ -48,7 +48,7 @@ class Manager:
def address(self) -> str: ...
def shutdown(self) -> None: ...

class Lighthouse:
class LighthouseServer:
def __init__(
self,
bind: str,
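Trainer ranks talk to that server through `ManagerClient`. A hedged sketch of the commit handshake documented on `should_commit`, assuming a `ManagerServer` like the one in the previous sketch is reachable at `manager_addr` and the replica group has `world_size=1`, so a single check-in completes the round:

```python
from datetime import timedelta

from torchft._torchft import ManagerClient

# Placeholder address; in practice use ManagerServer.address() from the rank 0 host.
manager_addr = "localhost:29512"

client = ManagerClient(manager_addr, connect_timeout=timedelta(seconds=10))

step = 0
local_step_ok = True  # e.g. gradients are finite and the allreduce completed

# Blocks until every rank in the replica group checks in for this step and
# returns False if any rank reported should_commit=False.
if client.should_commit(
    rank=0,
    step=step,
    should_commit=local_step_ok,
    timeout=timedelta(seconds=30),
):
    print("commit step", step)    # safe to persist optimizer state / checkpoint
else:
    print("rollback step", step)  # restore the last committed state
```

The `_quorum` and `_checkpoint_metadata` calls are now private; the high-level `Manager` in `torchft/manager.py` is their intended caller.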
23 changes: 23 additions & 0 deletions torchft/coordination.py
@@ -0,0 +1,23 @@
"""
Coordination (Low Level API)
============================
.. warning::
As torchft is still in development, the APIs in this module are subject to change.
This module exposes low level coordination APIs to allow you to build your own
custom fault tolerance algorithms on top of torchft.
If you're looking for a more complete solution, please use the other modules in
torchft.
This provides direct access to the Lighthouse and Manager servers and clients.
"""

from torchft._torchft import LighthouseServer, ManagerClient, ManagerServer

__all__ = [
"LighthouseServer",
"ManagerServer",
"ManagerClient",
]
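The new `torchft.coordination` module is the public face of these bindings: it simply re-exports the pyo3 classes, so user code does not need to import the private `torchft._torchft` extension directly. For example:

```python
# Preferred public import path; equivalent to the private torchft._torchft module.
from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer

# Docstrings come from the Rust definitions in src/lib.rs above.
print(LighthouseServer.__doc__)
```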
19 changes: 19 additions & 0 deletions torchft/coordination_test.py
@@ -0,0 +1,19 @@
import inspect
from unittest import TestCase

from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer


class TestCoordination(TestCase):
def test_coordination_docs(self) -> None:
classes = [
ManagerClient,
ManagerServer,
LighthouseServer,
]
for cls in classes:
self.assertIn("Args:", str(cls.__doc__), cls)
for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
if name.startswith("_"):
continue
self.assertIn("Args:", str(cls.__doc__), cls)
8 changes: 4 additions & 4 deletions torchft/lighthouse_test.py
@@ -4,15 +4,15 @@
import torch.distributed as dist

from torchft import Manager, ProcessGroupGloo
from torchft.torchft import Lighthouse
from torchft._torchft import LighthouseServer


class TestLighthouse(TestCase):
def test_join_timeout_behavior(self) -> None:
"""Test that join_timeout_ms affects joining behavior"""
# To test, we create a lighthouse with 100ms and 400ms join timeouts
# and measure the time taken to validate the quorum.
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
join_timeout_ms=100,
@@ -52,14 +52,14 @@ def test_join_timeout_behavior(self) -> None:
if "manager" in locals():
manager.shutdown()

lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
join_timeout_ms=400,
)

def test_heartbeat_timeout_ms_sanity(self) -> None:
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=1,
heartbeat_timeout_ms=100,
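These tests pair a `LighthouseServer` with the high-level `torchft.Manager` imported at the top of the file. The `Manager` construction itself is elided from this diff; the sketch below reconstructs the pattern from torchft's existing tests, so treat the exact constructor keywords and the `start_quorum()` call as assumptions based on the torchft API at the time, not as part of this change.

```python
import torch.distributed as dist

from torchft import Manager, ProcessGroupGloo
from torchft._torchft import LighthouseServer

lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1, join_timeout_ms=100)

# Store backing a single-rank replica group in this sketch.
store = dist.TCPStore("localhost", 0, is_master=True, wait_for_workers=False)

# Keyword arguments follow the pattern used elsewhere in torchft's test suite
# and may differ from the released API (assumption, not part of this diff).
manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    load_state_dict=lambda state_dict: None,
    state_dict=lambda: {},
    replica_id="lighthouse_example",
    store_addr="localhost",
    store_port=store.port,
    rank=0,
    world_size=1,
    use_async_quorum=False,
    lighthouse_addr=lighthouse.address(),
)

manager.start_quorum()  # roughly: returns once the lighthouse admits this replica
manager.shutdown()
lighthouse.shutdown()
```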
8 changes: 4 additions & 4 deletions torchft/local_sgd_integ_test.py
@@ -8,11 +8,11 @@
import torch
from torch import nn, optim

from torchft._torchft import LighthouseServer
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.manager_integ_test import FailureInjector, MyModel, Runner
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse

logger: logging.Logger = logging.getLogger(__name__)

@@ -166,7 +166,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]

class ManagerIntegTest(TestCase):
def test_local_sgd_recovery(self) -> None:
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
@@ -214,7 +214,7 @@ def test_local_sgd_recovery(self) -> None:
self.assertEqual(failure_injectors[1].count, 1)

def test_diloco_healthy(self) -> None:
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
@@ -258,7 +258,7 @@ def test_diloco_healthy(self) -> None:
)

def test_diloco_recovery(self) -> None:
lighthouse = Lighthouse(
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
10 changes: 5 additions & 5 deletions torchft/manager.py
@@ -38,9 +38,9 @@
import torch
from torch.distributed import ReduceOp, TCPStore

from torchft._torchft import ManagerClient, ManagerServer
from torchft.checkpointing import CheckpointTransport, HTTPTransport
from torchft.futures import future_timeout
from torchft.torchft import Manager as _Manager, ManagerClient

if TYPE_CHECKING:
from torchft.process_group import ProcessGroup
@@ -180,7 +180,7 @@ def __init__(
wait_for_workers=False,
)
self._pg = pg
self._manager: Optional[_Manager] = None
self._manager: Optional[ManagerServer] = None

if rank == 0:
if port is None:
@@ -192,7 +192,7 @@
if replica_id is None:
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
self._manager = _Manager(
self._manager = ManagerServer(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
hostname=hostname,
@@ -429,7 +429,7 @@ def wait_quorum(self) -> None:
def _async_quorum(
self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
) -> None:
quorum = self._client.quorum(
quorum = self._client._quorum(
rank=self._rank,
step=self._step,
checkpoint_metadata=self._checkpoint_transport.metadata(),
@@ -498,7 +498,7 @@ def _async_quorum(
primary_client = ManagerClient(
recover_src_manager_address, connect_timeout=self._connect_timeout
)
checkpoint_metadata = primary_client.checkpoint_metadata(
checkpoint_metadata = primary_client._checkpoint_metadata(
self._rank, timeout=self._timeout
)
recover_src_rank = quorum.recover_src_rank