Commit 663a292

overhaul timeouts for Lighthouse, Manager, checkpoint server

1 parent: 79572e6
12 files changed: +375 -123

Cargo.toml (+1)

@@ -12,6 +12,7 @@ log = "0.4.22"
 prost = "0.13.3"
 prost-types = "0.13.3"
 pyo3 = {version="0.22.3", features = ["extension-module"]}
+rand = "0.8.5"
 slog = "2.7.0"
 slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
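Note: the new rand dependency is presumably consumed by the retry jitter introduced in src/retry.rs (referenced from src/net.rs below); that file is not part of this excerpt.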

src/lib.rs (+18 -17)

@@ -6,6 +6,9 @@
 
 pub mod lighthouse;
 pub mod manager;
+mod net;
+mod retry;
+mod timeout;
 
 use core::time::Duration;
 use std::env;
@@ -84,36 +87,33 @@ impl Manager {
 struct ManagerClient {
     runtime: Runtime,
     client: ManagerServiceClient<Channel>,
-    timeout: Duration,
 }
 
 #[pymethods]
 impl ManagerClient {
     #[new]
-    fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
+    fn new(py: Python<'_>, addr: String) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = Runtime::new()?;
             let client = runtime
-                .block_on(manager::manager_client_new(addr, timeout))
+                .block_on(manager::manager_client_new(addr, Duration::from_secs(60)))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
             Ok(Self {
                 runtime: runtime,
                 client: client,
-                timeout: timeout,
             })
         })
     }
 
-    #[pyo3(signature = (rank, step, checkpoint_server_addr, shrink_only, timeout=None))]
     fn quorum(
         &mut self,
         py: Python<'_>,
         rank: i64,
         step: i64,
         checkpoint_server_addr: String,
         shrink_only: bool,
-        timeout: Option<Duration>,
+        timeout: Duration,
     ) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
@@ -122,9 +122,10 @@ impl ManagerClient {
                 checkpoint_server_addr: checkpoint_server_addr,
                 shrink_only: shrink_only,
             });
-            // This notifies the server about the timeout but doesn't affect the
-            // endpoint timeout which we set on client creation.
-            request.set_timeout(timeout.unwrap_or(self.timeout));
+
+            // This timeout is processed on the server side so we also enable
+            // keep alives to detect server health.
+            request.set_timeout(timeout);
 
             let response = self.runtime.block_on(self.client.quorum(request))?;
             let resp = response.into_inner();
@@ -142,18 +143,18 @@ impl ManagerClient {
         })
     }
 
-    #[pyo3(signature = (rank, timeout=None))]
     fn checkpoint_address(
         &mut self,
         py: Python<'_>,
         rank: i64,
-        timeout: Option<Duration>,
+        timeout: Duration,
     ) -> Result<String, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
-            // This notifies the server about the timeout but doesn't affect the
-            // endpoint timeout which we set on client creation.
-            request.set_timeout(timeout.unwrap_or(self.timeout));
+
+            // This timeout is processed on the server side so we also enable
+            // keep alives to detect server health.
+            request.set_timeout(timeout);
 
             let response = self
                 .runtime
@@ -163,24 +164,24 @@ impl ManagerClient {
         })
     }
 
-    #[pyo3(signature = (rank, step, should_commit, timeout=None))]
     fn should_commit(
         &mut self,
         py: Python<'_>,
         rank: i64,
         step: i64,
         should_commit: bool,
-        timeout: Option<Duration>,
+        timeout: Duration,
     ) -> Result<bool, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ShouldCommitRequest {
                 rank: rank,
                 step: step,
                 should_commit: should_commit,
             });
+
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(timeout.unwrap_or(self.timeout));
+            request.set_timeout(timeout);
 
             let response = self.runtime.block_on(self.client.should_commit(request))?;
             let resp = response.into_inner();
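Note: in tonic, Request::set_timeout does not cancel the call locally; it transmits the deadline to the server as a grpc-timeout metadata header, which is why src/manager.rs below parses it back out of the request metadata. A minimal sketch of the client side, using a hypothetical make_request helper:

use std::time::Duration;
use tonic::Request;

fn make_request() -> Request<()> {
    let mut request = Request::new(());
    // Sent on the wire as a grpc-timeout header per the gRPC spec;
    // the server reads it back from request.metadata().
    request.set_timeout(Duration::from_secs(5));
    request
}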

src/lighthouse.rs (+3 -5)

@@ -543,15 +543,13 @@ mod tests {
     use super::*;
     use std::ops::Sub;
 
-    use tonic::transport::{Channel, Endpoint};
+    use tonic::transport::Channel;
 
+    use crate::net::connect;
     use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 
     async fn lighthouse_client_new(addr: String) -> Result<LighthouseServiceClient<Channel>> {
-        let conn = Endpoint::new(addr)?
-            .connect_timeout(Duration::from_secs(10))
-            .connect()
-            .await?;
+        let conn = connect(addr, Duration::from_secs(10)).await?;
         Ok(LighthouseServiceClient::new(conn))
     }
src/manager.rs (+76 -57)

@@ -16,10 +16,12 @@ use tokio::sync::Mutex;
 use tokio::task::JoinSet;
 use tokio::time::sleep;
 use tonic::transport::server::TcpIncoming;
+use tonic::transport::Channel;
 use tonic::transport::Server;
-use tonic::transport::{Channel, Endpoint};
 use tonic::{Request, Response, Status};
 
+use crate::net::connect;
+use crate::timeout::try_parse_grpc_timeout;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{
@@ -59,16 +61,10 @@ pub struct Manager {
 
 pub async fn manager_client_new(
     addr: String,
-    timeout: Duration,
+    connect_timeout: Duration,
 ) -> Result<ManagerServiceClient<Channel>> {
-    // TODO add retries + backoff so other nodes can start before the rank0 comes up
-
     info!("ManagerClient: establishing connection to {}", &addr);
-    let conn = Endpoint::new(addr.clone())?
-        .timeout(timeout)
-        .connect_timeout(Duration::from_secs(60))
-        .connect()
-        .await?;
+    let conn = connect(addr, connect_timeout).await?;
     Ok(ManagerServiceClient::new(conn))
 }
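Note: the removed TODO about retries and backoff is addressed by the shared connect helper in src/net.rs below, which retries with exponential backoff until the connect timeout expires.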

@@ -163,12 +159,57 @@ impl Manager {
             &self.lighthouse_addr
         );
 
-        let conn = Endpoint::new(self.lighthouse_addr.clone())?
-            .connect_timeout(Duration::from_secs(60))
-            .connect()
-            .await?;
+        let conn = connect(self.lighthouse_addr.clone(), Duration::from_secs(60)).await?;
         Ok(LighthouseServiceClient::new(conn))
     }
+
+    async fn _run_quorum(
+        &self,
+        state: &mut ManagerState,
+        requester: QuorumMember,
+        timeout: Duration,
+    ) -> Result<(), Status> {
+        if (state.participants.len() as u64) < self.world_size {
+            return Ok(());
+        }
+
+        state.participants.clear();
+        info!("all workers joined -- starting quorum");
+
+        // TODO: don't hold the lock during quorum
+
+        let mut client = self
+            .lighthouse_client_new()
+            .await
+            .map_err(|e| Status::from_error(e.into()))?;
+
+        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
+            requester: Some(requester),
+        });
+        lighthouse_request.set_timeout(timeout);
+
+        let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request))
+            .await
+            .unwrap_or_else(|e| {
+                Err(Status::cancelled(format!(
+                    "lighthouse quorum timed out: {}",
+                    e.to_string()
+                )))
+            })?;
+        let resp = response.into_inner();
+
+        info!("got lighthouse quorum {:?}", resp);
+
+        state
+            .channel
+            .send(
+                resp.quorum
+                    .ok_or_else(|| Status::internal("missing quorum"))?,
+            )
+            .map_err(|e| Status::from_error(e.into()))?;
+
+        Ok(())
+    }
 }
 
 #[tonic::async_trait]
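Note: lighthouse_request.set_timeout only propagates the deadline to the Lighthouse; the surrounding tokio::time::timeout is what enforces it locally if the Lighthouse never responds. A minimal sketch of that pattern, using a hypothetical with_deadline helper:

use std::time::Duration;
use tonic::Status;

// Wraps a gRPC call so a hung peer yields a Status instead of blocking forever.
async fn with_deadline<T>(
    timeout: Duration,
    fut: impl std::future::Future<Output = Result<T, Status>>,
) -> Result<T, Status> {
    // tokio::time::timeout returns Err(Elapsed) when the deadline passes
    // before the inner future resolves.
    tokio::time::timeout(timeout, fut)
        .await
        .unwrap_or_else(|e| Err(Status::cancelled(format!("timed out: {}", e))))
}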
@@ -182,6 +223,15 @@ impl ManagerService for Arc<Manager> {
 
         info!("got quorum request for rank {}", rank);
 
+        let timeout = try_parse_grpc_timeout(&request.metadata())
+            .map_err(|e| {
+                Status::invalid_argument(format!(
+                    "invalid timeout {}",
+                    e.to_str().unwrap_or("invalid")
+                ))
+            })?
+            .ok_or_else(|| Status::invalid_argument("missing timeout"))?;
+
         let mut rx = {
             let mut state = self.state.lock().await;
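Note: src/timeout.rs is not part of this excerpt. A plausible sketch of try_parse_grpc_timeout, assuming the standard grpc-timeout wire format (ASCII digits plus a unit suffix H, M, S, m, u, or n) and matching the call site above, which expects an error type with to_str():

use std::time::Duration;

use tonic::metadata::{Ascii, MetadataMap, MetadataValue};

// Returns Ok(None) if no grpc-timeout header is present, and Err with the
// raw header value if it cannot be parsed.
pub fn try_parse_grpc_timeout(
    metadata: &MetadataMap,
) -> Result<Option<Duration>, &MetadataValue<Ascii>> {
    let Some(val) = metadata.get("grpc-timeout") else {
        return Ok(None);
    };
    let s = val.to_str().map_err(|_| val)?;
    if s.len() < 2 {
        return Err(val);
    }
    let (digits, unit) = s.split_at(s.len() - 1);
    let n: u64 = digits.parse().map_err(|_| val)?;
    let dur = match unit {
        "H" => Duration::from_secs(n * 3600),
        "M" => Duration::from_secs(n * 60),
        "S" => Duration::from_secs(n),
        "m" => Duration::from_millis(n),
        "u" => Duration::from_micros(n),
        "n" => Duration::from_nanos(n),
        _ => return Err(val),
    };
    Ok(Some(dur))
}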

@@ -195,50 +245,19 @@ impl ManagerService for Arc<Manager> {
             state.participants.insert(rank);
             let rx = state.channel.subscribe();
 
-            if state.participants.len() as u64 >= self.world_size {
-                state.participants.clear();
-                info!("all workers joined -- starting quorum");
-
-                // TODO: don't hold the lock during quorum
-
-                let mut client = self
-                    .lighthouse_client_new()
-                    .await
-                    .map_err(|e| Status::from_error(e.into()))?;
-
-                let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
-                    requester: Some(QuorumMember {
-                        replica_id: self.replica_id.clone(),
-                        address: self.address(),
-                        store_address: self.store_address.clone(),
-                        step: req.step,
-                        world_size: self.world_size,
-                        shrink_only: req.shrink_only,
-                    }),
-                });
-
-                // propagate timeout from request to lighthouse
-                let timeout = request
-                    .metadata()
-                    .get("grpc-timeout")
-                    .ok_or_else(|| Status::internal("grpc-timeout not set"))?;
-                lighthouse_request
-                    .metadata_mut()
-                    .insert("grpc-timeout", timeout.clone());
-
-                let response = client.quorum(lighthouse_request).await.unwrap();
-                let resp = response.into_inner();
-
-                info!("got lighthouse quorum {:?}", resp);
-
-                state
-                    .channel
-                    .send(
-                        resp.quorum
-                            .ok_or_else(|| Status::internal("missing quorum"))?,
-                    )
-                    .map_err(|e| Status::from_error(e.into()))?;
-            }
+            self._run_quorum(
+                &mut state,
+                QuorumMember {
+                    replica_id: self.replica_id.clone(),
+                    address: self.address(),
+                    store_address: self.store_address.clone(),
+                    step: req.step,
+                    world_size: self.world_size,
+                    shrink_only: req.shrink_only,
+                },
+                timeout,
+            )
+            .await?;
 
             rx
         };
src/net.rs (new file, +34)

@@ -0,0 +1,34 @@
+use std::time::Duration;
+
+use anyhow::Result;
+use tonic::transport::{Channel, Endpoint};
+
+use crate::retry::{retry_backoff, ExponentialBackoff};
+
+pub async fn connect_once(addr: String, connect_timeout: Duration) -> Result<Channel> {
+    let conn = Endpoint::new(addr)?
+        .connect_timeout(connect_timeout)
+        // Enable HTTP/2 keep alives.
+        .http2_keep_alive_interval(Duration::from_secs(60))
+        // Time allowed for the server to answer a ping; 20s is the gRPC default.
+        .keep_alive_timeout(Duration::from_secs(20))
+        // Enable keep alives for idle connections.
+        .keep_alive_while_idle(true)
+        .connect()
+        .await?;
+    Ok(conn)
+}
+
+pub async fn connect(addr: String, connect_timeout: Duration) -> Result<Channel> {
+    retry_backoff(
+        ExponentialBackoff {
+            initial_backoff: Duration::from_millis(100),
+            max_backoff: Duration::from_secs(10),
+            timeout: connect_timeout,
+            factor: 1.5,
+            jitter: Duration::from_millis(100),
+        },
+        || Box::pin(connect_once(addr.clone(), connect_timeout)),
+    )
+    .await
+}
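Note: src/retry.rs is not part of this excerpt. A minimal sketch of an implementation compatible with the call above is shown below; the actual code may differ. The jitter sampling is the presumed consumer of the rand dependency added in Cargo.toml.

use std::future::Future;
use std::pin::Pin;
use std::time::Duration;

use anyhow::Result;
use rand::Rng;

pub struct ExponentialBackoff {
    pub initial_backoff: Duration,
    pub max_backoff: Duration,
    /// Total time budget across all attempts.
    pub timeout: Duration,
    pub factor: f64,
    pub jitter: Duration,
}

pub async fn retry_backoff<T>(
    policy: ExponentialBackoff,
    f: impl Fn() -> Pin<Box<dyn Future<Output = Result<T>> + Send>>,
) -> Result<T> {
    let deadline = tokio::time::Instant::now() + policy.timeout;
    let mut backoff = policy.initial_backoff;
    loop {
        match f().await {
            Ok(v) => return Ok(v),
            // Give up if the next sleep would overrun the overall budget.
            Err(e) if tokio::time::Instant::now() + backoff >= deadline => return Err(e),
            Err(_) => {
                let jitter = rand::thread_rng().gen_range(Duration::ZERO..=policy.jitter);
                tokio::time::sleep(backoff + jitter).await;
                backoff = std::cmp::min(backoff.mul_f64(policy.factor), policy.max_backoff);
            }
        }
    }
}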
