Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 118 additions & 3 deletions lib/virtual-net/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use crate::VirtualIoSource;
use crate::VirtualNetworking;
use crate::VirtualRawSocket;
use crate::VirtualSocket;
use crate::VirtualTcpBoundSocket;
use crate::VirtualTcpListener;
use crate::VirtualTcpSocket;
use crate::VirtualUdpSocket;
Expand Down Expand Up @@ -276,6 +277,7 @@ impl RemoteNetworkingClient {
buffer_accept: Default::default(),
buffer_recv_with_addr: Default::default(),
send_available: 0,
owns_socket_bindings: true,
}
}
}
Expand Down Expand Up @@ -760,6 +762,39 @@ impl VirtualNetworking for RemoteNetworkingClient {
}
}

/// Asks the remote side to bind a TCP socket to `addr` without yet deciding
/// whether it will listen or connect.
///
/// A fresh socket id is allocated locally and sent with the request. When the
/// remote answers with `ResponseType::Socket`, the id carried in the response
/// is used for the returned socket; a `ResponseType::None` answer keeps the
/// locally allocated id.
///
/// # Errors
///
/// Returns the error carried by `ResponseType::Err`, or
/// `NetworkError::IOError` when the response is of any other kind.
async fn bind_tcp(
    &self,
    addr: SocketAddr,
    only_v6: bool,
    reuse_port: bool,
    reuse_addr: bool,
) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
    let seed = self.common.socket_seed.fetch_add(1, Ordering::SeqCst);
    let socket_id: SocketId = seed.into();

    let request = RequestType::BindTcp {
        socket_id,
        addr,
        only_v6,
        reuse_port,
        reuse_addr,
    };

    // Work out which id the new socket should carry, bailing out on failures.
    let bound_id = match self.common.io_iface(request).await {
        ResponseType::Err(err) => return Err(err),
        ResponseType::None => socket_id,
        ResponseType::Socket(socket_id) => socket_id,
        res => {
            tracing::debug!("invalid response to bind TCP request - {res:?}");
            return Err(NetworkError::IOError);
        }
    };

    Ok(Box::new(self.new_socket(bound_id)))
}

async fn bind_udp(
&self,
addr: SocketAddr,
Expand Down Expand Up @@ -880,19 +915,29 @@ struct RemoteSocket {
buffer_recv_with_addr: VecDeque<DataWithAddr>,
buffer_accept: VecDeque<SocketWithAddr>,
send_available: u64,
owns_socket_bindings: bool,
}
impl Drop for RemoteSocket {
    /// Closes the remote socket and unregisters its bindings, but only while
    /// this value still owns them.
    fn drop(&mut self) {
        if self.owns_socket_bindings {
            // Best effort: the result of the close request is deliberately ignored.
            let _ = self.io_socket_fire_and_forget(RequestType::Close);
            self.release_socket_bindings();
        }
    }
}

impl RemoteSocket {
fn release_socket_bindings(&mut self) {
self.owns_socket_bindings = false;
self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
self.common
.recv_with_addr_tx
.lock()
.unwrap()
.remove(&self.socket_id);
Comment thread
Arshia001 marked this conversation as resolved.
}
}

impl RemoteSocket {
async fn io_socket(&self, req: RequestType) -> ResponseType {
let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
let mut req_rx = {
Expand Down Expand Up @@ -941,6 +986,31 @@ impl RemoteSocket {
self.pending_accept.replace((child_id, rx_recv));
Ok(())
}

/// Moves this socket's live state into a new `RemoteSocket` that shares the
/// same `socket_id`.
///
/// The receivers and buffers are moved out of `self`; `self` is left with
/// placeholder receivers (their senders are dropped when this function
/// returns) and with `owns_socket_bindings` cleared, so ownership of the
/// bindings passes to the returned socket.
fn transition_socket(&mut self) -> RemoteSocket {
    // Placeholder channels that take the place of the receivers moved out below.
    let (_idle_recv_tx, idle_recv) = tokio::sync::mpsc::channel(1);
    let (_idle_recv_with_addr_tx, idle_recv_with_addr) = tokio::sync::mpsc::channel(1);
    let (_idle_accept_tx, idle_accept) = tokio::sync::mpsc::channel(1);
    let (_idle_sent_tx, idle_sent) = tokio::sync::mpsc::channel(1);

    // Ownership of the bindings moves to the socket returned from here.
    self.owns_socket_bindings = false;

    let rx_buffer = std::mem::take(&mut self.rx_buffer);
    let rx_recv = std::mem::replace(&mut self.rx_recv, idle_recv);
    let rx_recv_with_addr = std::mem::replace(&mut self.rx_recv_with_addr, idle_recv_with_addr);
    let rx_accept = std::mem::replace(&mut self.rx_accept, idle_accept);
    let rx_sent = std::mem::replace(&mut self.rx_sent, idle_sent);
    let pending_accept = self.pending_accept.take();
    let buffer_recv_with_addr = std::mem::take(&mut self.buffer_recv_with_addr);
    let buffer_accept = std::mem::take(&mut self.buffer_accept);

    RemoteSocket {
        socket_id: self.socket_id,
        common: self.common.clone(),
        rx_buffer,
        rx_recv,
        rx_recv_with_addr,
        tx_waker: self.tx_waker.clone(),
        rx_accept,
        rx_sent,
        pending_accept,
        buffer_recv_with_addr,
        buffer_accept,
        send_available: self.send_available,
        owns_socket_bindings: true,
    }
}
}

impl VirtualIoSource for RemoteSocket {
Expand Down Expand Up @@ -1121,6 +1191,7 @@ impl VirtualTcpListener for RemoteSocket {
buffer_accept: Default::default(),
buffer_recv_with_addr: Default::default(),
send_available: 0,
owns_socket_bindings: true,
};
Ok((Box::new(socket), accepted.addr))
}
Expand Down Expand Up @@ -1159,6 +1230,46 @@ impl VirtualTcpListener for RemoteSocket {
}
}

impl VirtualTcpBoundSocket for RemoteSocket {
fn addr_local(&self) -> Result<SocketAddr> {
VirtualSocket::addr_local(self)
}

fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
match block_on(self.io_socket(RequestType::ListenBound)) {
ResponseType::Err(err) => Err(err),
ResponseType::None => {
let mut socket = self.transition_socket();
socket.touch_begin_accept().ok();
Ok(Box::new(socket))
}
res => {
tracing::debug!("invalid response to listen bound request - {res:?}");
Err(NetworkError::IOError)
}
}
}

fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
ResponseType::Err(err) => Err(err),
ResponseType::None => Ok(Box::new(self.transition_socket())),
res => {
tracing::debug!("invalid response to connect bound request - {res:?}");
Err(NetworkError::IOError)
}
}
}

fn set_ttl(&mut self, ttl: u32) -> Result<()> {
VirtualSocket::set_ttl(self, ttl)
}

fn ttl(&self) -> Result<u32> {
VirtualSocket::ttl(self)
}
}

impl VirtualRawSocket for RemoteSocket {
fn try_send(&mut self, data: &[u8]) -> Result<usize> {
let mut cx = Context::from_waker(&self.tx_waker);
Expand Down Expand Up @@ -1431,7 +1542,11 @@ impl VirtualConnectedSocket for RemoteSocket {
}

/// Sends a fire-and-forget `Close` request for this socket.
///
/// When the request was sent successfully the socket's bindings are
/// released as well, which also clears `owns_socket_bindings` so that
/// dropping the socket afterwards does not send a second `Close`.
///
/// # Errors
///
/// Returns whatever error `io_socket_fire_and_forget` reports; in that case
/// the bindings are left in place.
fn close(&mut self) -> Result<()> {
    let ret = self.io_socket_fire_and_forget(RequestType::Close);
    if ret.is_ok() {
        self.release_socket_bindings();
    }
    ret
}

fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
Expand Down
175 changes: 160 additions & 15 deletions lib/virtual-net/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::ruleset::{Direction, Ruleset};
use crate::{
IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
};
use crate::{VirtualIoSource, io_err_into_net_error};
use bytes::{Buf, BytesMut};
Expand All @@ -29,6 +29,14 @@ use virtual_mio::{
HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, state_as_waker_map,
};

/// Backlog value handed to `listen` by `LocalTcpBoundSocket::listen`.
///
/// Use the platform's maximum listen backlog where available so that
/// `LocalTcpBoundSocket::listen` preserves the same accept capacity as
/// the previous `std::net::TcpListener`-based implementation.
#[cfg(all(target_family = "unix", feature = "libc"))]
const LISTEN_BACKLOG: i32 = libc::SOMAXCONN;
/// Fixed fallback backlog for builds where `libc::SOMAXCONN` is not used
/// (non-Unix targets, or the `libc` feature is disabled).
#[cfg(not(all(target_family = "unix", feature = "libc")))]
const LISTEN_BACKLOG: i32 = 128;

#[derive(Debug)]
pub struct LocalNetworking {
selector: Arc<Selector>,
Expand Down Expand Up @@ -66,6 +74,41 @@ impl Default for LocalNetworking {
}
}

/// Converts a `socket2::SockAddr` into a standard `SocketAddr`.
///
/// # Errors
///
/// Returns `NetworkError::UnknownError` when `as_socket` yields `None`.
fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
    match addr.as_socket() {
        Some(socket_addr) => Ok(socket_addr),
        None => Err(NetworkError::UnknownError),
    }
}

/// Picks the `socket2` domain matching the address family of `addr`.
fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
    match addr {
        SocketAddr::V4(_) => socket2::Domain::IPV4,
        SocketAddr::V6(_) => socket2::Domain::IPV6,
    }
}

/// Reports whether a `connect` error only means that the connection attempt
/// is still under way rather than having failed.
///
/// `WouldBlock` and `Interrupted` always count as "in progress". On Unix
/// builds with the `libc` feature, the raw OS errors `EINPROGRESS` and
/// `EALREADY` count as well; on every other build any other error is treated
/// as a failure.
#[allow(clippy::needless_bool)]
fn tcp_connect_in_progress(err: &io::Error) -> bool {
    let kind = err.kind();
    if kind == io::ErrorKind::WouldBlock || kind == io::ErrorKind::Interrupted {
        return true;
    }

    // Exactly one of the two blocks below survives cfg-stripping and becomes
    // the function's result.
    #[cfg(all(target_family = "unix", feature = "libc"))]
    {
        matches!(
            err.raw_os_error(),
            Some(raw) if raw == libc::EINPROGRESS || raw == libc::EALREADY
        )
    }

    #[cfg(not(all(target_family = "unix", feature = "libc")))]
    {
        false
    }
}

#[async_trait::async_trait]
#[allow(unused_variables)]
impl VirtualNetworking for LocalNetworking {
Expand All @@ -83,21 +126,47 @@ impl VirtualNetworking for LocalNetworking {
return Err(NetworkError::PermissionDenied);
}

let listener = std::net::TcpListener::bind(addr)
.map(|sock| {
sock.set_nonblocking(true).ok();
Box::new(LocalTcpListener {
stream: mio::net::TcpListener::from_std(sock),
selector: self.selector.clone(),
handler_guard: HandlerGuardState::None,
no_delay: None,
keep_alive: None,
backlog: Default::default(),
ruleset: self.ruleset.clone(),
})
})
self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
.await?
.listen()
}

/// Creates a non-blocking TCP socket, applies the requested socket options
/// and binds it to `addr`, without yet listening or connecting.
///
/// `only_v6` is applied only when `addr` is an IPv6 address, and
/// `reuse_port` is applied only on non-Windows targets.
///
/// # Errors
///
/// Returns `NetworkError::PermissionDenied` when a configured ruleset does
/// not allow inbound traffic on `addr`; any I/O failure while creating,
/// configuring or binding the socket is mapped through
/// `io_err_into_net_error`.
async fn bind_tcp(
    &self,
    addr: SocketAddr,
    only_v6: bool,
    reuse_port: bool,
    reuse_addr: bool,
) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
    let blocked = self
        .ruleset
        .as_ref()
        .is_some_and(|rules| !rules.allows_socket(addr, Direction::Inbound));
    if blocked {
        tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
        return Err(NetworkError::PermissionDenied);
    }

    let domain = tcp_socket_domain(addr);
    let socket = socket2::Socket::new(domain, socket2::Type::STREAM, None)
        .map_err(io_err_into_net_error)?;

    // Options are applied before `bind`, in the same order as before.
    socket
        .set_nonblocking(true)
        .map_err(io_err_into_net_error)?;
    if let SocketAddr::V6(_) = addr {
        socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
    }
    socket
        .set_reuse_address(reuse_addr)
        .map_err(io_err_into_net_error)?;
    #[cfg(not(windows))]
    socket
        .set_reuse_port(reuse_port)
        .map_err(io_err_into_net_error)?;

    let bind_addr: socket2::SockAddr = addr.into();
    socket.bind(&bind_addr).map_err(io_err_into_net_error)?;

    let bound = LocalTcpBoundSocket {
        socket: Some(socket),
        selector: self.selector.clone(),
        ruleset: self.ruleset.clone(),
    };
    Ok(Box::new(bound))
}

async fn bind_udp(
Expand Down Expand Up @@ -377,6 +446,82 @@ impl VirtualIoSource for LocalTcpListener {
}
}

/// A host TCP socket that has been bound to a local address but has not yet
/// been turned into a listener or a connected stream.
#[derive(Debug)]
pub struct LocalTcpBoundSocket {
    // The bound socket. `listen` and `connect` take it out, after which the
    // remaining methods report `NetworkError::InvalidFd`.
    socket: Option<socket2::Socket>,
    // Selector handed on to the listener or stream created from this socket.
    selector: Arc<Selector>,
    // Optional firewall rules; checked in `connect` and passed on to listeners.
    ruleset: Option<Ruleset>,
}

impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
fn addr_local(&self) -> Result<SocketAddr> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
let addr = socket.local_addr().map_err(io_err_into_net_error)?;
sock_addr_into_socket_addr(addr)
}

fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
socket
.listen(LISTEN_BACKLOG)
.map_err(io_err_into_net_error)?;
let listener = mio::net::TcpListener::from_std(socket.into());
Ok(Box::new(LocalTcpListener {
stream: listener,
selector: self.selector.clone(),
handler_guard: HandlerGuardState::None,
no_delay: None,
keep_alive: None,
backlog: Default::default(),
ruleset: self.ruleset.clone(),
}))
}
Comment thread
Arshia001 marked this conversation as resolved.

fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
if let Some(ruleset) = self.ruleset.as_ref()
&& !ruleset.allows_socket(peer, Direction::Outbound)
{
tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
return Err(NetworkError::PermissionDenied);
}

let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
if let Err(err) = socket.connect(&peer.into())
&& !tcp_connect_in_progress(&err)
{
return Err(io_err_into_net_error(err));
}

let stream = mio::net::TcpStream::from_std(socket.into());
if let Ok(p) = stream.peer_addr() {
peer = p;
}
Ok(Box::new(LocalTcpStream::new(
self.selector.clone(),
stream,
peer,
)))
}

fn set_ttl(&mut self, ttl: u32) -> Result<()> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
match self.addr_local()?.ip() {
IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
IpAddr::V6(_) => socket
.set_unicast_hops_v6(ttl)
.map_err(io_err_into_net_error),
}
}

fn ttl(&self) -> Result<u32> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
match self.addr_local()?.ip() {
IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
}
}
}

#[derive(Debug)]
enum ConnectState {
Unknown,
Expand Down
Loading
Loading