diff --git a/Cargo.toml b/Cargo.toml index 0d2c44f..e8d6b8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ log = "0.4.22" prost = "0.13.3" prost-types = "0.13.3" pyo3 = {version="0.22.3", features = ["extension-module"]} +rand = "0.8.5" slog = "2.7.0" slog-stdlog = "4.1.1" stderrlog = "0.6.0" diff --git a/src/lib.rs b/src/lib.rs index 8923682..bfbae26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,9 @@ pub mod lighthouse; pub mod manager; +mod net; +mod retry; +mod timeout; use core::time::Duration; use std::env; @@ -46,6 +49,7 @@ impl Manager { store_addr: String, world_size: u64, heartbeat_interval: Duration, + connect_timeout: Duration, ) -> PyResult { py.allow_threads(move || { let runtime = Runtime::new()?; @@ -58,6 +62,7 @@ impl Manager { store_addr, world_size, heartbeat_interval, + connect_timeout, )) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let handle = runtime.spawn(manager.clone().run()); @@ -84,28 +89,25 @@ impl Manager { struct ManagerClient { runtime: Runtime, client: ManagerServiceClient, - timeout: Duration, } #[pymethods] impl ManagerClient { #[new] - fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult { + fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult { py.allow_threads(move || { let runtime = Runtime::new()?; let client = runtime - .block_on(manager::manager_client_new(addr, timeout)) + .block_on(manager::manager_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; Ok(Self { runtime: runtime, client: client, - timeout: timeout, }) }) } - #[pyo3(signature = (rank, step, checkpoint_server_addr, shrink_only, timeout=None))] fn quorum( &mut self, py: Python<'_>, @@ -113,7 +115,7 @@ impl ManagerClient { step: i64, checkpoint_server_addr: String, shrink_only: bool, - timeout: Option, + timeout: Duration, ) -> Result<(i64, i64, i64, String, String, i64, Option, i64, bool), StatusError> { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { @@ -122,9 +124,10 @@ impl ManagerClient { checkpoint_server_addr: checkpoint_server_addr, shrink_only: shrink_only, }); - // This notifies the server about the timeout but doesn't affect the - // endpoint timeout which we set on client creation. - request.set_timeout(timeout.unwrap_or(self.timeout)); + + // This timeout is processed on the server side so we also enable + // keep alives to detect server health. + request.set_timeout(timeout); let response = self.runtime.block_on(self.client.quorum(request))?; let resp = response.into_inner(); @@ -142,18 +145,18 @@ impl ManagerClient { }) } - #[pyo3(signature = (rank, timeout=None))] fn checkpoint_address( &mut self, py: Python<'_>, rank: i64, - timeout: Option, + timeout: Duration, ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank }); - // This notifies the server about the timeout but doesn't affect the - // endpoint timeout which we set on client creation. - request.set_timeout(timeout.unwrap_or(self.timeout)); + + // This timeout is processed on the server side so we also enable + // keep alives to detect server health. 
+ request.set_timeout(timeout); let response = self .runtime @@ -163,14 +166,13 @@ impl ManagerClient { }) } - #[pyo3(signature = (rank, step, should_commit, timeout=None))] fn should_commit( &mut self, py: Python<'_>, rank: i64, step: i64, should_commit: bool, - timeout: Option, + timeout: Duration, ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ShouldCommitRequest { @@ -178,9 +180,10 @@ impl ManagerClient { step: step, should_commit: should_commit, }); + // This notifies the server about the timeout but doesn't affect the // endpoint timeout which we set on client creation. - request.set_timeout(timeout.unwrap_or(self.timeout)); + request.set_timeout(timeout); let response = self.runtime.block_on(self.client.should_commit(request))?; let resp = response.into_inner(); diff --git a/src/lighthouse.rs b/src/lighthouse.rs index 420bb3c..bdfdccb 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -543,15 +543,13 @@ mod tests { use super::*; use std::ops::Sub; - use tonic::transport::{Channel, Endpoint}; + use tonic::transport::Channel; + use crate::net::connect; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; async fn lighthouse_client_new(addr: String) -> Result> { - let conn = Endpoint::new(addr)? - .connect_timeout(Duration::from_secs(10)) - .connect() - .await?; + let conn = connect(addr, Duration::from_secs(10)).await?; Ok(LighthouseServiceClient::new(conn)) } diff --git a/src/manager.rs b/src/manager.rs index 76efb58..982500a 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -16,10 +16,12 @@ use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::sleep; use tonic::transport::server::TcpIncoming; +use tonic::transport::Channel; use tonic::transport::Server; -use tonic::transport::{Channel, Endpoint}; use tonic::{Request, Response, Status}; +use crate::net::connect; +use crate::timeout::try_parse_grpc_timeout; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ @@ -47,7 +49,6 @@ struct ManagerState { pub struct Manager { replica_id: String, - lighthouse_addr: String, hostname: String, store_address: String, world_size: u64, @@ -55,23 +56,27 @@ pub struct Manager { listener: Mutex>, local_addr: SocketAddr, heartbeat_interval: Duration, + lighthouse_client: LighthouseServiceClient, } pub async fn manager_client_new( addr: String, - timeout: Duration, + connect_timeout: Duration, ) -> Result> { - // TODO add retries + backoff so other nodes can start before the rank0 comes up - info!("ManagerClient: establishing connection to {}", &addr); - let conn = Endpoint::new(addr.clone())? 
- .timeout(timeout) - .connect_timeout(Duration::from_secs(60)) - .connect() - .await?; + let conn = connect(addr, connect_timeout).await?; Ok(ManagerServiceClient::new(conn)) } +pub async fn lighthouse_client_new( + addr: String, + connect_timeout: Duration, +) -> Result> { + info!("LighthouseClient: establishing connection to {}", &addr); + let conn = connect(addr, connect_timeout).await?; + Ok(LighthouseServiceClient::new(conn)) +} + impl Manager { pub async fn new( replica_id: String, @@ -81,6 +86,7 @@ impl Manager { store_addr: String, world_size: u64, heartbeat_interval: Duration, + connect_timeout: Duration, ) -> Result> { let listener = tokio::net::TcpListener::bind(&bind).await?; let local_addr = listener.local_addr()?; @@ -88,9 +94,11 @@ impl Manager { let (should_commit_tx, _) = broadcast::channel(16); let (tx, _) = broadcast::channel(16); + let client = lighthouse_client_new(lighthouse_addr.clone(), connect_timeout).await?; + Ok(Arc::new(Self { replica_id: replica_id, - lighthouse_addr: lighthouse_addr, + lighthouse_client: client, hostname: hostname, store_address: store_addr, world_size: world_size, @@ -145,7 +153,7 @@ impl Manager { } async fn _run_heartbeat(self: Arc) -> Result<()> { - let mut client = self.lighthouse_client_new().await?; + let mut client = self.lighthouse_client.clone(); loop { let request = tonic::Request::new(LighthouseHeartbeatRequest { replica_id: self.replica_id.clone(), @@ -157,17 +165,49 @@ impl Manager { } } - async fn lighthouse_client_new(&self) -> Result> { - info!( - "Manager: connecting to lighthouse at {}", - &self.lighthouse_addr - ); + async fn _run_quorum( + &self, + state: &mut ManagerState, + requester: QuorumMember, + timeout: Duration, + ) -> Result<(), Status> { + if (state.participants.len() as u64) < self.world_size { + return Ok(()); + } - let conn = Endpoint::new(self.lighthouse_addr.clone())? - .connect_timeout(Duration::from_secs(60)) - .connect() - .await?; - Ok(LighthouseServiceClient::new(conn)) + state.participants.clear(); + info!("all workers joined -- starting quorum"); + + // TODO: don't hold the lock during quorum + + let mut client = self.lighthouse_client.clone(); + + let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest { + requester: Some(requester), + }); + lighthouse_request.set_timeout(timeout); + + let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request)) + .await + .unwrap_or_else(|e| { + Err(Status::cancelled(format!( + "lighthouse quorum timed out: {}", + e.to_string() + ))) + })?; + let resp = response.into_inner(); + + info!("got lighthouse quorum {:?}", resp); + + state + .channel + .send( + resp.quorum + .ok_or_else(|| Status::internal("missing quorum"))?, + ) + .map_err(|e| Status::from_error(e.into()))?; + + Ok(()) } } @@ -182,6 +222,15 @@ impl ManagerService for Arc { info!("got quorum request for rank {}", rank); + let timeout = try_parse_grpc_timeout(&request.metadata()) + .map_err(|e| { + Status::invalid_argument(format!( + "invalid timeout {}", + e.to_str().unwrap_or("invalid") + )) + })? 
+ .ok_or_else(|| Status::invalid_argument("missing timeout"))?; + let mut rx = { let mut state = self.state.lock().await; @@ -195,50 +244,19 @@ impl ManagerService for Arc { state.participants.insert(rank); let rx = state.channel.subscribe(); - if state.participants.len() as u64 >= self.world_size { - state.participants.clear(); - info!("all workers joined -- starting quorum"); - - // TODO: don't hold the lock during quorum - - let mut client = self - .lighthouse_client_new() - .await - .map_err(|e| Status::from_error(e.into()))?; - - let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest { - requester: Some(QuorumMember { - replica_id: self.replica_id.clone(), - address: self.address(), - store_address: self.store_address.clone(), - step: req.step, - world_size: self.world_size, - shrink_only: req.shrink_only, - }), - }); - - // propagate timeout from request to lighthouse - let timeout = request - .metadata() - .get("grpc-timeout") - .ok_or_else(|| Status::internal("grpc-timeout not set"))?; - lighthouse_request - .metadata_mut() - .insert("grpc-timeout", timeout.clone()); - - let response = client.quorum(lighthouse_request).await.unwrap(); - let resp = response.into_inner(); - - info!("got lighthouse quorum {:?}", resp); - - state - .channel - .send( - resp.quorum - .ok_or_else(|| Status::internal("missing quorum"))?, - ) - .map_err(|e| Status::from_error(e.into()))?; - } + self._run_quorum( + &mut state, + QuorumMember { + replica_id: self.replica_id.clone(), + address: self.address(), + store_address: self.store_address.clone(), + step: req.step, + world_size: self.world_size, + shrink_only: req.shrink_only, + }, + timeout, + ) + .await?; rx }; @@ -413,14 +431,25 @@ mod tests { #[tokio::test] async fn test_should_commit() -> Result<()> { + let lighthouse = Lighthouse::new(LighthouseOpt { + bind: "[::]:0".to_string(), + join_timeout_ms: 100, + min_replicas: 1, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + }) + .await?; + let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); + let manager = Manager::new( "rep_id".to_string(), - "lighthouse".to_string(), + lighthouse.address(), "addr".to_string(), "[::]:29531".to_string(), "store_addr".to_string(), 2, // world size Duration::from_millis(100), // heartbeat interval + Duration::from_secs(10), // connect timeout ) .await?; let manager_fut = tokio::spawn(manager._run_grpc()); @@ -442,6 +471,7 @@ mod tests { assert!(!resp_b.should_commit); manager_fut.abort(); + lighthouse_fut.abort(); Ok(()) } @@ -466,6 +496,7 @@ mod tests { "store_addr".to_string(), 1, // world size Duration::from_millis(100), // heartbeat interval + Duration::from_secs(10), // connect timeout ) .await?; let manager_fut = tokio::spawn(manager.clone().run()); @@ -523,6 +554,7 @@ mod tests { "store_addr".to_string(), 1, // world size Duration::from_millis(100), // heartbeat interval + Duration::from_secs(10), // connect timeout ) .await?; let manager_fut = tokio::spawn(manager.clone().run()); diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..e6d9b69 --- /dev/null +++ b/src/net.rs @@ -0,0 +1,34 @@ +use std::time::Duration; + +use anyhow::Result; +use tonic::transport::{Channel, Endpoint}; + +use crate::retry::{retry_backoff, ExponentialBackoff}; + +pub async fn connect_once(addr: String, connect_timeout: Duration) -> Result { + let conn = Endpoint::new(addr)? + .connect_timeout(connect_timeout) + // Enable HTTP2 keep alives + .http2_keep_alive_interval(Duration::from_secs(60)) + // Time taken for server to respond. 
20s is default for GRPC. + .keep_alive_timeout(Duration::from_secs(20)) + // Enable alive for idle connections. + .keep_alive_while_idle(true) + .connect() + .await?; + Ok(conn) +} + +pub async fn connect(addr: String, connect_timeout: Duration) -> Result { + retry_backoff( + ExponentialBackoff { + initial_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(10), + timeout: connect_timeout, + factor: 1.5, + max_jitter: Duration::from_millis(100), + }, + || Box::pin(connect_once(addr.clone(), connect_timeout)), + ) + .await +} diff --git a/src/retry.rs b/src/retry.rs new file mode 100644 index 0000000..00f2f20 --- /dev/null +++ b/src/retry.rs @@ -0,0 +1,100 @@ +use anyhow::Result; +use std::future::Future; +use std::pin::Pin; +use std::time::{Duration, Instant}; + +pub struct ExponentialBackoff { + pub initial_backoff: Duration, + pub max_backoff: Duration, + pub timeout: Duration, + pub factor: f64, + pub max_jitter: Duration, +} + +pub async fn retry_backoff(policy: ExponentialBackoff, f: F) -> Result +where + F: Fn() -> Pin> + Send>>, + R: Send, +{ + assert!(policy.initial_backoff > Duration::from_millis(0)); + assert!(policy.factor > 1.0); + let mut backoff = policy.initial_backoff; + + let deadline = Instant::now() + policy.timeout; + + loop { + match f().await { + Ok(v) => return Ok(v), + Err(e) => { + if Instant::now() > deadline { + return Err(e); + } + let jitter = policy.max_jitter.mul_f64(rand::random::()); + tokio::time::sleep(backoff + jitter).await; + backoff = backoff.mul_f64(policy.factor); + if backoff > policy.max_backoff { + backoff = policy.max_backoff; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use std::sync::Mutex; + + #[tokio::test] + async fn test_retry_backoff() -> Result<()> { + let count = Arc::new(Mutex::new(0)); + let result = retry_backoff( + ExponentialBackoff { + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(100), + timeout: Duration::from_secs(1000), + factor: 2.0, + max_jitter: Duration::from_millis(1), + }, + || { + let current_count = { + let mut count = count.lock().unwrap(); + *count += 1; + *count + }; + + Box::pin(async move { + if current_count < 3 { + Err(anyhow::anyhow!("test")) + } else { + Ok(1234) + } + }) + }, + ) + .await?; + assert!(result == 1234); + let count = *count.lock().unwrap(); + assert!(count == 3, "count: {}", count); + Ok(()) + } + + #[tokio::test] + async fn test_retry_backoff_timeout() -> Result<()> { + let result: Result<()> = retry_backoff( + ExponentialBackoff { + initial_backoff: Duration::from_millis(1), + max_backoff: Duration::from_millis(100), + timeout: Duration::from_millis(1), + factor: 2.0, + max_jitter: Duration::from_millis(1), + }, + || Box::pin(async { Err(anyhow::anyhow!("test")) }), + ) + .await; + + assert!(result.is_err()); + Ok(()) + } +} diff --git a/src/timeout.rs b/src/timeout.rs new file mode 100644 index 0000000..d966e88 --- /dev/null +++ b/src/timeout.rs @@ -0,0 +1,87 @@ +use std::time::Duration; + +use anyhow::Result; +use tonic::metadata::{Ascii, MetadataMap, MetadataValue}; + +const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; +const SECONDS_IN_HOUR: u64 = 60 * 60; +const SECONDS_IN_MINUTE: u64 = 60; + +/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns +/// the value we attempted to parse. +/// +/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). 
+/// +/// From https://github.com/hyperium/tonic/blob/79a06cc8067818ec53bae76ab717063683bb0acb/tonic/src/transport/service/grpc_timeout.rs#L106 +/// Copyright (c) 2020 Lucio Franco MIT License +/// https://github.com/hyperium/tonic/blob/master/LICENSE +pub fn try_parse_grpc_timeout( + headers: &MetadataMap, +) -> Result, &MetadataValue> { + let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else { + return Ok(None); + }; + + let (timeout_value, timeout_unit) = val + .to_str() + .map_err(|_| val) + .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? + // `MetadataValue::to_str` only returns `Ok` if the header contains ASCII so this + // `split_at` will never panic from trying to split in the middle of a character. + // See https://docs.rs/http/0.2.4/http/header/struct.MetadataValue.html#method.to_str + // + // `len - 1` also wont panic since we just checked `s.is_empty`. + .split_at(val.len() - 1); + + // gRPC spec specifies `TimeoutValue` will be at most 8 digits + // Caping this at 8 digits also prevents integer overflow from ever occurring + if timeout_value.len() > 8 { + return Err(val); + } + + let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; + + let duration = match timeout_unit { + // Hours + "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), + // Minutes + "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), + // Seconds + "S" => Duration::from_secs(timeout_value), + // Milliseconds + "m" => Duration::from_millis(timeout_value), + // Microseconds + "u" => Duration::from_micros(timeout_value), + // Nanoseconds + "n" => Duration::from_nanos(timeout_value), + _ => return Err(val), + }; + + Ok(Some(duration)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parsing() { + fn map(val: &str) -> MetadataMap { + let mut map = MetadataMap::new(); + map.insert(GRPC_TIMEOUT_HEADER, val.parse().unwrap()); + map + } + + assert!(try_parse_grpc_timeout(&map("3H")).unwrap() == Some(Duration::from_secs(3 * 3600))); + assert!(try_parse_grpc_timeout(&map("3M")).unwrap() == Some(Duration::from_secs(3 * 60))); + assert!(try_parse_grpc_timeout(&map("3S")).unwrap() == Some(Duration::from_secs(3))); + assert!(try_parse_grpc_timeout(&map("3m")).unwrap() == Some(Duration::from_millis(3))); + assert!(try_parse_grpc_timeout(&map("3u")).unwrap() == Some(Duration::from_micros(3))); + assert!(try_parse_grpc_timeout(&map("3n")).unwrap() == Some(Duration::from_nanos(3))); + + assert!(try_parse_grpc_timeout(&MetadataMap::new()) + .unwrap() + .is_none()); + assert!(try_parse_grpc_timeout(&map("")).is_err()); + } +} diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index c358d51..aaad843 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -16,6 +16,7 @@ import socket import threading import urllib.request +from datetime import timedelta from http.server import BaseHTTPRequestHandler from typing import Callable, Generic, TypeVar @@ -40,36 +41,48 @@ class CheckpointServer(Generic[T]): state_dict: a callable that returns the state dict to be transferred """ - def __init__(self, state_dict: Callable[[], T]) -> None: + def __init__(self, state_dict: Callable[[], T], timeout: timedelta) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 + self._timeout = timeout ckpt_server = self class RequestHandler(BaseHTTPRequestHandler): - def do_GET(self): - with ckpt_server._checkpoint_lock: - step = ckpt_server._step + # set request socket timeout to avoid hanging forever + timeout = 
self._timeout.total_seconds() - if self.path != f"/checkpoint/{step}": - self.send_response(400) - self.send_header("Content-type", "text/plain") + def do_GET(self): + try: + # validate socket timeout is actually set + assert self.connection.gettimeout() == self.timeout + + with ckpt_server._checkpoint_lock: + step = ckpt_server._step + + if self.path != f"/checkpoint/{step}": + self.send_response(400) + self.send_header("Content-type", "text/plain") + self.end_headers() + self.err( + f"invalid checkpoint requested, serving {step} but got {self.path}" + ) + return + + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") self.end_headers() - self.err( - f"invalid checkpoint requested, serving {step} but got {self.path}" - ) - return - - self.send_response(200) - self.send_header( - "Content-type", "tensor" - ) # TODO: correct mime type - self.end_headers() - sd = state_dict() + sd = state_dict() - torch.save(sd, self.wfile) + torch.save(sd, self.wfile) + except Exception as e: + logger.exception( + f"Exception in checkpoint server when handling {self.path=}: {e}", + ) + self.send_response(500, str(e)) + self.end_headers() def err(self, msg: str) -> None: logger.error(msg) @@ -87,7 +100,7 @@ def err(self, msg: str) -> None: self._thread.start() @classmethod - def load_from_address(cls, address: str) -> T: + def load_from_address(cls, address: str, timeout: timedelta) -> T: """ Loads a checkpoint from the given address. @@ -96,7 +109,7 @@ def load_from_address(cls, address: str) -> T: """ logger.info(f"fetching checkpoint from {address}") - with urllib.request.urlopen(address) as f: + with urllib.request.urlopen(address, timeout=timeout.total_seconds()) as f: data = f.read() reader = io.BytesIO(data) diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index f27392b..983c429 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import urllib.error +from datetime import timedelta from unittest import TestCase from unittest.mock import MagicMock @@ -16,20 +17,27 @@ def test_checkpoint_server(self) -> None: expected = {"state": "dict"} state_dict_fn = MagicMock() state_dict_fn.return_value = expected - server = CheckpointServer(state_dict=state_dict_fn) + server = CheckpointServer( + state_dict=state_dict_fn, + timeout=timedelta(seconds=10), + ) server.disallow_checkpoint() server.allow_checkpoint(1234) addr = server.address() - out = CheckpointServer.load_from_address(addr) + out = CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) self.assertEqual(out, expected) + # test timeout + with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"): + CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=0.0)) + # test mismatch case server.allow_checkpoint(2345) with self.assertRaisesRegex(urllib.error.HTTPError, r"Error 400"): - CheckpointServer.load_from_address(addr) + CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) server.shutdown() diff --git a/torchft/manager.py b/torchft/manager.py index 9bc4ba1..d9ff366 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -92,6 +92,8 @@ def __init__( min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = timedelta(seconds=60), + quorum_timeout: timedelta = timedelta(seconds=60), + connect_timeout: timedelta = timedelta(seconds=60), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, @@ -114,10 +116,22 @@ def __init__( 2. TORCHFT_MANAGER_PORT env var 3. arbitrary port assigned via 0 use_async_quorum: whether to run the quorum asynchronously during the forward pass - timeout: - the default timeout for all operation, if you're using per - request timeouts this should be longer than the longest request - timeout. + timeout: the default timeout for all operations + Included: + * collectives such as allreduce + * should_commit rpc + * checkpoint_address rpc + * checkpoint HTTP operations + * wrap_future + quorum_timeout: the default timeout to wait for the quorum to complete. + This generally should be longer than the training step time / + the interval between quorum checks to avoid any split brain + issues. + + For LocalSGD/DiLoCo this may need to be set to ~1h or longer + depending on how frequently the syncs occur. 
+ connect_timeout: the timeout used for establishing rpc connections + to ManagerServer and Lighthouse rank: the replica group local rank world_size: the replica group local world size store_addr: TCPStore address for this replica group @@ -131,6 +145,8 @@ def __init__( self._pending_state_dict: Optional[Dict[str, object]] = None self._use_async_quorum = use_async_quorum self._timeout = timeout + self._quorum_timeout = quorum_timeout + self._connect_timeout = connect_timeout self._world_size_mode = world_size_mode store_addr = store_addr or os.environ["MASTER_ADDR"] @@ -146,7 +162,10 @@ def _manager_state_dict() -> Dict[str, T]: "torchft": cast(T, self.state_dict()), } - self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict) + self._ckpt_server = CheckpointServer[Dict[str, T]]( + _manager_state_dict, + timeout=timeout, + ) self._executor = ThreadPoolExecutor( max_workers=1, thread_name_prefix="async_quorum" ) @@ -179,13 +198,14 @@ def _manager_state_dict() -> Dict[str, T]: store_addr=f"{store_addr}:{store_port}", world_size=world_size, heartbeat_interval=heartbeat_interval, + connect_timeout=connect_timeout, ) self._store.set(MANAGER_ADDR_KEY, self._manager.address()) self._store.set(REPLICA_ID_KEY, replica_id) addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8") - self._client = ManagerClient(addr, timeout=timeout) + self._client = ManagerClient(addr, connect_timeout=connect_timeout) replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8") self._logger = _ManagerLogger( @@ -356,7 +376,8 @@ def start_quorum( If allow_heal is set, the manager will attempt to heal either synchronously before returning or asynchronously prior to any network calls. All replicas must pass the same value to allow_heal. - timeout: the timeout for quorum and recovery operations, if None, the manager's timeout will be used + timeout: the timeout for quorum to be ready, if None, the manager's timeout will be used + recovery operations will use the manager timeout """ # wait for previous quorum to complete @@ -374,7 +395,7 @@ def start_quorum( self._async_quorum, allow_heal=allow_heal, shrink_only=shrink_only, - timeout=timeout or self._timeout, + quorum_timeout=timeout or self._quorum_timeout, ) if not self._use_async_quorum: self.wait_quorum() @@ -399,7 +420,7 @@ def wait_quorum(self) -> None: self._quorum_future.result() def _async_quorum( - self, allow_heal: bool, shrink_only: bool, timeout: timedelta + self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta ) -> None: ( quorum_id, @@ -416,7 +437,7 @@ def _async_quorum( step=self._step, checkpoint_server_addr=self._ckpt_server.address(), shrink_only=shrink_only, - timeout=timeout, + timeout=quorum_timeout, ) # When using async quorum we need to take the recovered workers. 
@@ -455,14 +476,16 @@ def _async_quorum( self._logger.info( f"healing required, fetching checkpoint server address from {address=} {max_step=}" ) - primary_client = ManagerClient(address, timeout=timeout) + primary_client = ManagerClient( + address, connect_timeout=self._connect_timeout + ) checkpoint_server_address = primary_client.checkpoint_address( - self._rank, timeout=timeout + self._rank, timeout=self._timeout ) self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}") self._pending_state_dict = CheckpointServer.load_from_address( - checkpoint_server_address + checkpoint_server_address, timeout=self._timeout ) self.load_state_dict(self._pending_state_dict["torchft"]) # we apply the user state dict only when safe from the main thread diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 6a849e7..e86c7e0 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -133,7 +133,9 @@ def test_gloo_timeout(self) -> None: store_addr = f"localhost:{store.port}/prefix" pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01)) - with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"): + with self.assertRaisesRegex( + RuntimeError, "(timeout after 10ms|Socket Timeout)" + ): pg.configure(store_addr, 0, 2) # pyre-fixme[56]: Pyre was not able to infer the type of argument diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index a694920..fbd0293 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -2,24 +2,22 @@ from datetime import timedelta from typing import Optional, Tuple class ManagerClient: - def __init__(self, addr: str, timeout: timedelta) -> None: ... + def __init__(self, addr: str, connect_timeout: timedelta) -> None: ... def quorum( self, rank: int, step: int, checkpoint_server_addr: str, shrink_only: bool, - timeout: Optional[timedelta] = None, + timeout: timedelta, ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ... - def checkpoint_address( - self, rank: int, timeout: Optional[timedelta] = None - ) -> str: ... + def checkpoint_address(self, rank: int, timeout: timedelta) -> str: ... def should_commit( self, rank: int, step: int, should_commit: bool, - timeout: Optional[timedelta] = None, + timeout: timedelta, ) -> bool: ... class Manager: @@ -32,6 +30,7 @@ class Manager: store_addr: str, world_size: int, heartbeat_interval: timedelta, + connect_timeout: timedelta, ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ...
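A minimal sketch of how the revised ManagerClient API is called from Python after this change: the connection timeout is now fixed at construction time and every RPC takes an explicit per-call deadline, which is forwarded to the server in the grpc-timeout header. The address, rank and values below are illustrative, and a running ManagerServer is assumed.

    from datetime import timedelta

    from torchft.torchft import ManagerClient

    # connect_timeout only bounds establishing the connection (with retries);
    # each RPC carries its own deadline, which is enforced on the server side.
    client = ManagerClient("localhost:29512", connect_timeout=timedelta(seconds=10))

    checkpoint_addr = client.checkpoint_address(rank=0, timeout=timedelta(seconds=30))
    ok = client.should_commit(
        rank=0, step=0, should_commit=True, timeout=timedelta(seconds=30)
    )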
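The three Manager timeouts described in the docstring above map onto constructor arguments as sketched below. This assumes the pre-existing constructor parameters (pg, load_state_dict, state_dict, min_replica_size) from the current torchft API and that store and lighthouse addresses come from the usual environment variables; the names and values here are illustrative only.

    from datetime import timedelta

    from torchft.manager import Manager
    from torchft.process_group import ProcessGroupGloo

    manager = Manager(
        pg=ProcessGroupGloo(),
        load_state_dict=lambda sd: None,  # application-provided callbacks
        state_dict=lambda: {},
        min_replica_size=1,
        # collectives, should_commit/checkpoint_address RPCs, checkpoint HTTP transfers
        timeout=timedelta(seconds=60),
        # quorum; for LocalSGD/DiLoCo-style infrequent syncs this may need to be much larger
        quorum_timeout=timedelta(minutes=5),
        # establishing connections to the ManagerServer and the Lighthouse (with retries)
        connect_timeout=timedelta(seconds=30),
        rank=0,
        world_size=2,
    )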
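The checkpoint transfer path now has explicit deadlines on both ends: the per-request socket timeout passed to CheckpointServer and the urlopen timeout passed to load_from_address. A self-contained sketch using the updated API (the 10 second values are illustrative):

    from datetime import timedelta

    from torchft.checkpointing import CheckpointServer

    server = CheckpointServer(
        state_dict=lambda: {"step": 1234},  # callable returning the state to serve
        timeout=timedelta(seconds=10),      # per-request socket timeout on the server
    )
    server.allow_checkpoint(1234)

    try:
        state = CheckpointServer.load_from_address(
            server.address(), timeout=timedelta(seconds=10)
        )
    finally:
        server.shutdown()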
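The connection path in src/net.rs retries connect_once under the ExponentialBackoff policy from src/retry.rs: sleep an exponentially growing backoff plus random jitter between attempts, cap the backoff at max_backoff, and give up with the last error once the overall timeout has elapsed. The Python sketch below only illustrates that policy; it is not part of the library.

    import random
    import time

    def retry_backoff(fn, initial_backoff=0.1, max_backoff=10.0, timeout=60.0,
                      factor=1.5, max_jitter=0.1):
        """Call fn() until it succeeds, mirroring the backoff policy in src/retry.rs."""
        assert initial_backoff > 0 and factor > 1.0
        backoff = initial_backoff
        deadline = time.monotonic() + timeout
        while True:
            try:
                return fn()
            except Exception:
                if time.monotonic() > deadline:
                    raise  # out of budget, surface the last error
                time.sleep(backoff + random.random() * max_jitter)
                backoff = min(backoff * factor, max_backoff)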
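For reference, the grpc-timeout metadata value parsed by src/timeout.rs (and produced when request.set_timeout is called on the client) is an ASCII integer of at most 8 digits followed by a single unit character, per the gRPC over HTTP/2 spec. An illustrative Python parser of the same format (note that timedelta resolution is microseconds, so nanosecond values round):

    from datetime import timedelta

    # unit suffix -> seconds multiplier, per the gRPC over HTTP/2 spec
    _GRPC_TIMEOUT_UNITS = {
        "H": 3600.0,  # hours
        "M": 60.0,    # minutes
        "S": 1.0,     # seconds
        "m": 1e-3,    # milliseconds
        "u": 1e-6,    # microseconds
        "n": 1e-9,    # nanoseconds
    }

    def parse_grpc_timeout(value: str) -> timedelta:
        digits, unit = value[:-1], value[-1:]
        # the spec caps TimeoutValue at 8 digits, which also rules out integer overflow
        if not digits.isdigit() or len(digits) > 8 or unit not in _GRPC_TIMEOUT_UNITS:
            raise ValueError(f"invalid grpc-timeout value: {value!r}")
        return timedelta(seconds=int(digits) * _GRPC_TIMEOUT_UNITS[unit])

    assert parse_grpc_timeout("3M") == timedelta(minutes=3)
    assert parse_grpc_timeout("3m") == timedelta(milliseconds=3)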