
Commit 2dcd290

manager: add per request timeouts
1 parent a32f807 commit 2dcd290

File tree

4 files changed: +76 -23 lines changed

    src/lib.rs           +14 -4
    src/manager.rs       +19 -8
    torchft/manager.py   +27 -8
    torchft/torchft.pyi  +16 -3

src/lib.rs (+14 -4)

@@ -102,13 +102,15 @@ impl ManagerClient {
         })
     }
 
+    #[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
     fn quorum(
         &mut self,
         py: Python<'_>,
         room_id: String,
         rank: i64,
         step: i64,
         checkpoint_server_addr: String,
+        timeout: Option<Duration>,
     ) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
@@ -119,7 +121,7 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
             let response = self
                 .runtime
@@ -140,12 +142,18 @@ impl ManagerClient {
         })
     }
 
-    fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
+    #[pyo3(signature = (rank, timeout=None))]
+    fn checkpoint_address(
+        &mut self,
+        py: Python<'_>,
+        rank: i64,
+        timeout: Option<Duration>,
+    ) -> PyResult<String> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
             let response = self
                 .runtime
@@ -156,12 +164,14 @@ impl ManagerClient {
         })
     }
 
+    #[pyo3(signature = (rank, step, should_commit, timeout=None))]
     fn should_commit(
         &mut self,
         py: Python<'_>,
         rank: i64,
         step: i64,
         should_commit: bool,
+        timeout: Option<Duration>,
     ) -> PyResult<bool> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ShouldCommitRequest {
@@ -171,7 +181,7 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
             let response = self
                 .runtime

src/manager.rs (+19 -8)

@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
         &self,
         request: Request<ManagerQuorumRequest>,
     ) -> Result<Response<ManagerQuorumResponse>, Status> {
-        let req = request.into_inner();
+        let req = request.get_ref();
         let rank = req.rank;
-        let room_id = req.room_id;
+        let room_id = &req.room_id;
 
         info!("{}: got quorum request for rank {}", room_id, rank);
 
@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
             .checkpoint_servers
             .insert(req.rank, req.checkpoint_server_addr.clone());
 
-        if !state.rooms.contains_key(&room_id) {
+        if !state.rooms.contains_key(room_id) {
             let (tx, _) = broadcast::channel(16);
 
             state.rooms.insert(
@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
             );
         }
 
-        let room = state.rooms.get_mut(&room_id).unwrap();
+        let room = state.rooms.get_mut(room_id).unwrap();
 
         // TODO check step
         room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::from_error(e.into()))?;
 
-        let request = tonic::Request::new(LighthouseQuorumRequest {
+        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
             room_id: room_id.clone(),
             requester: Some(QuorumMember {
                 replica_id: self.replica_id.clone(),
@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
             }),
         });
 
-        let response = client.quorum(request).await.unwrap();
+        // propagate timeout from request to lighthouse
+        let timeout = request
+            .metadata()
+            .get("grpc-timeout")
+            .ok_or_else(|| Status::internal("grpc-timeout not set"))?;
+        lighthouse_request
+            .metadata_mut()
+            .insert("grpc-timeout", timeout.clone());
+
+        let response = client.quorum(lighthouse_request).await.unwrap();
         let resp = response.into_inner();
 
         info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {
 
         let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;
 
-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
             room_id: "room".to_string(),
             rank: 0,
             step: 123,
             checkpoint_server_addr: "addr".to_string(),
         });
+        request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
 
         manager_fut.abort();
@@ -526,12 +536,13 @@ mod tests {
         let mut client =
            manager_client_new(manager.address(), Duration::from_secs(10)).await?;
 
-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
             room_id: "room".to_string(),
             rank: 0,
             step: 0,
             checkpoint_server_addr: "addr".to_string(),
         });
+        request.set_timeout(Duration::from_secs(10));
 
         let result = client.quorum(request).await?.into_inner();
 
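Taken together, the Rust changes make per-request timeouts visible from Python: each ManagerClient method now accepts an optional timeout that overrides the client-level default (timeout.unwrap_or(self.timeout)), and the manager copies the resulting grpc-timeout metadata onto its lighthouse quorum request, returning an internal error if no timeout is attached. A rough usage sketch from Python, assuming ManagerClient is importable from torchft.torchft as the stub file later in this diff suggests; the address, rank, and step values are placeholders:

from datetime import timedelta

from torchft.torchft import ManagerClient  # assumed import path per torchft/torchft.pyi

# Client-level timeout: used whenever a call doesn't pass its own timeout.
client = ManagerClient("localhost:29500", timeout=timedelta(seconds=10))

# Per-request timeout: overrides the 10s default for this call only and is
# sent as grpc-timeout metadata, which the manager forwards to the lighthouse.
quorum = client.quorum(
    room_id="default",
    rank=0,
    step=123,
    checkpoint_server_addr="localhost:12345",
    timeout=timedelta(seconds=30),
)

# Omitting timeout falls back to the client-level 10s timeout.
addr = client.checkpoint_address(rank=0)
ok = client.should_commit(rank=0, step=123, should_commit=True)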

torchft/manager.py (+27 -8)

@@ -102,7 +102,10 @@ def __init__(
             min_replica_size: minimum number of replicas on each step
             port: if rank==0, the port to run the manager server on
             use_async_quorum: whether to run the quorum asynchronously during the forward pass
-            timeout: timeout for all operations
+            timeout:
+                the default timeout for all operations; if you're using per
+                request timeouts this should be longer than the longest request
+                timeout.
             rank: the replica group local rank
             world_size: the replica group local world size
             store_addr: TCPStore address for this replica group
@@ -279,7 +282,10 @@ def errored(self) -> Optional[Exception]:
         return self._errored
 
     def wrap_future(
-        self, fut: torch.futures.Future[T], default: T
+        self,
+        fut: torch.futures.Future[T],
+        default: T,
+        timeout: Optional[timedelta] = None,
     ) -> torch.futures.Future[T]:
         """
         Wrap a Future and swallow any errors that occur and report them to the manager.
@@ -289,10 +295,11 @@ def wrap_future(
         Args:
             fut: the Future to wrap
             default: the default value to complete the Future with if an error occurs
+            timeout: the timeout for the Future; if None, the manager's timeout will be used
         """
 
         # add a timeout to the future
-        fut = future_timeout(fut, self._timeout)
+        fut = future_timeout(fut, timeout or self._timeout)
 
         # schedule error handling as a continuation on the Future
         def callback(
@@ -313,7 +320,12 @@ def callback(
         self._pending_work.append(cast(torch.futures.Future[object], fut))
         return fut
 
-    def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
+    def start_quorum(
+        self,
+        room_id: str = "default",
+        allow_heal: bool = True,
+        timeout: Optional[timedelta] = None,
+    ) -> None:
         """
         .. note::
             We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -331,6 +343,7 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
                 calls. All replicas must pass the same value to allow_heal.
             room_id: (experimental) the room id to use for quorum, this allows
                 for multiple quorums to be used within the same job.
+            timeout: the timeout for quorum and recovery operations; if None, the manager's timeout will be used
         """
 
         # wait for previous quorum to complete
@@ -345,7 +358,10 @@ def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
         # block to allow gracefully recovering from issues in PG setup and quorum.
 
         self._quorum_future = self._executor.submit(
-            self._async_quorum, room_id=room_id, allow_heal=allow_heal
+            self._async_quorum,
+            room_id=room_id,
+            allow_heal=allow_heal,
+            timeout=timeout or self._timeout,
         )
         if not self._use_async_quorum:
             self.wait_quorum()
@@ -369,7 +385,7 @@ def wait_quorum(self) -> None:
         ), "must call start_quorum before wait_quorum"
         self._quorum_future.result()
 
-    def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
+    def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
         (
             quorum_id,
             replica_rank,
@@ -385,6 +401,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
             rank=self._rank,
             step=self._step,
             checkpoint_server_addr=self._ckpt_server.address(),
+            timeout=timeout,
         )
 
         # When using async quorum we need to take the recovered workers.
@@ -422,8 +439,10 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
         self._logger.info(
             f"healing required, fetching checkpoint server address from {address=} {max_step=}"
         )
-        primary_client = ManagerClient(address, timeout=self._timeout)
-        checkpoint_server_address = primary_client.checkpoint_address(self._rank)
+        primary_client = ManagerClient(address, timeout=timeout)
+        checkpoint_server_address = primary_client.checkpoint_address(
+            self._rank, timeout=timeout
+        )
 
         self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
         self._pending_state_dict = CheckpointServer.load_from_address(
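On the Python side, the same override flows through the public Manager API: start_quorum passes the per-call timeout down to _async_quorum (and to the checkpoint address lookup during healing), and wrap_future applies it to the wrapped Future. A minimal sketch of the intended usage, assuming a Manager instance named manager was constructed elsewhere (its constructor arguments are out of scope here) with a 60-second default timeout; the 30-second and 5-second values are placeholders:

from datetime import timedelta

import torch

# manager: a torchft.manager.Manager built elsewhere with
# timeout=timedelta(seconds=60) as the default for all operations.

# This quorum (and any recovery it triggers) uses a 30s timeout instead of
# the manager-wide 60s default.
manager.start_quorum(room_id="default", timeout=timedelta(seconds=30))

# Per-future override: the wrapped future is timed out after 5s instead of
# 60s, completing with the provided default value if an error occurs.
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut = manager.wrap_future(fut, default=torch.zeros(1), timeout=timedelta(seconds=5))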

torchft/torchft.pyi (+16 -3)

@@ -4,10 +4,23 @@ from typing import Optional, Tuple
 class ManagerClient:
     def __init__(self, addr: str, timeout: timedelta) -> None: ...
     def quorum(
-        self, room_id: str, rank: int, step: int, checkpoint_server_addr: str
+        self,
+        room_id: str,
+        rank: int,
+        step: int,
+        checkpoint_server_addr: str,
+        timeout: Optional[timedelta] = None,
     ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ...
-    def checkpoint_address(self, rank: int) -> str: ...
-    def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...
+    def checkpoint_address(
+        self, rank: int, timeout: Optional[timedelta] = None
+    ) -> str: ...
+    def should_commit(
+        self,
+        rank: int,
+        step: int,
+        should_commit: bool,
+        timeout: Optional[timedelta] = None,
+    ) -> bool: ...
 
 class Manager:
     def __init__(
