Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

121 changes: 118 additions & 3 deletions lib/virtual-net/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use crate::VirtualIoSource;
use crate::VirtualNetworking;
use crate::VirtualRawSocket;
use crate::VirtualSocket;
use crate::VirtualTcpBoundSocket;
use crate::VirtualTcpListener;
use crate::VirtualTcpSocket;
use crate::VirtualUdpSocket;
Expand Down Expand Up @@ -276,6 +277,7 @@ impl RemoteNetworkingClient {
buffer_accept: Default::default(),
buffer_recv_with_addr: Default::default(),
send_available: 0,
owns_socket_bindings: true,
}
}
}
Expand Down Expand Up @@ -760,6 +762,39 @@ impl VirtualNetworking for RemoteNetworkingClient {
}
}

async fn bind_tcp(
&self,
addr: SocketAddr,
only_v6: bool,
reuse_port: bool,
reuse_addr: bool,
) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
let socket_id: SocketId = self
.common
.socket_seed
.fetch_add(1, Ordering::SeqCst)
.into();
match self
.common
.io_iface(RequestType::BindTcp {
socket_id,
addr,
only_v6,
reuse_port,
reuse_addr,
})
.await
{
ResponseType::Err(err) => Err(err),
ResponseType::None => Ok(Box::new(self.new_socket(socket_id))),
ResponseType::Socket(socket_id) => Ok(Box::new(self.new_socket(socket_id))),
res => {
tracing::debug!("invalid response to bind TCP request - {res:?}");
Err(NetworkError::IOError)
}
}
}

async fn bind_udp(
&self,
addr: SocketAddr,
Expand Down Expand Up @@ -880,19 +915,29 @@ struct RemoteSocket {
buffer_recv_with_addr: VecDeque<DataWithAddr>,
buffer_accept: VecDeque<SocketWithAddr>,
send_available: u64,
owns_socket_bindings: bool,
}
impl Drop for RemoteSocket {
fn drop(&mut self) {
if !self.owns_socket_bindings {
return;
}
let _ = self.io_socket_fire_and_forget(RequestType::Close);
self.release_socket_bindings();
}
}

impl RemoteSocket {
fn release_socket_bindings(&mut self) {
self.owns_socket_bindings = false;
self.common.recv_tx.lock().unwrap().remove(&self.socket_id);
self.common
.recv_with_addr_tx
.lock()
.unwrap()
.remove(&self.socket_id);
Comment thread
Arshia001 marked this conversation as resolved.
}
}

impl RemoteSocket {
async fn io_socket(&self, req: RequestType) -> ResponseType {
let req_id = self.common.request_seed.fetch_add(1, Ordering::SeqCst);
let mut req_rx = {
Expand Down Expand Up @@ -941,6 +986,31 @@ impl RemoteSocket {
self.pending_accept.replace((child_id, rx_recv));
Ok(())
}

fn transition_socket(&mut self) -> RemoteSocket {
let (_tx_recv, rx_recv) = tokio::sync::mpsc::channel(1);
let (_tx_recv_with_addr, rx_recv_with_addr) = tokio::sync::mpsc::channel(1);
let (_tx_accept, rx_accept) = tokio::sync::mpsc::channel(1);
let (_tx_sent, rx_sent) = tokio::sync::mpsc::channel(1);

self.owns_socket_bindings = false;

RemoteSocket {
socket_id: self.socket_id,
common: self.common.clone(),
rx_buffer: std::mem::take(&mut self.rx_buffer),
rx_recv: std::mem::replace(&mut self.rx_recv, rx_recv),
rx_recv_with_addr: std::mem::replace(&mut self.rx_recv_with_addr, rx_recv_with_addr),
tx_waker: self.tx_waker.clone(),
rx_accept: std::mem::replace(&mut self.rx_accept, rx_accept),
rx_sent: std::mem::replace(&mut self.rx_sent, rx_sent),
pending_accept: self.pending_accept.take(),
buffer_recv_with_addr: std::mem::take(&mut self.buffer_recv_with_addr),
buffer_accept: std::mem::take(&mut self.buffer_accept),
send_available: self.send_available,
owns_socket_bindings: true,
}
}
}

impl VirtualIoSource for RemoteSocket {
Expand Down Expand Up @@ -1121,6 +1191,7 @@ impl VirtualTcpListener for RemoteSocket {
buffer_accept: Default::default(),
buffer_recv_with_addr: Default::default(),
send_available: 0,
owns_socket_bindings: true,
};
Ok((Box::new(socket), accepted.addr))
}
Expand Down Expand Up @@ -1159,6 +1230,46 @@ impl VirtualTcpListener for RemoteSocket {
}
}

impl VirtualTcpBoundSocket for RemoteSocket {
fn addr_local(&self) -> Result<SocketAddr> {
VirtualSocket::addr_local(self)
}

fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
match block_on(self.io_socket(RequestType::ListenBound)) {
ResponseType::Err(err) => Err(err),
ResponseType::None => {
let mut socket = self.transition_socket();
socket.touch_begin_accept().ok();
Ok(Box::new(socket))
}
res => {
tracing::debug!("invalid response to listen bound request - {res:?}");
Err(NetworkError::IOError)
}
}
}

fn connect(&mut self, peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
match block_on(self.io_socket(RequestType::ConnectBound { peer })) {
ResponseType::Err(err) => Err(err),
ResponseType::None => Ok(Box::new(self.transition_socket())),
res => {
tracing::debug!("invalid response to connect bound request - {res:?}");
Err(NetworkError::IOError)
}
}
}

fn set_ttl(&mut self, ttl: u32) -> Result<()> {
VirtualSocket::set_ttl(self, ttl)
}

fn ttl(&self) -> Result<u32> {
VirtualSocket::ttl(self)
}
}

impl VirtualRawSocket for RemoteSocket {
fn try_send(&mut self, data: &[u8]) -> Result<usize> {
let mut cx = Context::from_waker(&self.tx_waker);
Expand Down Expand Up @@ -1431,7 +1542,11 @@ impl VirtualConnectedSocket for RemoteSocket {
}

fn close(&mut self) -> Result<()> {
self.io_socket_fire_and_forget(RequestType::Close)
let ret = self.io_socket_fire_and_forget(RequestType::Close);
if ret.is_ok() {
self.release_socket_bindings();
}
ret
}

fn try_recv(&mut self, buf: &mut [std::mem::MaybeUninit<u8>], peek: bool) -> Result<usize> {
Expand Down
165 changes: 150 additions & 15 deletions lib/virtual-net/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::ruleset::{Direction, Ruleset};
use crate::{
IpCidr, IpRoute, NetworkError, Result, SocketStatus, StreamSecurity, VirtualConnectedSocket,
VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket,
VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
VirtualSocket, VirtualTcpBoundSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket,
};
use crate::{VirtualIoSource, io_err_into_net_error};
use bytes::{Buf, BytesMut};
Expand Down Expand Up @@ -66,6 +66,41 @@ impl Default for LocalNetworking {
}
}

fn sock_addr_into_socket_addr(addr: socket2::SockAddr) -> Result<SocketAddr> {
addr.as_socket().ok_or(NetworkError::UnknownError)
}

fn tcp_socket_domain(addr: SocketAddr) -> socket2::Domain {
if addr.is_ipv4() {
socket2::Domain::IPV4
} else {
socket2::Domain::IPV6
}
}

#[allow(clippy::needless_bool)]
fn tcp_connect_in_progress(err: &io::Error) -> bool {
if matches!(
err.kind(),
io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
) {
true
} else {
#[cfg(all(target_family = "unix", feature = "libc"))]
{
matches!(
err.raw_os_error(),
Some(raw) if raw == libc::EINPROGRESS || raw == libc::EALREADY
)
}

#[cfg(not(all(target_family = "unix", feature = "libc")))]
{
false
}
}
}

#[async_trait::async_trait]
#[allow(unused_variables)]
impl VirtualNetworking for LocalNetworking {
Expand All @@ -83,21 +118,47 @@ impl VirtualNetworking for LocalNetworking {
return Err(NetworkError::PermissionDenied);
}

let listener = std::net::TcpListener::bind(addr)
.map(|sock| {
sock.set_nonblocking(true).ok();
Box::new(LocalTcpListener {
stream: mio::net::TcpListener::from_std(sock),
selector: self.selector.clone(),
handler_guard: HandlerGuardState::None,
no_delay: None,
keep_alive: None,
backlog: Default::default(),
ruleset: self.ruleset.clone(),
})
})
self.bind_tcp(addr, only_v6, reuse_port, reuse_addr)
.await?
.listen()
}

async fn bind_tcp(
&self,
addr: SocketAddr,
only_v6: bool,
reuse_port: bool,
reuse_addr: bool,
) -> Result<Box<dyn VirtualTcpBoundSocket + Sync>> {
if let Some(ruleset) = self.ruleset.as_ref()
&& !ruleset.allows_socket(addr, Direction::Inbound)
{
tracing::warn!(%addr, "bind_tcp blocked by firewall rule");
return Err(NetworkError::PermissionDenied);
}

let socket = socket2::Socket::new(tcp_socket_domain(addr), socket2::Type::STREAM, None)
.map_err(io_err_into_net_error)?;
Ok(listener)
socket
.set_nonblocking(true)
.map_err(io_err_into_net_error)?;
if addr.is_ipv6() {
socket.set_only_v6(only_v6).map_err(io_err_into_net_error)?;
}
socket
.set_reuse_address(reuse_addr)
.map_err(io_err_into_net_error)?;
#[cfg(not(windows))]
socket
.set_reuse_port(reuse_port)
.map_err(io_err_into_net_error)?;
socket.bind(&addr.into()).map_err(io_err_into_net_error)?;

Ok(Box::new(LocalTcpBoundSocket {
socket: Some(socket),
selector: self.selector.clone(),
ruleset: self.ruleset.clone(),
}))
}

async fn bind_udp(
Expand Down Expand Up @@ -377,6 +438,80 @@ impl VirtualIoSource for LocalTcpListener {
}
}

#[derive(Debug)]
pub struct LocalTcpBoundSocket {
socket: Option<socket2::Socket>,
selector: Arc<Selector>,
ruleset: Option<Ruleset>,
}

impl VirtualTcpBoundSocket for LocalTcpBoundSocket {
fn addr_local(&self) -> Result<SocketAddr> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
let addr = socket.local_addr().map_err(io_err_into_net_error)?;
sock_addr_into_socket_addr(addr)
}

fn listen(&mut self) -> Result<Box<dyn VirtualTcpListener + Sync>> {
let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
socket.listen(128).map_err(io_err_into_net_error)?;
let listener = mio::net::TcpListener::from_std(socket.into());
Ok(Box::new(LocalTcpListener {
stream: listener,
selector: self.selector.clone(),
handler_guard: HandlerGuardState::None,
no_delay: None,
keep_alive: None,
backlog: Default::default(),
ruleset: self.ruleset.clone(),
}))
}
Comment thread
Arshia001 marked this conversation as resolved.

fn connect(&mut self, mut peer: SocketAddr) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
if let Some(ruleset) = self.ruleset.as_ref()
&& !ruleset.allows_socket(peer, Direction::Outbound)
{
tracing::warn!(%peer, "bound connect_tcp blocked by firewall rule");
return Err(NetworkError::PermissionDenied);
}

let socket = self.socket.take().ok_or(NetworkError::InvalidFd)?;
if let Err(err) = socket.connect(&peer.into())
&& !tcp_connect_in_progress(&err)
{
return Err(io_err_into_net_error(err));
}

let stream = mio::net::TcpStream::from_std(socket.into());
if let Ok(p) = stream.peer_addr() {
peer = p;
}
Ok(Box::new(LocalTcpStream::new(
self.selector.clone(),
stream,
peer,
)))
}

fn set_ttl(&mut self, ttl: u32) -> Result<()> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
match self.addr_local()?.ip() {
IpAddr::V4(_) => socket.set_ttl_v4(ttl).map_err(io_err_into_net_error),
IpAddr::V6(_) => socket
.set_unicast_hops_v6(ttl)
.map_err(io_err_into_net_error),
}
}

fn ttl(&self) -> Result<u32> {
let socket = self.socket.as_ref().ok_or(NetworkError::InvalidFd)?;
match self.addr_local()?.ip() {
IpAddr::V4(_) => socket.ttl_v4().map_err(io_err_into_net_error),
IpAddr::V6(_) => socket.unicast_hops_v6().map_err(io_err_into_net_error),
}
}
}

#[derive(Debug)]
enum ConnectState {
Unknown,
Expand Down
Loading
Loading