Skip to content

Commit 3ee2360

Browse files
authored
overhaul timeouts for Lighthouse, Manager, checkpoint server (#73)
1 parent 03160ee commit 3ee2360

12 files changed

+433
-133
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ log = "0.4.22"
1212
prost = "0.13.3"
1313
prost-types = "0.13.3"
1414
pyo3 = {version="0.22.3", features = ["extension-module"]}
15+
rand = "0.8.5"
1516
slog = "2.7.0"
1617
slog-stdlog = "4.1.1"
1718
stderrlog = "0.6.0"

src/lib.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
pub mod lighthouse;
88
pub mod manager;
9+
mod net;
10+
mod retry;
11+
mod timeout;
912

1013
use core::time::Duration;
1114
use std::env;
@@ -46,6 +49,7 @@ impl Manager {
4649
store_addr: String,
4750
world_size: u64,
4851
heartbeat_interval: Duration,
52+
connect_timeout: Duration,
4953
) -> PyResult<Self> {
5054
py.allow_threads(move || {
5155
let runtime = Runtime::new()?;
@@ -58,6 +62,7 @@ impl Manager {
5862
store_addr,
5963
world_size,
6064
heartbeat_interval,
65+
connect_timeout,
6166
))
6267
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
6368
let handle = runtime.spawn(manager.clone().run());
@@ -84,36 +89,33 @@ impl Manager {
8489
struct ManagerClient {
8590
runtime: Runtime,
8691
client: ManagerServiceClient<Channel>,
87-
timeout: Duration,
8892
}
8993

9094
#[pymethods]
9195
impl ManagerClient {
9296
#[new]
93-
fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
97+
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
9498
py.allow_threads(move || {
9599
let runtime = Runtime::new()?;
96100
let client = runtime
97-
.block_on(manager::manager_client_new(addr, timeout))
101+
.block_on(manager::manager_client_new(addr, connect_timeout))
98102
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
99103

100104
Ok(Self {
101105
runtime: runtime,
102106
client: client,
103-
timeout: timeout,
104107
})
105108
})
106109
}
107110

108-
#[pyo3(signature = (rank, step, checkpoint_server_addr, shrink_only, timeout=None))]
109111
fn quorum(
110112
&mut self,
111113
py: Python<'_>,
112114
rank: i64,
113115
step: i64,
114116
checkpoint_server_addr: String,
115117
shrink_only: bool,
116-
timeout: Option<Duration>,
118+
timeout: Duration,
117119
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
118120
py.allow_threads(move || {
119121
let mut request = tonic::Request::new(ManagerQuorumRequest {
@@ -122,9 +124,10 @@ impl ManagerClient {
122124
checkpoint_server_addr: checkpoint_server_addr,
123125
shrink_only: shrink_only,
124126
});
125-
// This notifies the server about the timeout but doesn't affect the
126-
// endpoint timeout which we set on client creation.
127-
request.set_timeout(timeout.unwrap_or(self.timeout));
127+
128+
// This timeout is processed on the server side so we also enable
129+
// keep alives to detect server health.
130+
request.set_timeout(timeout);
128131

129132
let response = self.runtime.block_on(self.client.quorum(request))?;
130133
let resp = response.into_inner();
@@ -142,18 +145,18 @@ impl ManagerClient {
142145
})
143146
}
144147

145-
#[pyo3(signature = (rank, timeout=None))]
146148
fn checkpoint_address(
147149
&mut self,
148150
py: Python<'_>,
149151
rank: i64,
150-
timeout: Option<Duration>,
152+
timeout: Duration,
151153
) -> Result<String, StatusError> {
152154
py.allow_threads(move || {
153155
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
154-
// This notifies the server about the timeout but doesn't affect the
155-
// endpoint timeout which we set on client creation.
156-
request.set_timeout(timeout.unwrap_or(self.timeout));
156+
157+
// This timeout is processed on the server side so we also enable
158+
// keep alives to detect server health.
159+
request.set_timeout(timeout);
157160

158161
let response = self
159162
.runtime
@@ -163,24 +166,24 @@ impl ManagerClient {
163166
})
164167
}
165168

166-
#[pyo3(signature = (rank, step, should_commit, timeout=None))]
167169
fn should_commit(
168170
&mut self,
169171
py: Python<'_>,
170172
rank: i64,
171173
step: i64,
172174
should_commit: bool,
173-
timeout: Option<Duration>,
175+
timeout: Duration,
174176
) -> Result<bool, StatusError> {
175177
py.allow_threads(move || {
176178
let mut request = tonic::Request::new(ShouldCommitRequest {
177179
rank: rank,
178180
step: step,
179181
should_commit: should_commit,
180182
});
183+
181184
// This notifies the server about the timeout but doesn't affect the
182185
// endpoint timeout which we set on client creation.
183-
request.set_timeout(timeout.unwrap_or(self.timeout));
186+
request.set_timeout(timeout);
184187

185188
let response = self.runtime.block_on(self.client.should_commit(request))?;
186189
let resp = response.into_inner();

src/lighthouse.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -564,15 +564,13 @@ mod tests {
564564
use super::*;
565565
use std::ops::Sub;
566566

567-
use tonic::transport::{Channel, Endpoint};
567+
use tonic::transport::Channel;
568568

569+
use crate::net::connect;
569570
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
570571

571572
async fn lighthouse_client_new(addr: String) -> Result<LighthouseServiceClient<Channel>> {
572-
let conn = Endpoint::new(addr)?
573-
.connect_timeout(Duration::from_secs(10))
574-
.connect()
575-
.await?;
573+
let conn = connect(addr, Duration::from_secs(10)).await?;
576574
Ok(LighthouseServiceClient::new(conn))
577575
}
578576

0 commit comments

Comments
 (0)