Skip to content

Commit e82057c

Browse files
committed
overhaul timeouts for Lighthouse, Manager, checkpoint server
1 parent 79572e6 commit e82057c

10 files changed

+271
-143
lines changed

src/lib.rs

+17-17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
pub mod lighthouse;
88
pub mod manager;
9+
mod net;
10+
mod timeout;
911

1012
use core::time::Duration;
1113
use std::env;
@@ -84,36 +86,33 @@ impl Manager {
8486
struct ManagerClient {
8587
runtime: Runtime,
8688
client: ManagerServiceClient<Channel>,
87-
timeout: Duration,
8889
}
8990

9091
#[pymethods]
9192
impl ManagerClient {
9293
#[new]
93-
fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
94+
fn new(py: Python<'_>, addr: String) -> PyResult<Self> {
9495
py.allow_threads(move || {
9596
let runtime = Runtime::new()?;
9697
let client = runtime
97-
.block_on(manager::manager_client_new(addr, timeout))
98+
.block_on(manager::manager_client_new(addr))
9899
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
99100

100101
Ok(Self {
101102
runtime: runtime,
102103
client: client,
103-
timeout: timeout,
104104
})
105105
})
106106
}
107107

108-
#[pyo3(signature = (rank, step, checkpoint_server_addr, shrink_only, timeout=None))]
109108
fn quorum(
110109
&mut self,
111110
py: Python<'_>,
112111
rank: i64,
113112
step: i64,
114113
checkpoint_server_addr: String,
115114
shrink_only: bool,
116-
timeout: Option<Duration>,
115+
timeout: Duration,
117116
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
118117
py.allow_threads(move || {
119118
let mut request = tonic::Request::new(ManagerQuorumRequest {
@@ -122,9 +121,10 @@ impl ManagerClient {
122121
checkpoint_server_addr: checkpoint_server_addr,
123122
shrink_only: shrink_only,
124123
});
125-
// This notifies the server about the timeout but doesn't affect the
126-
// endpoint timeout which we set on client creation.
127-
request.set_timeout(timeout.unwrap_or(self.timeout));
124+
125+
// This timeout is processed on the server side so we also enable
126+
// keep alives to detect server health.
127+
request.set_timeout(timeout);
128128

129129
let response = self.runtime.block_on(self.client.quorum(request))?;
130130
let resp = response.into_inner();
@@ -142,18 +142,18 @@ impl ManagerClient {
142142
})
143143
}
144144

145-
#[pyo3(signature = (rank, timeout=None))]
146145
fn checkpoint_address(
147146
&mut self,
148147
py: Python<'_>,
149148
rank: i64,
150-
timeout: Option<Duration>,
149+
timeout: Duration,
151150
) -> Result<String, StatusError> {
152151
py.allow_threads(move || {
153152
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
154-
// This notifies the server about the timeout but doesn't affect the
155-
// endpoint timeout which we set on client creation.
156-
request.set_timeout(timeout.unwrap_or(self.timeout));
153+
154+
// This timeout is processed on the server side so we also enable
155+
// keep alives to detect server health.
156+
request.set_timeout(timeout);
157157

158158
let response = self
159159
.runtime
@@ -163,24 +163,24 @@ impl ManagerClient {
163163
})
164164
}
165165

166-
#[pyo3(signature = (rank, step, should_commit, timeout=None))]
167166
fn should_commit(
168167
&mut self,
169168
py: Python<'_>,
170169
rank: i64,
171170
step: i64,
172171
should_commit: bool,
173-
timeout: Option<Duration>,
172+
timeout: Duration,
174173
) -> Result<bool, StatusError> {
175174
py.allow_threads(move || {
176175
let mut request = tonic::Request::new(ShouldCommitRequest {
177176
rank: rank,
178177
step: step,
179178
should_commit: should_commit,
180179
});
180+
181181
// This notifies the server about the timeout but doesn't affect the
182182
// endpoint timeout which we set on client creation.
183-
request.set_timeout(timeout.unwrap_or(self.timeout));
183+
request.set_timeout(timeout);
184184

185185
let response = self.runtime.block_on(self.client.should_commit(request))?;
186186
let resp = response.into_inner();

src/lighthouse.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ impl Lighthouse {
422422
return Err(AppError(anyhow!("failed to find replica")));
423423
};
424424

425-
let mut client = manager_client_new(addr, Duration::from_secs(10)).await?;
425+
let mut client = manager_client_new(addr).await?;
426426

427427
let request = tonic::Request::new(KillRequest {
428428
msg: "killed from dashboard".to_string(),
@@ -543,15 +543,13 @@ mod tests {
543543
use super::*;
544544
use std::ops::Sub;
545545

546-
use tonic::transport::{Channel, Endpoint};
546+
use tonic::transport::Channel;
547547

548+
use crate::net::connect;
548549
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
549550

550551
async fn lighthouse_client_new(addr: String) -> Result<LighthouseServiceClient<Channel>> {
551-
let conn = Endpoint::new(addr)?
552-
.connect_timeout(Duration::from_secs(10))
553-
.connect()
554-
.await?;
552+
let conn = connect(addr).await?;
555553
Ok(LighthouseServiceClient::new(conn))
556554
}
557555

src/manager.rs

+79-66
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ use tokio::sync::Mutex;
1616
use tokio::task::JoinSet;
1717
use tokio::time::sleep;
1818
use tonic::transport::server::TcpIncoming;
19+
use tonic::transport::Channel;
1920
use tonic::transport::Server;
20-
use tonic::transport::{Channel, Endpoint};
2121
use tonic::{Request, Response, Status};
2222

23+
use crate::net::connect;
24+
use crate::timeout::try_parse_grpc_timeout;
2325
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
2426
use crate::torchftpb::manager_service_client::ManagerServiceClient;
2527
use crate::torchftpb::{
@@ -57,18 +59,11 @@ pub struct Manager {
5759
heartbeat_interval: Duration,
5860
}
5961

60-
pub async fn manager_client_new(
61-
addr: String,
62-
timeout: Duration,
63-
) -> Result<ManagerServiceClient<Channel>> {
62+
pub async fn manager_client_new(addr: String) -> Result<ManagerServiceClient<Channel>> {
6463
// TODO add retries + backoff so other nodes can start before the rank0 comes up
6564

6665
info!("ManagerClient: establishing connection to {}", &addr);
67-
let conn = Endpoint::new(addr.clone())?
68-
.timeout(timeout)
69-
.connect_timeout(Duration::from_secs(60))
70-
.connect()
71-
.await?;
66+
let conn = connect(addr.clone()).await?;
7267
Ok(ManagerServiceClient::new(conn))
7368
}
7469

@@ -163,12 +158,57 @@ impl Manager {
163158
&self.lighthouse_addr
164159
);
165160

166-
let conn = Endpoint::new(self.lighthouse_addr.clone())?
167-
.connect_timeout(Duration::from_secs(60))
168-
.connect()
169-
.await?;
161+
let conn = connect(self.lighthouse_addr.clone()).await?;
170162
Ok(LighthouseServiceClient::new(conn))
171163
}
164+
165+
async fn _run_quorum(
166+
&self,
167+
state: &mut ManagerState,
168+
requester: QuorumMember,
169+
timeout: Duration,
170+
) -> Result<(), Status> {
171+
if (state.participants.len() as u64) < self.world_size {
172+
return Ok(());
173+
}
174+
175+
state.participants.clear();
176+
info!("all workers joined -- starting quorum");
177+
178+
// TODO: don't hold the lock during quorum
179+
180+
let mut client = self
181+
.lighthouse_client_new()
182+
.await
183+
.map_err(|e| Status::from_error(e.into()))?;
184+
185+
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
186+
requester: Some(requester),
187+
});
188+
lighthouse_request.set_timeout(timeout);
189+
190+
let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
191+
.await
192+
.unwrap_or_else(|e| {
193+
Err(Status::cancelled(format!(
194+
"lighthouse quorum timed out: {}",
195+
e.to_string()
196+
)))
197+
})?;
198+
let resp = response.into_inner();
199+
200+
info!("got lighthouse quorum {:?}", resp);
201+
202+
state
203+
.channel
204+
.send(
205+
resp.quorum
206+
.ok_or_else(|| Status::internal("missing quorum"))?,
207+
)
208+
.map_err(|e| Status::from_error(e.into()))?;
209+
210+
Ok(())
211+
}
172212
}
173213

174214
#[tonic::async_trait]
@@ -182,6 +222,15 @@ impl ManagerService for Arc<Manager> {
182222

183223
info!("got quorum request for rank {}", rank);
184224

225+
let timeout = try_parse_grpc_timeout(&request.metadata())
226+
.map_err(|e| {
227+
Status::invalid_argument(format!(
228+
"invalid timeout {}",
229+
e.to_str().unwrap_or("invalid")
230+
))
231+
})?
232+
.ok_or_else(|| Status::invalid_argument("missing timeout"))?;
233+
185234
let mut rx = {
186235
let mut state = self.state.lock().await;
187236

@@ -195,50 +244,19 @@ impl ManagerService for Arc<Manager> {
195244
state.participants.insert(rank);
196245
let rx = state.channel.subscribe();
197246

198-
if state.participants.len() as u64 >= self.world_size {
199-
state.participants.clear();
200-
info!("all workers joined -- starting quorum");
201-
202-
// TODO: don't hold the lock during quorum
203-
204-
let mut client = self
205-
.lighthouse_client_new()
206-
.await
207-
.map_err(|e| Status::from_error(e.into()))?;
208-
209-
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
210-
requester: Some(QuorumMember {
211-
replica_id: self.replica_id.clone(),
212-
address: self.address(),
213-
store_address: self.store_address.clone(),
214-
step: req.step,
215-
world_size: self.world_size,
216-
shrink_only: req.shrink_only,
217-
}),
218-
});
219-
220-
// propagate timeout from request to lighthouse
221-
let timeout = request
222-
.metadata()
223-
.get("grpc-timeout")
224-
.ok_or_else(|| Status::internal("grpc-timeout not set"))?;
225-
lighthouse_request
226-
.metadata_mut()
227-
.insert("grpc-timeout", timeout.clone());
228-
229-
let response = client.quorum(lighthouse_request).await.unwrap();
230-
let resp = response.into_inner();
231-
232-
info!("got lighthouse quorum {:?}", resp);
233-
234-
state
235-
.channel
236-
.send(
237-
resp.quorum
238-
.ok_or_else(|| Status::internal("missing quorum"))?,
239-
)
240-
.map_err(|e| Status::from_error(e.into()))?;
241-
}
247+
self._run_quorum(
248+
&mut state,
249+
QuorumMember {
250+
replica_id: self.replica_id.clone(),
251+
address: self.address(),
252+
store_address: self.store_address.clone(),
253+
step: req.step,
254+
world_size: self.world_size,
255+
shrink_only: req.shrink_only,
256+
},
257+
timeout,
258+
)
259+
.await?;
242260

243261
rx
244262
};
@@ -395,11 +413,7 @@ mod tests {
395413
use crate::lighthouse::{Lighthouse, LighthouseOpt};
396414

397415
async fn should_commit(rank: i64, should_commit: bool) -> Result<ShouldCommitResponse> {
398-
let mut client = manager_client_new(
399-
"http://localhost:29531".to_string(),
400-
Duration::from_secs(10),
401-
)
402-
.await?;
416+
let mut client = manager_client_new("http://localhost:29531".to_string()).await?;
403417

404418
let request = tonic::Request::new(ShouldCommitRequest {
405419
rank: rank,
@@ -470,7 +484,7 @@ mod tests {
470484
.await?;
471485
let manager_fut = tokio::spawn(manager.clone().run());
472486

473-
let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;
487+
let mut client = manager_client_new(manager.address()).await?;
474488

475489
let mut request = tonic::Request::new(ManagerQuorumRequest {
476490
rank: 0,
@@ -527,8 +541,7 @@ mod tests {
527541
.await?;
528542
let manager_fut = tokio::spawn(manager.clone().run());
529543

530-
let mut client =
531-
manager_client_new(manager.address(), Duration::from_secs(10)).await?;
544+
let mut client = manager_client_new(manager.address()).await?;
532545

533546
let mut request = tonic::Request::new(ManagerQuorumRequest {
534547
rank: 0,

src/net.rs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
use std::time::Duration;
2+
3+
use anyhow::Result;
4+
use tonic::transport::{Channel, Endpoint};
5+
6+
pub async fn connect(addr: String) -> Result<Channel> {
7+
let conn = Endpoint::new(addr)?
8+
.connect_timeout(Duration::from_secs(60))
9+
// Enable HTTP2 keep alives
10+
.http2_keep_alive_interval(Duration::from_secs(60))
11+
// Time taken for server to respond. 20s is default for GRPC.
12+
.keep_alive_timeout(Duration::from_secs(20))
13+
// Enable keep alive for idle connections.
14+
.keep_alive_while_idle(true)
15+
.connect()
16+
.await?;
17+
Ok(conn)
18+
}

0 commit comments

Comments
 (0)