diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index 7fc282b92..46aada4da 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -16,14 +16,14 @@ pub use crate::sys::op::{ Accept, Recv, RecvFrom, RecvFromVectored, RecvMsg, RecvVectored, Send, SendMsg, SendTo, SendToVectored, SendVectored, }; -#[cfg(windows)] -pub use crate::sys::op::{ConnectNamedPipe, DeviceIoControl}; #[cfg(unix)] pub use crate::sys::op::{ - CreateDir, CreateSocket, CurrentDir, FileStat, HardLink, Interest, OpenFile, PathStat, - PollOnce, ReadVectored, ReadVectoredAt, Rename, Stat, Symlink, TruncateFile, Unlink, + AcceptMulti, CreateDir, CreateSocket, CurrentDir, FileStat, HardLink, Interest, OpenFile, + PathStat, PollOnce, ReadVectored, ReadVectoredAt, Rename, Stat, Symlink, TruncateFile, Unlink, WriteVectored, WriteVectoredAt, }; +#[cfg(windows)] +pub use crate::sys::op::{ConnectNamedPipe, DeviceIoControl}; #[cfg(io_uring)] pub use crate::sys::op::{ ReadManaged, ReadManagedAt, ReadMulti, ReadMultiAt, RecvFromManaged, RecvManaged, RecvMulti, @@ -665,7 +665,9 @@ pub(crate) mod managed { } #[cfg(not(io_uring))] -pub use managed::*; +pub use managed::{ + ReadManaged, ReadManagedAt, ReadMulti, ReadMultiAt, RecvFromManaged, RecvManaged, RecvMulti, +}; bitflags::bitflags! { /// Flags for operations. diff --git a/compio-driver/src/sys/fusion/op.rs b/compio-driver/src/sys/fusion/op.rs index 15c3f6b59..b8800e694 100644 --- a/compio-driver/src/sys/fusion/op.rs +++ b/compio-driver/src/sys/fusion/op.rs @@ -126,6 +126,7 @@ mod iour { pub use crate::sys::iour::{op::*, OpCode}; } #[rustfmt::skip] mod poll { pub use crate::sys::poll::{op::*, OpCode}; } +op!( AcceptMulti(fd: S)); op!( RecvFrom(fd: S, buffer: T, flags: i32)); op!( SendTo(fd: S, buffer: T, addr: SockAddr, flags: i32)); op!( RecvFromVectored(fd: S, buffer: T, flags: i32)); diff --git a/compio-driver/src/sys/iour/op.rs b/compio-driver/src/sys/iour/op.rs index 3470a3aa0..dffaa7499 100644 --- a/compio-driver/src/sys/iour/op.rs +++ b/compio-driver/src/sys/iour/op.rs @@ -1,8 +1,9 @@ use std::{ + collections::VecDeque, ffi::CString, io, marker::PhantomPinned, - os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd}, + os::fd::{AsFd, AsRawFd, FromRawFd, IntoRawFd, OwnedFd}, pin::Pin, }; @@ -557,6 +558,82 @@ unsafe impl OpCode for Accept { } } +struct AcceptMultishotResult { + res: io::Result, + extra: crate::Extra, +} + +impl AcceptMultishotResult { + pub unsafe fn new(res: io::Result, extra: crate::Extra) -> Self { + Self { + res: res.map(|fd| unsafe { Socket2::from_raw_fd(fd as _) }), + extra, + } + } + + pub fn into_result(self) -> BufResult { + BufResult(self.res.map(|fd| fd.into_raw_fd() as _), self.extra) + } +} + +pin_project! { + /// Accept multiple connections. + pub struct AcceptMulti { + #[pin] + pub(crate) op: Accept, + multishots: VecDeque + } +} + +impl AcceptMulti { + /// Create [`AcceptMulti`]. + pub fn new(fd: S) -> Self { + Self { + op: Accept::new(fd), + multishots: VecDeque::new(), + } + } +} + +unsafe impl OpCode for AcceptMulti { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + let this = self.project(); + opcode::AcceptMulti::new(Fd(this.op.fd.as_fd().as_raw_fd())) + .flags(libc::SOCK_CLOEXEC) + .build() + .into() + } + + fn create_entry_fallback(self: Pin<&mut Self>) -> OpEntry { + self.project().op.create_entry() + } + + unsafe fn set_result(self: Pin<&mut Self>, res: &io::Result, extra: &crate::Extra) { + unsafe { self.project().op.set_result(res, extra) } + } + + unsafe fn push_multishot(self: Pin<&mut Self>, res: io::Result, extra: crate::Extra) { + self.project() + .multishots + .push_back(unsafe { AcceptMultishotResult::new(res, extra) }); + } + + fn pop_multishot(self: Pin<&mut Self>) -> Option> { + self.project() + .multishots + .pop_front() + .map(|res| res.into_result()) + } +} + +impl IntoInner for AcceptMulti { + type Inner = Socket2; + + fn into_inner(self) -> Self::Inner { + self.op.into_inner().0 + } +} + unsafe impl OpCode for Connect { fn create_entry(self: Pin<&mut Self>) -> OpEntry { opcode::Connect::new( diff --git a/compio-driver/src/sys/poll/op.rs b/compio-driver/src/sys/poll/op.rs index be99dc513..fec16aea5 100644 --- a/compio-driver/src/sys/poll/op.rs +++ b/compio-driver/src/sys/poll/op.rs @@ -768,6 +768,44 @@ unsafe impl OpCode for Accept { } } +pin_project! { + /// Accept multiple connections. + pub struct AcceptMulti { + #[pin] + pub(crate) op: Accept, + } +} + +impl AcceptMulti { + /// Create [`AcceptMulti`]. + pub fn new(fd: S) -> Self { + Self { + op: Accept::new(fd), + } + } +} + +unsafe impl OpCode for AcceptMulti { + fn pre_submit(self: Pin<&mut Self>) -> io::Result { + self.project().op.pre_submit() + } + + fn op_type(self: Pin<&mut Self>) -> Option { + self.project().op.op_type() + } + + fn operate(self: Pin<&mut Self>) -> Poll> { + self.project().op.operate() + } +} + +impl IntoInner for AcceptMulti { + type Inner = Socket2; + + fn into_inner(self) -> Self::Inner { + self.op.into_inner().0 + } +} unsafe impl OpCode for Connect { fn pre_submit(self: Pin<&mut Self>) -> io::Result { syscall!( diff --git a/compio-driver/src/sys/stub/op.rs b/compio-driver/src/sys/stub/op.rs index 29f71f8e3..e190562ad 100644 --- a/compio-driver/src/sys/stub/op.rs +++ b/compio-driver/src/sys/stub/op.rs @@ -3,7 +3,7 @@ use std::ffi::CString; use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; -use socket2::{SockAddr, SockAddrStorage, socklen_t}; +use socket2::{SockAddr, SockAddrStorage, Socket as Socket2, socklen_t}; use super::{OpCode, stub_unimpl}; pub use crate::sys::unix_op::*; @@ -125,6 +125,28 @@ impl OpCode for CloseSocket {} impl OpCode for Accept {} +/// Accept multiple connections. +pub struct AcceptMulti { + fd: S, +} + +impl AcceptMulti { + /// Create [`AcceptMulti`]. + pub fn new(fd: S) -> Self { + Self { fd } + } +} + +impl IntoInner for AcceptMulti { + type Inner = Socket2; + + fn into_inner(self) -> Self::Inner { + stub_unimpl() + } +} + +impl OpCode for AcceptMulti {} + impl OpCode for Connect {} impl OpCode for Recv {} diff --git a/compio-driver/src/sys/unix_op.rs b/compio-driver/src/sys/unix_op.rs index 115a7ca89..7a142b16a 100644 --- a/compio-driver/src/sys/unix_op.rs +++ b/compio-driver/src/sys/unix_op.rs @@ -548,9 +548,12 @@ impl Accept { _p: PhantomPinned, } } +} + +impl IntoInner for Accept { + type Inner = (Socket2, SockAddr); - /// Get the remote address from the inner buffer. - pub fn into_addr(mut self) -> (Socket2, SockAddr) { + fn into_inner(mut self) -> Self::Inner { let socket = self.accepted_fd.take().expect("socket not accepted"); (socket, unsafe { SockAddr::new(self.buffer, self.addr_len) }) } diff --git a/compio-driver/tests/file.rs b/compio-driver/tests/op.rs similarity index 67% rename from compio-driver/tests/file.rs rename to compio-driver/tests/op.rs index a9043b341..45b29e9b2 100644 --- a/compio-driver/tests/file.rs +++ b/compio-driver/tests/op.rs @@ -1,13 +1,12 @@ use std::{ io::{self, Write}, net::{TcpListener, TcpStream}, - ops::Deref, time::Duration, }; use compio_buf::BufResult; use compio_driver::{ - AsRawFd, BufferPool, Extra, OpCode, OwnedFd, Proactor, PushEntry, SharedFd, TakeBuffer, + AsRawFd, Extra, OpCode, OwnedFd, Proactor, PushEntry, SharedFd, TakeBuffer, op::{ Asyncify, CloseFile, CloseSocket, ReadAt, ReadManagedAt, ReadMultiAt, RecvMulti, ResultTakeBuffer, @@ -15,6 +14,9 @@ use compio_driver::{ }; mod pipe2; +#[cfg(unix)] +use compio_driver::op::AcceptMulti; + #[cfg(unix)] #[test] fn truncate_file_poll() { @@ -104,42 +106,46 @@ fn push_and_wait(driver: &mut Proactor, op: O) -> BufResult } } -fn push_and_wait_multi + 'static>( +fn push_and_wait_multi( driver: &mut Proactor, op: O, - pool: &BufferPool, -) -> Vec -where - for<'a> O::Buffer<'a>: Deref, -{ - match driver.push(op) { - PushEntry::Ready(res) => match (res, driver.default_extra()).take_buffer(pool) { - Ok(slice) => slice.to_vec(), - Err(_) => vec![], - }, - PushEntry::Pending(mut user_data) => { - let mut buffer = vec![]; - loop { - driver.poll(None).unwrap(); - while let Some(res) = driver.pop_multishot(&user_data) { - match res.take_buffer(pool) { - Ok(slice) => buffer.extend_from_slice(&slice), - Err(_) => break, - } +) -> impl Iterator)>> + '_ { + let mut op = Some(op); + let mut user_data = None; + let mut finished = false; + + std::iter::from_fn(move || { + if finished { + return None; + } + + if user_data.is_none() { + match driver.push(op.take().expect("operation should be pushed once")) { + PushEntry::Ready(BufResult(res, op)) => { + finished = true; + return Some(BufResult(res, (driver.default_extra(), Some(op)))); } - match driver.pop_with_extra(user_data) { - PushEntry::Pending(k) => user_data = k, - PushEntry::Ready(res) => { - if let Ok(slice) = res.take_buffer(pool) { - buffer.extend_from_slice(&slice) - } - break; - } + PushEntry::Pending(k) => user_data = Some(k), + } + } + + loop { + if let Some(res) = user_data.as_ref().and_then(|key| driver.pop_multishot(key)) { + return Some(res.map_buffer(|extra| (extra, None))); + } + + let key = user_data.take().expect("pending key should exist"); + match driver.pop_with_extra(key) { + PushEntry::Pending(k) => user_data = Some(k), + PushEntry::Ready((BufResult(res, op), extra)) => { + finished = true; + return Some(BufResult(res, (extra, Some(op)))); } } - buffer + + driver.poll(None).unwrap(); } - } + }) } #[test] @@ -275,7 +281,20 @@ fn read_multi() { let pool = driver.create_buffer_pool(4, 1024).unwrap(); let op = ReadMultiAt::new(fd.clone(), 0, &pool, 1024).unwrap(); - let buffer = push_and_wait_multi(&mut driver, op, &pool); + let buffer = push_and_wait_multi(&mut driver, op) + .map(|BufResult(res, (extra, op))| { + if let Some(op) = op { + (BufResult(res, op), extra).take_buffer(&pool) + } else { + BufResult(res, extra).take_buffer(&pool) + } + .map(|buf| buf.to_vec()) + .unwrap_or_default() + }) + .collect::>() + .into_iter() + .flatten() + .collect::>(); println!("{}", std::str::from_utf8(&buffer).unwrap()); @@ -310,7 +329,20 @@ fn recv_multi() { let mut buffer = vec![]; loop { let op = RecvMulti::new(stream.clone(), &pool, 0, 0).unwrap(); - let slice = push_and_wait_multi(&mut driver, op, &pool); + let slice = push_and_wait_multi(&mut driver, op) + .map(|BufResult(res, (extra, op))| { + if let Some(op) = op { + (BufResult(res, op), extra).take_buffer(&pool) + } else { + BufResult(res, extra).take_buffer(&pool) + } + .map(|buf| buf.to_vec()) + .unwrap_or_default() + }) + .collect::>() + .into_iter() + .flatten() + .collect::>(); if slice.is_empty() { break; } @@ -323,6 +355,61 @@ fn recv_multi() { push_and_wait(&mut driver, op).unwrap(); } +#[cfg(unix)] +#[test] +fn accept_multi() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + + let handle = std::thread::spawn(move || { + let mut driver = Proactor::new().unwrap(); + + let server = socket2::Socket::from(server); + if driver.driver_type().is_polling() { + server.set_nonblocking(true).unwrap(); + } + let server = SharedFd::new(server); + + driver.attach(server.as_raw_fd()).unwrap(); + + let mut i = 0; + loop { + let op = AcceptMulti::new(server.clone()); + for BufResult(res, (_, op)) in push_and_wait_multi(&mut driver, op) { + let mut client = if let Some(op) = op { + use compio_buf::IntoInner; + + op.into_inner() + } else { + unsafe { + use std::os::fd::FromRawFd; + socket2::Socket::from_raw_fd(res.unwrap() as _) + } + }; + client + .write_all(format!("Hello, {}", i).as_bytes()) + .unwrap(); + client.shutdown(std::net::Shutdown::Both).unwrap(); + i += 1; + if i >= 2 { + return; + } + } + } + }); + for i in 0..2 { + use std::io::Read; + + let mut client = TcpStream::connect(addr).unwrap(); + let mut s = String::new(); + client.read_to_string(&mut s).unwrap(); + assert_eq!(s, format!("Hello, {}", i)); + } + if let Err(e) = handle.join() { + std::panic::resume_unwind(e) + } +} + #[test] #[cfg(all(target_pointer_width = "64", any(io_uring, target_os = "windows")))] fn read_len_over_u32() { diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index eea8de0b1..ede37a2a7 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -95,7 +95,7 @@ impl Socket { pub async fn accept(&self) -> io::Result<(Self, SockAddr)> { let op = Accept::new(self.to_shared_fd()); let (_, op) = buf_try!(@try compio_runtime::submit(op).await); - let (accept_sock, addr) = op.into_addr(); + let (accept_sock, addr) = op.into_inner(); let accept_sock = Self::from_socket2(accept_sock)?; Ok((accept_sock, addr)) }