Commit

Update

[ghstack-poisoned]

fegin committed Jan 10, 2025
2 parents 556e286 + 3122fda commit a03406f
Showing 12 changed files with 482 additions and 136 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ Easy Per Step Fault Tolerance for PyTorch
</h3>

<p align="center">
| <a href="https://pytorch-labs.github.io/torchft/"><b>Documentation</b></a>
| <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
| <a href="https://github.com/pytorch-labs/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
| <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
|
73 changes: 53 additions & 20 deletions src/lib.rs
@@ -12,11 +12,12 @@ use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use tonic::Status;

pub mod torchftpb {
tonic::include_proto!("torchft");
@@ -102,14 +103,16 @@ impl ManagerClient {
})
}

#[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
fn quorum(
&mut self,
py: Python<'_>,
room_id: String,
rank: i64,
step: i64,
checkpoint_server_addr: String,
) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
timeout: Option<Duration>,
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: room_id,
@@ -119,12 +122,9 @@ impl ManagerClient {
});
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.quorum(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let response = self.runtime.block_on(self.client.quorum(request))?;
let resp = response.into_inner();
Ok((
resp.quorum_id,
@@ -140,29 +140,36 @@ impl ManagerClient {
})
}

fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
#[pyo3(signature = (rank, timeout=None))]
fn checkpoint_address(
&mut self,
py: Python<'_>,
rank: i64,
timeout: Option<Duration>,
) -> Result<String, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.checkpoint_address(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
.block_on(self.client.checkpoint_address(request))?;
let resp = response.into_inner();
Ok(resp.checkpoint_server_address)
})
}

#[pyo3(signature = (rank, step, should_commit, timeout=None))]
fn should_commit(
&mut self,
py: Python<'_>,
rank: i64,
step: i64,
should_commit: bool,
) -> PyResult<bool> {
timeout: Option<Duration>,
) -> Result<bool, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ShouldCommitRequest {
rank: rank,
Expand All @@ -171,12 +178,9 @@ impl ManagerClient {
});
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.should_commit(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let response = self.runtime.block_on(self.client.should_commit(request))?;
let resp = response.into_inner();
Ok(resp.should_commit)
})
@@ -225,16 +229,25 @@ struct Lighthouse {
#[pymethods]
impl Lighthouse {
#[new]
fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
fn new(
py: Python<'_>,
bind: String,
min_replicas: u64,
join_timeout_ms: Option<u64>,
quorum_tick_ms: Option<u64>,
) -> PyResult<Self> {
let join_timeout_ms = join_timeout_ms.unwrap_or(100);
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);

py.allow_threads(move || {
let rt = Runtime::new()?;

let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
bind: bind,
min_replicas: min_replicas,
join_timeout_ms: 100,
quorum_tick_ms: 100,
join_timeout_ms: join_timeout_ms,
quorum_tick_ms: quorum_tick_ms,
}))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

@@ -257,6 +270,26 @@ impl Lighthouse {
}
}

struct StatusError(Status);

impl From<StatusError> for PyErr {
fn from(error: StatusError) -> Self {
let code = error.0.code();
match code {
tonic::Code::Cancelled | tonic::Code::DeadlineExceeded => {
PyTimeoutError::new_err(error.0.to_string())
}
_ => PyRuntimeError::new_err(error.0.to_string()),
}
}
}

impl From<Status> for StatusError {
fn from(other: Status) -> Self {
Self(other)
}
}

#[pymodule]
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
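The lib.rs changes above do three things: every ManagerClient RPC gains an optional per-call `timeout` that overrides the client-wide default, gRPC Cancelled/DeadlineExceeded statuses are mapped to Python TimeoutError (everything else stays a RuntimeError) via the new StatusError wrapper, and the Lighthouse constructor exposes the previously hardcoded join_timeout_ms/quorum_tick_ms. A minimal Python sketch of how this might be used; the import path and the ManagerClient constructor arguments are assumptions, not part of this diff:

from datetime import timedelta  # pyo3 converts timedelta to std::time::Duration

from torchft.torchft import Lighthouse, ManagerClient  # assumed import path

# Lighthouse tick/join intervals are now configurable (previously fixed at 100 ms).
lighthouse = Lighthouse(
    bind="[::]:29510",
    min_replicas=2,
    join_timeout_ms=200,
    quorum_tick_ms=50,
)

# Assumed constructor: manager address plus a client-wide default timeout.
client = ManagerClient("localhost:29511", timeout=timedelta(seconds=10))

try:
    # Per-call timeouts override the default set at construction time.
    addr = client.checkpoint_address(rank=0, timeout=timedelta(seconds=5))
    ok = client.should_commit(rank=0, step=10, should_commit=True,
                              timeout=timedelta(seconds=5))
except TimeoutError:
    # Cancelled / DeadlineExceeded now raise TimeoutError instead of RuntimeError.
    ...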
1 change: 1 addition & 0 deletions src/lighthouse.rs
@@ -106,6 +106,7 @@ fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) {
for prev_member in prev_quorum.participants.iter() {
if !state.participants.contains_key(&prev_member.replica_id) {
is_fast_quorum = false;
break;
}
}

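The one-line lighthouse.rs change adds an early break to the fast-quorum check: once any member of the previous quorum is missing from the current participants, the flag can never flip back to true, so scanning the remaining members is wasted work. Restated as a tiny illustrative Python helper (not the actual Rust implementation):

def is_fast_quorum(prev_participants, current_participants) -> bool:
    """True only if every member of the previous quorum is still present."""
    for member in prev_participants:
        if member not in current_participants:
            # Equivalent to setting the flag and breaking out early.
            return False
    return True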
27 changes: 19 additions & 8 deletions src/manager.rs
@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
&self,
request: Request<ManagerQuorumRequest>,
) -> Result<Response<ManagerQuorumResponse>, Status> {
let req = request.into_inner();
let req = request.get_ref();
let rank = req.rank;
let room_id = req.room_id;
let room_id = &req.room_id;

info!("{}: got quorum request for rank {}", room_id, rank);

@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
.checkpoint_servers
.insert(req.rank, req.checkpoint_server_addr.clone());

if !state.rooms.contains_key(&room_id) {
if !state.rooms.contains_key(room_id) {
let (tx, _) = broadcast::channel(16);

state.rooms.insert(
@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
);
}

let room = state.rooms.get_mut(&room_id).unwrap();
let room = state.rooms.get_mut(room_id).unwrap();

// TODO check step
room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
.await
.map_err(|e| Status::from_error(e.into()))?;

let request = tonic::Request::new(LighthouseQuorumRequest {
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
room_id: room_id.clone(),
requester: Some(QuorumMember {
replica_id: self.replica_id.clone(),
@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
}),
});

let response = client.quorum(request).await.unwrap();
// propagate timeout from request to lighthouse
let timeout = request
.metadata()
.get("grpc-timeout")
.ok_or_else(|| Status::internal("grpc-timeout not set"))?;
lighthouse_request
.metadata_mut()
.insert("grpc-timeout", timeout.clone());

let response = client.quorum(lighthouse_request).await.unwrap();
let resp = response.into_inner();

info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {

let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;

let request = tonic::Request::new(ManagerQuorumRequest {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: "room".to_string(),
rank: 0,
step: 123,
checkpoint_server_addr: "addr".to_string(),
});
request.set_timeout(Duration::from_secs(10));
let resp = client.quorum(request).await?.into_inner();

manager_fut.abort();
@@ -526,12 +536,13 @@ mod tests {
let mut client =
manager_client_new(manager.address(), Duration::from_secs(10)).await?;

let request = tonic::Request::new(ManagerQuorumRequest {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: "room".to_string(),
rank: 0,
step: 0,
checkpoint_server_addr: "addr".to_string(),
});
request.set_timeout(Duration::from_secs(10));

let result = client.quorum(request).await?.into_inner();

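In manager.rs, quorum() now reads the request with get_ref() so it can inspect the caller's grpc-timeout metadata and forward the same deadline to the lighthouse quorum request; requests without a timeout are rejected, which is why both tests now call request.set_timeout(...). From Python the binding always sets the header (the per-call timeout or the client default), so a quorum that cannot form within the budget surfaces as TimeoutError. A hedged sketch, with the same assumed import path and constructor as above:

from datetime import timedelta

from torchft.torchft import ManagerClient  # assumed import path

client = ManagerClient("localhost:29511", timeout=timedelta(seconds=60))  # assumed constructor

try:
    # The 5 s deadline travels as grpc-timeout metadata and is propagated by the
    # manager to its lighthouse request, bounding the whole quorum operation.
    quorum = client.quorum(
        room_id="train_ddp",
        rank=0,
        step=123,
        checkpoint_server_addr="localhost:29512",
        timeout=timedelta(seconds=5),
    )
except TimeoutError:
    # DeadlineExceeded from the manager or the lighthouse ends up here.
    ...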
32 changes: 19 additions & 13 deletions torchft/fsdp_test.py
@@ -4,7 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import multiprocessing
import os
import unittest
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, Tuple
from unittest.mock import Mock

@@ -33,20 +36,14 @@
from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh


class FSDPTest(MultiProcessTestCase):
@property
def world_size(self) -> int:
return 4
class FSDPTest(unittest.TestCase):
@staticmethod
def _test_fsdp(world_size: int, rank: int) -> None:
torch.cuda.set_device(rank)

def setUp(self) -> None:
super().setUp()
os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
self._spawn_processes()

def test_fsdp(self) -> None:
group_size = self.world_size // 2
group = self.rank // group_size
group_rank = self.rank % group_size
group_size = world_size // 2
group = rank // group_size
group_rank = rank % group_size

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(12346 + group)
@@ -66,3 +63,12 @@ def test_fsdp(self) -> None:
batch = torch.randn(4, 128).cuda()
shard_model = fully_shard(model, mesh=device_mesh)
shard_model(batch).mean().backward()

@unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
def test_fsdp(self) -> None:
multiprocessing.set_start_method("spawn")
with ProcessPoolExecutor(max_workers=4) as executor:
futures = []
for i in range(4):
future = executor.submit(self._test_fsdp, 4, i)
futures.append(future)
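The tail of this hunk is truncated in the rendered diff. Presumably the submitted futures are then waited on so that an exception in any worker process fails the test; a sketch of that assumed continuation:

        # Assumed continuation (not shown above): block on each worker and surface errors.
        for future in futures:
            future.result()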
