manager: add per request timeouts #59

Merged
merged 1 commit on Jan 9, 2025
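This PR threads an optional per-request timeout through the Python bindings (`quorum`, `checkpoint_address`, `should_commit`) and the Python `Manager` API (`start_quorum`, `wrap_future`, `should_commit`); when no timeout is passed, the default timeout given at construction is used. A minimal usage sketch of the new keyword arguments; the training-loop body and the `Manager` construction are placeholders, not part of this PR:

```python
from datetime import timedelta

from torchft.manager import Manager


def train_step(manager: Manager) -> None:
    # Per-call deadline for quorum and recovery; overrides the Manager's
    # default timeout for this call only.
    manager.start_quorum(timeout=timedelta(seconds=30))

    # ... forward/backward and gradient averaging would happen here ...

    # should_commit gets its own, typically shorter, deadline.
    if manager.should_commit(timeout=timedelta(seconds=10)):
        pass  # optimizer.step() in a real training loop
```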
58 changes: 41 additions & 17 deletions src/lib.rs
@@ -12,11 +12,12 @@ use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use tonic::Status;

pub mod torchftpb {
tonic::include_proto!("torchft");
@@ -102,14 +103,16 @@ impl ManagerClient {
})
}

#[pyo3(signature = (room_id, rank, step, checkpoint_server_addr, timeout=None))]
fn quorum(
&mut self,
py: Python<'_>,
room_id: String,
rank: i64,
step: i64,
checkpoint_server_addr: String,
) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
timeout: Option<Duration>,
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: room_id,
@@ -119,12 +122,9 @@
});
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.quorum(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let response = self.runtime.block_on(self.client.quorum(request))?;
let resp = response.into_inner();
Ok((
resp.quorum_id,
@@ -140,29 +140,36 @@
})
}

fn checkpoint_address(&mut self, py: Python<'_>, rank: i64) -> PyResult<String> {
#[pyo3(signature = (rank, timeout=None))]
fn checkpoint_address(
&mut self,
py: Python<'_>,
rank: i64,
timeout: Option<Duration>,
) -> Result<String, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.checkpoint_address(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
.block_on(self.client.checkpoint_address(request))?;
let resp = response.into_inner();
Ok(resp.checkpoint_server_address)
})
}

#[pyo3(signature = (rank, step, should_commit, timeout=None))]
fn should_commit(
&mut self,
py: Python<'_>,
rank: i64,
step: i64,
should_commit: bool,
) -> PyResult<bool> {
timeout: Option<Duration>,
) -> Result<bool, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ShouldCommitRequest {
rank: rank,
@@ -171,12 +178,9 @@
});
// This notifies the server about the timeout but doesn't affect the
// endpoint timeout which we set on client creation.
request.set_timeout(self.timeout);
request.set_timeout(timeout.unwrap_or(self.timeout));

let response = self
.runtime
.block_on(self.client.should_commit(request))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let response = self.runtime.block_on(self.client.should_commit(request))?;
let resp = response.into_inner();
Ok(resp.should_commit)
})
@@ -257,6 +261,26 @@ impl Lighthouse {
}
}

struct StatusError(Status);

impl From<StatusError> for PyErr {
fn from(error: StatusError) -> Self {
let code = error.0.code();
match code {
tonic::Code::Cancelled | tonic::Code::DeadlineExceeded => {
PyTimeoutError::new_err(error.0.to_string())
}
_ => PyRuntimeError::new_err(error.0.to_string()),
}
}
}

impl From<Status> for StatusError {
fn from(other: Status) -> Self {
Self(other)
}
}

#[pymodule]
fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
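The new `StatusError` wrapper changes how gRPC failures surface in Python: `Cancelled` and `DeadlineExceeded` become `TimeoutError`, while every other status still becomes `RuntimeError`, so callers can tell an expired per-request deadline apart from other failures. A rough sketch of calling the bindings directly; the import path and the address are assumptions based on how `torchft/manager.py` uses `ManagerClient`:

```python
from datetime import timedelta

from torchft.torchft import ManagerClient  # assumed import path for the Rust extension

client = ManagerClient("localhost:29512", timeout=timedelta(seconds=10))

try:
    addr = client.checkpoint_address(0, timeout=timedelta(seconds=1))
except TimeoutError:
    # tonic Cancelled / DeadlineExceeded now map to PyTimeoutError
    addr = None
except RuntimeError:
    # any other gRPC status still maps to PyRuntimeError
    addr = None
```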
27 changes: 19 additions & 8 deletions src/manager.rs
@@ -180,9 +180,9 @@ impl ManagerService for Arc<Manager> {
&self,
request: Request<ManagerQuorumRequest>,
) -> Result<Response<ManagerQuorumResponse>, Status> {
let req = request.into_inner();
let req = request.get_ref();
let rank = req.rank;
let room_id = req.room_id;
let room_id = &req.room_id;

info!("{}: got quorum request for rank {}", room_id, rank);

@@ -195,7 +195,7 @@ impl ManagerService for Arc<Manager> {
.checkpoint_servers
.insert(req.rank, req.checkpoint_server_addr.clone());

if !state.rooms.contains_key(&room_id) {
if !state.rooms.contains_key(room_id) {
let (tx, _) = broadcast::channel(16);

state.rooms.insert(
@@ -207,7 +207,7 @@ impl ManagerService for Arc<Manager> {
);
}

let room = state.rooms.get_mut(&room_id).unwrap();
let room = state.rooms.get_mut(room_id).unwrap();

// TODO check step
room.participants.insert(rank);
@@ -224,7 +224,7 @@ impl ManagerService for Arc<Manager> {
.await
.map_err(|e| Status::from_error(e.into()))?;

let request = tonic::Request::new(LighthouseQuorumRequest {
let mut lighthouse_request = tonic::Request::new(LighthouseQuorumRequest {
room_id: room_id.clone(),
requester: Some(QuorumMember {
replica_id: self.replica_id.clone(),
@@ -235,7 +235,16 @@ impl ManagerService for Arc<Manager> {
}),
});

let response = client.quorum(request).await.unwrap();
// propagate timeout from request to lighthouse
let timeout = request
.metadata()
.get("grpc-timeout")
.ok_or_else(|| Status::internal("grpc-timeout not set"))?;
lighthouse_request
.metadata_mut()
.insert("grpc-timeout", timeout.clone());

let response = client.quorum(lighthouse_request).await.unwrap();
let resp = response.into_inner();

info!("{}: got lighthouse quorum {:?}", room_id, resp);
@@ -471,12 +480,13 @@ mod tests {

let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?;

let request = tonic::Request::new(ManagerQuorumRequest {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: "room".to_string(),
rank: 0,
step: 123,
checkpoint_server_addr: "addr".to_string(),
});
request.set_timeout(Duration::from_secs(10));
let resp = client.quorum(request).await?.into_inner();

manager_fut.abort();
Expand Down Expand Up @@ -526,12 +536,13 @@ mod tests {
let mut client =
manager_client_new(manager.address(), Duration::from_secs(10)).await?;

let request = tonic::Request::new(ManagerQuorumRequest {
let mut request = tonic::Request::new(ManagerQuorumRequest {
room_id: "room".to_string(),
rank: 0,
step: 0,
checkpoint_server_addr: "addr".to_string(),
});
request.set_timeout(Duration::from_secs(10));

let result = client.quorum(request).await?.into_inner();

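Server side, the manager now requires the caller to have set a deadline: it reads the `grpc-timeout` metadata on the incoming quorum request (returning an internal error if it is missing, which is why the tests now call `request.set_timeout` before `client.quorum`) and copies it onto the lighthouse request, so the client's deadline propagates one hop further. The metadata value is the standard gRPC wire encoding of a deadline, an integer followed by a unit suffix. A small illustrative helper, not part of this PR, that produces that encoding from a `timedelta`:

```python
from datetime import timedelta


def to_grpc_timeout(timeout: timedelta) -> str:
    """Encode a deadline using the gRPC `grpc-timeout` header format:
    an ASCII integer followed by a unit, one of "H", "M", "S",
    "m" (milliseconds), "u" (microseconds) or "n" (nanoseconds)."""
    micros = int(timeout.total_seconds() * 1_000_000)
    if micros % 1_000_000 == 0:
        return f"{micros // 1_000_000}S"
    if micros % 1_000 == 0:
        return f"{micros // 1_000}m"
    return f"{micros}u"


assert to_grpc_timeout(timedelta(seconds=10)) == "10S"
assert to_grpc_timeout(timedelta(milliseconds=500)) == "500m"
```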
42 changes: 32 additions & 10 deletions torchft/manager.py
@@ -102,7 +102,10 @@ def __init__(
min_replica_size: minimum number of replicas on each step
port: if rank==0, the port to run the manager server on
use_async_quorum: whether to run the quorum asynchronously during the forward pass
timeout: timeout for all operations
timeout:
the default timeout for all operations; if you're using per-request
timeouts, this should be longer than the longest request timeout.
rank: the replica group local rank
world_size: the replica group local world size
store_addr: TCPStore address for this replica group
@@ -279,7 +282,10 @@ def errored(self) -> Optional[Exception]:
return self._errored

def wrap_future(
self, fut: torch.futures.Future[T], default: T
self,
fut: torch.futures.Future[T],
default: T,
timeout: Optional[timedelta] = None,
) -> torch.futures.Future[T]:
"""
Wrap a Future and swallow any errors that occur and report them to the manager.
@@ -289,10 +295,11 @@ def wrap_future(
Args:
fut: the Future to wrap
default: the default value to complete the Future with if an error occurs
timeout: the timeout for the Future; if None, the manager's timeout will be used
"""

# add a timeout to the future
fut = future_timeout(fut, self._timeout)
fut = future_timeout(fut, timeout or self._timeout)

# schedule error handling as a continuation on the Future
def callback(
@@ -313,7 +320,12 @@ def callback(
self._pending_work.append(cast(torch.futures.Future[object], fut))
return fut

def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None:
def start_quorum(
self,
room_id: str = "default",
allow_heal: bool = True,
timeout: Optional[timedelta] = None,
) -> None:
"""
.. note::
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -331,6 +343,7 @@ def start_quorum(
calls. All replicas must pass the same value to allow_heal.
room_id: (experimental) the room id to use for quorum, this allows
for multiple quorums to be used within the same job.
timeout: the timeout for quorum and recovery operations; if None, the manager's timeout will be used
"""

# wait for previous quorum to complete
@@ -345,7 +358,10 @@ def start_quorum(
# block to allow gracefully recovering from issues in PG setup and quorum.

self._quorum_future = self._executor.submit(
self._async_quorum, room_id=room_id, allow_heal=allow_heal
self._async_quorum,
room_id=room_id,
allow_heal=allow_heal,
timeout=timeout or self._timeout,
)
if not self._use_async_quorum:
self.wait_quorum()
@@ -369,7 +385,7 @@ def wait_quorum(self) -> None:
), "must call start_quorum before wait_quorum"
self._quorum_future.result()

def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
(
quorum_id,
replica_rank,
@@ -385,6 +401,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
rank=self._rank,
step=self._step,
checkpoint_server_addr=self._ckpt_server.address(),
timeout=timeout,
)

# When using async quorum we need to take the recovered workers.
@@ -422,8 +439,10 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None:
self._logger.info(
f"healing required, fetching checkpoint server address from {address=} {max_step=}"
)
primary_client = ManagerClient(address, timeout=self._timeout)
checkpoint_server_address = primary_client.checkpoint_address(self._rank)
primary_client = ManagerClient(address, timeout=timeout)
checkpoint_server_address = primary_client.checkpoint_address(
self._rank, timeout=timeout
)

self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
self._pending_state_dict = CheckpointServer.load_from_address(
@@ -449,7 +468,7 @@ def _apply_pending_state_dict(self) -> None:
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None

def should_commit(self) -> bool:
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
.. note::
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
@@ -486,7 +505,10 @@ def should_commit(self) -> bool:
enough_replicas = self.num_participants() >= self._min_replica_size
local_should_commit = enough_replicas and self._errored is None
should_commit = self._client.should_commit(
self._rank, self._step, local_should_commit
self._rank,
self._step,
local_should_commit,
timeout=timeout or self._timeout,
)
self._logger.info(
f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
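`wrap_future` also gains a per-future timeout, so an in-flight collective can be given its own deadline independent of the manager default; on error or timeout the future is completed with the supplied default and the failure is reported to the manager. A hedged sketch, assuming an initialized process group and a gradient tensor from the surrounding training code (the helper name is hypothetical):

```python
from datetime import timedelta

import torch
import torch.distributed as dist

from torchft.manager import Manager


def allreduce_grad(manager: Manager, grad: torch.Tensor) -> torch.futures.Future:
    # Hypothetical helper: launch an async allreduce and bound how long we
    # wait for it; if it errors or times out, the wrapped future completes
    # with the local gradient and the failure is reported to the manager.
    work = dist.all_reduce(grad, async_op=True)
    return manager.wrap_future(work.get_future(), default=grad, timeout=timedelta(seconds=5))
```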