diff --git a/src/lib.rs b/src/lib.rs
index 5d03bf3..7fe848a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -45,6 +45,7 @@ impl Manager {
         bind: String,
         store_addr: String,
         world_size: u64,
+        heartbeat_interval: Duration,
     ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = Runtime::new()?;
@@ -56,6 +57,7 @@ impl Manager {
                 bind,
                 store_addr,
                 world_size,
+                heartbeat_interval,
             ))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());
@@ -228,6 +230,7 @@ struct Lighthouse {
 #[pymethods]
 impl Lighthouse {
+    #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
     #[new]
     fn new(
         py: Python<'_>,
@@ -235,9 +238,11 @@ impl Lighthouse {
         min_replicas: u64,
         join_timeout_ms: Option<u64>,
         quorum_tick_ms: Option<u64>,
+        heartbeat_timeout_ms: Option<u64>,
     ) -> PyResult<Self> {
         let join_timeout_ms = join_timeout_ms.unwrap_or(100);
         let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
+        let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
 
         py.allow_threads(move || {
             let rt = Runtime::new()?;
@@ -248,6 +253,7 @@ impl Lighthouse {
                 min_replicas: min_replicas,
                 join_timeout_ms: join_timeout_ms,
                 quorum_tick_ms: quorum_tick_ms,
+                heartbeat_timeout_ms: heartbeat_timeout_ms,
             }))
             .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index c22fe23..7c9cfea 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -38,6 +38,7 @@ use crate::torchftpb::{
     LighthouseQuorumResponse, Quorum, QuorumMember,
 };
 
+#[derive(Clone)]
 struct QuorumMemberDetails {
     joined: Instant,
     member: QuorumMember,
@@ -69,17 +70,39 @@ pub struct Lighthouse {
 #[structopt()]
 pub struct LighthouseOpt {
     // bind is the address to bind the server to.
-    #[structopt(long = "bind", default_value = "[::]:29510")]
+    #[structopt(
+        long = "bind",
+        default_value = "[::]:29510",
+        help = "Address to bind the server to"
+    )]
     pub bind: String,
 
-    #[structopt(long = "join_timeout_ms", default_value = "60000")]
+    #[structopt(
+        long = "join_timeout_ms",
+        default_value = "60000",
+        help = "How long to wait for new replicas to join before considering a quorum"
+    )]
     pub join_timeout_ms: u64,
 
-    #[structopt(long = "min_replicas")]
+    #[structopt(
+        long = "min_replicas",
+        help = "Minimum number of replicas to consider a quorum"
+    )]
     pub min_replicas: u64,
 
-    #[structopt(long = "quorum_tick_ms", default_value = "100")]
+    #[structopt(
+        long = "quorum_tick_ms",
+        default_value = "100",
+        help = "How frequently to check for quorum when waiting for workers"
+    )]
     pub quorum_tick_ms: u64,
+
+    #[structopt(
+        long = "heartbeat_timeout_ms",
+        default_value = "5000",
+        help = "How long to wait for a heartbeat before considering a replica dead"
+    )]
+    pub heartbeat_timeout_ms: u64,
 }
 
 fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool {
@@ -89,56 +112,94 @@ fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool {
     return a_ids != b_ids;
 }
 
-// Checks whether the quorum is valid and an explanation for the state.
-fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) {
-    let mut first_joined = Instant::now();
+// Computes the new quorum (if any) along with an explanation of the current state.
+fn quorum_compute(
+    now: Instant,
+    heartbeats: &HashMap<String, Instant>,
+    state: &RoomState,
+    opt: &LighthouseOpt,
+) -> (Option<Vec<QuorumMember>>, String) {
+    let healthy_participants: HashMap<String, QuorumMemberDetails> = state
+        .participants
+        .clone()
+        .into_iter()
+        .filter(|(replica_id, _details)| {
+            let last_heartbeat = heartbeats.get(replica_id);
+            if last_heartbeat.is_none() {
+                return false;
+            }
 
-    for details in state.participants.values() {
-        if details.joined < first_joined {
-            first_joined = details.joined;
-        }
-    }
+            now.duration_since(*last_heartbeat.unwrap())
+                < Duration::from_millis(opt.heartbeat_timeout_ms)
+        })
+        .collect();
+
+    let mut candidate_participants: Vec<QuorumMember> = healthy_participants
+        .values()
+        .map(|details| details.member.clone())
+        .collect();
+    // Sort by replica ID to get a consistent ordering across runs.
+    candidate_participants.sort_by_key(|p| p.replica_id.clone());
+
+    let metadata = format!(
+        "[{}/{} participants healthy]",
+        healthy_participants.len(),
+        state.participants.len()
+    );
+
+    // Check if we can use the previous quorum.
     if state.prev_quorum.is_some() {
-        let mut is_fast_quorum = true;
         let prev_quorum = state.prev_quorum.as_ref().unwrap();
 
-        for prev_member in prev_quorum.participants.iter() {
-            if !state.participants.contains_key(&prev_member.replica_id) {
-                is_fast_quorum = false;
-                break;
-            }
-        }
+        // A fast quorum is possible when all participants from the previous quorum
+        // are still healthy, so we can skip waiting for the join timeout.
+        let is_fast_quorum = prev_quorum
+            .participants
+            .iter()
+            .all(|prev_member| healthy_participants.contains_key(&prev_member.replica_id));
 
         if is_fast_quorum {
-            return (is_fast_quorum, format!("Fast quorum found!"));
+            return (
+                Some(candidate_participants),
+                format!("Fast quorum found! {}", metadata),
+            );
         }
     }
 
-    if state.participants.len() < opt.min_replicas as usize {
+    if healthy_participants.len() < opt.min_replicas as usize {
         return (
-            false,
+            None,
             format!(
-                "No quorum, only have {} participants, need {}",
-                state.participants.len(),
-                opt.min_replicas
+                "No quorum, only have {} participants, need {} {}",
+                healthy_participants.len(),
+                opt.min_replicas,
+                metadata
             ),
         );
     }
 
     // Quorum is valid at this point but lets wait for stragglers.
-
-    if Instant::now().duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) {
+    let first_joined = healthy_participants
+        .values()
+        .map(|details| details.joined)
+        .min()
+        .unwrap_or(now);
+    if now.duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) {
         return (
-            false,
+            None,
             format!(
-                "Valid quorum with {} participants, waiting for stragglers due to join timeout",
-                state.participants.len()
+                "Valid quorum with {} participants, waiting for stragglers due to join timeout {}",
+                healthy_participants.len(),
+                metadata
             ),
         );
     }
 
-    (true, format!("Valid quorum found"))
+    (
+        Some(candidate_participants),
+        format!("Valid quorum found {}", metadata),
+    )
 }
 
 impl Lighthouse {
@@ -155,19 +216,16 @@ impl Lighthouse {
         }))
     }
 
-    fn _quorum_tick(self: Arc<Self>, state: &mut RoomState) -> Result<()> {
-        let (quorum_met, reason) = quorum_valid(state, &self.opt);
+    fn _quorum_tick(
+        self: Arc<Self>,
+        heartbeats: &HashMap<String, Instant>,
+        state: &mut RoomState,
+    ) -> Result<()> {
+        let (quorum_met, reason) = quorum_compute(Instant::now(), heartbeats, state, &self.opt);
         info!("{}: {}", state.room_id, reason);
 
-        if quorum_met {
-            let mut participants: Vec<QuorumMember> = state
-                .participants
-                .values()
-                .map(|details| details.member.clone())
-                .collect();
-
-            // Sort by replica ID to get a consistent ordering across runs.
-            participants.sort_by_key(|p| p.replica_id.clone());
+        if quorum_met.is_some() {
+            let participants = quorum_met.unwrap();
 
             // only increment quorum ID if something about the quorum
             // changed (members/addresses/etc)
@@ -206,8 +264,9 @@ impl Lighthouse {
         loop {
             {
                 let mut state = self.state.lock().await;
+                let heartbeats = state.heartbeats.clone();
                 for (_room_id, room) in &mut state.rooms {
-                    self.clone()._quorum_tick(room)?;
+                    self.clone()._quorum_tick(&heartbeats, room)?;
                 }
             }
@@ -284,7 +343,8 @@ impl Lighthouse {
                 .rooms
                 .iter()
                 .map(|(room_id, room)| {
-                    let (_, quorum_status) = quorum_valid(&room, &self.opt);
+                    let (_, quorum_status) =
+                        quorum_compute(Instant::now(), &state.heartbeats, &room, &self.opt);
 
                     let max_step = {
                         if let Some(quorum) = room.prev_quorum.clone() {
@@ -314,7 +374,7 @@ impl Lighthouse {
                 rooms: rooms,
                 heartbeats: state.heartbeats.clone(),
                 old_age_threshold: Instant::now()
-                    .checked_sub(Duration::from_secs(1))
+                    .checked_sub(Duration::from_millis(self.opt.heartbeat_timeout_ms))
                     .unwrap_or(Instant::now()),
             }
         };
@@ -367,6 +427,13 @@ impl LighthouseService for Arc<Lighthouse> {
         let mut rx = {
             let mut state = self.state.lock().await;
 
+            // implicit heartbeat
+            state
+                .heartbeats
+                .insert(requester.replica_id.clone(), Instant::now());
+
+            let heartbeats = state.heartbeats.clone();
+
             if !state.rooms.contains_key(&room_id) {
                 let (tx, _) = broadcast::channel(16);
@@ -395,7 +462,7 @@ impl LighthouseService for Arc<Lighthouse> {
             // proactively run quorum tick
             self.clone()
-                ._quorum_tick(room)
+                ._quorum_tick(&heartbeats, room)
                 .map_err(|e| Status::from_error(e.into()))?;
 
             rx
@@ -497,6 +564,7 @@ mod tests {
             bind: "[::]:0".to_string(),
             join_timeout_ms: 60 * 60 * 1000, // 1hr
             quorum_tick_ms: 10,
+            heartbeat_timeout_ms: 5000,
         };
 
         let mut state = RoomState {
@@ -506,13 +574,16 @@
             prev_quorum: None,
             quorum_id: 0,
         };
+        let mut heartbeats = HashMap::new();
 
-        assert!(!quorum_valid(&state, &opt).0);
+        let now = Instant::now();
+
+        assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
 
         state.participants.insert(
             "a".to_string(),
             QuorumMemberDetails {
-                joined: Instant::now(),
+                joined: now,
                 member: QuorumMember {
                     replica_id: "a".to_string(),
                     address: "".to_string(),
@@ -522,13 +593,82 @@
                 },
             },
         );
+        heartbeats.insert("a".to_string(), now);
 
-        assert!(!quorum_valid(&state, &opt).0);
+        assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
 
         state.participants.get_mut("a").unwrap().joined =
-            Instant::now().sub(Duration::from_secs(10 * 60 * 60));
+            now.sub(Duration::from_secs(10 * 60 * 60));
+
+        assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_quorum_heartbeats() -> Result<()> {
+        let opt = LighthouseOpt {
+            min_replicas: 1,
+            bind: "[::]:0".to_string(),
+            join_timeout_ms: 0,
+            quorum_tick_ms: 10,
+            heartbeat_timeout_ms: 5000,
+        };
+
+        let mut state = RoomState {
+            room_id: "test".to_string(),
+            channel: broadcast::channel(16).0,
+            participants: HashMap::new(),
+            prev_quorum: None,
+            quorum_id: 0,
+        };
+        let mut heartbeats = HashMap::new();
+
+        let now = Instant::now();
+
+        state.participants.insert(
+            "a".to_string(),
+            QuorumMemberDetails {
+                joined: now,
+                member: QuorumMember {
+                    replica_id: "a".to_string(),
+                    address: "".to_string(),
+                    store_address: "".to_string(),
+                    step: 1,
+                    world_size: 1,
+                },
+            },
+        );
+        heartbeats.insert("a".to_string(), now);
+
+        assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
+
+        // expired heartbeat
+        heartbeats.insert("a".to_string(), now.sub(Duration::from_secs(10)));
+
+        let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt);
+        assert!(quorum_met.is_none(), "{}", reason);
+
+        // 1 healthy, 1 expired
+        state.participants.insert(
+            "b".to_string(),
+            QuorumMemberDetails {
+                joined: now,
+                member: QuorumMember {
+                    replica_id: "b".to_string(),
+                    address: "".to_string(),
+                    store_address: "".to_string(),
+                    step: 1,
+                    world_size: 1,
+                },
+            },
+        );
+        heartbeats.insert("b".to_string(), now);
 
-        assert!(quorum_valid(&state, &opt).0);
+        let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt);
+        assert!(quorum_met.is_some(), "{}", reason);
+        let participants = quorum_met.unwrap();
+        assert!(participants.len() == 1);
 
         Ok(())
     }
@@ -540,6 +680,7 @@ mod tests {
             bind: "[::]:0".to_string(),
             join_timeout_ms: 60 * 60 * 1000, // 1hr
             quorum_tick_ms: 10,
+            heartbeat_timeout_ms: 5000,
         };
 
         let mut state = RoomState {
@@ -549,13 +690,16 @@
             prev_quorum: None,
             quorum_id: 0,
         };
+        let mut heartbeats = HashMap::new();
+
+        let now = Instant::now();
 
-        assert!(!quorum_valid(&state, &opt).0);
+        assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
 
         state.participants.insert(
             "a".to_string(),
             QuorumMemberDetails {
-                joined: Instant::now(),
+                joined: now,
                 member: QuorumMember {
                     replica_id: "a".to_string(),
                     address: "".to_string(),
@@ -565,8 +709,9 @@
                 },
             },
         );
+        heartbeats.insert("a".to_string(), now);
 
-        assert!(!quorum_valid(&state, &opt).0);
+        assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
 
         state.prev_quorum = Some(Quorum {
             quorum_id: 1,
@@ -580,7 +725,28 @@
             created: Some(SystemTime::now().into()),
         });
 
-        assert!(quorum_valid(&state, &opt).0);
+        assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some());
+
+        // test expanding quorum w/ fast quorum
+        state.participants.insert(
+            "b".to_string(),
+            QuorumMemberDetails {
+                joined: now,
+                member: QuorumMember {
+                    replica_id: "b".to_string(),
+                    address: "".to_string(),
+                    store_address: "".to_string(),
+                    step: 1,
+                    world_size: 1,
+                },
+            },
+        );
+        heartbeats.insert("b".to_string(), now);
+
+        let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt);
+        assert!(quorum_met.is_some(), "{}", reason);
+        let participants = quorum_met.unwrap();
+        assert!(participants.len() == 2);
 
         Ok(())
     }
@@ -592,6 +758,7 @@ mod tests {
             bind: "[::]:0".to_string(),
             join_timeout_ms: 1,
             quorum_tick_ms: 10,
+            heartbeat_timeout_ms: 5000,
         };
 
         let lighthouse = Lighthouse::new(opt).await?;
diff --git a/src/manager.rs b/src/manager.rs
index 4f37cdf..eaf521c 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -58,6 +58,7 @@ pub struct Manager {
     state: Mutex<ManagerState>,
     listener: Mutex<Option<tokio::net::TcpListener>>,
     local_addr: SocketAddr,
+    heartbeat_interval: Duration,
 }
 
 pub async fn manager_client_new(
@@ -83,6 +84,7 @@ impl Manager {
         bind: String,
         store_addr: String,
         world_size: u64,
+        heartbeat_interval: Duration,
     ) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&bind).await?;
         let local_addr = listener.local_addr()?;
@@ -95,6 +97,7 @@ impl Manager {
             hostname: hostname,
             store_address: store_addr,
             world_size: world_size,
+            heartbeat_interval: heartbeat_interval,
             state: Mutex::new(ManagerState {
                 checkpoint_servers: HashMap::new(),
                 rooms: HashMap::new(),
@@ -152,7 +155,7 @@ impl Manager {
 
             let _response = client.heartbeat(request).await;
 
-            sleep(Duration::from_millis(100)).await;
+            sleep(self.heartbeat_interval).await;
         }
     }
 
@@ -426,7 +429,8 @@ mod tests {
            "addr".to_string(),
            "[::]:29531".to_string(),
            "store_addr".to_string(),
-            2,
+            2,                          // world size
+            Duration::from_millis(100), // heartbeat interval
         )
         .await?;
         let manager_fut = tokio::spawn(manager._run_grpc());
@@ -459,6 +463,7 @@ mod tests {
             join_timeout_ms: 100,
             min_replicas: 1,
             quorum_tick_ms: 100,
+            heartbeat_timeout_ms: 5000,
         })
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
@@ -469,7 +474,8 @@
             "localhost".to_string(),
             "[::]:0".to_string(),
             "store_addr".to_string(),
-            1, // world size
+            1,                          // world size
+            Duration::from_millis(100), // heartbeat interval
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
@@ -508,6 +514,7 @@ mod tests {
             join_timeout_ms: 100,
             min_replicas: 2,
             quorum_tick_ms: 100,
+            heartbeat_timeout_ms: 5000,
         })
         .await?;
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
@@ -524,7 +531,8 @@
             "localhost".to_string(),
             "[::]:0".to_string(),
             "store_addr".to_string(),
-            1, // world size
+            1,                          // world size
+            Duration::from_millis(100), // heartbeat interval
         )
         .await?;
         let manager_fut = tokio::spawn(manager.clone().run());
diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py
index f6efc32..36ab62c 100644
--- a/torchft/lighthouse_test.py
+++ b/torchft/lighthouse_test.py
@@ -91,3 +91,11 @@ def test_join_timeout_behavior(self) -> None:
             lighthouse.shutdown()
             if "manager" in locals():
                 manager.shutdown()
+
+    def test_heartbeat_timeout_ms_sanity(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=1,
+            heartbeat_timeout_ms=100,
+        )
+        lighthouse.shutdown()
diff --git a/torchft/manager.py b/torchft/manager.py
index 58fbdf3..c75bc4c 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -101,6 +101,7 @@ def __init__(
         replica_id: Optional[str] = None,
         port: Optional[int] = None,
         hostname: str = socket.gethostname(),
+        heartbeat_interval: timedelta = timedelta(milliseconds=100),
     ) -> None:
         """
         Args:
@@ -177,6 +178,7 @@ def _manager_state_dict() -> Dict[str, T]:
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
             world_size=world_size,
+            heartbeat_interval=heartbeat_interval,
         )
 
         self._store.set(MANAGER_ADDR_KEY, self._manager.address())
@@ -443,6 +445,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
 
             # We use the replica rank and world as we want all replicas in the PG.
+            # TODO: handle configure errors
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
 
             self._quorum_id = quorum_id
diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi
index 1dcb09d..07365aa 100644
--- a/torchft/torchft.pyi
+++ b/torchft/torchft.pyi
@@ -31,6 +31,7 @@ class Manager:
         bind: str,
         store_addr: str,
         world_size: int,
+        heartbeat_interval: timedelta,
     ) -> None: ...
     def address(self) -> str: ...
     def shutdown(self) -> None: ...
@@ -42,6 +43,7 @@ class Lighthouse:
         min_replicas: int,
         join_timeout_ms: Optional[int] = None,
         quorum_tick_ms: Optional[int] = None,
+        heartbeat_timeout_ms: Optional[int] = None,
    ) -> None: ...
    def address(self) -> str: ...
    def shutdown(self) -> None: ...
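
A quick usage sketch of the new knobs (illustrative only: the `torchft.torchft` module path is inferred from the torchft/torchft.pyi stub above, and the surrounding setup is assumed):

    from datetime import timedelta

    # Module path assumed from torchft/torchft.pyi in this diff.
    from torchft.torchft import Lighthouse

    # A replica is considered dead once no heartbeat arrives within
    # heartbeat_timeout_ms. Quorum requests count as implicit heartbeats,
    # so replicas that are actively polling for quorum stay healthy.
    lighthouse = Lighthouse(
        bind="[::]:0",
        min_replicas=2,
        heartbeat_timeout_ms=1000,
    )
    try:
        print(lighthouse.address())
    finally:
        lighthouse.shutdown()

The manager-side `heartbeat_interval` (default `timedelta(milliseconds=100)`) should stay well below the lighthouse-side `heartbeat_timeout_ms` (default 5000) so that a few delayed or dropped heartbeats do not evict a healthy replica.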