Commit a03406f (2 parents: 556e286 + 3122fda)

Update

[ghstack-poisoned]

12 files changed: +482 -136 lines

README.md (+1 -1)

@@ -10,7 +10,7 @@ Easy Per Step Fault Tolerance for PyTorch
 </h3>
 
 <p align="center">
-| <a href="https://pytorch-labs.github.io/torchft/"><b>Documentation</b></a>
+| <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
 | <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
 | <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
 |

src/lib.rs (+53 -20)

@@ -12,11 +12,12 @@ use std::env;
 use std::sync::Arc;
 
 use anyhow::Result;
-use pyo3::exceptions::PyRuntimeError;
+use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tonic::transport::Channel;
+use tonic::Status;
 
 pub mod torchftpb {
     tonic::include_proto!("torchft");
@@ -102,14 +103,16 @@ impl ManagerClient {
         })
     }
 
+    #[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
     fn quorum(
         &mut self,
         py: Python<'_>,
         room_id: String,
         rank: i64,
         step: i64,
         checkpoint_server_addr: String,
-    ) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
+        timeout: Option<Duration>,
+    ) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ManagerQuorumRequest {
                 room_id: room_id,
@@ -119,12 +122,9 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
-            let response = self
-                .runtime
-                .block_on(self.client.quorum(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            let response = self.runtime.block_on(self.client.quorum(request))?;
             let resp = response.into_inner();
             Ok((
                 resp.quorum_id,
@@ -140,29 +140,36 @@ impl ManagerClient {
         })
     }
 
-    fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
+    #[pyo3(signature = (rank, timeout=None))]
+    fn checkpoint_address(
+        &mut self,
+        py: Python<'_>,
+        rank: i64,
+        timeout: Option<Duration>,
+    ) -> Result<String, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
             let response = self
                 .runtime
-                .block_on(self.client.checkpoint_address(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+                .block_on(self.client.checkpoint_address(request))?;
             let resp = response.into_inner();
             Ok(resp.checkpoint_server_address)
         })
     }
 
+    #[pyo3(signature = (rank, step, should_commit, timeout=None))]
     fn should_commit(
         &mut self,
         py: Python<'_>,
         rank: i64,
         step: i64,
         should_commit: bool,
-    ) -> PyResult<bool> {
+        timeout: Option<Duration>,
+    ) -> Result<bool, StatusError> {
         py.allow_threads(move || {
             let mut request = tonic::Request::new(ShouldCommitRequest {
                 rank: rank,
@@ -171,12 +178,9 @@ impl ManagerClient {
             });
             // This notifies the server about the timeout but doesn't affect the
             // endpoint timeout which we set on client creation.
-            request.set_timeout(self.timeout);
+            request.set_timeout(timeout.unwrap_or(self.timeout));
 
-            let response = self
-                .runtime
-                .block_on(self.client.should_commit(request))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            let response = self.runtime.block_on(self.client.should_commit(request))?;
             let resp = response.into_inner();
             Ok(resp.should_commit)
         })
@@ -225,16 +229,25 @@ struct Lighthouse {
 #[pymethods]
 impl Lighthouse {
     #[new]
-    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+    fn new(
+        py: Python<'_>,
+        bind: String,
+        min_replicas: u64,
+        join_timeout_ms: Option<u64>,
+        quorum_tick_ms: Option<u64>,
+    ) -> PyResult<Self> {
+        let join_timeout_ms = join_timeout_ms.unwrap_or(100);
+        let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
+
         py.allow_threads(move || {
             let rt = Runtime::new()?;
 
             let lighthouse = rt
                 .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
                     bind: bind,
                     min_replicas: min_replicas,
-                    join_timeout_ms: 100,
-                    quorum_tick_ms: 100,
+                    join_timeout_ms: join_timeout_ms,
+                    quorum_tick_ms: quorum_tick_ms,
                 }))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
@@ -257,6 +270,26 @@ impl Lighthouse {
     }
 }
 
+struct StatusError(Status);
+
+impl From<StatusError> for PyErr {
+    fn from(error: StatusError) -> Self {
+        let code = error.0.code();
+        match code {
+            tonic::Code::Cancelled | tonic::Code::DeadlineExceeded => {
+                PyTimeoutError::new_err(error.0.to_string())
+            }
+            _ => PyRuntimeError::new_err(error.0.to_string()),
+        }
+    }
+}
+
+impl From<Status> for StatusError {
+    fn from(other: Status) -> Self {
+        Self(other)
+    }
+}
+
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
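
Taken together, these changes expose the per-call timeout to Python as an optional keyword argument on each ManagerClient method, and the new StatusError wrapper maps the gRPC Cancelled and DeadlineExceeded codes onto Python's TimeoutError instead of a generic RuntimeError. A minimal usage sketch from the Python side, assuming pyo3 converts datetime.timedelta into std::time::Duration; the import path and constructor arguments are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage sketch; import path and constructor args are assumptions.
from datetime import timedelta

from torchft.torchft import ManagerClient

client = ManagerClient("localhost:29512", timeout=timedelta(seconds=10))

try:
    # The per-call timeout overrides the client-wide default for this RPC only.
    quorum = client.quorum(
        room_id="room",
        rank=0,
        step=123,
        checkpoint_server_addr="addr",
        timeout=timedelta(seconds=5),
    )
except TimeoutError:
    # Cancelled / DeadlineExceeded gRPC statuses now surface as TimeoutError.
    pass
```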

src/lighthouse.rs (+1)

@@ -106,6 +106,7 @@ fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) {
         for prev_member in prev_quorum.participants.iter() {
             if !state.participants.contains_key(&prev_member.replica_id) {
                 is_fast_quorum = false;
+                break;
             }
         }
 
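
The added break is safe because the fast-quorum check only asks whether any member of the previous quorum is missing from the current participants; once one missing member is found the answer cannot change, so the rest of the scan is skipped. A conceptual restatement in illustrative Python with invented names, not the Rust source:

```python
# Illustrative sketch only; names are invented, not from src/lighthouse.rs.
def is_fast_quorum(prev_participants: list[str], current_ids: set[str]) -> bool:
    # A fast quorum requires every member of the previous quorum to be present
    # again; the first missing member settles the result, like the early break.
    return all(replica_id in current_ids for replica_id in prev_participants)
```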

src/manager.rs (+19 -8)

@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
         &self,
         request: Request<ManagerQuorumRequest>,
     ) -> Result<Response<ManagerQuorumResponse>, Status> {
-        let req = request.into_inner();
+        let req = request.get_ref();
         let rank = req.rank;
-        let room_id = req.room_id;
+        let room_id = &req.room_id;
 
         info!("{}: got quorum request for rank {}", room_id, rank);
 
@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
             .checkpoint_servers
             .insert(req.rank, req.checkpoint_server_addr.clone());
 
-        if !state.rooms.contains_key(&room_id) {
+        if !state.rooms.contains_key(room_id) {
             let (tx, _) = broadcast::channel(16);
 
             state.rooms.insert(
@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
             );
         }
 
-        let room = state.rooms.get_mut(&room_id).unwrap();
+        let room = state.rooms.get_mut(room_id).unwrap();
 
         // TODO check step
         room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::from_error(e.into()))?;
 
-        let request = tonic::Request::new(LighthouseQuorumRequest {
+        let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
             room_id: room_id.clone(),
             requester: Some(QuorumMember {
                 replica_id: self.replica_id.clone(),
@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
             }),
         });
 
-        let response = client.quorum(request).await.unwrap();
+        // propagate timeout from request to lighthouse
+        let timeout = request
+            .metadata()
+            .get("grpc-timeout")
+            .ok_or_else(|| Status::internal("grpc-timeout not set"))?;
+        lighthouse_request
+            .metadata_mut()
+            .insert("grpc-timeout", timeout.clone());
+
+        let response = client.quorum(lighthouse_request).await.unwrap();
         let resp = response.into_inner();
 
         info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {
 
         let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;
 
-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
            room_id: "room".to_string(),
            rank: 0,
            step: 123,
            checkpoint_server_addr: "addr".to_string(),
        });
+        request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
 
         manager_fut.abort();
@@ -526,12 +536,13 @@ mod tests {
         let mut client =
             manager_client_new(manager.address(), Duration::from_secs(10)).await?;
 
-        let request = tonic::Request::new(ManagerQuorumRequest {
+        let mut request = tonic::Request::new(ManagerQuorumRequest {
            room_id: "room".to_string(),
            rank: 0,
            step: 0,
            checkpoint_server_addr: "addr".to_string(),
        });
+        request.set_timeout(Duration::from_secs(10));
 
         let result = client.quorum(request).await?.into_inner();
 
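
The manager now propagates the caller's deadline to the lighthouse by copying the grpc-timeout metadata from the inbound request onto the outbound one, which is also why the updated tests call request.set_timeout so the header is present. The same mechanism sketched with gRPC's Python API, purely as an illustration; the stub and method names are placeholders, not torchft code:

```python
# Illustrative sketch of deadline propagation in a gRPC service (grpcio);
# not code from this repository, and the stub/method names are placeholders.
import grpc


def handle_quorum(request, context: grpc.ServicerContext, lighthouse_stub):
    # time_remaining() is derived from the grpc-timeout header sent by the caller.
    remaining = context.time_remaining()
    if remaining is None:
        context.abort(grpc.StatusCode.INTERNAL, "grpc-timeout not set")
    # Apply the caller's remaining deadline to the downstream call.
    return lighthouse_stub.Quorum(request, timeout=remaining)
```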

torchft/fsdp_test.py (+19 -13)

@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import multiprocessing
 import os
+import unittest
+from concurrent.futures import ProcessPoolExecutor
 from typing import Any, Dict, Tuple
 from unittest.mock import Mock
 
@@ -33,20 +36,14 @@
 from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh
 
 
-class FSDPTest(MultiProcessTestCase):
-    @property
-    def world_size(self) -> int:
-        return 4
+class FSDPTest(unittest.TestCase):
+    @staticmethod
+    def _test_fsdp(world_size: int, rank: int) -> None:
+        torch.cuda.set_device(rank)
 
-    def setUp(self) -> None:
-        super().setUp()
-        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
-        self._spawn_processes()
-
-    def test_fsdp(self) -> None:
-        group_size = self.world_size // 2
-        group = self.rank // group_size
-        group_rank = self.rank % group_size
+        group_size = world_size // 2
+        group = rank // group_size
+        group_rank = rank % group_size
 
         os.environ["MASTER_ADDR"] = "127.0.0.1"
         os.environ["MASTER_PORT"] = str(12346 + group)
@@ -66,3 +63,12 @@ def test_fsdp(self) -> None:
         batch = torch.randn(4, 128).cuda()
         shard_model = fully_shard(model, mesh=device_mesh)
         shard_model(batch).mean().backward()
+
+    @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
+    def test_fsdp(self) -> None:
+        multiprocessing.set_start_method("spawn")
+        with ProcessPoolExecutor(max_workers=4) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(self._test_fsdp, 4, i)
+                futures.append(future)
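
The rewritten test drives four CUDA ranks itself: the spawn start method gives each worker a clean CUDA/NCCL context, and _test_fsdp runs one rank per process. A minimal sketch of the same spawn-and-submit pattern with the worker results collected, so an exception in any rank would fail the caller; the result-gathering step is an illustration and not part of this commit:

```python
# Sketch of the spawn-and-gather pattern; not code from this commit.
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def _worker(world_size: int, rank: int) -> None:
    # Placeholder for per-rank work such as FSDPTest._test_fsdp.
    print(f"rank {rank}/{world_size} done")


def run_all(world_size: int = 4) -> None:
    ctx = multiprocessing.get_context("spawn")
    with ProcessPoolExecutor(max_workers=world_size, mp_context=ctx) as executor:
        futures = [executor.submit(_worker, world_size, rank) for rank in range(world_size)]
        for future in futures:
            # .result() re-raises any exception raised in the worker process.
            future.result()


if __name__ == "__main__":
    run_all()
```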
