
Commit 3122fda

Update (base update)

[ghstack-poisoned]

2 parents 73d0e81 + bc99344

11 files changed: +442 -98
README.md (+1 -1)

@@ -10,7 +10,7 @@ Easy Per Step Fault Tolerance for PyTorch
 </h3>

 <p align="center">
-| <a href="https://pytorch-labs.github.io/torchft/"><b>Documentation</b></a>
+| <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
 | <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
 | <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
 |

src/lib.rs (+53 -20)

@@ -12,11 +12,12 @@ use std::env;
 use std::sync::Arc;

 use anyhow::Result;
-use pyo3::exceptions::PyRuntimeError;
+use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tonic::transport::Channel;
+use tonic::Status;

 pub mod torchftpb {
     tonic::include_proto!("torchft");

@@ -102,14 +103,16 @@ impl ManagerClient {
         })
     }

+    #[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
     fn quorum(
         &mut self,
         py: Python<'_>,
         room_id: String,
         rank: i64,
         step: i64,
         checkpoint_server_addr: String,
-    ) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
+        timeout: Option<Duration>,
+    ) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
                 room_id: room_id,
@@ -119,12 +122,9 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));

-            let response = self
-                .runtime
-                .block_on(self.client.quorum(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            let response = self.runtime.block_on(self.client.quorum(request))?;
             let resp = response.into_inner();
             Ok((
                 resp.quorum_id,
@@ -140,29 +140,36 @@ impl ManagerClient {
         })
     }

-    fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
+    #[pyo3(signature = (rank, timeout=None))]
+    fn checkpoint_address(
+        &mut self,
+        py: Python<'_>,
+        rank: i64,
+        timeout: Option<Duration>,
+    ) -> Result<String, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));

             let response = self
                 .runtime
-                .block_on(self.client.checkpoint_address(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+                .block_on(self.client.checkpoint_address(request))?;
             let resp = response.into_inner();
             Ok(resp.checkpoint_server_address)
         })
     }

+    #[pyo3(signature = (rank, step, should_commit, timeout=None))]
     fn should_commit(
         &mut self,
         py: Python<'_>,
         rank: i64,
         step: i64,
         should_commit: bool,
-    ) -> PyResult<bool> {
+        timeout: Option<Duration>,
+    ) -> Result<bool, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ShouldCommitRequest {
                 rank: rank,
@@ -171,12 +178,9 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));

-            let response = self
-                .runtime
-                .block_on(self.client.should_commit(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            let response = self.runtime.block_on(self.client.should_commit(request))?;
             let resp = response.into_inner();
             Ok(resp.should_commit)
         })
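
Note on the bindings above: `quorum`, `checkpoint_address`, and `should_commit` now accept an optional per-call `timeout` that falls back to the client-wide default (`timeout.unwrap_or(self.timeout)`). A minimal sketch of how this might be used from Python; the `ManagerClient` constructor shown here is an assumption (it is not part of this diff), as is the `datetime.timedelta` to Rust `Duration` conversion at the pyo3 boundary:

```python
# Sketch only: ManagerClient construction is assumed, not shown in this diff.
from datetime import timedelta

from torchft.torchft import ManagerClient

client = ManagerClient("localhost:29510", timeout=timedelta(seconds=10))

# Per-call override: wait at most 5s for this RPC instead of the 10s default.
addr = client.checkpoint_address(rank=0, timeout=timedelta(seconds=5))

# Omitting timeout keeps the old behavior (client-wide default).
ok = client.should_commit(rank=0, step=1, should_commit=True)
```
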
@@ -225,16 +229,25 @@ struct Lighthouse {
 #[pymethods]
 impl Lighthouse {
     #[new]
-    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+    fn new(
+        py: Python<'_>,
+        bind: String,
+        min_replicas: u64,
+        join_timeout_ms: Option<u64>,
+        quorum_tick_ms: Option<u64>,
+    ) -> PyResult<Self> {
+        let join_timeout_ms = join_timeout_ms.unwrap_or(100);
+        let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
+
         py.allow_threads(move || {
             let rt = Runtime::new()?;

             let lighthouse = rt
                 .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
                     bind: bind,
                     min_replicas: min_replicas,
-                    join_timeout_ms: 100,
-                    quorum_tick_ms: 100,
+                    join_timeout_ms: join_timeout_ms,
+                    quorum_tick_ms: quorum_tick_ms,
                 }))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
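
The `Lighthouse` constructor now takes optional `join_timeout_ms` and `quorum_tick_ms` arguments, both defaulting to 100 when omitted. The new `torchft/lighthouse_test.py` below drives this from Python; a condensed sketch of the same usage (port and values chosen for illustration):

```python
from torchft.torchft import Lighthouse

# Bind to any free port; wait up to 400ms for more replicas to join before
# cutting a quorum, with the quorum loop ticking every 100ms.
lighthouse = Lighthouse(
    bind="[::]:0",
    min_replicas=1,
    join_timeout_ms=400,
    quorum_tick_ms=100,
)
try:
    print(lighthouse.address())  # address managers pass as lighthouse_addr
finally:
    lighthouse.shutdown()
```
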

@@ -257,6 +270,26 @@ impl Lighthouse {
     }
 }

+struct StatusError(Status);
+
+impl From<StatusError> for PyErr {
+    fn from(error: StatusError) -> Self {
+        let code = error.0.code();
+        match code {
+            tonic::Code::Cancelled | tonic::Code::DeadlineExceeded => {
+                PyTimeoutError::new_err(error.0.to_string())
+            }
+            _ => PyRuntimeError::new_err(error.0.to_string()),
+        }
+    }
+}
+
+impl From<Status> for StatusError {
+    fn from(other: Status) -> Self {
+        Self(other)
+    }
+}
+
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
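
The `StatusError` newtype above maps tonic gRPC statuses onto Python exceptions at the pyo3 boundary: `Cancelled` and `DeadlineExceeded` surface as `TimeoutError`, while any other status stays a `RuntimeError`. A hedged sketch of what callers can now distinguish, reusing the hypothetical `client` from the earlier sketch:

```python
from datetime import timedelta

try:
    quorum = client.quorum(
        room_id="room",
        rank=0,
        step=0,
        checkpoint_server_addr="addr",
        timeout=timedelta(milliseconds=250),
    )
except TimeoutError:
    # gRPC Cancelled / DeadlineExceeded -> PyTimeoutError
    print("quorum timed out; retry or trigger recovery")
except RuntimeError as e:
    # any other gRPC status -> PyRuntimeError
    print(f"quorum failed: {e}")
```
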

src/lighthouse.rs (+1)

@@ -106,6 +106,7 @@ fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) {
         for prev_member in prev_quorum.participants.iter() {
             if !state.participants.contains_key(&prev_member.replica_id) {
                 is_fast_quorum = false;
+                break;
             }
         }

src/manager.rs (+19 -8)

@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
         &self,
         request: Request<ManagerQuorumRequest>,
     ) -> Result<Response<ManagerQuorumResponse>, Status> {
-        let req = request.into_inner();
+        let req = request.get_ref();
         let rank = req.rank;
-        let room_id = req.room_id;
+        let room_id = &req.room_id;

         info!("{}: got quorum request for rank {}", room_id, rank);

@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
             .checkpoint_servers
             .insert(req.rank, req.checkpoint_server_addr.clone());

-        if !state.rooms.contains_key(&room_id) {
+        if !state.rooms.contains_key(room_id) {
             let (tx, _) = broadcast::channel(16);

             state.rooms.insert(

@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
             );
         }

-        let room = state.rooms.get_mut(&room_id).unwrap();
+        let room = state.rooms.get_mut(room_id).unwrap();

         // TODO check step
         room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::from_error(e.into()))?;

-        let request = tonic::Request::new(LighthouseQuorumRequest {
+        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
             room_id: room_id.clone(),
             requester: Some(QuorumMember {
                 replica_id: self.replica_id.clone(),

@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
             }),
         });

-        let response = client.quorum(request).await.unwrap();
+        // propagate timeout from request to lighthouse
+        let timeout = request
+            .metadata()
+            .get("grpc-timeout")
+            .ok_or_else(|| Status::internal("grpc-timeout not set"))?;
+        lighthouse_request
+            .metadata_mut()
+            .insert("grpc-timeout", timeout.clone());
+
+        let response = client.quorum(lighthouse_request).await.unwrap();
         let resp = response.into_inner();

         info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {

         let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;

-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
             room_id: "room".to_string(),
             rank: 0,
             step: 123,
             checkpoint_server_addr: "addr".to_string(),
         });
+        request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();

         manager_fut.abort();

@@ -526,12 +536,13 @@ mod tests {
         let mut client =
             manager_client_new(manager.address(), Duration::from_secs(10)).await?;

-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
             room_id: "room".to_string(),
             rank: 0,
             step: 0,
             checkpoint_server_addr: "addr".to_string(),
         });
+        request.set_timeout(Duration::from_secs(10));

         let result = client.quorum(request).await?.into_inner();

torchft/lighthouse_test.py (new file, +93)

@@ -0,0 +1,93 @@
+import time
+from unittest import TestCase
+
+import torch.distributed as dist
+
+from torchft import Manager, ProcessGroupGloo
+from torchft.torchft import Lighthouse
+
+
+class TestLighthouse(TestCase):
+    def test_join_timeout_behavior(self) -> None:
+        """Test that join_timeout_ms affects joining behavior"""
+        # To test, we create a lighthouse with 100ms and 400ms join timeouts
+        # and measure the time taken to validate the quorum.
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=1,
+            join_timeout_ms=100,
+        )
+
+        # Create a manager that tries to join
+        try:
+            store = dist.TCPStore(
+                host_name="localhost",
+                port=0,
+                is_master=True,
+                wait_for_workers=False,
+            )
+            pg = ProcessGroupGloo()
+            manager = Manager(
+                pg=pg,
+                min_replica_size=1,
+                load_state_dict=lambda x: None,
+                state_dict=lambda: None,
+                replica_id=f"lighthouse_test",
+                store_addr="localhost",
+                store_port=store.port,
+                rank=0,
+                world_size=1,
+                use_async_quorum=False,
+                lighthouse_addr=lighthouse.address(),
+            )
+
+            start_time = time.time()
+            manager.start_quorum()
+            time_taken = time.time() - start_time
+            assert time_taken < 0.4, f"Time taken to join: {time_taken} > 0.4s"
+
+        finally:
+            # Cleanup
+            lighthouse.shutdown()
+            if "manager" in locals():
+                manager.shutdown()
+
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=1,
+            join_timeout_ms=400,
+        )
+
+        # Create a manager that tries to join
+        try:
+            store = dist.TCPStore(
+                host_name="localhost",
+                port=0,
+                is_master=True,
+                wait_for_workers=False,
+            )
+            pg = ProcessGroupGloo()
+            manager = Manager(
+                pg=pg,
+                min_replica_size=1,
+                load_state_dict=lambda x: None,
+                state_dict=lambda: None,
+                replica_id=f"lighthouse_test",
+                store_addr="localhost",
+                store_port=store.port,
+                rank=0,
+                world_size=1,
+                use_async_quorum=False,
+                lighthouse_addr=lighthouse.address(),
+            )
+
+            start_time = time.time()
+            manager.start_quorum()
+            time_taken = time.time() - start_time
+            assert time_taken > 0.4, f"Time taken to join: {time_taken} < 0.4s"
+
+        finally:
+            # Cleanup
+            lighthouse.shutdown()
+            if "manager" in locals():
+                manager.shutdown()
