Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

manager: added E2E tests and support getting lighthouse and manager addresses #25

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bin/lighthouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() {
.unwrap();

let opt = LighthouseOpt::from_args();
let lighthouse = Lighthouse::new(opt);
let lighthouse = Lighthouse::new(opt).await.unwrap();

lighthouse.run().await.unwrap();
}
38 changes: 19 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ pub mod manager;

use core::time::Duration;
use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::PyRuntimeError;
Expand All @@ -28,8 +27,6 @@ use pyo3::prelude::*;

#[pyclass]
struct Manager {
runtime: Runtime,
manager: Arc<manager::Manager>,
handle: JoinHandle<Result<()>>,
}

Expand All @@ -47,20 +44,18 @@ impl Manager {
) -> Self {
py.allow_threads(move || {
let runtime = Runtime::new().unwrap();
let manager = manager::Manager::new(
replica_id,
lighthouse_addr,
address,
bind,
store_addr,
world_size,
);
let manager = runtime
.block_on(manager::Manager::new(
replica_id,
lighthouse_addr,
address,
bind,
store_addr,
world_size,
))
.unwrap();
let handle = runtime.spawn(manager.clone().run());
Self {
runtime: runtime,
manager: manager,
handle: handle,
}
Self { handle: handle }
})
}

Expand Down Expand Up @@ -193,11 +188,16 @@ fn lighthouse_main(py: Python<'_>) {
let mut args = env::args();
args.next(); // discard binary arg
let opt = lighthouse::LighthouseOpt::from_iter(args);
let lighthouse = lighthouse::Lighthouse::new(opt);

let rt = Runtime::new().unwrap();
rt.block_on(lighthouse_main_async(opt)).unwrap();
}

rt.block_on(lighthouse.run()).unwrap();
/// Async body of the lighthouse entry point: builds a lighthouse from `opt`
/// and runs it.
///
/// # Errors
///
/// Propagates any error returned by `Lighthouse::new` or by `run`.
async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
// Construction is async and fallible, so it is awaited here inside the runtime.
let lighthouse = lighthouse::Lighthouse::new(opt).await?;

lighthouse.run().await?;

Ok(())
}

#[pymodule]
Expand Down
70 changes: 43 additions & 27 deletions src/lighthouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio::time::sleep;
use tonic::service::Routes;
use tonic::transport::server::TcpIncoming;
use tonic::transport::Server;
use tonic::{Request, Response, Status};

Expand Down Expand Up @@ -56,23 +57,25 @@ struct State {
pub struct Lighthouse {
state: Mutex<State>,
opt: LighthouseOpt,
listener: Mutex<Option<tokio::net::TcpListener>>,
local_addr: SocketAddr,
}

#[derive(StructOpt, Debug)]
#[structopt()]
pub struct LighthouseOpt {
// bind is the address to bind the server to.
#[structopt(long = "bind", default_value = "[::]:29510")]
bind: String,
pub bind: String,

#[structopt(long = "join_timeout_ms", default_value = "60000")]
join_timeout_ms: u64,
pub join_timeout_ms: u64,

#[structopt(long = "min_replicas")]
min_replicas: u64,
pub min_replicas: u64,

#[structopt(long = "quorum_tick_ms", default_value = "100")]
quorum_tick_ms: u64,
pub quorum_tick_ms: u64,
}

fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool {
Expand All @@ -83,9 +86,10 @@ fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool {
}

impl Lighthouse {
pub fn new(opt: LighthouseOpt) -> Arc<Self> {
pub async fn new(opt: LighthouseOpt) -> Result<Arc<Self>> {
let (tx, _) = broadcast::channel(16);
Arc::new(Self {
let listener = tokio::net::TcpListener::bind(&opt.bind).await?;
Ok(Arc::new(Self {
state: Mutex::new(State {
participants: HashMap::new(),
channel: tx,
Expand All @@ -94,7 +98,9 @@ impl Lighthouse {
heartbeats: HashMap::new(),
}),
opt: opt,
})
local_addr: listener.local_addr()?,
listener: Mutex::new(Some(listener)),
}))
}

// Checks whether the quorum is valid and returns an explanation for the state.
Expand Down Expand Up @@ -209,13 +215,20 @@ impl Lighthouse {
}
}

async fn _run_grpc(self: Arc<Self>) -> Result<()> {
let bind: SocketAddr = self.opt.bind.parse()?;
info!(
"Lighthouse listening on: http://{}:{}",
pub fn address(&self) -> String {
format!(
"http://{}:{}",
gethostname().into_string().unwrap(),
bind.port()
);
self.local_addr.port()
)
}

async fn _run_grpc(self: Arc<Self>) -> Result<()> {
info!("Lighthouse listening on: {}", self.address());

let listener = self.listener.lock().await.take().unwrap();
let incoming =
TcpIncoming::from_listener(listener, true, None).map_err(|e| anyhow::anyhow!(e))?;

// Setup HTTP endpoints
let app = Router::new()
Expand Down Expand Up @@ -245,7 +258,7 @@ impl Lighthouse {
// allow non-GRPC connections
.accept_http1(true)
.add_routes(routes)
.serve(bind)
.serve_with_incoming(incoming)
.await
.map_err(|e| e.into())
}
Expand Down Expand Up @@ -429,14 +442,14 @@ mod tests {

use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;

fn lighthouse_test_new() -> Arc<Lighthouse> {
async fn lighthouse_test_new() -> Result<Arc<Lighthouse>> {
let opt = LighthouseOpt {
min_replicas: 1,
bind: "0.0.0.0:29510".to_string(),
bind: "[::]:0".to_string(),
join_timeout_ms: 60 * 60 * 1000, // 1hr
quorum_tick_ms: 10,
};
Lighthouse::new(opt)
Lighthouse::new(opt).await
}

async fn lighthouse_client_new(addr: String) -> Result<LighthouseServiceClient<Channel>> {
Expand All @@ -448,8 +461,8 @@ mod tests {
}

#[tokio::test]
async fn test_quorum_join_timeout() {
let lighthouse = lighthouse_test_new();
async fn test_quorum_join_timeout() -> Result<()> {
let lighthouse = lighthouse_test_new().await?;
assert!(!lighthouse.quorum_valid().await.0);

{
Expand Down Expand Up @@ -478,11 +491,13 @@ mod tests {
}

assert!(lighthouse.quorum_valid().await.0);

Ok(())
}

#[tokio::test]
async fn test_quorum_fast_prev_quorum() {
let lighthouse = lighthouse_test_new();
async fn test_quorum_fast_prev_quorum() -> Result<()> {
let lighthouse = lighthouse_test_new().await?;
assert!(!lighthouse.quorum_valid().await.0);

{
Expand Down Expand Up @@ -520,23 +535,23 @@ mod tests {
}

assert!(lighthouse.quorum_valid().await.0);

Ok(())
}

#[tokio::test]
async fn test_lighthouse_e2e() {
async fn test_lighthouse_e2e() -> Result<()> {
let opt = LighthouseOpt {
min_replicas: 1,
bind: "0.0.0.0:29510".to_string(),
bind: "[::]:0".to_string(),
join_timeout_ms: 1,
quorum_tick_ms: 10,
};
let lighthouse = Lighthouse::new(opt);
let lighthouse = Lighthouse::new(opt).await?;

let lighthouse_task = tokio::spawn(lighthouse.clone().run());

let mut client = lighthouse_client_new("http://localhost:29510".to_string())
.await
.unwrap();
let mut client = lighthouse_client_new(lighthouse.address()).await.unwrap();

{
let request = tonic::Request::new(LighthouseHeartbeatRequest {
Expand All @@ -563,6 +578,7 @@ mod tests {
}

lighthouse_task.abort();
Ok(())
}

#[tokio::test]
Expand Down
Loading
Loading