diff --git a/compio-net/Cargo.toml b/compio-net/Cargo.toml index e4f24bd4b..1559582f2 100644 --- a/compio-net/Cargo.toml +++ b/compio-net/Cargo.toml @@ -24,6 +24,7 @@ cfg-if = { workspace = true } either = "1.9.0" once_cell = { workspace = true } socket2 = { workspace = true } +futures-util = { workspace = true } [target.'cfg(windows)'.dependencies] widestring = { workspace = true } @@ -40,7 +41,6 @@ libc = { workspace = true } # Shared dev dependencies for all platforms [dev-dependencies] compio-macros = { workspace = true } -futures-util = { workspace = true } tempfile = { workspace = true } [features] diff --git a/compio-net/src/lib.rs b/compio-net/src/lib.rs index d24b0e9f0..c50c675e8 100644 --- a/compio-net/src/lib.rs +++ b/compio-net/src/lib.rs @@ -48,7 +48,7 @@ pub type CMsgBuilder<'a> = compio_io::ancillary::AncillaryBuilder<'a>; pub type PollFd = compio_runtime::fd::PollFd; pub use opts::SocketOpts; pub use resolve::ToSocketAddrsAsync; -pub(crate) use resolve::{each_addr, first_addr_buf}; +pub(crate) use resolve::{each_addr, first_addr_buf, first_addr_buf_zerocopy}; pub(crate) use socket::*; pub use split::*; pub use tcp::*; diff --git a/compio-net/src/resolve/mod.rs b/compio-net/src/resolve/mod.rs index 623898e77..312b8e3f3 100644 --- a/compio-net/src/resolve/mod.rs +++ b/compio-net/src/resolve/mod.rs @@ -9,7 +9,7 @@ cfg_if::cfg_if! { } use std::{ - future::Future, + future::{Future, Ready, ready}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, }; @@ -169,3 +169,34 @@ pub async fn first_addr_buf>>( ) } } + +pub async fn first_addr_buf_zerocopy( + addr: impl ToSocketAddrsAsync, + buffer: B, + f: impl FnOnce(SocketAddr, B) -> F1, +) -> BufResult, F2>> +where + F1: Future>, + F2: Future, +{ + fn ret(fut: T) -> Either, F> { + Either::Left(ready(fut)) + } + + let mut addrs = match addr.to_socket_addrs_async().await { + Ok(addrs) => addrs, + Err(e) => return BufResult(Err(e), ret(buffer)), + }; + if let Some(addr) = addrs.next() { + let BufResult(res, fut) = f(addr, buffer).await; + BufResult(res, Either::Right(fut)) + } else { + BufResult( + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "could not operate on first address", + )), + ret(buffer), + ) + } +} diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index ede37a2a7..1577e4789 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -8,15 +8,17 @@ use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectore #[cfg(unix)] use compio_driver::op::CreateSocket; use compio_driver::{ - AsRawFd, ToSharedFd, impl_raw_fd, + AsRawFd, OpCode, ToSharedFd, impl_raw_fd, op::{ Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromManaged, RecvFromVectored, RecvManaged, RecvMsg, RecvResultExt, RecvVectored, ResultTakeBuffer, - Send, SendMsg, SendTo, SendToVectored, SendVectored, ShutdownSocket, VecBufResultExt, + Send, SendMsg, SendMsgZc, SendTo, SendToVectored, SendToVectoredZc, SendToZc, SendVectored, + SendVectoredZc, SendZc, ShutdownSocket, VecBufResultExt, }, syscall, }; use compio_runtime::{Attacher, BorrowedBuffer, BufferPool, fd::PollFd}; +use futures_util::StreamExt; use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type}; #[derive(Debug, Clone)] @@ -208,6 +210,22 @@ impl Socket { compio_runtime::submit(op).await.into_inner() } + pub async fn send_zerocopy( + &self, + buf: T, + flags: i32, + ) -> BufResult + use> { + submit_zerocopy(SendZc::new(self.to_shared_fd(), buf, flags)).await + } + + pub async fn send_zerocopy_vectored( + &self, + buf: T, + flags: i32, + ) -> BufResult + use> { + submit_zerocopy(SendVectoredZc::new(self.to_shared_fd(), buf, flags)).await + } + pub async fn recv_from( &self, buffer: T, @@ -275,6 +293,26 @@ impl Socket { compio_runtime::submit(op).await.into_inner() } + pub async fn send_to_zerocopy( + &self, + buffer: T, + addr: &SockAddr, + flags: i32, + ) -> BufResult + use> { + let op = SendToZc::new(self.to_shared_fd(), buffer, addr.clone(), flags); + submit_zerocopy(op).await + } + + pub async fn send_to_zerocopy_vectored( + &self, + buffer: T, + addr: &SockAddr, + flags: i32, + ) -> BufResult + use> { + let op = SendToVectoredZc::new(self.to_shared_fd(), buffer, addr.clone(), flags); + submit_zerocopy(op).await + } + pub async fn send_msg( &self, buffer: T, @@ -299,6 +337,33 @@ impl Socket { compio_runtime::submit(op).await.into_inner() } + pub async fn send_msg_zerocopy( + &self, + buffer: T, + control: C, + addr: Option<&SockAddr>, + flags: i32, + ) -> BufResult + use> { + self.send_msg_zerocopy_vectored([buffer], control, addr, flags) + .await + .map_buffer(|fut| async move { + let ([buffer], control) = fut.await; + (buffer, control) + }) + } + + pub async fn send_msg_zerocopy_vectored( + &self, + buffer: T, + control: C, + addr: Option<&SockAddr>, + flags: i32, + ) -> BufResult + use> { + let fd = self.to_shared_fd(); + let op = SendMsgZc::new(fd, buffer, control, addr.cloned(), flags); + submit_zerocopy(op).await + } + #[cfg(unix)] pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { let mut value: MaybeUninit = MaybeUninit::uninit(); @@ -377,3 +442,27 @@ impl Socket { } impl_raw_fd!(Socket, Socket2, socket, socket); + +async fn submit_zerocopy( + op: T, +) -> BufResult + use> { + let mut stream = compio_runtime::submit_multi(op); + let res = stream + .next() + .await + .expect("SubmitMulti should yield at least one item") + .0; + + let fut = async move { + // we don't need 2nd CQE's result + _ = stream.next().await; + + stream + .try_take() + .map_err(|_| ()) + .expect("Cannot retrieve buffer") + .into_inner() + }; + + BufResult(res, fut) +} diff --git a/compio-net/src/split.rs b/compio-net/src/split.rs index e61f7b01a..fe15294b6 100644 --- a/compio-net/src/split.rs +++ b/compio-net/src/split.rs @@ -37,19 +37,19 @@ where for<'a> &'a T: AsyncWrite, { async fn write(&mut self, buf: B) -> BufResult { - self.0.write(buf).await + (self.0).write(buf).await } async fn write_vectored(&mut self, buf: B) -> BufResult { - self.0.write_vectored(buf).await + (self.0).write_vectored(buf).await } async fn flush(&mut self) -> io::Result<()> { - self.0.flush().await + (self.0).flush().await } async fn shutdown(&mut self) -> io::Result<()> { - self.0.shutdown().await + (self.0).shutdown().await } } diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index aa0d36830..e98c70dba 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -328,6 +328,30 @@ impl TcpStream { self.inner.send(buf, MSG_OOB).await } + + /// Sends data using [zero-copy send](https://man7.org/linux/man-pages/man3/io_uring_prep_send_zc.3.html). + /// + /// If the underlying platform doesn't support zero-copy send, it will fall + /// back to normal send. + pub async fn send_zerocopy( + &self, + buf: T, + flags: i32, + ) -> BufResult + use> { + self.inner.send_zerocopy(buf, flags).await + } + + /// Sends vectorized data using [zero-copy send](https://man7.org/linux/man-pages/man3/io_uring_prep_send_zc.3.html). + /// + /// If the underlying platform doesn't support zero-copy send, it will fall + /// back to normal send. + pub async fn send_zerocopy_vectored( + &self, + buf: T, + flags: i32, + ) -> BufResult + use> { + self.inner.send_zerocopy_vectored(buf, flags).await + } } impl AsyncRead for TcpStream { @@ -405,12 +429,14 @@ impl AsyncWrite for TcpStream { impl AsyncWrite for &TcpStream { #[inline] async fn write(&mut self, buf: T) -> BufResult { - self.inner.send(buf, 0).await + let BufResult(res, fut) = self.send_zerocopy(buf, 0).await; + BufResult(res, fut.await) } #[inline] async fn write_vectored(&mut self, buf: T) -> BufResult { - self.inner.send_vectored(buf, 0).await + let BufResult(res, fut) = self.send_zerocopy_vectored(buf, 0).await; + BufResult(res, fut.await) } #[inline] diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index d5bca496e..1f797a458 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -389,6 +389,76 @@ impl UdpSocket { .await } + /// Sends data on the socket to the given address with zero copy. + /// + /// Returns the result of send and a future that resolves to the + /// original buffer when the send is complete. + pub async fn send_to_zerocopy( + &self, + buffer: T, + addr: A, + ) -> BufResult + use> { + super::first_addr_buf_zerocopy(addr, buffer, |addr, buffer| async move { + self.inner.send_to_zerocopy(buffer, &addr.into(), 0).await + }) + .await + } + + /// Sends vectored data on the socket to the given address with zero copy. + /// + /// Returns the result of send and a future that resolves to the + /// original buffer when the send is complete. + pub async fn send_to_zerocopy_vectored( + &self, + buffer: T, + addr: A, + ) -> BufResult + use> { + super::first_addr_buf_zerocopy(addr, buffer, |addr, buffer| async move { + self.inner + .send_to_zerocopy_vectored(buffer, &addr.into(), 0) + .await + }) + .await + } + + /// Sends data with control message on the socket to the given address with + /// zero copy. + /// + /// Returns the result of send and a future that resolves to the + /// original buffer when the send is complete. + pub async fn send_msg_zerocopy( + &self, + buffer: T, + control: C, + addr: A, + ) -> BufResult + use> { + super::first_addr_buf_zerocopy(addr, (buffer, control), |addr, (b, c)| async move { + self.inner + .send_msg_zerocopy(b, c, Some(&addr.into()), 0) + .await + }) + .await + } + + /// Sends vectored data with control message on the socket to the given + /// address with zero copy. + /// + /// Returns the result of send and a future that resolves to the + /// original buffer when the send is complete. + pub async fn send_msg_zerocopy_vectored( + &self, + buffer: T, + control: C, + addr: A, + ) -> BufResult + use> { + super::first_addr_buf_zerocopy(addr, (buffer, control), |addr, (b, c)| async move { + self.inner + .send_msg_zerocopy_vectored(b, c, Some(&addr.into()), 0) + .await + }) + .await + } + /// Gets a socket option. /// /// # Safety diff --git a/compio-net/tests/zero_copy.rs b/compio-net/tests/zero_copy.rs new file mode 100644 index 000000000..f4a1e3215 --- /dev/null +++ b/compio-net/tests/zero_copy.rs @@ -0,0 +1,93 @@ +use compio_buf::BufResult; +use compio_io::AsyncReadExt; +use compio_net::{TcpListener, TcpStream, UdpSocket}; + +#[compio_macros::test] +async fn tcp_zerocopy() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let task = compio_runtime::spawn(async move { listener.accept().await.unwrap() }); + + let tx = TcpStream::connect(&addr).await.unwrap(); + let (mut rx, _) = task.await.unwrap(); + + let buffer = Vec::from(b"hello world" as &[u8]); + let BufResult(res, fut) = tx.send_zerocopy(buffer, 0).await; + assert_eq!(res.unwrap(), 11); + let buffer = fut.await; + assert_eq!(buffer, b"hello world"); + + let buf = Vec::with_capacity(11); + let (_, buf) = rx.read_exact(buf).await.unwrap(); + assert_eq!(buf, b"hello world"); +} + +#[compio_macros::test] +async fn tcp_zerocopy_vectored() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let task = compio_runtime::spawn(async move { listener.accept().await.unwrap() }); + + let tx = TcpStream::connect(&addr).await.unwrap(); + let (mut rx, _) = task.await.unwrap(); + + let buffer = [ + Vec::from(b"hello" as &[u8]), + Vec::from(b" " as &[u8]), + Vec::from(b"world" as &[u8]), + ]; + let BufResult(res, fut) = tx.send_zerocopy_vectored(buffer, 0).await; + assert_eq!(res.unwrap(), 11); + let buffer = fut.await; + assert_eq!(buffer[0], b"hello"); + assert_eq!(buffer[1], b" "); + assert_eq!(buffer[2], b"world"); + + let buf = Vec::with_capacity(11); + let (_, buf) = rx.read_exact(buf).await.unwrap(); + assert_eq!(buf, b"hello world"); +} + +#[compio_macros::test] +async fn udp_zerocopy() { + let receiver = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = receiver.local_addr().unwrap(); + + let sender = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let buffer = Vec::from(b"hello world" as &[u8]); + let BufResult(res, fut) = sender.send_to_zerocopy(buffer, addr).await; + assert_eq!(res.unwrap(), 11); + let buffer = fut.await; + assert_eq!(buffer, b"hello world"); + + let (len, buf) = receiver.recv(Vec::with_capacity(11)).await.unwrap(); + assert_eq!(len, 11); + assert_eq!(buf, b"hello world"); +} + +#[compio_macros::test] +async fn udp_zerocopy_vectored() { + let receiver = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let addr = receiver.local_addr().unwrap(); + + let sender = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let buffer = [ + Vec::from(b"hello" as &[u8]), + Vec::from(b" " as &[u8]), + Vec::from(b"world" as &[u8]), + ]; + let BufResult(res, fut) = sender.send_to_zerocopy_vectored(buffer, addr).await; + assert_eq!(res.unwrap(), 11); + let buffer = fut.await; + assert_eq!(buffer[0], b"hello"); + assert_eq!(buffer[1], b" "); + assert_eq!(buffer[2], b"world"); + + let (len, buf) = receiver.recv(Vec::with_capacity(11)).await.unwrap(); + assert_eq!(len, 11); + assert_eq!(buf, b"hello world"); +}