diff --git a/compio-net/Cargo.toml b/compio-net/Cargo.toml index 1559582f2..677caf70e 100644 --- a/compio-net/Cargo.toml +++ b/compio-net/Cargo.toml @@ -22,9 +22,9 @@ compio-runtime = { workspace = true } cfg-if = { workspace = true } either = "1.9.0" +futures-util = { workspace = true } once_cell = { workspace = true } socket2 = { workspace = true } -futures-util = { workspace = true } [target.'cfg(windows)'.dependencies] widestring = { workspace = true } diff --git a/compio-net/src/incoming/mod.rs b/compio-net/src/incoming/mod.rs new file mode 100644 index 000000000..3e58cb864 --- /dev/null +++ b/compio-net/src/incoming/mod.rs @@ -0,0 +1,11 @@ +cfg_if::cfg_if! { + if #[cfg(windows)] { + #[path = "windows.rs"] + mod sys; + } else if #[cfg(unix)] { + #[path = "unix.rs"] + mod sys; + } +} + +pub use sys::*; diff --git a/compio-net/src/incoming/unix.rs b/compio-net/src/incoming/unix.rs new file mode 100644 index 000000000..9d6d9d89c --- /dev/null +++ b/compio-net/src/incoming/unix.rs @@ -0,0 +1,65 @@ +use std::{ + io, + os::fd::FromRawFd, + pin::Pin, + task::{Context, Poll, ready}, +}; + +use compio_buf::{BufResult, IntoInner}; +use compio_driver::{SharedFd, ToSharedFd, op::AcceptMulti}; +use compio_runtime::SubmitMulti; +use futures_util::{Stream, StreamExt, stream::FusedStream}; +use socket2::Socket as Socket2; + +use crate::Socket; + +pub struct Incoming<'a> { + listener: &'a Socket, + op: Option>>>, +} + +impl<'a> Incoming<'a> { + pub fn new(listener: &'a Socket) -> Self { + Self { listener, op: None } + } +} + +impl Stream for Incoming<'_> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + if let Some(op) = &mut this.op { + let res = ready!(op.poll_next_unpin(cx)); + if let Some(BufResult(res, _)) = res { + let socket = if op.is_terminated() && res.is_ok() { + let Some(op) = this.op.take() else { + // SAFETY: op is guaranteed to be Some at this point. + unsafe { std::hint::unreachable_unchecked() } + }; + op.try_take() + .map_err(|_| ()) + .expect("AcceptMulti has not completed") + .into_inner() + } else { + unsafe { Socket2::from_raw_fd(res? as _) } + }; + return Poll::Ready(Some(Socket::from_socket2(socket))); + } else { + this.op = None; + } + } else { + this.op = Some(compio_runtime::submit_multi(AcceptMulti::new( + this.listener.to_shared_fd(), + ))); + } + } + } +} + +impl FusedStream for Incoming<'_> { + fn is_terminated(&self) -> bool { + false + } +} diff --git a/compio-net/src/incoming/windows.rs b/compio-net/src/incoming/windows.rs new file mode 100644 index 000000000..156c3b854 --- /dev/null +++ b/compio-net/src/incoming/windows.rs @@ -0,0 +1,92 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll, ready}, +}; + +use compio_buf::BufResult; +use compio_driver::{SharedFd, ToSharedFd, op::Accept}; +use compio_runtime::{JoinHandle, Submit}; +use futures_util::{FutureExt, Stream, stream::FusedStream}; +use socket2::Socket as Socket2; + +use crate::Socket; + +#[allow(clippy::large_enum_variant)] +enum IncomingState { + Idle, + CreatingSocket(JoinHandle>), + Accepting(Submit>>), +} + +pub struct Incoming<'a> { + listener: &'a Socket, + state: IncomingState, +} + +impl<'a> Incoming<'a> { + pub fn new(listener: &'a Socket) -> Self { + Self { + listener, + state: IncomingState::Idle, + } + } +} + +impl Stream for Incoming<'_> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + loop { + match &mut this.state { + IncomingState::Idle => { + let domain = this.listener.local_addr().map(|addr| addr.domain())?; + let ty = this.listener.socket.r#type()?; + let protocol = this.listener.socket.protocol()?; + let handle = + compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol)); + this.state = IncomingState::CreatingSocket(handle); + } + IncomingState::CreatingSocket(handle) => match ready!(handle.poll_unpin(cx)) { + Ok(Ok(socket)) => { + let op = compio_runtime::submit(Accept::new( + this.listener.to_shared_fd(), + socket, + )); + this.state = IncomingState::Accepting(op); + } + Ok(Err(e)) => { + this.state = IncomingState::Idle; + return Poll::Ready(Some(Err(e))); + } + Err(e) => { + this.state = IncomingState::Idle; + std::panic::resume_unwind(e) + } + }, + IncomingState::Accepting(op) => { + let BufResult(res, op) = ready!(op.poll_unpin(cx)); + match res { + Ok(_) => { + this.state = IncomingState::Idle; + op.update_context()?; + let (accept_sock, _) = op.into_addr()?; + return Poll::Ready(Some(Ok(Socket::from_socket2(accept_sock)?))); + } + Err(e) => { + this.state = IncomingState::Idle; + return Poll::Ready(Some(Err(e))); + } + } + } + } + } + } +} + +impl FusedStream for Incoming<'_> { + fn is_terminated(&self) -> bool { + false + } +} diff --git a/compio-net/src/lib.rs b/compio-net/src/lib.rs index c50c675e8..a2d47d4cd 100644 --- a/compio-net/src/lib.rs +++ b/compio-net/src/lib.rs @@ -14,6 +14,7 @@ html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg" )] +mod incoming; mod opts; mod resolve; mod socket; @@ -46,6 +47,7 @@ pub type CMsgBuilder<'a> = compio_io::ancillary::AncillaryBuilder<'a>; /// Providing functionalities to wait for readiness. #[deprecated(since = "0.12.0", note = "Use `compio::runtime::fd::PollFd` instead")] pub type PollFd = compio_runtime::fd::PollFd; +pub(crate) use incoming::*; pub use opts::SocketOpts; pub use resolve::ToSocketAddrsAsync; pub(crate) use resolve::{each_addr, first_addr_buf, first_addr_buf_zerocopy}; diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 1577e4789..a44c91fee 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -21,6 +21,8 @@ use compio_runtime::{Attacher, BorrowedBuffer, BufferPool, fd::PollFd}; use futures_util::StreamExt; use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type}; +use crate::Incoming; + #[derive(Debug, Clone)] pub struct Socket { pub(crate) socket: Attacher, @@ -121,6 +123,10 @@ impl Socket { Ok((Self::from_socket2(accept_sock)?, addr)) } + pub fn incoming(&self) -> Incoming<'_> { + Incoming::new(self) + } + pub fn close(self) -> impl Future> { // Make sure that self won't be dropped after `close` called. // Users may call this method and drop the future immediately. In that way the diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index e98c70dba..3373dee4c 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -1,13 +1,21 @@ -use std::{future::Future, io, net::SocketAddr}; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::impl_raw_fd; use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable}; use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd}; +use futures_util::{Stream, StreamExt, stream::FusedStream}; use socket2::{Protocol, SockAddr, Socket as Socket2, Type}; use crate::{ - OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync, WriteHalf, + Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync, + WriteHalf, }; /// A TCP socket server, listening for connections. @@ -120,6 +128,20 @@ impl TcpListener { Ok((stream, addr.as_socket().expect("should be SocketAddr"))) } + /// Returns a stream of incoming connections to this listener. + pub fn incoming(&self) -> TcpIncoming<'_> { + self.incoming_with_options(&SocketOpts::default()) + } + + /// Returns a stream of incoming connections to this listener, and sets + /// options for each accepted connection. + pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> TcpIncoming<'a> { + TcpIncoming { + inner: self.inner.incoming(), + opts: *options, + } + } + /// Returns the local address that this listener is bound to. /// /// This can be useful, for example, when binding to port 0 to @@ -152,6 +174,33 @@ impl TcpListener { impl_raw_fd!(TcpListener, socket2::Socket, inner, socket); +/// A stream of incoming TCP connections. +pub struct TcpIncoming<'a> { + inner: Incoming<'a>, + opts: SocketOpts, +} + +impl Stream for TcpIncoming<'_> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.inner.poll_next_unpin(cx).map(|res| { + res.map(|res| { + let socket = res?; + this.opts.setup_socket(&socket)?; + Ok(TcpStream { inner: socket }) + }) + }) + } +} + +impl FusedStream for TcpIncoming<'_> { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} + /// A TCP stream between a local and a remote socket. /// /// A TCP stream can either be created by connecting to an endpoint, via the diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index cb02cd220..df7d655d3 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -1,12 +1,19 @@ -use std::{future::Future, io, path::Path}; +use std::{ + future::Future, + io, + path::Path, + pin::Pin, + task::{Context, Poll}, +}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::impl_raw_fd; use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable}; use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd}; +use futures_util::{Stream, StreamExt, stream::FusedStream}; use socket2::{SockAddr, Socket as Socket2, Type}; -use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf}; +use crate::{Incoming, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, SocketOpts, WriteHalf}; /// A Unix socket server, listening for connections. /// @@ -113,6 +120,20 @@ impl UnixListener { Ok((stream, addr)) } + /// Returns a stream of incoming connections to this listener. + pub fn incoming(&self) -> UnixIncoming<'_> { + self.incoming_with_options(&SocketOpts::default()) + } + + /// Returns a stream of incoming connections to this listener, and sets + /// options for each accepted connection. + pub fn incoming_with_options<'a>(&'a self, options: &SocketOpts) -> UnixIncoming<'a> { + UnixIncoming { + inner: self.inner.incoming(), + opts: *options, + } + } + /// Returns the local address that this listener is bound to. pub fn local_addr(&self) -> io::Result { self.inner.local_addr() @@ -121,6 +142,33 @@ impl UnixListener { impl_raw_fd!(UnixListener, socket2::Socket, inner, socket); +/// A stream of incoming Unix connections. +pub struct UnixIncoming<'a> { + inner: Incoming<'a>, + opts: SocketOpts, +} + +impl Stream for UnixIncoming<'_> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.inner.poll_next_unpin(cx).map(|res| { + res.map(|res| { + let socket = res?; + this.opts.setup_socket(&socket)?; + Ok(UnixStream { inner: socket }) + }) + }) + } +} + +impl FusedStream for UnixIncoming<'_> { + fn is_terminated(&self) -> bool { + self.inner.is_terminated() + } +} + /// A Unix stream between two local sockets on Windows & WSL. /// /// A Unix stream can either be created by connecting to an endpoint, via the diff --git a/compio-net/tests/incoming.rs b/compio-net/tests/incoming.rs new file mode 100644 index 000000000..12a2f1d09 --- /dev/null +++ b/compio-net/tests/incoming.rs @@ -0,0 +1,62 @@ +use compio_io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use compio_net::{TcpListener, TcpStream, UnixListener, UnixStream}; +use futures_util::StreamExt; + +#[compio_macros::test] +async fn incoming_tcp() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let task = compio_runtime::spawn(async move { + let mut incoming = listener.incoming(); + let mut i = 0usize; + while let Some(stream) = incoming.next().await { + let mut stream = stream.unwrap(); + stream.write_all(format!("Hello, {}", i)).await.unwrap(); + stream.shutdown().await.unwrap(); + i += 1; + if i >= 2 { + break; + } + } + }); + + for i in 0..2 { + let mut client = TcpStream::connect(&addr).await.unwrap(); + let (_, text) = client.read_to_string(String::new()).await.unwrap(); + assert_eq!(text, format!("Hello, {}", i)); + } + + task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)); +} + +#[compio_macros::test] +async fn incoming_unix() { + let dir = tempfile::Builder::new() + .prefix("compio-uds-incoming-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("connect.sock"); + + let listener = UnixListener::bind(&sock_path).await.unwrap(); + let task = compio_runtime::spawn(async move { + let mut incoming = listener.incoming(); + let mut i = 0usize; + while let Some(stream) = incoming.next().await { + let mut stream = stream.unwrap(); + stream.write_all(format!("Hello, {}", i)).await.unwrap(); + stream.shutdown().await.unwrap(); + i += 1; + if i >= 2 { + break; + } + } + }); + + for i in 0..2 { + let mut client = UnixStream::connect(&sock_path).await.unwrap(); + let (_, text) = client.read_to_string(String::new()).await.unwrap(); + assert_eq!(text, format!("Hello, {}", i)); + } + + task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)); +}