diff --git a/lib/virtual-net/src/client.rs b/lib/virtual-net/src/client.rs index 52d66ac9fc7..7b382615353 100644 --- a/lib/virtual-net/src/client.rs +++ b/lib/virtual-net/src/client.rs @@ -54,6 +54,7 @@ use crate::VirtualIoSource; use crate::VirtualNetworking; use crate::VirtualRawSocket; use crate::VirtualSocket; +use crate::VirtualTcpBoundSocket; use crate::VirtualTcpListener; use crate::VirtualTcpSocket; use crate::VirtualUdpSocket; @@ -276,6 +277,7 @@ impl RemoteNetworkingClient { buffer_accept: Default::default(), buffer_recv_with_addr: Default::default(), send_available: 0, + owns_socket_bindings: true, } } } @@ -760,6 +762,39 @@ impl VirtualNetworking for RemoteNetworkingClient { } } + async fn bind_tcp( + &self, + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + ) -> Result> { + let socket_id: SocketId = self + .common + .socket_seed + .fetch_add(1, Ordering::SeqCst) + .into(); + match self + .common + .io_iface(RequestType::BindTcp { + socket_id, + addr, + only_v6, + reuse_port, + reuse_addr, + }) + .await + { + ResponseType::Err(err) => Err(err), + ResponseType::None => Ok(Box::new(self.new_socket(socket_id))), + ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))), + res => { + tracing::debug!("invalid response to bind TCP request - {res:?}"); + Err(NetworkError::IOError) + } + } + } + async fn bind_udp( &self, addr: SocketAddr, @@ -880,19 +915,40 @@ struct RemoteSocket { buffer_recv_with_addr: VecDeque, buffer_accept: VecDeque, send_available: u64, + owns_socket_bindings: bool, } impl Drop for RemoteSocket { fn drop(&mut self) { + if !self.owns_socket_bindings { + return; + } + let _ = self.io_socket_fire_and_forget(RequestType::Close); + self.release_socket_bindings(); + } +} + +impl RemoteSocket { + fn release_socket_bindings(&mut self) { + self.owns_socket_bindings = false; self.common.recv_tx.lock().unwrap().remove(&self.socket_id); self.common .recv_with_addr_tx .lock() .unwrap() .remove(&self.socket_id); + self.common + .accept_tx + .lock() + .unwrap() + .remove(&self.socket_id); + self.common.sent_tx.lock().unwrap().remove(&self.socket_id); + self.common.handlers.lock().unwrap().remove(&self.socket_id); + + if let Some((child_id, _)) = self.pending_accept.take() { + self.common.recv_tx.lock().unwrap().remove(&child_id); + } } -} -impl RemoteSocket { async fn io_socket(&self, req: RequestType) -> ResponseType { let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst); let mut req_rx = { @@ -941,6 +997,31 @@ impl RemoteSocket { self.pending_accept.replace((child_id, rx_recv)); Ok(()) } + + fn transition_socket(&mut self) -> RemoteSocket { + let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1); + let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1); + let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1); + let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1); + + self.owns_socket_bindings = false; + + RemoteSocket { + socket_id: self.socket_id, + common: self.common.clone(), + rx_buffer: std::mem::take(&mut self.rx_buffer), + rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv), + rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr), + tx_waker: self.tx_waker.clone(), + rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept), + rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent), + pending_accept: self.pending_accept.take(), + buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr), + buffer_accept: std::mem::take(&mut self.buffer_accept), + send_available: self.send_available, + owns_socket_bindings: true, + } + } } impl VirtualIoSource for RemoteSocket { @@ -1121,6 +1202,7 @@ impl VirtualTcpListener for RemoteSocket { buffer_accept: Default::default(), buffer_recv_with_addr: Default::default(), send_available: 0, + owns_socket_bindings: true, }; Ok((Box::new(socket), accepted.addr)) } @@ -1159,6 +1241,46 @@ impl VirtualTcpListener for RemoteSocket { } } +impl VirtualTcpBoundSocket for RemoteSocket { + fn addr_local(&self) -> Result { + VirtualSocket::addr_local(self) + } + + fn listen(&mut self) -> Result> { + match block_on(self.io_socket(RequestType::ListenBound)) { + ResponseType::Err(err) => Err(err), + ResponseType::None => { + let mut socket = self.transition_socket(); + socket.touch_begin_accept().ok(); + Ok(Box::new(socket)) + } + res => { + tracing::debug!("invalid response to listen bound request - {res:?}"); + Err(NetworkError::IOError) + } + } + } + + fn connect(&mut self, peer: SocketAddr) -> Result> { + match block_on(self.io_socket(RequestType::ConnectBound { peer })) { + ResponseType::Err(err) => Err(err), + ResponseType::None => Ok(Box::new(self.transition_socket())), + res => { + tracing::debug!("invalid response to connect bound request - {res:?}"); + Err(NetworkError::IOError) + } + } + } + + fn set_ttl(&mut self, ttl: u32) -> Result<()> { + VirtualSocket::set_ttl(self, ttl) + } + + fn ttl(&self) -> Result { + VirtualSocket::ttl(self) + } +} + impl VirtualRawSocket for RemoteSocket { fn try_send(&mut self, data: &[u8]) -> Result { let mut cx = Context::from_waker(&self.tx_waker); @@ -1431,7 +1553,11 @@ impl VirtualConnectedSocket for RemoteSocket { } fn close(&mut self) -> Result<()> { - self.io_socket_fire_and_forget(RequestType::Close) + let ret = self.io_socket_fire_and_forget(RequestType::Close); + if ret.is_ok() { + self.release_socket_bindings(); + } + ret } fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit], peek: bool) -> Result { diff --git a/lib/virtual-net/src/host.rs b/lib/virtual-net/src/host.rs index 3f1af0ff481..4549a1a8d95 100644 --- a/lib/virtual-net/src/host.rs +++ b/lib/virtual-net/src/host.rs @@ -4,7 +4,7 @@ use crate::ruleset::{Direction, Ruleset}; use crate::{ IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket, VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, - VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, + VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, }; use crate::{VirtualIoSource, io_err_into_net_error}; use bytes::{Buf, BytesMut}; @@ -29,6 +29,14 @@ use virtual_mio::{ HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map, }; +/// Use the platform's maximum listen backlog where available so that +/// `LocalTcpBoundSocket::listen` preserves the same accept capacity as +/// the previous `std::net::TcpListener`-based implementation. +#[cfg(all(target_family = "unix", feature = "libc"))] +const LISTEN_BACKLOG: i32 = libc::SOMAXCONN; +#[cfg(not(all(target_family = "unix", feature = "libc")))] +const LISTEN_BACKLOG: i32 = 128; + #[derive(Debug)] pub struct LocalNetworking { selector: Arc, @@ -66,6 +74,41 @@ impl Default for LocalNetworking { } } +fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result { + addr.as_socket().ok_or(NetworkError::UnknownError) +} + +fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain { + if addr.is_ipv4() { + socket2::Domain::IPV4 + } else { + socket2::Domain::IPV6 + } +} + +#[allow(clippy::needless_bool)] +fn tcp_connect_in_progress(err: &io::Error) -> bool { + if matches!( + err.kind(), + io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted + ) { + true + } else { + #[cfg(all(target_family = "unix", feature = "libc"))] + { + matches!( + err.raw_os_error(), + Some(raw) if raw == libc::EINPROGRESS || raw == libc::EALREADY + ) + } + + #[cfg(not(all(target_family = "unix", feature = "libc")))] + { + false + } + } +} + #[async_trait::async_trait] #[allow(unused_variables)] impl VirtualNetworking for LocalNetworking { @@ -83,21 +126,47 @@ impl VirtualNetworking for LocalNetworking { return Err(NetworkError::PermissionDenied); } - let listener = std::net::TcpListener::bind(addr) - .map(|sock| { - sock.set_nonblocking(true).ok(); - Box::new(LocalTcpListener { - stream: mio::net::TcpListener::from_std(sock), - selector: self.selector.clone(), - handler_guard: HandlerGuardState::None, - no_delay: None, - keep_alive: None, - backlog: Default::default(), - ruleset: self.ruleset.clone(), - }) - }) + self.bind_tcp(addr, only_v6, reuse_port, reuse_addr) + .await? + .listen() + } + + async fn bind_tcp( + &self, + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + ) -> Result> { + if let Some(ruleset) = self.ruleset.as_ref() + && !ruleset.allows_socket(addr, Direction::Inbound) + { + tracing::warn!(%addr, "bind_tcp blocked by firewall rule"); + return Err(NetworkError::PermissionDenied); + } + + let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None) .map_err(io_err_into_net_error)?; - Ok(listener) + socket + .set_nonblocking(true) + .map_err(io_err_into_net_error)?; + if addr.is_ipv6() { + socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?; + } + socket + .set_reuse_address(reuse_addr) + .map_err(io_err_into_net_error)?; + #[cfg(not(windows))] + socket + .set_reuse_port(reuse_port) + .map_err(io_err_into_net_error)?; + socket.bind(&addr.into()).map_err(io_err_into_net_error)?; + + Ok(Box::new(LocalTcpBoundSocket { + socket: Some(socket), + selector: self.selector.clone(), + ruleset: self.ruleset.clone(), + })) } async fn bind_udp( @@ -377,6 +446,82 @@ impl VirtualIoSource for LocalTcpListener { } } +#[derive(Debug)] +pub struct LocalTcpBoundSocket { + socket: Option, + selector: Arc, + ruleset: Option, +} + +impl VirtualTcpBoundSocket for LocalTcpBoundSocket { + fn addr_local(&self) -> Result { + let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?; + let addr = socket.local_addr().map_err(io_err_into_net_error)?; + sock_addr_into_socket_addr(addr) + } + + fn listen(&mut self) -> Result> { + let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?; + socket + .listen(LISTEN_BACKLOG) + .map_err(io_err_into_net_error)?; + let listener = mio::net::TcpListener::from_std(socket.into()); + Ok(Box::new(LocalTcpListener { + stream: listener, + selector: self.selector.clone(), + handler_guard: HandlerGuardState::None, + no_delay: None, + keep_alive: None, + backlog: Default::default(), + ruleset: self.ruleset.clone(), + })) + } + + fn connect(&mut self, mut peer: SocketAddr) -> Result> { + if let Some(ruleset) = self.ruleset.as_ref() + && !ruleset.allows_socket(peer, Direction::Outbound) + { + tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule"); + return Err(NetworkError::PermissionDenied); + } + + let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?; + if let Err(err) = socket.connect(&peer.into()) + && !tcp_connect_in_progress(&err) + { + return Err(io_err_into_net_error(err)); + } + + let stream = mio::net::TcpStream::from_std(socket.into()); + if let Ok(p) = stream.peer_addr() { + peer = p; + } + Ok(Box::new(LocalTcpStream::new( + self.selector.clone(), + stream, + peer, + ))) + } + + fn set_ttl(&mut self, ttl: u32) -> Result<()> { + let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?; + match self.addr_local()?.ip() { + IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error), + IpAddr::V6(_) => socket + .set_unicast_hops_v6(ttl) + .map_err(io_err_into_net_error), + } + } + + fn ttl(&self) -> Result { + let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?; + match self.addr_local()?.ip() { + IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error), + IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error), + } + } +} + #[derive(Debug)] enum ConnectState { Unknown, diff --git a/lib/virtual-net/src/lib.rs b/lib/virtual-net/src/lib.rs index 22ae000a499..6a2db323649 100644 --- a/lib/virtual-net/src/lib.rs +++ b/lib/virtual-net/src/lib.rs @@ -183,6 +183,18 @@ pub trait VirtualNetworking: fmt::Debug + Send + Sync + 'static { Err(NetworkError::Unsupported) } + /// Binds a TCP socket to a specific IP and port without immediately + /// listening for connections or connecting to a peer. + async fn bind_tcp( + &self, + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + ) -> Result> { + Err(NetworkError::Unsupported) + } + /// Opens a UDP socket that listens on a specific IP and Port combination /// Multiple servers (processes or threads) can bind to the same port if they each set /// the reuse-port and-or reuse-addr flags @@ -241,6 +253,23 @@ pub trait VirtualTcpListener: VirtualIoSource + fmt::Debug + Send + Sync + 'stat fn ttl(&self) -> Result; } +pub trait VirtualTcpBoundSocket: fmt::Debug + Send + Sync + 'static { + /// Returns the local address of this bound TCP socket. + fn addr_local(&self) -> Result; + + /// Places the socket into listening mode. + fn listen(&mut self) -> Result>; + + /// Initiates a TCP connection using the already-bound local address. + fn connect(&mut self, peer: SocketAddr) -> Result>; + + /// Sets how many network hops the packets are permitted for this socket. + fn set_ttl(&mut self, ttl: u32) -> Result<()>; + + /// Returns the maximum number of network hops before packets are dropped. + fn ttl(&self) -> Result; +} + #[async_trait::async_trait] pub trait VirtualTcpListenerExt: VirtualTcpListener { /// Accepts a new connection from the TCP listener diff --git a/lib/virtual-net/src/loopback.rs b/lib/virtual-net/src/loopback.rs index b59519790bc..23f43ad25cb 100644 --- a/lib/virtual-net/src/loopback.rs +++ b/lib/virtual-net/src/loopback.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{HashSet, VecDeque}; use std::net::SocketAddr; use std::sync::Mutex; use std::task::{Context, Poll, Waker}; @@ -6,17 +6,32 @@ use std::{collections::HashMap, sync::Arc}; use crate::tcp_pair::TcpSocketHalf; use crate::{ - InterestHandler, IpAddr, IpCidr, Ipv4Addr, Ipv6Addr, NetworkError, VirtualIoSource, - VirtualNetworking, VirtualTcpListener, VirtualTcpSocket, + InterestHandler, IpAddr, IpCidr, Ipv4Addr, Ipv6Addr, NetworkError, VirtualConnectedSocket, + VirtualIoSource, VirtualNetworking, VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, + VirtualTcpSocket, }; use virtual_mio::InterestType; const DEFAULT_MAX_BUFFER_SIZE: usize = 1_048_576; +const LOOPBACK_EPHEMERAL_PORT_START: u16 = 49152; -#[derive(Debug, Default)] +#[derive(Debug)] struct LoopbackNetworkingState { tcp_listeners: HashMap, + tcp_bound: HashSet, ip_addresses: Vec, + next_ephemeral_port: u16, +} + +impl Default for LoopbackNetworkingState { + fn default() -> Self { + Self { + tcp_listeners: HashMap::new(), + tcp_bound: HashSet::new(), + ip_addresses: Vec::new(), + next_ephemeral_port: LOOPBACK_EPHEMERAL_PORT_START, + } + } } #[derive(Debug, Clone)] @@ -62,6 +77,61 @@ impl LoopbackNetworking { .map(|listener| listener.1.connect_to(local_addr)) } } + + fn allocate_tcp_bind_addr( + state: &mut LoopbackNetworkingState, + mut addr: SocketAddr, + ) -> crate::Result { + let is_available = |candidate: SocketAddr, state: &LoopbackNetworkingState| { + let key = Self::normalize_listener_addr(candidate); + !state.tcp_listeners.contains_key(&key) && !state.tcp_bound.contains(&key) + }; + + if addr.port() == 0 { + let start = state.next_ephemeral_port; + let mut candidate = start; + loop { + let candidate_addr = SocketAddr::new(addr.ip(), candidate); + if is_available(candidate_addr, state) { + addr.set_port(candidate); + state.tcp_bound.insert(Self::normalize_listener_addr(addr)); + state.next_ephemeral_port = if candidate == u16::MAX { + LOOPBACK_EPHEMERAL_PORT_START + } else { + candidate + 1 + }; + return Ok(addr); + } + + candidate = if candidate == u16::MAX { + LOOPBACK_EPHEMERAL_PORT_START + } else { + candidate + 1 + }; + if candidate == start { + return Err(NetworkError::AddressInUse); + } + } + } + + let reservation_key = Self::normalize_listener_addr(addr); + if state.tcp_listeners.contains_key(&reservation_key) + || state.tcp_bound.contains(&reservation_key) + { + return Err(NetworkError::AddressInUse); + } + state.tcp_bound.insert(reservation_key); + Ok(addr) + } + + fn normalize_listener_addr(mut addr: SocketAddr) -> SocketAddr { + if addr.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) { + addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), addr.port()); + } else if addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) { + addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), addr.port()); + } + addr + } } impl Default for LoopbackNetworking { @@ -115,23 +185,195 @@ impl VirtualNetworking for LoopbackNetworking { async fn listen_tcp( &self, - mut addr: SocketAddr, + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + ) -> crate::Result> { + self.bind_tcp(addr, only_v6, reuse_port, reuse_addr) + .await? + .listen() + } + + async fn bind_tcp( + &self, + addr: SocketAddr, _only_v6: bool, _reuse_port: bool, _reuse_addr: bool, - ) -> crate::Result> { - let listener = LoopbackTcpListener::new(addr); + ) -> crate::Result> { + let mut state = self.state.lock().unwrap(); + let addr = Self::allocate_tcp_bind_addr(&mut state, addr)?; + Ok(Box::new(LoopbackTcpBoundSocket { + networking: self.clone(), + local_addr: addr, + reservation_key: Some(Self::normalize_listener_addr(addr)), + ttl: 64, + })) + } +} - if addr.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED) { - addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), addr.port()); - } else if addr.ip() == IpAddr::V6(Ipv6Addr::UNSPECIFIED) { - addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), addr.port()); +#[cfg(test)] +impl LoopbackNetworking { + pub(crate) fn exhaust_tcp_ephemeral_ports_for_test(&self, ip: IpAddr) { + let mut state = self.state.lock().unwrap(); + for port in LOOPBACK_EPHEMERAL_PORT_START..=u16::MAX { + let addr = SocketAddr::new(ip, port); + state + .tcp_listeners + .insert(addr, LoopbackTcpListener::new(addr, 64)); } + state.next_ephemeral_port = LOOPBACK_EPHEMERAL_PORT_START; + } +} - let mut state = self.state.lock().unwrap(); - state.tcp_listeners.insert(addr, listener.clone()); +/// A connected TCP socket that keeps its local-port reservation in +/// `LoopbackNetworkingState::tcp_bound` until it is explicitly closed or +/// dropped, matching POSIX/Linux semantics where a connected socket holds +/// its local port for its entire lifetime. +#[derive(Debug)] +struct LoopbackConnectedSocket { + inner: TcpSocketHalf, + networking: LoopbackNetworking, + /// `None` once the reservation has been released (after `close()` or `drop`). + reservation_key: Option, +} - Ok(Box::new(listener)) +impl LoopbackConnectedSocket { + fn release_reservation(&mut self) { + if let Some(key) = self.reservation_key.take() { + self.networking.state.lock().unwrap().tcp_bound.remove(&key); + } + } +} + +impl Drop for LoopbackConnectedSocket { + fn drop(&mut self) { + self.release_reservation(); + } +} + +impl VirtualIoSource for LoopbackConnectedSocket { + fn remove_handler(&mut self) { + self.inner.remove_handler(); + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_read_ready(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_write_ready(cx) + } +} + +impl VirtualSocket for LoopbackConnectedSocket { + fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> { + self.inner.set_ttl(ttl) + } + + fn ttl(&self) -> crate::Result { + self.inner.ttl() + } + + fn addr_local(&self) -> crate::Result { + self.inner.addr_local() + } + + fn status(&self) -> crate::Result { + self.inner.status() + } + + fn set_handler( + &mut self, + handler: Box, + ) -> crate::Result<()> { + self.inner.set_handler(handler) + } +} + +impl VirtualConnectedSocket for LoopbackConnectedSocket { + fn set_linger(&mut self, linger: Option) -> crate::Result<()> { + self.inner.set_linger(linger) + } + + fn linger(&self) -> crate::Result> { + self.inner.linger() + } + + fn try_send(&mut self, data: &[u8]) -> crate::Result { + self.inner.try_send(data) + } + + fn try_flush(&mut self) -> crate::Result<()> { + self.inner.try_flush() + } + + fn close(&mut self) -> crate::Result<()> { + self.release_reservation(); + self.inner.close() + } + + fn try_recv( + &mut self, + buf: &mut [std::mem::MaybeUninit], + peek: bool, + ) -> crate::Result { + self.inner.try_recv(buf, peek) + } +} + +impl VirtualTcpSocket for LoopbackConnectedSocket { + fn set_recv_buf_size(&mut self, size: usize) -> crate::Result<()> { + self.inner.set_recv_buf_size(size) + } + + fn recv_buf_size(&self) -> crate::Result { + self.inner.recv_buf_size() + } + + fn set_send_buf_size(&mut self, size: usize) -> crate::Result<()> { + self.inner.set_send_buf_size(size) + } + + fn send_buf_size(&self) -> crate::Result { + self.inner.send_buf_size() + } + + fn set_nodelay(&mut self, reuse: bool) -> crate::Result<()> { + self.inner.set_nodelay(reuse) + } + + fn nodelay(&self) -> crate::Result { + self.inner.nodelay() + } + + fn set_keepalive(&mut self, keepalive: bool) -> crate::Result<()> { + self.inner.set_keepalive(keepalive) + } + + fn keepalive(&self) -> crate::Result { + self.inner.keepalive() + } + + fn set_dontroute(&mut self, dontroute: bool) -> crate::Result<()> { + self.inner.set_dontroute(dontroute) + } + + fn dontroute(&self) -> crate::Result { + self.inner.dontroute() + } + + fn addr_peer(&self) -> crate::Result { + self.inner.addr_peer() + } + + fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> { + self.inner.shutdown(how) + } + + fn is_closed(&self) -> bool { + self.inner.is_closed() } } @@ -139,6 +381,7 @@ impl VirtualNetworking for LoopbackNetworking { struct LoopbackTcpListenerState { handler: Option>, addr_local: SocketAddr, + ttl: u8, backlog: VecDeque, wakers: Vec, } @@ -149,11 +392,12 @@ pub struct LoopbackTcpListener { } impl LoopbackTcpListener { - pub fn new(addr_local: SocketAddr) -> Self { + pub fn new(addr_local: SocketAddr, ttl: u8) -> Self { Self { state: Arc::new(Mutex::new(LoopbackTcpListenerState { handler: None, addr_local, + ttl, backlog: Default::default(), wakers: Default::default(), })), @@ -162,8 +406,9 @@ impl LoopbackTcpListener { pub fn connect_to(&self, addr_local: SocketAddr) -> TcpSocketHalf { let mut state = self.state.lock().unwrap(); - let (half1, half2) = + let (mut half1, half2) = TcpSocketHalf::channel(DEFAULT_MAX_BUFFER_SIZE, state.addr_local, addr_local); + half1.set_ttl(u32::from(state.ttl)).ok(); state.backlog.push_back(half1); if let Some(handler) = state.handler.as_mut() { @@ -227,11 +472,83 @@ impl VirtualTcpListener for LoopbackTcpListener { Ok(state.addr_local) } - fn set_ttl(&mut self, _ttl: u8) -> crate::Result<()> { + fn set_ttl(&mut self, ttl: u8) -> crate::Result<()> { + let mut state = self.state.lock().unwrap(); + state.ttl = ttl; Ok(()) } fn ttl(&self) -> crate::Result { - Ok(64) + let state = self.state.lock().unwrap(); + Ok(state.ttl) + } +} + +#[derive(Debug)] +pub struct LoopbackTcpBoundSocket { + networking: LoopbackNetworking, + local_addr: SocketAddr, + reservation_key: Option, + ttl: u32, +} + +impl Drop for LoopbackTcpBoundSocket { + fn drop(&mut self) { + if let Some(reservation_key) = self.reservation_key.take() { + let mut state = self.networking.state.lock().unwrap(); + state.tcp_bound.remove(&reservation_key); + } + } +} + +impl VirtualTcpBoundSocket for LoopbackTcpBoundSocket { + fn addr_local(&self) -> crate::Result { + Ok(self.local_addr) + } + + fn listen(&mut self) -> crate::Result> { + let listener = + LoopbackTcpListener::new(self.local_addr, u8::try_from(self.ttl).unwrap_or(u8::MAX)); + let mut state = self.networking.state.lock().unwrap(); + let reservation_key = self.reservation_key.ok_or(NetworkError::InvalidFd)?; + if !state.tcp_bound.remove(&reservation_key) { + return Err(NetworkError::InvalidFd); + } + if state.tcp_listeners.contains_key(&reservation_key) { + state.tcp_bound.insert(reservation_key); + return Err(NetworkError::AddressInUse); + } + state + .tcp_listeners + .insert(reservation_key, listener.clone()); + self.reservation_key = None; + Ok(Box::new(listener)) + } + + fn connect(&mut self, peer: SocketAddr) -> crate::Result> { + let mut socket = self + .networking + .loopback_connect_to(self.local_addr, peer) + .ok_or(NetworkError::ConnectionRefused)?; + // Transfer the port reservation to the connected socket so that the + // local port stays in `tcp_bound` for the socket's entire lifetime, + // matching POSIX/Linux semantics (a connected socket holds its local + // port; rebinding it returns EADDRINUSE). + let reservation_key = self.reservation_key.take().ok_or(NetworkError::InvalidFd)?; + socket.set_ttl(self.ttl)?; + Ok(Box::new(LoopbackConnectedSocket { + inner: socket, + networking: self.networking.clone(), + reservation_key: Some(reservation_key), + })) + } + + fn set_ttl(&mut self, ttl: u32) -> crate::Result<()> { + self.ttl = ttl; + Ok(()) + } + + fn ttl(&self) -> crate::Result { + Ok(self.ttl) } } diff --git a/lib/virtual-net/src/meta.rs b/lib/virtual-net/src/meta.rs index 7a8f516a8c5..1c62af53353 100644 --- a/lib/virtual-net/src/meta.rs +++ b/lib/virtual-net/src/meta.rs @@ -97,6 +97,14 @@ pub enum RequestType { reuse_port: bool, reuse_addr: bool, }, + /// Binds a TCP socket without immediately listening or connecting. + BindTcp { + socket_id: SocketId, + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + }, /// Opens a UDP socket that listens on a specific IP and Port combination /// Multiple servers (processes or threads) can bind to the same port if they each set /// the reuse-port and-or reuse-addr flags @@ -123,6 +131,10 @@ pub enum RequestType { }, /// Closes the socket Close, + /// Converts a bound TCP socket into a listening socket. + ListenBound, + /// Converts a bound TCP socket into a connected TCP stream. + ConnectBound { peer: SocketAddr }, /// Begins the process of accepting a socket and returns it later BeginAccept(SocketId), /// Returns the local address of this TCP listener diff --git a/lib/virtual-net/src/server.rs b/lib/virtual-net/src/server.rs index 7a0b8a6c6fe..e036e55512f 100644 --- a/lib/virtual-net/src/server.rs +++ b/lib/virtual-net/src/server.rs @@ -1,8 +1,9 @@ use crate::meta::{FrameSerializationFormat, ResponseType}; use crate::rx_tx::{RemoteRx, RemoteTx, RemoteTxWakers}; -use crate::{IpCidr, IpRoute, NetworkError, StreamSecurity, VirtualIcmpSocket}; +use crate::{IpCidr, IpRoute, NetworkError, SocketStatus, StreamSecurity, VirtualIcmpSocket}; use crate::{ - VirtualNetworking, VirtualRawSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, + VirtualNetworking, VirtualRawSocket, VirtualTcpBoundSocket, VirtualTcpListener, + VirtualTcpSocket, VirtualUdpSocket, meta::{MessageRequest, MessageResponse, RequestType, SocketId}, }; use futures_util::stream::FuturesOrdered; @@ -187,6 +188,11 @@ impl RemoteNetworkingServer { let rx = RemoteRx::HyperWebSocket { rx, format }; Self::new(tx, rx, rx_work, inner) } + + #[cfg(test)] + pub(crate) fn socket_count_for_test(&self) -> usize { + self.common.sockets.lock().unwrap().len() + } } #[async_trait::async_trait] @@ -534,6 +540,7 @@ impl RemoteNetworkingServerDriver { // a child ID we can actually use Ok(()) } + RemoteAdapterSocket::BoundTcp(_) => Ok(()), RemoteAdapterSocket::TcpSocket(s) => s.set_handler(handler), RemoteAdapterSocket::UdpSocket(s) => s.set_handler(handler), RemoteAdapterSocket::IcmpSocket(s) => s.set_handler(handler), @@ -756,6 +763,23 @@ impl RemoteNetworkingServerDriver { socket_id, req_id, ), + RequestType::BindTcp { + socket_id, + addr, + only_v6, + reuse_port, + reuse_addr, + } => self.process_async_new_socket( + move |inner: Arc| async move { + Ok(RemoteAdapterSocket::BoundTcp( + inner + .bind_tcp(addr, only_v6, reuse_port, reuse_addr) + .await?, + )) + }, + socket_id, + req_id, + ), RequestType::ListenTcp { socket_id, addr, @@ -849,19 +873,89 @@ impl RemoteNetworkingServerDriver { socket_id, req_id, ), - RequestType::Close => self.process_inner_noop( - move |socket| match socket { - RemoteAdapterSocket::TcpSocket(s) => s.close(), - _ => Err(NetworkError::Unsupported), - }, - socket_id, - req_id, - ), + RequestType::Close => { + let res = { + let mut guard = self.common.sockets.lock().unwrap(); + self.common.socket_accept.lock().unwrap().remove(&socket_id); + match guard.remove(&socket_id) { + Some(RemoteAdapterSocket::TcpSocket(mut socket)) => socket.close(), + Some(_) => Ok(()), + None => Err(NetworkError::InvalidFd), + } + }; + req_id.and_then(|req_id| { + self.common.send(MessageResponse::ResponseToRequest { + req_id, + res: match res { + Ok(()) => ResponseType::None, + Err(err) => ResponseType::Err(err), + }, + }) + }) + } + RequestType::ListenBound => { + let res = { + let mut guard = self.common.sockets.lock().unwrap(); + match guard.get_mut(&socket_id) { + Some(socket) => match socket { + RemoteAdapterSocket::BoundTcp(bound) => match bound.listen() { + Ok(listener) => { + *socket = RemoteAdapterSocket::TcpListener { + socket: listener, + next_accept: None, + }; + Ok(()) + } + Err(err) => Err(err), + }, + _ => Err(NetworkError::Unsupported), + }, + _ => Err(NetworkError::Unsupported), + } + }; + req_id.and_then(|req_id| { + self.common.send(MessageResponse::ResponseToRequest { + req_id, + res: match res { + Ok(()) => ResponseType::None, + Err(err) => ResponseType::Err(err), + }, + }) + }) + } + RequestType::ConnectBound { peer } => { + let res = { + let mut guard = self.common.sockets.lock().unwrap(); + match guard.get_mut(&socket_id) { + Some(socket) => match socket { + RemoteAdapterSocket::BoundTcp(bound) => match bound.connect(peer) { + Ok(connected) => { + *socket = RemoteAdapterSocket::TcpSocket(connected); + Ok(()) + } + Err(err) => Err(err), + }, + _ => Err(NetworkError::Unsupported), + }, + _ => Err(NetworkError::Unsupported), + } + }; + req_id.and_then(|req_id| { + self.common.send(MessageResponse::ResponseToRequest { + req_id, + res: match res { + Ok(()) => ResponseType::None, + Err(err) => ResponseType::Err(err), + }, + }) + }) + } RequestType::BeginAccept(child_id) => { self.process_inner_begin_accept(socket_id, child_id, req_id) } RequestType::GetAddrLocal => self.process_inner( move |socket| match socket { + RemoteAdapterSocket::BoundTcp(s) => s.addr_local(), RemoteAdapterSocket::TcpSocket(s) => s.addr_local(), RemoteAdapterSocket::TcpListener { socket: s, .. } => s.addr_local(), RemoteAdapterSocket::UdpSocket(s) => s.addr_local(), @@ -877,6 +971,7 @@ impl RemoteNetworkingServerDriver { ), RequestType::GetAddrPeer => self.process_inner( move |socket| match socket { + RemoteAdapterSocket::BoundTcp(_) => Err(NetworkError::Unsupported), RemoteAdapterSocket::TcpSocket(s) => s.addr_peer().map(Some), RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported), RemoteAdapterSocket::UdpSocket(s) => s.addr_peer(), @@ -893,6 +988,7 @@ impl RemoteNetworkingServerDriver { ), RequestType::SetTtl(ttl) => self.process_inner_noop( move |socket| match socket { + RemoteAdapterSocket::BoundTcp(s) => s.set_ttl(ttl), RemoteAdapterSocket::TcpSocket(s) => s.set_ttl(ttl), RemoteAdapterSocket::TcpListener { socket: s, .. } => { s.set_ttl(ttl.try_into().unwrap_or_default()) @@ -906,6 +1002,7 @@ impl RemoteNetworkingServerDriver { ), RequestType::GetTtl => self.process_inner( move |socket| match socket { + RemoteAdapterSocket::BoundTcp(s) => s.ttl(), RemoteAdapterSocket::TcpSocket(s) => s.ttl(), RemoteAdapterSocket::TcpListener { socket: s, .. } => s.ttl().map(|t| t as u32), RemoteAdapterSocket::UdpSocket(s) => s.ttl(), @@ -921,6 +1018,7 @@ impl RemoteNetworkingServerDriver { ), RequestType::GetStatus => self.process_inner( move |socket| match socket { + RemoteAdapterSocket::BoundTcp(_) => Ok(SocketStatus::Opened), RemoteAdapterSocket::TcpSocket(s) => s.status(), RemoteAdapterSocket::TcpListener { .. } => Err(NetworkError::Unsupported), RemoteAdapterSocket::UdpSocket(s) => s.status(), @@ -1227,6 +1325,7 @@ impl RemoteNetworkingServerDriver { #[derive(Debug)] enum RemoteAdapterSocket { + BoundTcp(Box), TcpListener { socket: Box, next_accept: Option, @@ -1414,6 +1513,7 @@ impl RemoteAdapterSocket { let mut ret: FuturesOrdered> = Default::default(); loop { break match self { + Self::BoundTcp(_) => {} Self::TcpListener { socket, next_accept, diff --git a/lib/virtual-net/src/tests.rs b/lib/virtual-net/src/tests.rs index 9ed1a46d8e1..d75ae630348 100644 --- a/lib/virtual-net/src/tests.rs +++ b/lib/virtual-net/src/tests.rs @@ -122,6 +122,98 @@ async fn test_tcp(client: RemoteNetworkingClient, _server: RemoteNetworkingServe tracing::info!("all good"); } +#[cfg(feature = "remote")] +async fn test_bound_tcp(client: RemoteNetworkingClient, _server: RemoteNetworkingServer) { + let mut bound = client + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + let addr_after_bind = bound.addr_local().unwrap(); + assert_ne!( + addr_after_bind.port(), + 0, + "remote bind_tcp should allocate a real ephemeral port before listen" + ); + + let listener = bound.listen().unwrap(); + let addr_after_listen = listener.addr_local().unwrap(); + assert_eq!( + addr_after_listen, addr_after_bind, + "remote listen should preserve the already-bound local address" + ); +} + +#[cfg(feature = "remote")] +async fn test_bound_tcp_ttl(client: RemoteNetworkingClient, _server: RemoteNetworkingServer) { + let mut bound = client + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + bound.set_ttl(42).unwrap(); + assert_eq!( + bound.ttl().unwrap(), + 42, + "remote bound_tcp should round-trip TTL before listen" + ); + + let listener = bound.listen().unwrap(); + assert_eq!( + listener.ttl().unwrap(), + 42, + "remote listener should preserve TTL set while the socket was only bound" + ); +} + +#[cfg(feature = "remote")] +async fn test_bound_tcp_drop_releases_server_socket( + client: RemoteNetworkingClient, + server: RemoteNetworkingServer, +) { + use tokio::time::{Duration, Instant, sleep}; + + let bound = client + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + assert_eq!( + server.socket_count_for_test(), + 1, + "server should retain the bound socket until the client drops it" + ); + + drop(bound); + + let deadline = Instant::now() + Duration::from_secs(1); + loop { + if server.socket_count_for_test() == 0 { + break; + } + assert!( + Instant::now() < deadline, + "server retained a dropped bound tcp socket" + ); + sleep(Duration::from_millis(10)).await; + } +} + #[cfg(feature = "remote")] #[cfg_attr(windows, ignore)] #[traced_test] @@ -132,6 +224,36 @@ async fn test_tcp_with_mpsc() { test_tcp(client, server).await } +#[cfg(feature = "remote")] +#[cfg_attr(windows, ignore)] +#[traced_test] +#[tokio::test(flavor = "multi_thread")] +#[serial_test::serial] +async fn test_bound_tcp_with_mpsc() { + let (client, server) = setup_mpsc().await; + test_bound_tcp(client, server).await +} + +#[cfg(feature = "remote")] +#[cfg_attr(windows, ignore)] +#[traced_test] +#[tokio::test(flavor = "multi_thread")] +#[serial_test::serial] +async fn test_bound_tcp_ttl_with_mpsc() { + let (client, server) = setup_mpsc().await; + test_bound_tcp_ttl(client, server).await +} + +#[cfg(feature = "remote")] +#[cfg_attr(windows, ignore)] +#[traced_test] +#[tokio::test(flavor = "multi_thread")] +#[serial_test::serial] +async fn test_bound_tcp_drop_releases_server_socket_with_mpsc() { + let (client, server) = setup_mpsc().await; + test_bound_tcp_drop_releases_server_socket(client, server).await +} + // Disabled on musl due to flakiness. // See https://github.com/wasmerio/wasmer/issues/4425 #[cfg(not(target_env = "musl"))] @@ -548,3 +670,283 @@ async fn test_failed_connect_status_stays_failed() { assert!(matches!(socket.status().unwrap(), SocketStatus::Failed)); } + +#[cfg(not(target_os = "windows"))] +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_bind_tcp_assigns_ephemeral_port_before_listen() { + let networking = LocalNetworking::new(); + let mut bound = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + let addr_after_bind = bound.addr_local().unwrap(); + assert_ne!( + addr_after_bind.port(), + 0, + "bind_tcp should allocate a real ephemeral port before listen" + ); + + let listener = bound.listen().unwrap(); + let addr_after_listen = listener.addr_local().unwrap(); + assert_eq!( + addr_after_listen, addr_after_bind, + "listen should preserve the already-bound local address" + ); +} + +#[cfg(not(target_os = "windows"))] +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_bind_tcp_keeps_same_port_across_connect() { + let probe = std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + let peer = probe.local_addr().unwrap(); + + let networking = LocalNetworking::new(); + let mut bound = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + let addr_after_bind = bound.addr_local().unwrap(); + assert_ne!(addr_after_bind.port(), 0); + + let socket = bound.connect(peer).unwrap(); + let addr_after_connect = socket.addr_local().unwrap(); + assert_eq!( + addr_after_connect, addr_after_bind, + "connect should preserve the already-bound local address" + ); +} + +#[cfg(not(target_os = "windows"))] +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_bind_tcp_preserves_ttl_across_connect() { + let probe = std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + let peer = probe.local_addr().unwrap(); + + let networking = LocalNetworking::new(); + let mut bound = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + bound.set_ttl(42).unwrap(); + assert_eq!(bound.ttl().unwrap(), 42); + + let socket = bound.connect(peer).unwrap(); + assert_eq!( + socket.ttl().unwrap(), + 42, + "connect should preserve TTL set while the socket was only bound" + ); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_bind_tcp_assigns_ephemeral_port_before_listen() { + let networking = LoopbackNetworking::new(); + let mut bound = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + let addr_after_bind = bound.addr_local().unwrap(); + assert_ne!( + addr_after_bind.port(), + 0, + "loopback bind_tcp should allocate a real ephemeral port before listen" + ); + + let listener = bound.listen().unwrap(); + let addr_after_listen = listener.addr_local().unwrap(); + assert_eq!( + addr_after_listen, addr_after_bind, + "loopback listen should preserve the already-bound local address" + ); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_bind_tcp_preserves_ttl_across_listen() { + let networking = LoopbackNetworking::new(); + let mut bound = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + bound.set_ttl(42).unwrap(); + assert_eq!(bound.ttl().unwrap(), 42); + + let listener = bound.listen().unwrap(); + assert_eq!( + listener.ttl().unwrap(), + 42, + "loopback listen should preserve TTL set while the socket was only bound" + ); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_bind_tcp_preserves_ttl_across_connect() { + let server_networking = LoopbackNetworking::new(); + let listener = server_networking + .listen_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + let peer = listener.addr_local().unwrap(); + + let client_networking = server_networking.clone(); + let mut bound = client_networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + + bound.set_ttl(42).unwrap(); + assert_eq!(bound.ttl().unwrap(), 42); + + let socket = bound.connect(peer).unwrap(); + assert_eq!( + socket.ttl().unwrap(), + 42, + "loopback connect should preserve TTL set while the socket was only bound" + ); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_bind_tcp_returns_error_when_ephemeral_ports_are_exhausted() { + let networking = LoopbackNetworking::new(); + networking.exhaust_tcp_ephemeral_ports_for_test(Ipv4Addr::LOCALHOST.into()); + + let err = networking + .bind_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap_err(); + + assert!( + matches!(err, NetworkError::AddressInUse), + "expected AddressInUse when all loopback ephemeral ports are exhausted, got {err:?}" + ); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_bind_tcp_reserves_port_before_listen() { + let networking = LoopbackNetworking::new(); + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 40123)); + + let bound = networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap(); + + let err = networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap_err(); + assert!( + matches!(err, NetworkError::AddressInUse), + "expected AddressInUse while a bound socket is reserving the port, got {err:?}" + ); + + drop(bound); + + networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap(); +} + +#[traced_test] +#[tokio::test] +#[serial_test::serial] +async fn test_loopback_connected_socket_holds_local_port_reservation() { + let server_networking = LoopbackNetworking::new(); + let listener = server_networking + .listen_tcp( + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + false, + false, + false, + ) + .await + .unwrap(); + let peer = listener.addr_local().unwrap(); + + let client_networking = server_networking.clone(); + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 40124)); + let mut bound = client_networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap(); + + let socket = bound.connect(peer).unwrap(); + + // While the connected socket is alive the local port must stay reserved. + let err = client_networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap_err(); + assert!( + matches!(err, NetworkError::AddressInUse), + "expected AddressInUse while connected socket holds the port, got {err:?}" + ); + + // After the connected socket is dropped the port must be released. + drop(socket); + client_networking + .bind_tcp(bind_addr, false, false, false) + .await + .unwrap(); +} diff --git a/lib/wasix/src/net/socket.rs b/lib/wasix/src/net/socket.rs index 5c6245c2ded..e97792c9b97 100644 --- a/lib/wasix/src/net/socket.rs +++ b/lib/wasix/src/net/socket.rs @@ -4,7 +4,7 @@ use std::{ mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, - sync::{Arc, RwLock, RwLockWriteGuard}, + sync::{Arc, RwLock}, task::{Context, Poll}, time::Duration, }; @@ -13,8 +13,8 @@ use std::{ use serde_derive::{Deserialize, Serialize}; use virtual_mio::InterestHandler; use virtual_net::{ - NetworkError, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualTcpListener, - VirtualTcpSocket, VirtualUdpSocket, net_error_into_io_err, + NetworkError, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualTcpBoundSocket, + VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, net_error_into_io_err, }; use wasmer_types::MemorySize; use wasmer_wasix_types::wasi::{Addressfamily, Errno, Rights, SockProto, Sockoption, Socktype}; @@ -32,6 +32,9 @@ pub enum InodeHttpSocketType { Headers, } +type TcpConnectFuture<'a> = + Pin, Errno>> + 'a>>; + #[derive(Debug)] pub struct SocketProperties { pub family: Addressfamily, @@ -52,6 +55,29 @@ pub struct SocketProperties { pub handler: Option>, } +impl Default for SocketProperties { + fn default() -> Self { + Self { + family: Addressfamily::Unspec, + ty: Socktype::Unknown, + pt: SockProto::Ip, + only_v6: false, + reuse_port: false, + reuse_addr: false, + no_delay: None, + keep_alive: None, + dont_route: None, + send_buf_size: None, + recv_buf_size: None, + write_timeout: None, + read_timeout: None, + accept_timeout: None, + connect_timeout: None, + handler: None, + } + } +} + #[derive(Debug)] //#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] pub enum InodeSocketKind { @@ -70,6 +96,10 @@ pub enum InodeSocketKind { write_timeout: Option, read_timeout: Option, }, + BoundTcp { + socket: Box, + props: SocketProperties, + }, UdpSocket { socket: Box, peer: Option, @@ -238,9 +268,6 @@ impl InodeSocket { // When a sendto or connect call comes in for a UDP "pre-socket", it must be bound to // an ephemeral port automatically. - // Apparently, clippy fails to recognize the write-locked guard being passed into - // the other function, hence the `allow` attribute. - #[allow(clippy::await_holding_lock, clippy::readonly_write_lock)] pub async fn auto_bind_udp( &self, tasks: &dyn VirtualTaskManager, @@ -251,18 +278,26 @@ impl InodeSocket { .ok() .flatten() .unwrap_or(Duration::from_secs(30)); - let inner = self.inner.protected.write().unwrap(); - match &inner.kind { - InodeSocketKind::PreSocket { props, .. } if props.ty == Socktype::Dgram => { - let addr = match props.family { - Addressfamily::Inet4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - Addressfamily::Inet6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), - _ => return Err(Errno::Notsup), - }; - Self::bind_internal(tasks, net, addr, timeout, inner).await + let family = { + let inner = self.inner.protected.read().unwrap(); + match &inner.kind { + InodeSocketKind::PreSocket { props, .. } if props.ty == Socktype::Dgram => { + Some(props.family) + } + _ => None, } - _ => Ok(None), - } + }; + let Some(family) = family else { + return Ok(None); + }; + + let addr = match family { + Addressfamily::Inet4 => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + Addressfamily::Inet6 => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + _ => return Err(Errno::Notsup), + }; + + self.bind_internal(tasks, net, addr, timeout).await } pub async fn bind( @@ -276,20 +311,32 @@ impl InodeSocket { .ok() .flatten() .unwrap_or(Duration::from_secs(30)); - let inner = self.inner.protected.write().unwrap(); - Self::bind_internal(tasks, net, set_addr, timeout, inner).await + self.bind_internal(tasks, net, set_addr, timeout).await } - // The lock is dropped before awaiting, but clippy doesn't realize it - #[allow(clippy::await_holding_lock)] async fn bind_internal( + &self, tasks: &dyn VirtualTaskManager, net: &dyn VirtualNetworking, set_addr: SocketAddr, timeout: Duration, - mut inner: RwLockWriteGuard<'_, InodeSocketProtected>, ) -> Result, Errno> { - let socket = { + enum PendingBind { + Tcp { + addr: SocketAddr, + only_v6: bool, + reuse_port: bool, + reuse_addr: bool, + }, + Udp { + addr: SocketAddr, + reuse_port: bool, + reuse_addr: bool, + }, + } + + let bind = { + let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { InodeSocketKind::PreSocket { props, addr, .. } => { match props.family { @@ -318,17 +365,17 @@ impl InodeSocket { let addr = (*addr).unwrap(); match props.ty { - Socktype::Stream => { - // we already set the socket address - next we need a listen or connect so nothing - // more to do at this time - return Ok(None); - } - Socktype::Dgram => { - let reuse_port = props.reuse_port; - let reuse_addr = props.reuse_addr; - - net.bind_udp(addr, reuse_port, reuse_addr) - } + Socktype::Stream => PendingBind::Tcp { + addr, + only_v6: props.only_v6, + reuse_port: props.reuse_port, + reuse_addr: props.reuse_addr, + }, + Socktype::Dgram => PendingBind::Udp { + addr, + reuse_port: props.reuse_port, + reuse_addr: props.reuse_addr, + }, _ => return Err(Errno::Inval), } } @@ -368,27 +415,81 @@ impl InodeSocket { // more to do at this time return Ok(None); } - Socktype::Dgram => { - let reuse_port = props.reuse_port; - let reuse_addr = props.reuse_addr; - - net.bind_udp(addr, reuse_port, reuse_addr) - } + Socktype::Dgram => PendingBind::Udp { + addr, + reuse_port: props.reuse_port, + reuse_addr: props.reuse_addr, + }, _ => return Err(Errno::Inval), } } + InodeSocketKind::BoundTcp { .. } => return Err(Errno::Inval), _ => return Err(Errno::Notsup), } }; - drop(inner); - - tokio::select! { - socket = socket => { - let socket = socket.map_err(net_error_into_wasi_err)?; - Ok(Some(InodeSocket::new(InodeSocketKind::UdpSocket { socket, peer: None }))) - }, - _ = tasks.sleep_now(timeout) => Err(Errno::Timedout) + match bind { + PendingBind::Tcp { + addr, + only_v6, + reuse_port, + reuse_addr, + } => { + tokio::select! { + socket = net.bind_tcp(addr, only_v6, reuse_port, reuse_addr) => { + match socket { + Ok(socket) => { + let props = { + let mut inner = self.inner.protected.write().unwrap(); + match &mut inner.kind { + InodeSocketKind::PreSocket { props, .. } => { + std::mem::take(props) + } + _ => return Err(Errno::Inval), + } + }; + Ok(Some(InodeSocket::new(InodeSocketKind::BoundTcp { socket, props }))) + } + Err(NetworkError::Unsupported) => { + // Fallback for backends that still only materialize TCP state at + // listen/connect time. + Ok(None) + } + Err(err) => { + // Roll back the pre-set address so the socket stays unbound, + // matching Linux semantics where a failed bind(2) leaves the + // socket unbound. + let mut inner = self.inner.protected.write().unwrap(); + if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind { + addr.take(); + } + Err(net_error_into_wasi_err(err)) + } + } + }, + _ = tasks.sleep_now(timeout) => { + // Bind timed out; roll back the pre-set address for the same reason. + let mut inner = self.inner.protected.write().unwrap(); + if let InodeSocketKind::PreSocket { addr, .. } = &mut inner.kind { + addr.take(); + } + Err(Errno::Timedout) + } + } + } + PendingBind::Udp { + addr, + reuse_port, + reuse_addr, + } => { + tokio::select! { + socket = net.bind_udp(addr, reuse_port, reuse_addr) => { + let socket = socket.map_err(net_error_into_wasi_err)?; + Ok(Some(InodeSocket::new(InodeSocketKind::UdpSocket { socket, peer: None }))) + }, + _ = tasks.sleep_now(timeout) => Err(Errno::Timedout) + } + } } } @@ -405,8 +506,8 @@ impl InodeSocket { .unwrap_or(Duration::from_secs(30)); let socket = { - let inner = self.inner.protected.read().unwrap(); - match &inner.kind { + let mut inner = self.inner.protected.write().unwrap(); + match &mut inner.kind { InodeSocketKind::PreSocket { props, addr, .. } => match props.ty { Socktype::Stream => { if addr.is_none() { @@ -451,6 +552,12 @@ impl InodeSocket { return Err(Errno::Notsup); } }, + InodeSocketKind::BoundTcp { socket, .. } => { + return Ok(Some(InodeSocket::new(InodeSocketKind::TcpListener { + socket: socket.listen().map_err(net_error_into_wasi_err)?, + accept_timeout: Some(timeout), + }))); + } InodeSocketKind::Icmp(_) => { tracing::warn!("wasi[?]::sock_listen - failed - not supported(icmp)"); return Err(Errno::Notsup); @@ -562,6 +669,7 @@ impl InodeSocket { InodeSocketKind::TcpStream { socket, .. } => { socket.close().map_err(net_error_into_wasi_err)?; } + InodeSocketKind::BoundTcp { .. } => {} InodeSocketKind::Icmp(_) => {} InodeSocketKind::UdpSocket { .. } => {} InodeSocketKind::Raw(_) => {} @@ -587,7 +695,7 @@ impl InodeSocket { .unwrap_or(Duration::from_secs(30)); let handler; - let connect = { + let connect: TcpConnectFuture<'_> = { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { InodeSocketKind::PreSocket { props, addr, .. } => { @@ -610,7 +718,10 @@ impl InodeSocket { } }; Box::pin(async move { - let mut ret = net.connect_tcp(addr, peer).await?; + let mut ret = net + .connect_tcp(addr, peer) + .await + .map_err(net_error_into_wasi_err)?; if let Some(no_delay) = no_delay { ret.set_nodelay(no_delay).ok(); } @@ -621,7 +732,41 @@ impl InodeSocket { ret.set_dontroute(dont_route).ok(); } if !nonblocking { - futures::future::poll_fn(|cx| ret.poll_write_ready(cx)).await?; + futures::future::poll_fn(|cx| ret.poll_write_ready(cx)) + .await + .map_err(net_error_into_wasi_err)?; + } + Ok(ret) + }) + } + Socktype::Dgram => return Err(Errno::Inval), + _ => return Err(Errno::Notsup), + } + } + InodeSocketKind::BoundTcp { socket, props } => { + handler = props.handler.take(); + new_write_timeout = props.write_timeout; + new_read_timeout = props.read_timeout; + match props.ty { + Socktype::Stream => { + let no_delay = props.no_delay; + let keep_alive = props.keep_alive; + let dont_route = props.dont_route; + let mut ret = socket.connect(peer).map_err(net_error_into_wasi_err)?; + if let Some(no_delay) = no_delay { + ret.set_nodelay(no_delay).ok(); + } + if let Some(keep_alive) = keep_alive { + ret.set_keepalive(keep_alive).ok(); + } + if let Some(dont_route) = dont_route { + ret.set_dontroute(dont_route).ok(); + } + Box::pin(async move { + if !nonblocking { + futures::future::poll_fn(|cx| ret.poll_write_ready(cx)) + .await + .map_err(net_error_into_wasi_err)?; } Ok(ret) }) @@ -645,7 +790,7 @@ impl InodeSocket { }; let mut socket = tokio::select! { - res = connect => res.map_err(net_error_into_wasi_err)?, + res = connect => res?, _ = tasks.sleep_now(timeout) => return Err(Errno::Timedout) }; @@ -668,6 +813,7 @@ impl InodeSocket { let inner = self.inner.protected.read().unwrap(); Ok(match &inner.kind { InodeSocketKind::PreSocket { .. } => WasiSocketStatus::Opening, + InodeSocketKind::BoundTcp { .. } => WasiSocketStatus::Opened, InodeSocketKind::TcpListener { .. } => WasiSocketStatus::Opened, InodeSocketKind::TcpStream { socket, .. } => match socket.status() { Ok(virtual_net::SocketStatus::Opening) => WasiSocketStatus::Opening, @@ -709,6 +855,9 @@ impl InodeSocket { InodeSocketKind::TcpStream { socket, .. } => { socket.addr_local().map_err(net_error_into_wasi_err)? } + InodeSocketKind::BoundTcp { socket, .. } => { + socket.addr_local().map_err(net_error_into_wasi_err)? + } InodeSocketKind::UdpSocket { socket, .. } => { socket.addr_local().map_err(net_error_into_wasi_err)? } @@ -730,6 +879,14 @@ impl InodeSocket { }, 0, ), + InodeSocketKind::BoundTcp { props, .. } => SocketAddr::new( + match props.family { + Addressfamily::Inet4 => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + Addressfamily::Inet6 => IpAddr::V6(Ipv6Addr::UNSPECIFIED), + _ => return Err(Errno::Inval), + }, + 0, + ), InodeSocketKind::TcpStream { socket, .. } => { socket.addr_peer().map_err(net_error_into_wasi_err)? } @@ -760,6 +917,7 @@ impl InodeSocket { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { match option { WasiSocketOption::OnlyV6 => props.only_v6 = val, @@ -811,6 +969,7 @@ impl InodeSocket { let mut inner = self.inner.protected.write().unwrap(); Ok(match &mut inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => match option { WasiSocketOption::OnlyV6 => props.only_v6, WasiSocketOption::ReusePort => props.reuse_port, @@ -855,6 +1014,7 @@ impl InodeSocket { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { props.send_buf_size = Some(size); } @@ -872,6 +1032,7 @@ impl InodeSocket { let inner = self.inner.protected.read().unwrap(); match &inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { Ok(props.send_buf_size.unwrap_or_default()) } @@ -886,6 +1047,7 @@ impl InodeSocket { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { props.recv_buf_size = Some(size); } @@ -903,6 +1065,7 @@ impl InodeSocket { let inner = self.inner.protected.read().unwrap(); match &inner.kind { InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { Ok(props.recv_buf_size.unwrap_or_default()) } @@ -920,6 +1083,7 @@ impl InodeSocket { socket.set_linger(linger).map_err(net_error_into_wasi_err) } InodeSocketKind::RemoteSocket { .. } => Ok(()), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -931,6 +1095,7 @@ impl InodeSocket { InodeSocketKind::TcpStream { socket, .. } => { socket.linger().map_err(net_error_into_wasi_err) } + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -963,6 +1128,7 @@ impl InodeSocket { Ok(()) } InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { match ty { TimeType::ConnectTimeout => props.connect_timeout = timeout, @@ -994,6 +1160,7 @@ impl InodeSocket { _ => return Err(Errno::Inval), }), InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => match ty { TimeType::ConnectTimeout => Ok(props.connect_timeout), TimeType::AcceptTimeout => Ok(props.accept_timeout), @@ -1008,6 +1175,9 @@ impl InodeSocket { pub fn set_ttl(&self, ttl: u32) -> Result<(), Errno> { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { + InodeSocketKind::BoundTcp { socket, .. } => { + socket.set_ttl(ttl).map_err(net_error_into_wasi_err) + } InodeSocketKind::TcpStream { socket, .. } => { socket.set_ttl(ttl).map_err(net_error_into_wasi_err) } @@ -1026,6 +1196,9 @@ impl InodeSocket { pub fn ttl(&self) -> Result { let inner = self.inner.protected.read().unwrap(); match &inner.kind { + InodeSocketKind::BoundTcp { socket, .. } => { + socket.ttl().map_err(net_error_into_wasi_err) + } InodeSocketKind::TcpStream { socket, .. } => { socket.ttl().map_err(net_error_into_wasi_err) } @@ -1051,6 +1224,7 @@ impl InodeSocket { *set_ttl = ttl; Ok(()) } + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1063,6 +1237,7 @@ impl InodeSocket { socket.multicast_ttl_v4().map_err(net_error_into_wasi_err) } InodeSocketKind::RemoteSocket { multicast_ttl, .. } => Ok(*multicast_ttl), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1075,6 +1250,7 @@ impl InodeSocket { .join_multicast_v4(multiaddr, iface) .map_err(net_error_into_wasi_err), InodeSocketKind::RemoteSocket { .. } => Ok(()), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1087,6 +1263,7 @@ impl InodeSocket { .leave_multicast_v4(multiaddr, iface) .map_err(net_error_into_wasi_err), InodeSocketKind::RemoteSocket { .. } => Ok(()), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1099,6 +1276,7 @@ impl InodeSocket { .join_multicast_v6(multiaddr, iface) .map_err(net_error_into_wasi_err), InodeSocketKind::RemoteSocket { .. } => Ok(()), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1111,6 +1289,7 @@ impl InodeSocket { .leave_multicast_v6(multiaddr, iface) .map_err(net_error_into_wasi_err), InodeSocketKind::RemoteSocket { .. } => Ok(()), + InodeSocketKind::BoundTcp { .. } => Err(Errno::Io), InodeSocketKind::PreSocket { .. } => Err(Errno::Io), _ => Err(Errno::Notsup), } @@ -1480,6 +1659,7 @@ impl InodeSocket { socket.shutdown(how).map_err(net_error_into_wasi_err)?; } InodeSocketKind::RemoteSocket { .. } => return Ok(()), + InodeSocketKind::BoundTcp { .. } => return Err(Errno::Notconn), InodeSocketKind::PreSocket { .. } => return Err(Errno::Notconn), _ => return Err(Errno::Notsup), } @@ -1491,6 +1671,7 @@ impl InodeSocket { #[allow(clippy::match_like_matches_macro)] match &mut guard.kind { InodeSocketKind::TcpStream { .. } + | InodeSocketKind::BoundTcp { .. } | InodeSocketKind::UdpSocket { .. } | InodeSocketKind::Raw(..) => true, InodeSocketKind::RemoteSocket { is_dead, .. } => !(*is_dead), @@ -1513,6 +1694,9 @@ impl InodeSocketProtected { InodeSocketKind::PreSocket { props, .. } => { props.handler.take(); } + InodeSocketKind::BoundTcp { props, .. } => { + props.handler.take(); + } InodeSocketKind::RemoteSocket { props, .. } => { props.handler.take(); } @@ -1526,6 +1710,7 @@ impl InodeSocketProtected { InodeSocketKind::UdpSocket { socket, .. } => socket.poll_read_ready(cx), InodeSocketKind::Raw(socket) => socket.poll_read_ready(cx), InodeSocketKind::Icmp(socket) => socket.poll_read_ready(cx), + InodeSocketKind::BoundTcp { .. } => Poll::Pending, InodeSocketKind::PreSocket { .. } => Poll::Pending, InodeSocketKind::RemoteSocket { is_dead, .. } => match is_dead { true => Poll::Ready(Ok(0)), @@ -1542,6 +1727,9 @@ impl InodeSocketProtected { InodeSocketKind::UdpSocket { socket, .. } => socket.poll_write_ready(cx), InodeSocketKind::Raw(socket) => socket.poll_write_ready(cx), InodeSocketKind::Icmp(socket) => socket.poll_write_ready(cx), + // A bound-but-not-yet-listening TCP socket is writable immediately, + // matching Linux select()/poll() semantics. + InodeSocketKind::BoundTcp { .. } => Poll::Ready(Ok(0)), InodeSocketKind::PreSocket { .. } => Poll::Pending, InodeSocketKind::RemoteSocket { is_dead, .. } => match is_dead { true => Poll::Ready(Ok(0)), @@ -1562,6 +1750,7 @@ impl InodeSocketProtected { InodeSocketKind::Raw(socket) => socket.set_handler(handler), InodeSocketKind::Icmp(socket) => socket.set_handler(handler), InodeSocketKind::PreSocket { props, .. } + | InodeSocketKind::BoundTcp { props, .. } | InodeSocketKind::RemoteSocket { props, .. } => { props.handler.replace(handler); Ok(()) @@ -1593,8 +1782,9 @@ pub(crate) fn all_socket_rights() -> Rights { #[cfg(test)] mod tests { - use super::{InodeSocket, InodeSocketKind, WasiSocketStatus}; + use super::{InodeSocket, InodeSocketKind, SocketProperties, WasiSocketStatus}; use std::{ + future::pending, mem::MaybeUninit, net::{Ipv4Addr, Shutdown, SocketAddr}, pin::Pin, @@ -1608,8 +1798,9 @@ mod tests { use virtual_mio::InterestHandler; use virtual_net::{ NetworkError, Result as NetResult, SocketStatus, VirtualConnectedSocket, VirtualIoSource, - VirtualSocket, VirtualTcpSocket, + VirtualNetworking, VirtualSocket, VirtualTcpBoundSocket, VirtualTcpSocket, }; + use wasmer_wasix_types::wasi::{Addressfamily, Errno, SockProto, Socktype}; #[derive(Debug)] struct MockTcpSocket { @@ -1748,6 +1939,54 @@ mod tests { } } + #[derive(Debug)] + struct MockTcpBoundSocket { + ttl: Arc, + } + + impl VirtualTcpBoundSocket for MockTcpBoundSocket { + fn addr_local(&self) -> NetResult { + Ok(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))) + } + + fn listen(&mut self) -> NetResult> { + Err(NetworkError::Unsupported) + } + + fn connect( + &mut self, + _peer: SocketAddr, + ) -> NetResult> { + Err(NetworkError::Unsupported) + } + + fn set_ttl(&mut self, ttl: u32) -> NetResult<()> { + self.ttl.store(ttl as usize, Ordering::Relaxed); + Ok(()) + } + + fn ttl(&self) -> NetResult { + Ok(self.ttl.load(Ordering::Relaxed) as u32) + } + } + + #[derive(Debug)] + struct PendingBindNetworking; + + #[async_trait::async_trait] + impl VirtualNetworking for PendingBindNetworking { + async fn bind_tcp( + &self, + _addr: SocketAddr, + _only_v6: bool, + _reuse_port: bool, + _reuse_addr: bool, + ) -> NetResult> { + pending::<()>().await; + unreachable!("pending bind_tcp future should never complete") + } + } + #[test] fn inode_socket_poll_write_ready_uses_write_path() { let read_calls = Arc::new(AtomicUsize::new(0)); @@ -1789,4 +2028,73 @@ mod tests { status.store(MOCK_STATUS_OPENED, Ordering::Relaxed); assert!(matches!(inode.status().unwrap(), WasiSocketStatus::Opened)); } + + #[test] + fn inode_socket_bound_tcp_forwards_ttl() { + let ttl = Arc::new(AtomicUsize::new(64)); + let inode = InodeSocket::new(InodeSocketKind::BoundTcp { + socket: Box::new(MockTcpBoundSocket { ttl: ttl.clone() }), + props: SocketProperties { + family: Addressfamily::Inet4, + ty: Socktype::Stream, + pt: SockProto::Tcp, + only_v6: false, + reuse_port: false, + reuse_addr: false, + no_delay: None, + keep_alive: None, + dont_route: None, + send_buf_size: None, + recv_buf_size: None, + write_timeout: None, + read_timeout: None, + accept_timeout: None, + connect_timeout: None, + handler: None, + }, + }); + + inode.set_ttl(42).unwrap(); + assert_eq!(inode.ttl().unwrap(), 42); + assert_eq!(ttl.load(Ordering::Relaxed), 42); + } + + #[cfg(feature = "sys")] + #[tokio::test(flavor = "current_thread")] + async fn inode_socket_tcp_bind_respects_bind_timeout() { + let inode = InodeSocket::new(InodeSocketKind::PreSocket { + props: SocketProperties { + family: Addressfamily::Inet4, + ty: Socktype::Stream, + pt: SockProto::Tcp, + only_v6: false, + reuse_port: false, + reuse_addr: false, + no_delay: None, + keep_alive: None, + dont_route: None, + send_buf_size: None, + recv_buf_size: None, + write_timeout: None, + read_timeout: None, + accept_timeout: None, + connect_timeout: None, + handler: None, + }, + addr: None, + }); + let tasks = crate::runtime::task_manager::tokio::TokioTaskManager::default(); + let net = PendingBindNetworking; + + let err = inode + .bind_internal( + &tasks, + &net, + SocketAddr::from((Ipv4Addr::LOCALHOST, 0)), + Duration::from_millis(10), + ) + .await + .unwrap_err(); + assert_eq!(err, Errno::Timedout); + } } diff --git a/lib/wasix/src/syscalls/wasix/sock_bind.rs b/lib/wasix/src/syscalls/wasix/sock_bind.rs index 37bec2ac69d..8a5af867d04 100644 --- a/lib/wasix/src/syscalls/wasix/sock_bind.rs +++ b/lib/wasix/src/syscalls/wasix/sock_bind.rs @@ -28,7 +28,13 @@ pub fn sock_bind( #[cfg(feature = "journal")] if ctx.data().enable_journal { - JournalEffector::save_sock_bind(&mut ctx, sock, addr).map_err(|err| { + let effective_addr = wasi_try_ok!(__sock_actor( + &mut ctx, + sock, + Rights::empty(), + |socket, _| socket.addr_local() + )); + JournalEffector::save_sock_bind(&mut ctx, sock, effective_addr).map_err(|err| { tracing::error!("failed to save sock_bind event - {}", err); WasiError::Exit(ExitCode::from(Errno::Fault)) })?; diff --git a/lib/wasix/tests/wasm_tests/socket_tests.rs b/lib/wasix/tests/wasm_tests/socket_tests.rs index 7c624923631..d7abf2b225d 100644 --- a/lib/wasix/tests/wasm_tests/socket_tests.rs +++ b/lib/wasix/tests/wasm_tests/socket_tests.rs @@ -30,6 +30,84 @@ fn test_nonblocking_connect() { ); } +#[test] +// https://github.com/wasmerio/wasmer/issues/6403 +fn test_bind_port_zero_allocates_ephemeral_port() { + let wasm = run_build_script(file!(), "bind-port-zero").unwrap(); + let result = run_wasm_with_result(&wasm, wasm.parent().unwrap()).unwrap(); + let stdout = String::from_utf8_lossy(&result.stdout); + assert_eq!( + stdout.trim(), + "bind port 0 allocates an ephemeral port", + "exit_code={:?}\nstdout:\n{}\nstderr:\n{}", + result.exit_code, + stdout, + String::from_utf8_lossy(&result.stderr) + ); +} + +#[test] +// https://github.com/wasmerio/wasmer/issues/6403 +fn test_bind_port_zero_keeps_same_port_across_connect() { + let wasm = run_build_script(file!(), "bind-port-zero-connect").unwrap(); + let result = run_wasm_with_result(&wasm, wasm.parent().unwrap()).unwrap(); + let stdout = String::from_utf8_lossy(&result.stdout); + assert_eq!( + stdout.trim(), + "bind port 0 keeps the same ephemeral port across connect", + "exit_code={:?}\nstdout:\n{}\nstderr:\n{}", + result.exit_code, + stdout, + String::from_utf8_lossy(&result.stderr) + ); +} + +#[test] +// https://github.com/wasmerio/wasmer/issues/6403 +fn test_connect_holds_local_port() { + let wasm = run_build_script(file!(), "connect-holds-local-port").unwrap(); + let result = run_wasm_with_result(&wasm, wasm.parent().unwrap()).unwrap(); + let stdout = String::from_utf8_lossy(&result.stdout); + assert_eq!( + stdout.trim(), + "connected socket holds its local port", + "exit_code={:?}\nstdout:\n{}\nstderr:\n{}", + result.exit_code, + stdout, + String::from_utf8_lossy(&result.stderr) + ); +} + +#[test] +fn test_bound_tcp_socket_is_writable() { + let wasm = run_build_script(file!(), "bound-tcp-writable").unwrap(); + let result = run_wasm_with_result(&wasm, wasm.parent().unwrap()).unwrap(); + let stdout = String::from_utf8_lossy(&result.stdout); + assert_eq!( + stdout.trim(), + "bound TCP socket is writable", + "exit_code={:?}\nstdout:\n{}\nstderr:\n{}", + result.exit_code, + stdout, + String::from_utf8_lossy(&result.stderr) + ); +} + +#[test] +fn test_bind_fail_leaves_socket_unbound() { + let wasm = run_build_script(file!(), "bind-fail-stays-unbound").unwrap(); + let result = run_wasm_with_result(&wasm, wasm.parent().unwrap()).unwrap(); + let stdout = String::from_utf8_lossy(&result.stdout); + assert_eq!( + stdout.trim(), + "bind failure leaves socket unbound", + "exit_code={:?}\nstdout:\n{}\nstderr:\n{}", + result.exit_code, + stdout, + String::from_utf8_lossy(&result.stderr) + ); +} + #[test] // https://github.com/wasmerio/wasmer/issues/6366 #[ignore = "flaky test (#6366)"] diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/build.sh b/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/build.sh new file mode 100755 index 00000000000..b3d2a483d5f --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +$CC main.c -o main diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/main.c b/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/main.c new file mode 100644 index 00000000000..4f61c679762 --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-fail-stays-unbound/main.c @@ -0,0 +1,109 @@ +/* + * Verify that a failed bind(2) leaves the socket logically unbound. + * + * Steps: + * 1. Bind socket A to 127.0.0.1:0 so the OS assigns an ephemeral port. + * 2. Read the assigned port back with getsockname. + * 3. Try to bind socket B to the same address/port — this must fail with + * EADDRINUSE because socket A still holds the port. + * 4. Call getsockname on socket B. The returned port must be 0, proving + * the failed bind did not leave a stale address on the socket. + */ +#include +#include +#include +#include +#include +#include + +int main(void) { + /* --- socket A: claim an ephemeral port --- */ + int fd_a = socket(AF_INET, SOCK_STREAM, 0); + if (fd_a < 0) { + perror("socket A"); + return 1; + } + + struct sockaddr_in addr_zero; + memset(&addr_zero, 0, sizeof(addr_zero)); + addr_zero.sin_family = AF_INET; + addr_zero.sin_port = htons(0); + if (inet_pton(AF_INET, "127.0.0.1", &addr_zero.sin_addr) != 1) { + fprintf(stderr, "inet_pton failed\n"); + close(fd_a); + return 1; + } + + if (bind(fd_a, (struct sockaddr*)&addr_zero, sizeof(addr_zero)) < 0) { + perror("bind A"); + close(fd_a); + return 1; + } + + /* Find out which port was assigned to socket A. */ + struct sockaddr_in addr_a; + socklen_t len = sizeof(addr_a); + memset(&addr_a, 0, sizeof(addr_a)); + if (getsockname(fd_a, (struct sockaddr*)&addr_a, &len) < 0) { + perror("getsockname A"); + close(fd_a); + return 1; + } + + int port_a = (int)ntohs(addr_a.sin_port); + if (port_a == 0) { + fprintf(stderr, "getsockname returned port 0 for socket A\n"); + close(fd_a); + return 1; + } + + /* --- socket B: attempt a conflicting bind --- */ + int fd_b = socket(AF_INET, SOCK_STREAM, 0); + if (fd_b < 0) { + perror("socket B"); + close(fd_a); + return 1; + } + + /* Bind B to the exact same address that A already owns. */ + if (bind(fd_b, (struct sockaddr*)&addr_a, sizeof(addr_a)) == 0) { + fprintf(stderr, "bind B unexpectedly succeeded on port %d\n", port_a); + close(fd_a); + close(fd_b); + return 1; + } + if (errno != EADDRINUSE) { + fprintf(stderr, "bind B failed with errno %d (%s), expected EADDRINUSE\n", + errno, strerror(errno)); + close(fd_a); + close(fd_b); + return 1; + } + + /* --- check that socket B is still unbound --- */ + struct sockaddr_in local_b; + socklen_t len_b = sizeof(local_b); + memset(&local_b, 0, sizeof(local_b)); + if (getsockname(fd_b, (struct sockaddr*)&local_b, &len_b) < 0) { + perror("getsockname B"); + close(fd_a); + close(fd_b); + return 1; + } + + int port_b = (int)ntohs(local_b.sin_port); + if (port_b != 0) { + fprintf(stderr, + "after failed bind, getsockname returned port %d for socket B " + "(expected 0 — socket should still be unbound)\n", + port_b); + close(fd_a); + close(fd_b); + return 1; + } + + close(fd_a); + close(fd_b); + printf("bind failure leaves socket unbound\n"); + return 0; +} diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/build.sh b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/build.sh new file mode 100644 index 00000000000..b3d2a483d5f --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +$CC main.c -o main diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/main.c b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/main.c new file mode 100644 index 00000000000..4bf441db441 --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero-connect/main.c @@ -0,0 +1,123 @@ +#include +#include +#include +#include +#include + +static int get_local_addr(int fd, struct sockaddr_in* addr) { + socklen_t len = sizeof(*addr); + memset(addr, 0, sizeof(*addr)); + return getsockname(fd, (struct sockaddr*)addr, &len); +} + +int main(void) { + int server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd < 0) { + perror("socket(server)"); + return 1; + } + + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(0); + if (inet_pton(AF_INET, "127.0.0.1", &server_addr.sin_addr) != 1) { + fprintf(stderr, "inet_pton(server) failed\n"); + close(server_fd); + return 1; + } + + if (bind(server_fd, (struct sockaddr*)&server_addr, sizeof(server_addr)) < + 0) { + perror("bind(server)"); + close(server_fd); + return 1; + } + + if (listen(server_fd, 1) < 0) { + perror("listen(server)"); + close(server_fd); + return 1; + } + + struct sockaddr_in server_bound_addr; + if (get_local_addr(server_fd, &server_bound_addr) < 0) { + perror("getsockname(server)"); + close(server_fd); + return 1; + } + + int client_fd = socket(AF_INET, SOCK_STREAM, 0); + if (client_fd < 0) { + perror("socket(client)"); + close(server_fd); + return 1; + } + + struct sockaddr_in client_bind_addr; + memset(&client_bind_addr, 0, sizeof(client_bind_addr)); + client_bind_addr.sin_family = AF_INET; + client_bind_addr.sin_port = htons(0); + if (inet_pton(AF_INET, "127.0.0.1", &client_bind_addr.sin_addr) != 1) { + fprintf(stderr, "inet_pton(client) failed\n"); + close(client_fd); + close(server_fd); + return 1; + } + + if (bind(client_fd, (struct sockaddr*)&client_bind_addr, + sizeof(client_bind_addr)) < 0) { + perror("bind(client)"); + close(client_fd); + close(server_fd); + return 1; + } + + struct sockaddr_in client_after_bind; + if (get_local_addr(client_fd, &client_after_bind) < 0) { + perror("getsockname(client after bind)"); + close(client_fd); + close(server_fd); + return 1; + } + + int bind_port = ntohs(client_after_bind.sin_port); + if (bind_port == 0) { + fprintf(stderr, "expected nonzero client port after bind, got 0\n"); + close(client_fd); + close(server_fd); + return 1; + } + + if (connect(client_fd, (struct sockaddr*)&server_bound_addr, + sizeof(server_bound_addr)) < 0) { + perror("connect(client)"); + close(client_fd); + close(server_fd); + return 1; + } + + struct sockaddr_in client_after_connect; + if (get_local_addr(client_fd, &client_after_connect) < 0) { + perror("getsockname(client after connect)"); + close(client_fd); + close(server_fd); + return 1; + } + + int connect_port = ntohs(client_after_connect.sin_port); + if (connect_port != bind_port) { + fprintf( + stderr, + "expected client port to stay stable across connect, got %d then %d\n", + bind_port, connect_port); + close(client_fd); + close(server_fd); + return 1; + } + + puts("bind port 0 keeps the same ephemeral port across connect"); + close(client_fd); + close(server_fd); + return 0; +} diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/build.sh b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/build.sh new file mode 100644 index 00000000000..b3d2a483d5f --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +$CC main.c -o main diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/main.c b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/main.c new file mode 100644 index 00000000000..3aa0dfe7a69 --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bind-port-zero/main.c @@ -0,0 +1,75 @@ +#include +#include +#include +#include +#include + +static int get_local_addr(int fd, struct sockaddr_in* addr) { + socklen_t len = sizeof(*addr); + memset(addr, 0, sizeof(*addr)); + return getsockname(fd, (struct sockaddr*)addr, &len); +} + +int main(void) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + perror("socket"); + return 1; + } + + struct sockaddr_in bind_addr; + memset(&bind_addr, 0, sizeof(bind_addr)); + bind_addr.sin_family = AF_INET; + bind_addr.sin_port = htons(0); + if (inet_pton(AF_INET, "127.0.0.1", &bind_addr.sin_addr) != 1) { + fprintf(stderr, "inet_pton failed\n"); + close(fd); + return 1; + } + + if (bind(fd, (struct sockaddr*)&bind_addr, sizeof(bind_addr)) < 0) { + perror("bind"); + close(fd); + return 1; + } + + struct sockaddr_in after_bind; + if (get_local_addr(fd, &after_bind) < 0) { + perror("getsockname(after bind)"); + close(fd); + return 1; + } + + int bind_port = ntohs(after_bind.sin_port); + if (bind_port == 0) { + fprintf(stderr, "expected nonzero ephemeral port after bind, got 0\n"); + close(fd); + return 1; + } + + if (listen(fd, 1) < 0) { + perror("listen"); + close(fd); + return 1; + } + + struct sockaddr_in after_listen; + if (get_local_addr(fd, &after_listen) < 0) { + perror("getsockname(after listen)"); + close(fd); + return 1; + } + + int listen_port = ntohs(after_listen.sin_port); + if (listen_port != bind_port) { + fprintf(stderr, + "expected port to stay stable after listen, got %d then %d\n", + bind_port, listen_port); + close(fd); + return 1; + } + + puts("bind port 0 allocates an ephemeral port"); + close(fd); + return 0; +} diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/build.sh b/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/build.sh new file mode 100755 index 00000000000..b3d2a483d5f --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +$CC main.c -o main diff --git a/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/main.c b/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/main.c new file mode 100644 index 00000000000..dad49669bfb --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/bound-tcp-writable/main.c @@ -0,0 +1,68 @@ +/* + * Verify that a successfully bound TCP socket is immediately reported writable + * by select(2), matching Linux semantics. + * + * On Linux: + * int fd = socket(...); bind(fd, ...); + * select(fd+1, NULL, &wfds, NULL, &zero_timeout); + * returns 1 — the socket is writable right away. + */ +#include +#include +#include +#include +#include +#include +#include + +int main(void) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + perror("socket"); + return 1; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(0); + if (inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr) != 1) { + fprintf(stderr, "inet_pton failed\n"); + close(fd); + return 1; + } + + if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + perror("bind"); + close(fd); + return 1; + } + + /* Zero-timeout select: only report what is ready RIGHT NOW. */ + fd_set wfds; + FD_ZERO(&wfds); + FD_SET(fd, &wfds); + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 0; + + int n = select(fd + 1, NULL, &wfds, NULL, &tv); + if (n < 0) { + perror("select"); + close(fd); + return 1; + } + + if (n == 0 || !FD_ISSET(fd, &wfds)) { + fprintf(stderr, + "bound TCP socket not reported writable by select " + "(n=%d) — expected writable immediately after bind\n", + n); + close(fd); + return 1; + } + + close(fd); + printf("bound TCP socket is writable\n"); + return 0; +} diff --git a/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/build.sh b/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/build.sh new file mode 100755 index 00000000000..b3d2a483d5f --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/build.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +$CC main.c -o main diff --git a/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/main.c b/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/main.c new file mode 100644 index 00000000000..172b03a00f4 --- /dev/null +++ b/lib/wasix/tests/wasm_tests/socket_tests/connect-holds-local-port/main.c @@ -0,0 +1,185 @@ +/* + * Verify that a connected TCP socket keeps its local port reserved. + * + * POSIX / Linux behaviour: + * - A socket that has completed connect() holds its local (ephemeral) port + * for its entire lifetime. + * - Attempting to bind a *different* socket to the same local address while + * the first socket is still connected must fail with EADDRINUSE. + * - After the connected socket is closed the explicit reservation is gone. + * A new bind() with SO_REUSEADDR should then succeed; without that Linux + * may still reject the bind because the connection can remain in TIME_WAIT. + * + * Steps + * 1. Create a server socket and listen on 127.0.0.1:0. + * 2. Bind a client socket to 127.0.0.1:0 (ephemeral) and connect to server. + * 3. Record the client's local address via getsockname. + * 4. Try to bind a third socket to that exact local address → EADDRINUSE. + * 5. Close the connected client socket. + * 6. Bind a third socket to the same address again with SO_REUSEADDR → must + * succeed. + */ +#include +#include +#include +#include +#include +#include + +int main(void) { + /* ---- step 1: server ---- */ + int server = socket(AF_INET, SOCK_STREAM, 0); + if (server < 0) { + perror("socket(server)"); + return 1; + } + + int one = 1; + setsockopt(server, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + struct sockaddr_in srv_addr; + memset(&srv_addr, 0, sizeof(srv_addr)); + srv_addr.sin_family = AF_INET; + srv_addr.sin_port = htons(0); + int inet_result = inet_pton(AF_INET, "127.0.0.1", &srv_addr.sin_addr); + if (inet_result != 1) { + if (inet_result == 0) { + fprintf(stderr, "inet_pton(server) could not parse 127.0.0.1\n"); + } else { + perror("inet_pton(server)"); + } + close(server); + return 1; + } + + if (bind(server, (struct sockaddr*)&srv_addr, sizeof(srv_addr)) < 0) { + perror("bind(server)"); + close(server); + return 1; + } + if (listen(server, 1) < 0) { + perror("listen"); + close(server); + return 1; + } + + socklen_t srv_len = sizeof(srv_addr); + if (getsockname(server, (struct sockaddr*)&srv_addr, &srv_len) < 0) { + perror("getsockname(server)"); + close(server); + return 1; + } + + /* ---- step 2: client — bind to ephemeral port then connect ---- */ + int client = socket(AF_INET, SOCK_STREAM, 0); + if (client < 0) { + perror("socket(client)"); + close(server); + return 1; + } + + struct sockaddr_in cli_bind; + memset(&cli_bind, 0, sizeof(cli_bind)); + cli_bind.sin_family = AF_INET; + cli_bind.sin_port = htons(0); + inet_result = inet_pton(AF_INET, "127.0.0.1", &cli_bind.sin_addr); + if (inet_result != 1) { + if (inet_result == 0) { + fprintf(stderr, "inet_pton(client) could not parse 127.0.0.1\n"); + } else { + perror("inet_pton(client)"); + } + close(server); + close(client); + return 1; + } + setsockopt(client, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + if (bind(client, (struct sockaddr*)&cli_bind, sizeof(cli_bind)) < 0) { + perror("bind(client)"); + close(server); + close(client); + return 1; + } + if (connect(client, (struct sockaddr*)&srv_addr, sizeof(srv_addr)) < 0) { + perror("connect"); + close(server); + close(client); + return 1; + } + + /* ---- step 3: record the client's local port ---- */ + struct sockaddr_in cli_local; + socklen_t cli_len = sizeof(cli_local); + memset(&cli_local, 0, sizeof(cli_local)); + if (getsockname(client, (struct sockaddr*)&cli_local, &cli_len) < 0) { + perror("getsockname(client)"); + close(server); + close(client); + return 1; + } + int cli_port = (int)ntohs(cli_local.sin_port); + if (cli_port == 0) { + fprintf(stderr, "client local port is 0 after connect\n"); + close(server); + close(client); + return 1; + } + + /* ---- step 4: rebind to same port while client is still connected ---- */ + int probe = socket(AF_INET, SOCK_STREAM, 0); + if (probe < 0) { + perror("socket(probe)"); + close(server); + close(client); + return 1; + } + + if (bind(probe, (struct sockaddr*)&cli_local, sizeof(cli_local)) == 0) { + fprintf(stderr, + "bind to port %d succeeded while client socket is still connected " + "(expected EADDRINUSE)\n", + cli_port); + close(probe); + close(server); + close(client); + return 1; + } + if (errno != EADDRINUSE) { + fprintf(stderr, + "bind to port %d failed with errno %d (%s), expected EADDRINUSE\n", + cli_port, errno, strerror(errno)); + close(probe); + close(server); + close(client); + return 1; + } + close(probe); + + /* ---- step 5: close the connected client ---- */ + close(client); + + /* ---- step 6: now the port must be available again ---- */ + int probe2 = socket(AF_INET, SOCK_STREAM, 0); + if (probe2 < 0) { + perror("socket(probe2)"); + close(server); + return 1; + } + setsockopt(probe2, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + if (bind(probe2, (struct sockaddr*)&cli_local, sizeof(cli_local)) < 0) { + fprintf(stderr, + "bind to port %d failed after client socket was closed even with " + "SO_REUSEADDR: %s\n", + cli_port, strerror(errno)); + close(probe2); + close(server); + return 1; + } + close(probe2); + close(server); + + printf("connected socket holds its local port\n"); + return 0; +}