From 58fe0813032690882f12fc9a1576b7f03b894e10 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 13 Jan 2025 13:47:06 -0800 Subject: [PATCH] lighthouse, manager: remove room support --- proto/torchft.proto | 16 +--- src/lib.rs | 4 +- src/lighthouse.rs | 193 ++++++++++++++++-------------------------- src/manager.rs | 53 ++++-------- templates/status.html | 13 +-- torchft/manager.py | 7 +- torchft/torchft.pyi | 1 - 7 files changed, 101 insertions(+), 186 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index 3bf03b3..e84855c 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -50,11 +50,7 @@ message Quorum { } message LighthouseQuorumRequest { - // room_id is the specific quorum channel to use. All workers/replicas - // participating in the quorum must specify the same channel. - // Multiple channels can be active simultaneously. - string room_id = 1; - QuorumMember requester = 2; + QuorumMember requester = 1; } message LighthouseQuorumResponse { @@ -73,13 +69,9 @@ service LighthouseService { } message ManagerQuorumRequest { - // room_id is the specific quorum channel to use. All workers/replicas - // participating in the quorum must specify the same channel. - // Multiple channels can be active simultaneously. - string room_id = 1; - int64 rank = 2; - int64 step = 3; - string checkpoint_server_addr = 4; + int64 rank = 1; + int64 step = 2; + string checkpoint_server_addr = 3; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 7fe848a..199d4cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,11 +105,10 @@ impl ManagerClient { }) } - #[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))] + #[pyo3(signature = (rank, step, checkpoint_server_addr, timeout=None))] fn quorum( &mut self, py: Python<'_>, - room_id: String, rank: i64, step: i64, checkpoint_server_addr: String, @@ -117,7 +116,6 @@ impl ManagerClient { ) -> Result<(i64, i64, i64, String, String, i64, Option, i64, bool), StatusError> { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { - room_id: room_id, rank: rank, step: step, checkpoint_server_addr: checkpoint_server_addr, diff --git a/src/lighthouse.rs b/src/lighthouse.rs index 7c9cfea..c70bb74 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -44,16 +44,12 @@ struct QuorumMemberDetails { member: QuorumMember, } -struct RoomState { - room_id: String, +struct State { channel: broadcast::Sender, participants: HashMap, prev_quorum: Option, quorum_id: i64, -} -struct State { - rooms: HashMap, // heartbeat information // replica_id -> last heartbeat heartbeats: HashMap, @@ -115,10 +111,10 @@ fn quorum_changed(a: &Vec, b: &Vec) -> bool { // Checks whether the quorum is valid, the new quorum and an explanation for the state. fn quorum_compute( now: Instant, - heartbeats: &HashMap, - state: &RoomState, + state: &State, opt: &LighthouseOpt, ) -> (Option>, String) { + let heartbeats = &state.heartbeats; let healthy_participants: HashMap = state .participants .clone() @@ -205,9 +201,15 @@ fn quorum_compute( impl Lighthouse { pub async fn new(opt: LighthouseOpt) -> Result> { let listener = tokio::net::TcpListener::bind(&opt.bind).await?; + + let (tx, _) = broadcast::channel(16); + Ok(Arc::new(Self { state: Mutex::new(State { - rooms: HashMap::new(), + participants: HashMap::new(), + channel: tx, + prev_quorum: None, + quorum_id: 0, heartbeats: HashMap::new(), }), opt: opt, @@ -216,13 +218,9 @@ impl Lighthouse { })) } - fn _quorum_tick( - self: Arc, - heartbeats: &HashMap, - state: &mut RoomState, - ) -> Result<()> { - let (quorum_met, reason) = quorum_compute(Instant::now(), heartbeats, state, &self.opt); - info!("{}: {}", state.room_id, reason); + fn _quorum_tick(self: Arc, state: &mut State) -> Result<()> { + let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt); + info!("{}", reason); if quorum_met.is_some() { let participants = quorum_met.unwrap(); @@ -237,8 +235,8 @@ impl Lighthouse { { state.quorum_id += 1; info!( - "{}: Detected quorum change, bumping quorum_id to {}", - state.room_id, state.quorum_id + "Detected quorum change, bumping quorum_id to {}", + state.quorum_id ); } @@ -248,7 +246,7 @@ impl Lighthouse { created: Some(SystemTime::now().into()), }; - info!("{}: Quorum! {:?}", state.room_id, quorum); + info!("Quorum! {:?}", quorum); state.prev_quorum = Some(quorum.clone()); state.participants.clear(); @@ -264,10 +262,7 @@ impl Lighthouse { loop { { let mut state = self.state.lock().await; - let heartbeats = state.heartbeats.clone(); - for (_room_id, room) in &mut state.rooms { - self.clone()._quorum_tick(&heartbeats, room)?; - } + self.clone()._quorum_tick(&mut state)?; } sleep(Duration::from_millis(self.opt.quorum_tick_ms)).await; @@ -339,39 +334,27 @@ impl Lighthouse { let template = { let state = self.state.lock().await; - let rooms = state - .rooms - .iter() - .map(|(room_id, room)| { - let (_, quorum_status) = - quorum_compute(Instant::now(), &state.heartbeats, &room, &self.opt); - - let max_step = { - if let Some(quorum) = room.prev_quorum.clone() { - quorum - .participants - .iter() - .map(|p| p.step) - .max() - .unwrap_or(-1) - } else { - -1 - } - }; - - RoomStatus { - room_id: room_id.clone(), - quorum_id: room.quorum_id, - prev_quorum: room.prev_quorum.clone(), - quorum_status: quorum_status, - - max_step: max_step, - } - }) - .collect(); + let (_, quorum_status) = quorum_compute(Instant::now(), &state, &self.opt); + + let max_step = { + if let Some(quorum) = state.prev_quorum.clone() { + quorum + .participants + .iter() + .map(|p| p.step) + .max() + .unwrap_or(-1) + } else { + -1 + } + }; StatusTemplate { - rooms: rooms, + quorum_id: state.quorum_id, + prev_quorum: state.prev_quorum.clone(), + quorum_status: quorum_status, + max_step: max_step, + heartbeats: state.heartbeats.clone(), old_age_threshold: Instant::now() .checked_sub(Duration::from_millis(self.opt.heartbeat_timeout_ms)) @@ -385,17 +368,16 @@ impl Lighthouse { let addr = 'addr: { let state = self.state.lock().await; - for (_room_id, room) in &state.rooms { - if room.prev_quorum.is_none() { - return Err(AppError(anyhow!("failed to find replica"))); - } + if state.prev_quorum.is_none() { + return Err(AppError(anyhow!("failed to find replica"))); + } - for member in room.prev_quorum.clone().unwrap().participants { - if member.replica_id == replica_id { - break 'addr member.address; - } + for member in state.prev_quorum.clone().unwrap().participants { + if member.replica_id == replica_id { + break 'addr member.address; } } + return Err(AppError(anyhow!("failed to find replica"))); }; @@ -417,7 +399,6 @@ impl LighthouseService for Arc { request: Request, ) -> Result, Status> { let req = request.into_inner(); - let room_id = req.room_id; let requester = req .requester .ok_or_else(|| return Status::invalid_argument("missing requester"))?; @@ -432,37 +413,18 @@ impl LighthouseService for Arc { .heartbeats .insert(requester.replica_id.clone(), Instant::now()); - let heartbeats = state.heartbeats.clone(); - - if !state.rooms.contains_key(&room_id) { - let (tx, _) = broadcast::channel(16); - - state.rooms.insert( - room_id.clone(), - RoomState { - room_id: room_id.clone(), - participants: HashMap::new(), - channel: tx, - prev_quorum: None, - quorum_id: 0, - }, - ); - } - - let room = state.rooms.get_mut(&room_id).unwrap(); - - room.participants.insert( + state.participants.insert( requester.replica_id.clone(), QuorumMemberDetails { joined: Instant::now(), member: requester, }, ); - let rx = room.channel.subscribe(); + let rx = state.channel.subscribe(); // proactively run quorum tick self.clone() - ._quorum_tick(&heartbeats, room) + ._quorum_tick(&mut state) .map_err(|e| Status::from_error(e.into()))?; rx @@ -500,19 +462,14 @@ struct IndexTemplate {} #[derive(Template)] #[template(path = "status.html")] struct StatusTemplate { - rooms: Vec, - heartbeats: HashMap, - - // visualization thresholds - old_age_threshold: Instant, -} - -struct RoomStatus { - room_id: String, prev_quorum: Option, quorum_id: i64, quorum_status: String, max_step: i64, + heartbeats: HashMap, + + // visualization thresholds + old_age_threshold: Instant, } // Make our own error that wraps `anyhow::Error`. @@ -567,18 +524,17 @@ mod tests { heartbeat_timeout_ms: 5000, }; - let mut state = RoomState { - room_id: "test".to_string(), + let mut state = State { channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, + heartbeats: HashMap::new(), }; - let mut heartbeats = HashMap::new(); let now = Instant::now(); - assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(!quorum_compute(now, &state, &opt).0.is_some()); state.participants.insert( "a".to_string(), @@ -593,14 +549,14 @@ mod tests { }, }, ); - heartbeats.insert("a".to_string(), now); + state.heartbeats.insert("a".to_string(), now); - assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(!quorum_compute(now, &state, &opt).0.is_some()); state.participants.get_mut("a").unwrap().joined = now.sub(Duration::from_secs(10 * 60 * 60)); - assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(quorum_compute(now, &state, &opt).0.is_some()); Ok(()) } @@ -615,14 +571,13 @@ mod tests { heartbeat_timeout_ms: 5000, }; - let mut state = RoomState { - room_id: "test".to_string(), + let mut state = State { channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, + heartbeats: HashMap::new(), }; - let mut heartbeats = HashMap::new(); let now = Instant::now(); @@ -639,14 +594,16 @@ mod tests { }, }, ); - heartbeats.insert("a".to_string(), now); + state.heartbeats.insert("a".to_string(), now); - assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(quorum_compute(now, &state, &opt).0.is_some()); // expired heartbeat - heartbeats.insert("a".to_string(), now.sub(Duration::from_secs(10))); + state + .heartbeats + .insert("a".to_string(), now.sub(Duration::from_secs(10))); - let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_none(), "{}", reason); // 1 healthy, 1 expired @@ -663,9 +620,9 @@ mod tests { }, }, ); - heartbeats.insert("b".to_string(), now); + state.heartbeats.insert("b".to_string(), now); - let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_some(), "{}", reason); let participants = quorum_met.unwrap(); assert!(participants.len() == 1); @@ -683,18 +640,17 @@ mod tests { heartbeat_timeout_ms: 5000, }; - let mut state = RoomState { - room_id: "test".to_string(), + let mut state = State { channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, + heartbeats: HashMap::new(), }; - let mut heartbeats = HashMap::new(); let now = Instant::now(); - assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(!quorum_compute(now, &state, &opt).0.is_some()); state.participants.insert( "a".to_string(), @@ -709,9 +665,9 @@ mod tests { }, }, ); - heartbeats.insert("a".to_string(), now); + state.heartbeats.insert("a".to_string(), now); - assert!(!quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(!quorum_compute(now, &state, &opt).0.is_some()); state.prev_quorum = Some(Quorum { quorum_id: 1, @@ -725,7 +681,7 @@ mod tests { created: Some(SystemTime::now().into()), }); - assert!(quorum_compute(now, &heartbeats, &state, &opt).0.is_some()); + assert!(quorum_compute(now, &state, &opt).0.is_some()); // test expanding quorum w/ fast quorum state.participants.insert( @@ -741,9 +697,9 @@ mod tests { }, }, ); - heartbeats.insert("b".to_string(), now); + state.heartbeats.insert("b".to_string(), now); - let (quorum_met, reason) = quorum_compute(now, &heartbeats, &state, &opt); + let (quorum_met, reason) = quorum_compute(now, &state, &opt); assert!(quorum_met.is_some(), "{}", reason); let participants = quorum_met.unwrap(); assert!(participants.len() == 2); @@ -776,7 +732,6 @@ mod tests { { let request = tonic::Request::new(LighthouseQuorumRequest { - room_id: "test".to_string(), requester: Some(QuorumMember { replica_id: "foo".to_string(), address: "".to_string(), diff --git a/src/manager.rs b/src/manager.rs index eaf521c..a51bbe9 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -35,14 +35,10 @@ use log::{info, warn}; #[cfg(test)] use std::{println as info, println as warn}; -struct RoomState { - channel: broadcast::Sender, - participants: HashSet, -} - struct ManagerState { checkpoint_servers: HashMap, - rooms: HashMap, + channel: broadcast::Sender, + participants: HashSet, should_commit_channel: broadcast::Sender, should_commit_failures: HashSet, @@ -90,6 +86,7 @@ impl Manager { let local_addr = listener.local_addr()?; let (should_commit_tx, _) = broadcast::channel(16); + let (tx, _) = broadcast::channel(16); Ok(Arc::new(Self { replica_id: replica_id, @@ -100,7 +97,8 @@ impl Manager { heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_servers: HashMap::new(), - rooms: HashMap::new(), + channel: tx, + participants: HashSet::new(), should_commit_channel: should_commit_tx, should_commit_count: HashSet::new(), @@ -181,9 +179,8 @@ impl ManagerService for Arc { ) -> Result, Status> { let req = request.get_ref(); let rank = req.rank; - let room_id = &req.room_id; - info!("{}: got quorum request for rank {}", room_id, rank); + info!("got quorum request for rank {}", rank); let mut rx = { let mut state = self.state.lock().await; @@ -194,27 +191,13 @@ impl ManagerService for Arc { .checkpoint_servers .insert(req.rank, req.checkpoint_server_addr.clone()); - if !state.rooms.contains_key(room_id) { - let (tx, _) = broadcast::channel(16); - - state.rooms.insert( - room_id.clone(), - RoomState { - channel: tx, - participants: HashSet::new(), - }, - ); - } - - let room = state.rooms.get_mut(room_id).unwrap(); - // TODO check step - room.participants.insert(rank); - let rx = room.channel.subscribe(); + state.participants.insert(rank); + let rx = state.channel.subscribe(); - if room.participants.len() as u64 >= self.world_size { - room.participants.clear(); - info!("{}: all workers joined -- starting quorum", room_id); + if state.participants.len() as u64 >= self.world_size { + state.participants.clear(); + info!("all workers joined -- starting quorum"); // TODO: don't hold the lock during quorum @@ -224,7 +207,6 @@ impl ManagerService for Arc { .map_err(|e| Status::from_error(e.into()))?; let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest { - room_id: room_id.clone(), requester: Some(QuorumMember { replica_id: self.replica_id.clone(), address: self.address(), @@ -246,9 +228,10 @@ impl ManagerService for Arc { let response = client.quorum(lighthouse_request).await.unwrap(); let resp = response.into_inner(); - info!("{}: got lighthouse quorum {:?}", room_id, resp); + info!("got lighthouse quorum {:?}", resp); - room.channel + state + .channel .send( resp.quorum .ok_or_else(|| Status::internal("missing quorum"))?, @@ -295,8 +278,8 @@ impl ManagerService for Arc { let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id; if heal { info!( - "{}: healing is required step={}, max_step={}", - room_id, req.step, max_step + "healing is required step={}, max_step={}", + req.step, max_step ); } @@ -313,7 +296,7 @@ impl ManagerService for Arc { heal: heal, }; - info!("{}: returning quorum for rank {}", room_id, rank); + info!("returning quorum for rank {}", rank); Ok(Response::new(reply)) } @@ -483,7 +466,6 @@ mod tests { let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?; let mut request = tonic::Request::new(ManagerQuorumRequest { - room_id: "room".to_string(), rank: 0, step: 123, checkpoint_server_addr: "addr".to_string(), @@ -541,7 +523,6 @@ mod tests { manager_client_new(manager.address(), Duration::from_secs(10)).await?; let mut request = tonic::Request::new(ManagerQuorumRequest { - room_id: "room".to_string(), rank: 0, step: 0, checkpoint_server_addr: "addr".to_string(), diff --git a/templates/status.html b/templates/status.html index bacd340..83ca845 100644 --- a/templates/status.html +++ b/templates/status.html @@ -1,12 +1,9 @@ -{% for room in rooms %} -

Room Status: {{room.room_id}}

- -Current quorum_id: {{room.quorum_id}}
-Next quorum status: {{room.quorum_status}} +Current quorum_id: {{quorum_id}}
+Next quorum status: {{quorum_status}}

Previous Quorum

-{% if let Some(prev_quorum) = room.prev_quorum %} +{% if let Some(prev_quorum) = prev_quorum %} Previous quorum id: {{prev_quorum.quorum_id}}
Quorum age: @@ -16,7 +13,7 @@

Previous Quorum

{% for member in prev_quorum.participants %}
{{ member.replica_id }}
Step: {{ member.step }}
@@ -35,8 +32,6 @@

Previous Quorum

{% endif %} -{% endfor %} -

Heartbeats

    diff --git a/torchft/manager.py b/torchft/manager.py index c75bc4c..c75bb48 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -337,7 +337,6 @@ def callback( def start_quorum( self, - room_id: str = "default", allow_heal: bool = True, timeout: Optional[timedelta] = None, ) -> None: @@ -356,8 +355,6 @@ def start_quorum( If allow_heal is set, the manager will attempt to heal either synchronously before returning or asynchronously prior to any network calls. All replicas must pass the same value to allow_heal. - room_id: (experimental) the room id to use for quorum, this allows - for multiple quorums to be used within the same job. timeout: the timeout for quorum and recovery operations, if None, the manager's timeout will be used """ @@ -374,7 +371,6 @@ def start_quorum( self._quorum_future = self._executor.submit( self._async_quorum, - room_id=room_id, allow_heal=allow_heal, timeout=timeout or self._timeout, ) @@ -400,7 +396,7 @@ def wait_quorum(self) -> None: ), "must call start_quorum before wait_quorum" self._quorum_future.result() - def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None: + def _async_quorum(self, allow_heal: bool, timeout: timedelta) -> None: ( quorum_id, replica_rank, @@ -412,7 +408,6 @@ def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> N max_world_size, heal, ) = self._client.quorum( - room_id=room_id, rank=self._rank, step=self._step, checkpoint_server_addr=self._ckpt_server.address(), diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index 07365aa..a5196ea 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -5,7 +5,6 @@ class ManagerClient: def __init__(self, addr: str, timeout: timedelta) -> None: ... def quorum( self, - room_id: str, rank: int, step: int, checkpoint_server_addr: str,