Skip to content

Commit 517f300

Browse files
authored
manager: add per request timeouts (#59)
1 parent 5dd6f38 commit 517f300

File tree

6 files changed

+206
-50
lines changed

6 files changed

+206
-50
lines changed

src/lib.rs

+41-17
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ use std::env;
1212
use std::sync::Arc;
1313

1414
use anyhow::Result;
15-
use pyo3::exceptions::PyRuntimeError;
15+
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
1616
use structopt::StructOpt;
1717
use tokio::runtime::Runtime;
1818
use tokio::task::JoinHandle;
1919
use tonic::transport::Channel;
20+
use tonic::Status;
2021

2122
pub mod torchftpb {
2223
tonic::include_proto!("torchft");
@@ -102,14 +103,16 @@ impl ManagerClient {
102103
})
103104
}
104105

106+
#[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
105107
fn quorum(
106108
&mut self,
107109
py: Python<'_>,
108110
room_id: String,
109111
rank: i64,
110112
step: i64,
111113
checkpoint_server_addr: String,
112-
) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
114+
timeout: Option<Duration>,
115+
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
113116
py.allow_threads(move || {
114117
let mut request = tonic::Request::new(ManagerQuorumRequest {
115118
room_id: room_id,
@@ -119,12 +122,9 @@ impl ManagerClient {
119122
});
120123
// This notifies the server about the timeout but doesn't affect the
121124
// endpoint timeout which we set on client creation.
122-
request.set_timeout(self.timeout);
125+
request.set_timeout(timeout.unwrap_or(self.timeout));
123126

124-
let response = self
125-
.runtime
126-
.block_on(self.client.quorum(request))
127-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
127+
let response = self.runtime.block_on(self.client.quorum(request))?;
128128
let resp = response.into_inner();
129129
Ok((
130130
resp.quorum_id,
@@ -140,29 +140,36 @@ impl ManagerClient {
140140
})
141141
}
142142

143-
fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
143+
#[pyo3(signature = (rank, timeout=None))]
144+
fn checkpoint_address(
145+
&mut self,
146+
py: Python<'_>,
147+
rank: i64,
148+
timeout: Option<Duration>,
149+
) -> Result<String, StatusError> {
144150
py.allow_threads(move || {
145151
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
146152
// This notifies the server about the timeout but doesn't affect the
147153
// endpoint timeout which we set on client creation.
148-
request.set_timeout(self.timeout);
154+
request.set_timeout(timeout.unwrap_or(self.timeout));
149155

150156
let response = self
151157
.runtime
152-
.block_on(self.client.checkpoint_address(request))
153-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
158+
.block_on(self.client.checkpoint_address(request))?;
154159
let resp = response.into_inner();
155160
Ok(resp.checkpoint_server_address)
156161
})
157162
}
158163

164+
#[pyo3(signature = (rank, step, should_commit, timeout=None))]
159165
fn should_commit(
160166
&mut self,
161167
py: Python<'_>,
162168
rank: i64,
163169
step: i64,
164170
should_commit: bool,
165-
) -> PyResult<bool> {
171+
timeout: Option<Duration>,
172+
) -> Result<bool, StatusError> {
166173
py.allow_threads(move || {
167174
let mut request = tonic::Request::new(ShouldCommitRequest {
168175
rank: rank,
@@ -171,12 +178,9 @@ impl ManagerClient {
171178
});
172179
// This notifies the server about the timeout but doesn't affect the
173180
// endpoint timeout which we set on client creation.
174-
request.set_timeout(self.timeout);
181+
request.set_timeout(timeout.unwrap_or(self.timeout));
175182

176-
let response = self
177-
.runtime
178-
.block_on(self.client.should_commit(request))
179-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
183+
let response = self.runtime.block_on(self.client.should_commit(request))?;
180184
let resp = response.into_inner();
181185
Ok(resp.should_commit)
182186
})
@@ -266,6 +270,26 @@ impl Lighthouse {
266270
}
267271
}
268272

273+
struct StatusError(Status);
274+
275+
impl From<StatusError> for PyErr {
276+
fn from(error: StatusError) -> Self {
277+
let code = error.0.code();
278+
match code {
279+
tonic::Code::Cancelled | tonic::Code::DeadlineExceeded => {
280+
PyTimeoutError::new_err(error.0.to_string())
281+
}
282+
_ => PyRuntimeError::new_err(error.0.to_string()),
283+
}
284+
}
285+
}
286+
287+
impl From<Status> for StatusError {
288+
fn from(other: Status) -> Self {
289+
Self(other)
290+
}
291+
}
292+
269293
#[pymodule]
270294
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
271295
// setup logging on import

src/manager.rs

+19-8
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
180180
&self,
181181
request: Request<ManagerQuorumRequest>,
182182
) -> Result<Response<ManagerQuorumResponse>, Status> {
183-
let req = request.into_inner();
183+
let req = request.get_ref();
184184
let rank = req.rank;
185-
let room_id = req.room_id;
185+
let room_id = &req.room_id;
186186

187187
info!("{}: got quorum request for rank {}", room_id, rank);
188188

@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
195195
.checkpoint_servers
196196
.insert(req.rank, req.checkpoint_server_addr.clone());
197197

198-
if !state.rooms.contains_key(&room_id) {
198+
if !state.rooms.contains_key(room_id) {
199199
let (tx, _) = broadcast::channel(16);
200200

201201
state.rooms.insert(
@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
207207
);
208208
}
209209

210-
let room = state.rooms.get_mut(&room_id).unwrap();
210+
let room = state.rooms.get_mut(room_id).unwrap();
211211

212212
// TODO check step
213213
room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
224224
.await
225225
.map_err(|e| Status::from_error(e.into()))?;
226226

227-
let request = tonic::Request::new(LighthouseQuorumRequest {
227+
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
228228
room_id: room_id.clone(),
229229
requester: Some(QuorumMember {
230230
replica_id: self.replica_id.clone(),
@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
235235
}),
236236
});
237237

238-
let response = client.quorum(request).await.unwrap();
238+
// propagate timeout from request to lighthouse
239+
let timeout = request
240+
.metadata()
241+
.get("grpc-timeout")
242+
.ok_or_else(|| Status::internal("grpc-timeout not set"))?;
243+
lighthouse_request
244+
.metadata_mut()
245+
.insert("grpc-timeout", timeout.clone());
246+
247+
let response = client.quorum(lighthouse_request).await.unwrap();
239248
let resp = response.into_inner();
240249

241250
info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {
471480

472481
let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;
473482

474-
let request = tonic::Request::new(ManagerQuorumRequest {
483+
let mut request = tonic::Request::new(ManagerQuorumRequest {
475484
room_id: "room".to_string(),
476485
rank: 0,
477486
step: 123,
478487
checkpoint_server_addr: "addr".to_string(),
479488
});
489+
request.set_timeout(Duration::from_secs(10));
480490
let resp = client.quorum(request).await?.into_inner();
481491

482492
manager_fut.abort();
@@ -526,12 +536,13 @@ mod tests {
526536
let mut client =
527537
manager_client_new(manager.address(), Duration::from_secs(10)).await?;
528538

529-
let request = tonic::Request::new(ManagerQuorumRequest {
539+
let mut request = tonic::Request::new(ManagerQuorumRequest {
530540
room_id: "room".to_string(),
531541
rank: 0,
532542
step: 0,
533543
checkpoint_server_addr: "addr".to_string(),
534544
});
545+
request.set_timeout(Duration::from_secs(10));
535546

536547
let result = client.quorum(request).await?.into_inner();
537548

torchft/manager.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ def __init__(
102102
min_replica_size: minimum number of replicas on each step
103103
port: if rank==0, the port to run the manager server on
104104
use_async_quorum: whether to run the quorum asynchronously during the forward pass
105-
timeout: timeout for all operations
105+
timeout:
106+
the default timeout for all operations; if you're using per
107+
request timeouts this should be longer than the longest request
108+
timeout.
106109
rank: the replica group local rank
107110
world_size: the replica group local world size
108111
store_addr: TCPStore address for this replica group
@@ -279,7 +282,10 @@ def errored(self) -> Optional[Exception]:
279282
return self._errored
280283

281284
def wrap_future(
282-
self, fut: torch.futures.Future[T], default: T
285+
self,
286+
fut: torch.futures.Future[T],
287+
default: T,
288+
timeout: Optional[timedelta] = None,
283289
) -> torch.futures.Future[T]:
284290
"""
285291
Wrap a Future and swallow any errors that occur and report them to the manager.
@@ -289,10 +295,11 @@ def wrap_future(
289295
Args:
290296
fut: the Future to wrap
291297
default: the default value to complete the Future with if an error occurs
298+
timeout: the timeout for the Future; if None, the manager's timeout will be used
292299
"""
293300

294301
# add a timeout to the future
295-
fut = future_timeout(fut, self._timeout)
302+
fut = future_timeout(fut, timeout or self._timeout)
296303

297304
# schedule error handling as a continuation on the Future
298305
def callback(
@@ -313,7 +320,12 @@ def callback(
313320
self._pending_work.append(cast(torch.futures.Future[object], fut))
314321
return fut
315322

316-
def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
323+
def start_quorum(
324+
self,
325+
room_id: str = "default",
326+
allow_heal: bool = True,
327+
timeout: Optional[timedelta] = None,
328+
) -> None:
317329
"""
318330
.. note::
319331
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -331,6 +343,7 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> Non
331343
calls. All replicas must pass the same value to allow_heal.
332344
room_id: (experimental) the room id to use for quorum, this allows
333345
for multiple quorums to be used within the same job.
346+
timeout: the timeout for quorum and recovery operations; if None, the manager's timeout will be used
334347
"""
335348

336349
# wait for previous quorum to complete
@@ -345,7 +358,10 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> Non
345358
# block to allow gracefully recovering from issues in PG setup and quorum.
346359

347360
self._quorum_future = self._executor.submit(
348-
self._async_quorum, room_id=room_id, allow_heal=allow_heal
361+
self._async_quorum,
362+
room_id=room_id,
363+
allow_heal=allow_heal,
364+
timeout=timeout or self._timeout,
349365
)
350366
if not self._use_async_quorum:
351367
self.wait_quorum()
@@ -369,7 +385,7 @@ def wait_quorum(self) -> None:
369385
), "must call start_quorum before wait_quorum"
370386
self._quorum_future.result()
371387

372-
def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
388+
def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
373389
(
374390
quorum_id,
375391
replica_rank,
@@ -385,6 +401,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
385401
rank=self._rank,
386402
step=self._step,
387403
checkpoint_server_addr=self._ckpt_server.address(),
404+
timeout=timeout,
388405
)
389406

390407
# When using async quorum we need to take the recovered workers.
@@ -422,8 +439,10 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
422439
self._logger.info(
423440
f"healing required, fetching checkpoint server address from {address=} {max_step=}"
424441
)
425-
primary_client = ManagerClient(address, timeout=self._timeout)
426-
checkpoint_server_address = primary_client.checkpoint_address(self._rank)
442+
primary_client = ManagerClient(address, timeout=timeout)
443+
checkpoint_server_address = primary_client.checkpoint_address(
444+
self._rank, timeout=timeout
445+
)
427446

428447
self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
429448
self._pending_state_dict = CheckpointServer.load_from_address(
@@ -449,7 +468,7 @@ def _apply_pending_state_dict(self) -> None:
449468
self._load_state_dict(self._pending_state_dict["user"])
450469
self._pending_state_dict = None
451470

452-
def should_commit(self) -> bool:
471+
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
453472
"""
454473
.. note::
455474
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -486,7 +505,10 @@ def should_commit(self) -> bool:
486505
enough_replicas = self.num_participants() >= self._min_replica_size
487506
local_should_commit = enough_replicas and self._errored is None
488507
should_commit = self._client.should_commit(
489-
self._rank, self._step, local_should_commit
508+
self._rank,
509+
self._step,
510+
local_should_commit,
511+
timeout=timeout or self._timeout,
490512
)
491513
self._logger.info(
492514
f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"

0 commit comments

Comments
 (0)