diff --git a/patchbay/src/balancer.rs b/patchbay/src/balancer.rs new file mode 100644 index 0000000..4b07000 --- /dev/null +++ b/patchbay/src/balancer.rs @@ -0,0 +1,534 @@ +//! L4 load balancer backed by nftables DNAT rules on a router. +//! +//! The balancer uses the router's existing IX (public) IP as the VIP. +//! Different balancers on the same router use different ports. Backends +//! are private devices behind the router. +//! +//! Traffic flow: client sends to `:` on the IX bridge. +//! The router's DNAT rules rewrite to a backend's private IP. Masquerade +//! ensures return traffic goes through the router. + +use std::{ + net::{Ipv4Addr, Ipv6Addr}, + sync::Arc, + time::Duration, +}; + +use anyhow::{anyhow, Result}; +use tracing::debug; + +use crate::{ + core::{NodeId, RouterData}, + lab::LabInner, + nft::run_nft_in, +}; + +// ───────────────────────────────────────────── +// Types +// ───────────────────────────────────────────── + +/// Load-balancing algorithm. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum LbAlgorithm { + /// Distribute connections evenly across backends in order. + #[default] + RoundRobin, +} + +/// Transport protocol for the load balancer. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum LbProtocol { + /// TCP (default). + #[default] + Tcp, + /// UDP. + Udp, +} + +impl LbProtocol { + fn nft_name(self) -> &'static str { + match self { + Self::Tcp => "tcp", + Self::Udp => "udp", + } + } +} + +/// A single backend target: device ID and port. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct BackendEntry { + /// Device node identifier. + pub device_id: NodeId, + /// Port on the backend device. + pub port: u16, +} + +/// Resolved configuration for a single load balancer, stored on [`RouterData`]. +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub(crate) struct BalancerConfig { + /// Human-readable name (e.g. `"web"`). + pub name: String, + /// Frontend port on the router's WAN IP. 
+    pub port: u16,
+    /// Backend targets.
+    pub backends: Vec<BackendEntry>,
+    /// Balancing algorithm.
+    pub algorithm: LbAlgorithm,
+    /// Transport protocol.
+    pub protocol: LbProtocol,
+    /// Optional session affinity timeout (not yet wired into nft rules).
+    pub affinity: Option<Duration>,
+}
+
+// ─────────────────────────────────────────────
+// BalancerBuilder
+// ─────────────────────────────────────────────
+
+/// Builder for an L4 load balancer on a router.
+///
+/// Created by [`Router::add_balancer`]. Call `.backend()` to register
+/// targets, then `.build().await` to install the nftables rules.
+pub struct BalancerBuilder {
+    router_id: NodeId,
+    lab: Arc<LabInner>,
+    name: String,
+    port: u16,
+    backends: Vec<BackendEntry>,
+    algorithm: LbAlgorithm,
+    protocol: LbProtocol,
+    affinity: Option<Duration>,
+}
+
+impl BalancerBuilder {
+    /// Adds a backend device at the given port.
+    pub fn backend(mut self, device_id: NodeId, port: u16) -> Self {
+        self.backends.push(BackendEntry { device_id, port });
+        self
+    }
+
+    /// Selects round-robin distribution (the default).
+    pub fn round_robin(mut self) -> Self {
+        self.algorithm = LbAlgorithm::RoundRobin;
+        self
+    }
+
+    /// Sets the transport protocol.
+    pub fn protocol(mut self, proto: LbProtocol) -> Self {
+        self.protocol = proto;
+        self
+    }
+
+    /// Enables session affinity with the given timeout.
+    pub fn session_affinity(mut self, duration: Duration) -> Self {
+        self.affinity = Some(duration);
+        self
+    }
+
+    /// Builds the balancer, installs nftables rules, and returns a handle.
+    pub async fn build(self) -> Result<Balancer> {
+        if self.backends.is_empty() {
+            return Err(anyhow!("balancer '{}' has no backends", self.name));
+        }
+
+        let config = BalancerConfig {
+            name: self.name.clone(),
+            port: self.port,
+            backends: self.backends,
+            algorithm: self.algorithm,
+            protocol: self.protocol,
+            affinity: self.affinity,
+        };
+
+        // Store config on the router.
+        {
+            let mut core = self.lab.core.lock().expect("poisoned");
+            let router = core
+                .router_mut(self.router_id)
+                .ok_or_else(|| anyhow!("router removed"))?;
+            // Check for duplicate name.
+            if router.balancers.iter().any(|b| b.name == config.name) {
+                return Err(anyhow!(
+                    "balancer '{}' already exists on this router",
+                    config.name
+                ));
+            }
+            router.balancers.push(config);
+        }
+
+        // Apply nftables rules.
+        apply_balancer_rules(&self.lab, self.router_id).await?;
+
+        Ok(Balancer {
+            router_id: self.router_id,
+            name: self.name,
+            lab: self.lab,
+        })
+    }
+}
+
+// ─────────────────────────────────────────────
+// Balancer handle
+// ─────────────────────────────────────────────
+
+/// Handle to an active L4 load balancer on a router.
+///
+/// Provides read access to the VIP and runtime mutation (add/remove backends).
+pub struct Balancer {
+    router_id: NodeId,
+    name: String,
+    lab: Arc<LabInner>,
+}
+
+impl Clone for Balancer {
+    fn clone(&self) -> Self {
+        Self {
+            router_id: self.router_id,
+            name: self.name.clone(),
+            lab: Arc::clone(&self.lab),
+        }
+    }
+}
+
+impl std::fmt::Debug for Balancer {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Balancer")
+            .field("router_id", &self.router_id)
+            .field("name", &self.name)
+            .finish()
+    }
+}
+
+impl Balancer {
+    /// Returns the balancer name.
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Returns the frontend port.
+    pub fn port(&self) -> u16 {
+        let core = self.lab.core.lock().expect("poisoned");
+        core.router(self.router_id)
+            .and_then(|r| r.balancers.iter().find(|b| b.name == self.name))
+            .map(|b| b.port)
+            .unwrap_or(0)
+    }
+
+    /// Returns the VIP (router's IX IPv4 address).
+    pub fn ip(&self) -> Option<Ipv4Addr> {
+        let core = self.lab.core.lock().expect("poisoned");
+        core.router(self.router_id).and_then(|r| r.upstream_ip)
+    }
+
+    /// Returns the VIP (router's IX IPv6 address).
+    pub fn ip6(&self) -> Option<Ipv6Addr> {
+        let core = self.lab.core.lock().expect("poisoned");
+        core.router(self.router_id).and_then(|r| r.upstream_ip_v6)
+    }
+
+    /// Adds a backend device at runtime and regenerates rules.
+    pub async fn add_backend(&self, device_id: NodeId, port: u16) -> Result<()> {
+        {
+            let mut core = self.lab.core.lock().expect("poisoned");
+            let router = core
+                .router_mut(self.router_id)
+                .ok_or_else(|| anyhow!("router removed"))?;
+            let cfg = router
+                .balancers
+                .iter_mut()
+                .find(|b| b.name == self.name)
+                .ok_or_else(|| anyhow!("balancer '{}' not found", self.name))?;
+            if cfg
+                .backends
+                .iter()
+                .any(|b| b.device_id == device_id && b.port == port)
+            {
+                return Ok(());
+            }
+            cfg.backends.push(BackendEntry { device_id, port });
+        }
+        apply_balancer_rules(&self.lab, self.router_id).await
+    }
+
+    /// Removes a backend device at runtime and regenerates rules.
+    pub async fn remove_backend(&self, device_id: NodeId) -> Result<()> {
+        {
+            let mut core = self.lab.core.lock().expect("poisoned");
+            let router = core
+                .router_mut(self.router_id)
+                .ok_or_else(|| anyhow!("router removed"))?;
+            let cfg = router
+                .balancers
+                .iter_mut()
+                .find(|b| b.name == self.name)
+                .ok_or_else(|| anyhow!("balancer '{}' not found", self.name))?;
+            cfg.backends.retain(|b| b.device_id != device_id);
+        }
+        apply_balancer_rules(&self.lab, self.router_id).await
+    }
+}
+
+// ─────────────────────────────────────────────
+// Router / Lab glue (impl blocks on foreign types)
+// ─────────────────────────────────────────────
+
+impl crate::Router {
+    /// Begins building an L4 load balancer on this router.
+    ///
+    /// The balancer uses this router's public (IX) IP as the VIP and the
+    /// given `port` as the frontend port.
+    pub fn add_balancer(&self, name: &str, port: u16) -> BalancerBuilder {
+        BalancerBuilder {
+            router_id: self.id(),
+            lab: Arc::clone(&self.lab),
+            name: name.to_string(),
+            port,
+            backends: Vec::new(),
+            algorithm: LbAlgorithm::default(),
+            protocol: LbProtocol::default(),
+            affinity: None,
+        }
+    }
+
+    /// Returns a handle to an existing balancer by name.
+    pub fn balancer(&self, name: &str) -> Option<Balancer> {
+        let core = self.lab.core.lock().expect("poisoned");
+        let router = core.router(self.id())?;
+        if router.balancers.iter().any(|b| b.name == name) {
+            Some(Balancer {
+                router_id: self.id(),
+                name: name.to_string(),
+                lab: Arc::clone(&self.lab),
+            })
+        } else {
+            None
+        }
+    }
+}
+
+// ─────────────────────────────────────────────
+// nftables rule generation
+// ─────────────────────────────────────────────
+
+/// Resolved backend address for rule generation.
+struct ResolvedBackend {
+    ip: Ipv4Addr,
+    port: u16,
+}
+
+/// Resolved v6 backend address for rule generation.
+struct ResolvedBackendV6 {
+    ip: Ipv6Addr,
+    port: u16,
+}
+
+/// Regenerates and applies all balancer nftables rules for a router.
+///
+/// Deletes the old `table ip lb` (and `table ip6 lb`) then recreates
+/// them from the current balancer configs stored on the router.
+async fn apply_balancer_rules(lab: &Arc<LabInner>, router_id: NodeId) -> Result<()> {
+    // Phase 1: lock, snapshot, unlock.
+    let (ns, rules) = {
+        let core = lab.core.lock().expect("poisoned");
+        let router = core
+            .router(router_id)
+            .ok_or_else(|| anyhow!("router removed"))?;
+        let ns = router.ns.clone();
+
+        if router.balancers.is_empty() {
+            (ns, None)
+        } else {
+            let r = generate_all_balancer_rules(router, &core)?;
+            (ns, Some(r))
+        }
+    };
+
+    // Phase 2: apply rules (no lock held).
+    // Always delete existing lb tables first (ignoring errors if they
+    // do not exist yet).
+    run_nft_in(&lab.netns, &ns, "delete table ip lb\n")
+        .await
+        .ok();
+    run_nft_in(&lab.netns, &ns, "delete table ip6 lb\n")
+        .await
+        .ok();
+
+    match rules {
+        None => Ok(()),
+        Some(rules) => {
+            debug!(ns = %ns, rules = %rules, "balancer: apply rules");
+            run_nft_in(&lab.netns, &ns, &rules).await
+        }
+    }
+}
+
+/// Generates the complete nftables ruleset for all balancers on a router.
+fn generate_all_balancer_rules(
+    router: &RouterData,
+    core: &crate::core::NetworkCore,
+) -> Result<String> {
+    let wan_ip = router
+        .upstream_ip
+        .ok_or_else(|| anyhow!("router has no WAN IP for balancer"))?;
+
+    let mut rules = String::new();
+
+    // IPv4 table.
+    rules.push_str("table ip lb {\n");
+
+    // Per-service chains.
+    for cfg in &router.balancers {
+        let resolved = resolve_backends_v4(cfg, core)?;
+        if resolved.is_empty() {
+            continue;
+        }
+        rules.push_str(&format!("    chain svc_{} {{\n", cfg.name));
+        rules.push_str(&generate_dnat_map_v4(&resolved, cfg.protocol));
+        rules.push_str("    }\n");
+    }
+
+    // Prerouting chain.
+    rules.push_str("    chain prerouting {\n");
+    rules.push_str("        type nat hook prerouting priority dstnat - 5; policy accept;\n");
+    for cfg in &router.balancers {
+        let resolved = resolve_backends_v4(cfg, core)?;
+        if resolved.is_empty() {
+            continue;
+        }
+        rules.push_str(&format!(
+            "        ip daddr {} {} dport {} goto svc_{}\n",
+            wan_ip,
+            cfg.protocol.nft_name(),
+            cfg.port,
+            cfg.name
+        ));
+    }
+    rules.push_str("    }\n");
+
+    // Postrouting chain.
+    rules.push_str("    chain postrouting {\n");
+    rules.push_str("        type nat hook postrouting priority srcnat; policy accept;\n");
+    rules.push_str("        ct status dnat masquerade\n");
+    rules.push_str("    }\n");
+
+    rules.push_str("}\n");
+
+    // IPv6 table (if the router has an upstream v6 address).
+    if let Some(wan_ip6) = router.upstream_ip_v6 {
+        rules.push_str("table ip6 lb {\n");
+
+        for cfg in &router.balancers {
+            let resolved = resolve_backends_v6(cfg, core);
+            if resolved.is_empty() {
+                continue;
+            }
+            rules.push_str(&format!("    chain svc_{} {{\n", cfg.name));
+            rules.push_str(&generate_dnat_map_v6(&resolved, cfg.protocol));
+            rules.push_str("    }\n");
+        }
+
+        rules.push_str("    chain prerouting {\n");
+        rules.push_str("        type nat hook prerouting priority dstnat - 5; policy accept;\n");
+        for cfg in &router.balancers {
+            let resolved = resolve_backends_v6(cfg, core);
+            if resolved.is_empty() {
+                continue;
+            }
+            rules.push_str(&format!(
+                "        ip6 daddr {} {} dport {} goto svc_{}\n",
+                wan_ip6,
+                cfg.protocol.nft_name(),
+                cfg.port,
+                cfg.name
+            ));
+        }
+        rules.push_str("    }\n");
+
+        rules.push_str("    chain postrouting {\n");
+        rules.push_str("        type nat hook postrouting priority srcnat; policy accept;\n");
+        rules.push_str("        ct status dnat masquerade\n");
+        rules.push_str("    }\n");
+
+        rules.push_str("}\n");
+    }
+
+    Ok(rules)
+}
+
+/// Resolves backend device IDs to IPv4 addresses.
+fn resolve_backends_v4(
+    cfg: &BalancerConfig,
+    core: &crate::core::NetworkCore,
+) -> Result<Vec<ResolvedBackend>> {
+    let mut out = Vec::new();
+    for be in &cfg.backends {
+        let dev = core
+            .device(be.device_id)
+            .ok_or_else(|| anyhow!("backend device {} not found", be.device_id))?;
+        if let Some(ip) = dev.default_iface().ip {
+            out.push(ResolvedBackend { ip, port: be.port });
+        }
+    }
+    Ok(out)
+}
+
+/// Resolves backend device IDs to IPv6 addresses.
+fn resolve_backends_v6(
+    cfg: &BalancerConfig,
+    core: &crate::core::NetworkCore,
+) -> Vec<ResolvedBackendV6> {
+    let mut out = Vec::new();
+    for be in &cfg.backends {
+        if let Some(dev) = core.device(be.device_id) {
+            if let Some(ip) = dev.default_iface().ip_v6 {
+                out.push(ResolvedBackendV6 { ip, port: be.port });
+            }
+        }
+    }
+    out
+}
+
+/// Generates the DNAT map rule for IPv4 backends.
+/// +/// The `meta l4proto` match must appear on the same rule as the dnat +/// statement so nft can resolve the port part of the concatenated target. +fn generate_dnat_map_v4(backends: &[ResolvedBackend], proto: LbProtocol) -> String { + let proto_kw = proto.nft_name(); + if backends.len() == 1 { + return format!( + " meta l4proto {} dnat to {} : {}\n", + proto_kw, backends[0].ip, backends[0].port + ); + } + let mut s = format!( + " meta l4proto {} dnat to numgen inc mod {} map {{\n", + proto_kw, + backends.len() + ); + for (i, be) in backends.iter().enumerate() { + s.push_str(&format!(" {} : {} . {},\n", i, be.ip, be.port)); + } + s.push_str(" }\n"); + s +} + +/// Generates the DNAT map rule for IPv6 backends. +fn generate_dnat_map_v6(backends: &[ResolvedBackendV6], proto: LbProtocol) -> String { + let proto_kw = proto.nft_name(); + if backends.len() == 1 { + return format!( + " meta l4proto {} dnat to {} . {}\n", + proto_kw, backends[0].ip, backends[0].port + ); + } + let mut s = format!( + " meta l4proto {} dnat to numgen inc mod {} map {{\n", + proto_kw, + backends.len() + ); + for (i, be) in backends.iter().enumerate() { + s.push_str(&format!(" {} : {} . {},\n", i, be.ip, be.port)); + } + s.push_str(" }\n"); + s +} diff --git a/patchbay/src/core.rs b/patchbay/src/core.rs index 479e504..4a7d192 100644 --- a/patchbay/src/core.rs +++ b/patchbay/src/core.rs @@ -280,6 +280,8 @@ pub(crate) struct RouterData { pub ra_runtime: Arc, /// Per-router operation lock — serializes multi-step mutations. pub op: Arc>, + /// Active load balancer configurations. 
+ pub balancers: Vec, } impl RouterData { @@ -802,6 +804,7 @@ impl NetworkCore { RA_DEFAULT_LIFETIME_SECS, )), op: Arc::new(tokio::sync::Mutex::new(())), + balancers: Vec::new(), }, ); id diff --git a/patchbay/src/lib.rs b/patchbay/src/lib.rs index 41c40c4..6fcf6f1 100644 --- a/patchbay/src/lib.rs +++ b/patchbay/src/lib.rs @@ -197,6 +197,8 @@ use anyhow::{anyhow, bail, Context, Result}; +/// L4 load balancer backed by nftables DNAT. +pub(crate) mod balancer; /// TOML configuration structures used by [`Lab::load`]. pub mod config; /// Shared filename constants for the run output directory. @@ -235,6 +237,7 @@ pub mod util; pub(crate) mod wiring; pub(crate) mod writer; +pub use balancer::{Balancer, BalancerBuilder, LbAlgorithm, LbProtocol}; pub use device::{Device, DeviceBuilder}; pub use firewall::PortPolicy; pub use iface::{Iface, IfaceConfig}; diff --git a/patchbay/src/router.rs b/patchbay/src/router.rs index 6878d13..e1bacff 100644 --- a/patchbay/src/router.rs +++ b/patchbay/src/router.rs @@ -100,7 +100,7 @@ pub struct Router { id: NodeId, name: Arc, ns: Arc, - lab: Arc, + pub(crate) lab: Arc, dispatch: tracing::Dispatch, } diff --git a/patchbay/src/tests/balancer.rs b/patchbay/src/tests/balancer.rs new file mode 100644 index 0000000..d3817d1 --- /dev/null +++ b/patchbay/src/tests/balancer.rs @@ -0,0 +1,454 @@ +//! L4 load balancer tests. + +use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + time::Duration, +}; + +use anyhow::{Context, Result}; +use n0_tracing_test::traced_test; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, +}; +use tracing::debug; + +use super::*; +use crate::LbProtocol; + +/// Spawns a TCP server that replies with `name` on each accepted connection. 
+async fn spawn_named_tcp_server(bind: SocketAddr, name: &str) -> Result<()> {
+    let name = name.to_string();
+    let (ready_tx, ready_rx) = oneshot::channel::<Result<()>>();
+    tokio::spawn(async move {
+        match TcpListener::bind(bind).await {
+            Ok(listener) => {
+                let _ = ready_tx.send(Ok(()));
+                loop {
+                    let Ok((mut stream, _)) = listener.accept().await else {
+                        break;
+                    };
+                    let msg = name.clone();
+                    let _ = stream.write_all(msg.as_bytes()).await;
+                }
+            }
+            Err(e) => {
+                let _ = ready_tx.send(Err(anyhow::anyhow!("tcp bind {bind}: {e}")));
+            }
+        }
+    });
+    ready_rx
+        .await
+        .map_err(|_| anyhow::anyhow!("server task dropped before ready"))?
+}
+
+/// Spawns a UDP server that replies with `name` on each received datagram.
+async fn spawn_named_udp_server(bind: SocketAddr, name: &str) -> Result<()> {
+    let name = name.to_string();
+    let (ready_tx, ready_rx) = oneshot::channel::<Result<()>>();
+    tokio::spawn(async move {
+        match UdpSocket::bind(bind).await {
+            Ok(sock) => {
+                let _ = ready_tx.send(Ok(()));
+                let mut buf = [0u8; 64];
+                loop {
+                    let Ok((n, peer)) = sock.recv_from(&mut buf).await else {
+                        break;
+                    };
+                    let _ = sock.send_to(name.as_bytes(), peer).await;
+                    debug!(recv = n, %peer, "udp echo");
+                }
+            }
+            Err(e) => {
+                let _ = ready_tx.send(Err(anyhow::anyhow!("udp bind {bind}: {e}")));
+            }
+        }
+    });
+    ready_rx
+        .await
+        .map_err(|_| anyhow::anyhow!("udp server task dropped before ready"))?
+}
+
+/// TCP connect, read reply, return the reply string.
+async fn tcp_query(target: SocketAddr) -> Result<String> {
+    let timeout = Duration::from_millis(1000);
+    let mut stream = tokio::time::timeout(timeout, TcpStream::connect(target))
+        .await
+        .context("tcp connect timeout")?
+        .context("tcp connect")?;
+    let mut buf = vec![0u8; 256];
+    let n = tokio::time::timeout(timeout, stream.read(&mut buf))
+        .await
+        .context("tcp read timeout")?
+        .context("tcp read")?;
+    Ok(String::from_utf8_lossy(&buf[..n]).to_string())
+}
+
+/// UDP send/recv, return the reply string.
+async fn udp_query(target: SocketAddr) -> Result { + let timeout = Duration::from_millis(1000); + let sock = UdpSocket::bind("0.0.0.0:0").await?; + sock.send_to(b"hello", target).await?; + let mut buf = [0u8; 256]; + let (n, _) = tokio::time::timeout(timeout, sock.recv_from(&mut buf)) + .await + .context("udp recv timeout")? + .context("udp recv")?; + Ok(String::from_utf8_lossy(&buf[..n]).to_string()) +} + +/// 2 backends behind a public router. Client on a different router connects +/// to the VIP:port. Verify both backends get connections. +#[tokio::test(flavor = "current_thread")] +#[traced_test] +async fn round_robin_distribution() -> Result<()> { + check_caps()?; + let lab = Lab::new().await?; + + // DC router (public, no NAT) hosts the load balancer. + let dc = lab + .add_router("dc") + .preset(RouterPreset::Public) + .build() + .await?; + + // Two backend servers. + let web1 = lab + .add_device("web1") + .iface("eth0", dc.id()) + .build() + .await?; + let web2 = lab + .add_device("web2") + .iface("eth0", dc.id()) + .build() + .await?; + + let web1_ip = web1.ip().context("web1 has no ip")?; + let web2_ip = web2.ip().context("web2 has no ip")?; + + // Start named TCP servers on each backend. + web1.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(web1_ip), 8080), "web1").await + })? + .await + .context("web1 server task panicked")??; + + web2.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(web2_ip), 8080), "web2").await + })? + .await + .context("web2 server task panicked")??; + + // Build the balancer. + let lb = dc + .add_balancer("web", 80) + .backend(web1.id(), 8080) + .backend(web2.id(), 8080) + .round_robin() + .build() + .await?; + + let vip = lb.ip().context("lb has no VIP")?; + assert_eq!(vip, dc.uplink_ip().unwrap()); + assert_eq!(lb.port(), 80); + assert_eq!(lb.name(), "web"); + + // Client on a separate router. 
+ let client_router = lab.add_router("client").build().await?; + let client = lab + .add_device("client") + .iface("eth0", client_router.id()) + .build() + .await?; + + // Make several connections and tally responses. + let target = SocketAddr::new(IpAddr::V4(vip), 80); + let mut counts: HashMap = HashMap::new(); + for _ in 0..6 { + let reply = client + .spawn(move |_| async move { tcp_query(target).await })? + .await + .context("client query panicked")??; + *counts.entry(reply).or_default() += 1; + } + + debug!(?counts, "round robin distribution"); + assert!(counts.contains_key("web1"), "web1 never received traffic"); + assert!(counts.contains_key("web2"), "web2 never received traffic"); + Ok(()) +} + +/// Client behind NAT, backends behind the LB router (private IPs). +/// Verify the client can reach the LB and gets balanced to backends. +#[tokio::test(flavor = "current_thread")] +#[traced_test] +async fn nat_client_to_lb() -> Result<()> { + check_caps()?; + let lab = Lab::new().await?; + + // DC router hosts the LB. + let dc = lab + .add_router("dc") + .preset(RouterPreset::Public) + .build() + .await?; + + // Home router with NAT for the client. + let home = lab.add_router("home").nat(Nat::Home).build().await?; + + // Backend servers behind dc. + let srv1 = lab + .add_device("srv1") + .iface("eth0", dc.id()) + .build() + .await?; + let srv2 = lab + .add_device("srv2") + .iface("eth0", dc.id()) + .build() + .await?; + + let srv1_ip = srv1.ip().context("srv1 has no ip")?; + let srv2_ip = srv2.ip().context("srv2 has no ip")?; + + srv1.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(srv1_ip), 9090), "srv1").await + })? + .await + .context("srv1 server task panicked")??; + + srv2.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(srv2_ip), 9090), "srv2").await + })? + .await + .context("srv2 server task panicked")??; + + // Build LB on dc. 
+ let lb = dc + .add_balancer("api", 443) + .backend(srv1.id(), 9090) + .backend(srv2.id(), 9090) + .build() + .await?; + + let vip = lb.ip().context("lb has no VIP")?; + + // Client behind NAT. + let client = lab + .add_device("client") + .iface("eth0", home.id()) + .build() + .await?; + + let target = SocketAddr::new(IpAddr::V4(vip), 443); + let mut counts: HashMap = HashMap::new(); + for _ in 0..6 { + let reply = client + .spawn(move |_| async move { tcp_query(target).await })? + .await + .context("client query panicked")??; + *counts.entry(reply).or_default() += 1; + } + + debug!(?counts, "nat client to lb distribution"); + assert!(counts.contains_key("srv1"), "srv1 never received traffic"); + assert!(counts.contains_key("srv2"), "srv2 never received traffic"); + Ok(()) +} + +/// Add/remove backends at runtime, verify redistribution. +#[tokio::test(flavor = "current_thread")] +#[traced_test] +async fn backend_add_remove() -> Result<()> { + check_caps()?; + let lab = Lab::new().await?; + + let dc = lab + .add_router("dc") + .preset(RouterPreset::Public) + .build() + .await?; + + let web1 = lab + .add_device("web1") + .iface("eth0", dc.id()) + .build() + .await?; + let web2 = lab + .add_device("web2") + .iface("eth0", dc.id()) + .build() + .await?; + let web3 = lab + .add_device("web3") + .iface("eth0", dc.id()) + .build() + .await?; + + let web1_ip = web1.ip().context("no ip")?; + let web2_ip = web2.ip().context("no ip")?; + let web3_ip = web3.ip().context("no ip")?; + + web1.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(web1_ip), 8080), "web1").await + })? + .await??; + web2.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(web2_ip), 8080), "web2").await + })? + .await??; + web3.spawn(move |_| async move { + spawn_named_tcp_server(SocketAddr::new(IpAddr::V4(web3_ip), 8080), "web3").await + })? + .await??; + + // Start with 2 backends. 
+ let lb = dc + .add_balancer("web", 80) + .backend(web1.id(), 8080) + .backend(web2.id(), 8080) + .build() + .await?; + + let vip = lb.ip().context("no VIP")?; + + let client_router = lab.add_router("client").build().await?; + let client = lab + .add_device("client") + .iface("eth0", client_router.id()) + .build() + .await?; + + let target = SocketAddr::new(IpAddr::V4(vip), 80); + + // Verify initial 2-backend distribution. + let mut counts: HashMap = HashMap::new(); + for _ in 0..4 { + let reply = client + .spawn(move |_| async move { tcp_query(target).await })? + .await??; + *counts.entry(reply).or_default() += 1; + } + assert!(counts.contains_key("web1")); + assert!(counts.contains_key("web2")); + + // Add web3. + lb.add_backend(web3.id(), 8080).await?; + + // Flush conntrack to ensure new rules take effect immediately. + dc.run_sync(|| { + let _ = std::process::Command::new("conntrack") + .args(["-F"]) + .status(); + Ok(()) + })?; + + let mut counts: HashMap = HashMap::new(); + for _ in 0..6 { + let reply = client + .spawn(move |_| async move { tcp_query(target).await })? + .await??; + *counts.entry(reply).or_default() += 1; + } + debug!(?counts, "after adding web3"); + assert!( + counts.contains_key("web3"), + "web3 should receive traffic after add" + ); + + // Remove web1. + lb.remove_backend(web1.id()).await?; + + dc.run_sync(|| { + let _ = std::process::Command::new("conntrack") + .args(["-F"]) + .status(); + Ok(()) + })?; + + let mut counts: HashMap = HashMap::new(); + for _ in 0..4 { + let reply = client + .spawn(move |_| async move { tcp_query(target).await })? + .await??; + *counts.entry(reply).or_default() += 1; + } + debug!(?counts, "after removing web1"); + assert!( + !counts.contains_key("web1"), + "web1 should not receive traffic after removal" + ); + + Ok(()) +} + +/// UDP round-robin balancing. 
+#[tokio::test(flavor = "current_thread")] +#[traced_test] +async fn udp_balancing() -> Result<()> { + check_caps()?; + let lab = Lab::new().await?; + + let dc = lab + .add_router("dc") + .preset(RouterPreset::Public) + .build() + .await?; + + let udp1 = lab + .add_device("udp1") + .iface("eth0", dc.id()) + .build() + .await?; + let udp2 = lab + .add_device("udp2") + .iface("eth0", dc.id()) + .build() + .await?; + + let udp1_ip = udp1.ip().context("no ip")?; + let udp2_ip = udp2.ip().context("no ip")?; + + udp1.spawn(move |_| async move { + spawn_named_udp_server(SocketAddr::new(IpAddr::V4(udp1_ip), 5000), "udp1").await + })? + .await??; + udp2.spawn(move |_| async move { + spawn_named_udp_server(SocketAddr::new(IpAddr::V4(udp2_ip), 5000), "udp2").await + })? + .await??; + + let _lb = dc + .add_balancer("dns", 53) + .backend(udp1.id(), 5000) + .backend(udp2.id(), 5000) + .protocol(LbProtocol::Udp) + .build() + .await?; + + let vip = dc.uplink_ip().context("no VIP")?; + + let client_router = lab.add_router("client").build().await?; + let client = lab + .add_device("client") + .iface("eth0", client_router.id()) + .build() + .await?; + + let target = SocketAddr::new(IpAddr::V4(vip), 53); + let mut counts: HashMap = HashMap::new(); + // Use different source ports to get different numgen slots. + for _ in 0..6 { + let reply = client + .spawn(move |_| async move { udp_query(target).await })? + .await??; + *counts.entry(reply).or_default() += 1; + } + + debug!(?counts, "udp distribution"); + assert!(counts.contains_key("udp1"), "udp1 never received traffic"); + assert!(counts.contains_key("udp2"), "udp2 never received traffic"); + Ok(()) +} diff --git a/patchbay/src/tests/mod.rs b/patchbay/src/tests/mod.rs index e7431b7..dfd89dc 100644 --- a/patchbay/src/tests/mod.rs +++ b/patchbay/src/tests/mod.rs @@ -34,6 +34,7 @@ use super::*; use crate::{check_caps, config}; mod alloc; +mod balancer; mod devtools; mod dns; mod firewall;