From 0c6ba39174c7ada9404c825fdc5e7e7e3084382b Mon Sep 17 00:00:00 2001 From: Abhinav Date: Thu, 23 Apr 2026 17:53:46 +0530 Subject: [PATCH] feat(driver, net): support recv_send_poll_first --- compio-driver/Cargo.toml | 2 + compio-driver/src/sys/op/managed/fallback.rs | 15 ++++++ compio-driver/src/sys/op/managed/fusion.rs | 33 +++++++++++++ compio-driver/src/sys/op/managed/iour.rs | 41 ++++++++++++---- compio-driver/src/sys/op/socket/iour.rs | 28 ++++++----- compio-driver/src/sys/op/socket/mod.rs | 50 +++++++++++++++++++- compio-driver/src/sys/pal/iour/mod.rs | 12 +++++ compio-net/src/socket/mod.rs | 50 +++++++++++++------- compio-net/src/tcp.rs | 11 ----- compio-net/src/udp.rs | 11 ----- compio-net/src/unix.rs | 11 ----- 11 files changed, 191 insertions(+), 73 deletions(-) diff --git a/compio-driver/Cargo.toml b/compio-driver/Cargo.toml index 027b4981..de5dea05 100644 --- a/compio-driver/Cargo.toml +++ b/compio-driver/Cargo.toml @@ -60,6 +60,7 @@ io-uring = { version = "0.7.12", optional = true } once_cell = { workspace = true, optional = true } polling = { version = "3.3.0", optional = true } rustix = { workspace = true, features = ["linux_5_11"] } +linux-raw-sys = { version = "0.12.1", optional = true } # Other platform dependencies [target.'cfg(all(unix, not(target_os = "linux")))'.dependencies] @@ -83,6 +84,7 @@ io-uring = [ "rustix/mm", "rustix/event", "rustix/system", + "linux-raw-sys/io_uring", "dep:io-uring", "dep:once_cell", ] diff --git a/compio-driver/src/sys/op/managed/fallback.rs b/compio-driver/src/sys/op/managed/fallback.rs index e2cd5839..b690b013 100644 --- a/compio-driver/src/sys/op/managed/fallback.rs +++ b/compio-driver/src/sys/op/managed/fallback.rs @@ -69,6 +69,11 @@ impl RecvManaged { op: Recv::new(fd, pool.pop()?.with_capacity(len), flags), }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + // This method has been added here for the sake of API compatibility. + pub fn poll_first(&mut self) {} } impl TakeBuffer for RecvManaged { @@ -91,6 +96,11 @@ impl RecvFromManaged { op: RecvFrom::new(fd, pool.pop()?.with_capacity(len), flags), }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + // This method has been added here for the sake of API compatibility. + pub fn poll_first(&mut self) {} } impl TakeBuffer for RecvFromManaged { @@ -119,6 +129,11 @@ impl RecvMsgManaged { op: RecvMsg::new(fd, [pool.pop()?.with_capacity(len)], control, flags), }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + // This method has been added here for the sake of API compatibility. + pub fn poll_first(&mut self) {} } impl TakeBuffer for RecvMsgManaged { diff --git a/compio-driver/src/sys/op/managed/fusion.rs b/compio-driver/src/sys/op/managed/fusion.rs index 733c8417..c511037a 100644 --- a/compio-driver/src/sys/op/managed/fusion.rs +++ b/compio-driver/src/sys/op/managed/fusion.rs @@ -131,6 +131,39 @@ mop!( RecvMulti(fd: S, pool: &BufferPool, len: usize, flags: RecvFlags) mop!( RecvFromMulti(fd: S, pool: &BufferPool, flags: RecvFlags) with pool; RecvFromMultiResult); mop!( RecvMsgMulti(fd: S, pool: &BufferPool, control_len: usize, flags: RecvFlags) with pool; RecvMsgMultiResult); +impl RecvManaged { + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + match self.inner { + RecvManagedInner::Poll(ref mut i) => i.poll_first(), + RecvManagedInner::IoUring(ref mut i) => i.poll_first(), + } + } +} + +impl RecvFromManaged { + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + match self.inner { + RecvFromManagedInner::Poll(ref mut i) => i.poll_first(), + RecvFromManagedInner::IoUring(ref mut i) => i.poll_first(), + } + } +} + +impl RecvMsgManaged { + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + match self.inner { + RecvMsgManagedInner::Poll(ref mut i) => i.poll_first(), + RecvMsgManagedInner::IoUring(ref mut i) => i.poll_first(), + } + } +} + enum RecvFromMultiResultInner { Poll(fallback::RecvFromMultiResult), IoUring(iour::RecvFromMultiResult), diff --git a/compio-driver/src/sys/op/managed/iour.rs b/compio-driver/src/sys/op/managed/iour.rs index ca679162..255c4812 100644 --- a/compio-driver/src/sys/op/managed/iour.rs +++ b/compio-driver/src/sys/op/managed/iour.rs @@ -12,8 +12,9 @@ use rustix::net::RecvFlags; use socket2::{SockAddr, SockAddrStorage, socklen_t}; use crate::{ - BufferPool, BufferRef, Extra, IourOpCode as OpCode, OpEntry, op::TakeBuffer, - sys::pal::is_kernel_at_least, + BufferPool, BufferRef, Extra, IourOpCode as OpCode, OpEntry, + op::TakeBuffer, + sys::pal::{is_kernel_at_least, set_poll_first}, }; /// Read a file at specified position into specified buffer. @@ -143,6 +144,7 @@ pub struct RecvManaged { buffer_group: u16, buffer_pool: BufferPool, buffer: Option, + poll_first: bool, } impl RecvManaged { @@ -157,8 +159,15 @@ impl RecvManaged { flags, buffer_pool: buffer_pool.clone(), buffer: None, + poll_first: false, }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.poll_first = true; + } } unsafe impl OpCode for RecvManaged { @@ -166,12 +175,13 @@ unsafe impl OpCode for RecvManaged { fn create_entry(&mut self, _: &mut Self::Control) -> OpEntry { let fd = self.fd.as_fd().as_raw_fd(); - opcode::Recv::new(Fd(fd), ptr::null_mut(), self.len) + let entry = opcode::Recv::new(Fd(fd), ptr::null_mut(), self.len) .flags(self.flags.bits() as _) .buf_group(self.buffer_group) .build() - .flags(Flags::BUFFER_SELECT) - .into() + .flags(Flags::BUFFER_SELECT); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } unsafe fn set_result(&mut self, _: &mut Self::Control, _: &io::Result, extra: &Extra) { @@ -205,6 +215,7 @@ pub struct RecvFromManaged { buffer_group: u16, buffer_pool: BufferPool, buffer: Option, + poll_first: bool, } #[doc(hidden)] @@ -236,8 +247,15 @@ impl RecvFromManaged { addr, buffer_pool: buffer_pool.clone(), buffer: None, + poll_first: false, }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.poll_first = true; + } } impl TakeBuffer for RecvFromManaged { @@ -262,12 +280,13 @@ unsafe impl OpCode for RecvFromManaged { } fn create_entry(&mut self, control: &mut Self::Control) -> OpEntry { - opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &raw mut control.msg) + let entry = opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &raw mut control.msg) .flags(self.flags.bits() as _) .buf_group(self.buffer_group) .build() - .flags(Flags::BUFFER_SELECT) - .into() + .flags(Flags::BUFFER_SELECT); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } unsafe fn set_result( @@ -311,6 +330,12 @@ impl RecvMsgManaged { control_len: 0, }) } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.op.poll_first(); + } } unsafe impl OpCode for RecvMsgManaged { diff --git a/compio-driver/src/sys/op/socket/iour.rs b/compio-driver/src/sys/op/socket/iour.rs index ccadff44..a2d3be86 100644 --- a/compio-driver/src/sys/op/socket/iour.rs +++ b/compio-driver/src/sys/op/socket/iour.rs @@ -238,14 +238,15 @@ unsafe impl OpCode for Recv { let fd = self.fd.as_fd().as_raw_fd(); let slice = self.buffer.sys_slice_mut(); - opcode::Recv::new( + let entry = opcode::Recv::new( Fd(fd), slice.ptr() as _, slice.len().try_into().unwrap_or(u32::MAX), ) .flags(self.flags.bits() as _) - .build() - .into() + .build(); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } fn call_blocking(&mut self, _: &mut Self::Control) -> io::Result { @@ -261,10 +262,11 @@ unsafe impl OpCode for RecvVectored { } fn create_entry(&mut self, control: &mut Self::Control) -> OpEntry { - opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &mut control.msg) + let entry = opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &mut control.msg) .flags(self.flags.bits() as _) - .build() - .into() + .build(); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } fn call_blocking(&mut self, control: &mut Self::Control) -> io::Result { @@ -286,10 +288,11 @@ impl RecvFromHeader { } pub fn create_entry(&mut self, control: &mut RecvMsgControl) -> OpEntry { - opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &mut control.msg) + let entry = opcode::RecvMsg::new(Fd(self.fd.as_fd().as_raw_fd()), &mut control.msg) .flags(self.flags.bits() as _) - .build() - .into() + .build(); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } pub fn set_result(&mut self, control: &mut RecvMsgControl) { @@ -357,10 +360,11 @@ unsafe impl OpCode for RecvMsg OpEntry { - opcode::RecvMsg::new(Fd(self.header.fd.as_fd().as_raw_fd()), &mut control.msg) + let entry = opcode::RecvMsg::new(Fd(self.header.fd.as_fd().as_raw_fd()), &mut control.msg) .flags(self.header.flags.bits() as _) - .build() - .into() + .build(); + let entry = set_poll_first(entry, self.poll_first); + entry.into() } unsafe fn set_result( diff --git a/compio-driver/src/sys/op/socket/mod.rs b/compio-driver/src/sys/op/socket/mod.rs index f1e08a1b..cc0ca5b6 100644 --- a/compio-driver/src/sys/op/socket/mod.rs +++ b/compio-driver/src/sys/op/socket/mod.rs @@ -83,6 +83,7 @@ pub struct Recv { pub(crate) fd: S, pub(crate) buffer: T, pub(crate) flags: RecvFlags, + poll_first: bool, } /// Receive data from remote into vectored buffer. @@ -90,6 +91,7 @@ pub struct RecvVectored { pub(crate) fd: S, pub(crate) buffer: T, pub(crate) flags: RecvFlags, + poll_first: bool, } pub(crate) struct RecvFromHeader { @@ -97,6 +99,7 @@ pub(crate) struct RecvFromHeader { pub(crate) flags: RecvFlags, pub(crate) addr: SockAddrStorage, pub(crate) addr_len: socklen_t, + poll_first: bool, } /// Receive data and source address. @@ -118,6 +121,7 @@ pub struct RecvMsg { pub(crate) buffer: T, pub(crate) control: C, pub(crate) control_len: usize, + poll_first: bool, } impl Connect { @@ -254,8 +258,15 @@ impl RecvMsg { buffer, control, control_len: 0, + poll_first: false, } } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.poll_first = true; + } } impl IntoInner for RecvMsg { @@ -273,7 +284,18 @@ impl IntoInner for RecvMsg { impl Recv { /// Create [`Recv`]. pub fn new(fd: S, buffer: T, flags: RecvFlags) -> Self { - Self { fd, buffer, flags } + Self { + fd, + buffer, + flags, + poll_first: false, + } + } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.poll_first = true; } } @@ -288,7 +310,18 @@ impl IntoInner for Recv { impl RecvVectored { /// Create [`RecvVectored`]. pub fn new(fd: S, buffer: T, flags: RecvFlags) -> Self { - Self { fd, buffer, flags } + Self { + fd, + buffer, + flags, + poll_first: false, + } + } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.poll_first = true; } } @@ -309,6 +342,7 @@ impl RecvFromHeader { addr, flags, addr_len: name_len, + poll_first: false, } } @@ -325,6 +359,12 @@ impl RecvFromVectored { buffer, } } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.header.poll_first = true; + } } impl IntoInner for RecvFromVectored { @@ -344,6 +384,12 @@ impl RecvFrom { buffer, } } + + /// This method sets the `IORING_RECVSEND_POLL_FIRST` flag in the `ioprio` + /// of the SQE on the IO_URING driver. + pub fn poll_first(&mut self) { + self.header.poll_first = true; + } } impl IntoInner for RecvFrom { diff --git a/compio-driver/src/sys/pal/iour/mod.rs b/compio-driver/src/sys/pal/iour/mod.rs index 87b4327f..967ea7e8 100644 --- a/compio-driver/src/sys/pal/iour/mod.rs +++ b/compio-driver/src/sys/pal/iour/mod.rs @@ -1,6 +1,8 @@ #[cfg(feature = "once_cell_try")] use std::sync::OnceLock; +use io_uring::squeue::Entry; +use linux_raw_sys::io_uring::{IORING_RECVSEND_POLL_FIRST, io_uring_sqe}; #[cfg(not(feature = "once_cell_try"))] use once_cell::sync::OnceCell as OnceLock; @@ -53,3 +55,13 @@ pub fn is_kernel_at_least(v: impl Into) -> bool { .map(|kv| kv >= v.into()) .unwrap_or_default() } + +pub(crate) fn set_poll_first(mut entry: Entry, flag: bool) -> Entry { + if flag && is_kernel_at_least((5, 19)) { + let sqe = &raw mut entry as *mut io_uring_sqe; + unsafe { + (*sqe).ioprio |= IORING_RECVSEND_POLL_FIRST as u16; + } + } + entry +} diff --git a/compio-net/src/socket/mod.rs b/compio-net/src/socket/mod.rs index 5f65d02e..53f29a78 100644 --- a/compio-net/src/socket/mod.rs +++ b/compio-net/src/socket/mod.rs @@ -243,19 +243,12 @@ impl Socket { } } - /// This method signifies whether the socket was non-empty after the last - /// receive operation. - /// - /// # Behavior - /// - /// It returns `Some(..)` only on the IO_URING driver and `None` on others. - pub fn sock_nonempty(&self) -> Option { - self.state.get() - } - pub async fn recv(&self, buffer: B, flags: RecvFlags) -> BufResult { let fd = self.to_shared_fd(); - let op = Recv::new(fd, buffer, flags); + let mut op = Recv::new(fd, buffer, flags); + if self.state.get() == Some(false) { + op.poll_first(); + } let (res, extra) = compio_runtime::submit(op).with_extra().await; self.state.set(&extra); let res = res.into_inner(); @@ -268,7 +261,10 @@ impl Socket { flags: RecvFlags, ) -> BufResult { let fd = self.to_shared_fd(); - let op = RecvVectored::new(fd, buffer, flags); + let mut op = RecvVectored::new(fd, buffer, flags); + if self.state.get() == Some(false) { + op.poll_first(); + } let (res, extra) = compio_runtime::submit(op).with_extra().await; self.state.set(&extra); let res = res.into_inner(); @@ -283,7 +279,10 @@ impl Socket { let fd = self.to_shared_fd(); let (res, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; - let op = RecvManaged::new(fd, &buffer_pool, len, flags)?; + let mut op = RecvManaged::new(fd, &buffer_pool, len, flags)?; + if self.state.get() == Some(false) { + op.poll_first(); + } io::Result::Ok(rt.submit(op).with_extra()) })? .await; @@ -346,7 +345,10 @@ impl Socket { flags: RecvFlags, ) -> BufResult<(usize, Option), T> { let fd = self.to_shared_fd(); - let op = RecvFrom::new(fd, buffer, flags); + let mut op = RecvFrom::new(fd, buffer, flags); + if self.state.get() == Some(false) { + op.poll_first(); + } let (res, extra) = compio_runtime::submit(op).with_extra().await; self.state.set(&extra); let res = res.into_inner().map_addr(); @@ -359,7 +361,10 @@ impl Socket { flags: RecvFlags, ) -> BufResult<(usize, Option), T> { let fd = self.to_shared_fd(); - let op = RecvFromVectored::new(fd, buffer, flags); + let mut op = RecvFromVectored::new(fd, buffer, flags); + if self.state.get() == Some(false) { + op.poll_first(); + } let (res, extra) = compio_runtime::submit(op).with_extra().await; self.state.set(&extra); let res = res.into_inner().map_addr(); @@ -374,7 +379,10 @@ impl Socket { let fd = self.to_shared_fd(); let (inner, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; - let op = RecvFromManaged::new(fd, &buffer_pool, len, flags)?; + let mut op = RecvFromManaged::new(fd, &buffer_pool, len, flags)?; + if self.state.get() == Some(false) { + op.poll_first(); + } io::Result::Ok(rt.submit(op).with_extra()) })? .await; @@ -426,7 +434,10 @@ impl Socket { flags: RecvFlags, ) -> BufResult<(usize, usize, Option), (T, C)> { let fd = self.to_shared_fd(); - let op = RecvMsg::new(fd, buffer, control, flags); + let mut op = RecvMsg::new(fd, buffer, control, flags); + if self.state.get() == Some(false) { + op.poll_first(); + } let (res, extra) = compio_runtime::submit(op).with_extra().await; self.state.set(&extra); let res = res.into_inner().map_addr(); @@ -442,7 +453,10 @@ impl Socket { let fd = self.to_shared_fd(); let (inner, extra) = Runtime::with_current(|rt| { let buffer_pool = rt.buffer_pool()?; - let op = RecvMsgManaged::new(fd, &buffer_pool, len, control, flags)?; + let mut op = RecvMsgManaged::new(fd, &buffer_pool, len, control, flags)?; + if self.state.get() == Some(false) { + op.poll_first(); + } io::Result::Ok(rt.submit(op).with_extra()) })? .await; diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index 2fdf8a3c..ab86b4a1 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -440,17 +440,6 @@ impl TcpStream { ) .await } - - /// Signifies whether the underlying socket was non-empty after the last - /// receive operation. - /// - /// # Behaviour - /// - /// Returns `Some(..)` only on the IO_URING driver and `None` on other - /// drivers. - pub fn sock_nonempty(&self) -> Option { - self.inner.sock_nonempty() - } } impl AsyncRead for TcpStream { diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index ccac3dea..d4870a84 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -196,17 +196,6 @@ impl UdpSocket { .map(|addr| addr.as_socket().expect("should be SocketAddr")) } - /// Signifies whether the underlying socket was non-empty after the last - /// receive operation. - /// - /// # Behaviour - /// - /// Returns `Some(..)` only on the IO-URING driver and `None` on other - /// drivers. - pub fn sock_nonempty(&self) -> Option { - self.inner.sock_nonempty() - } - /// Receives a packet of data from the socket into the buffer, returning the /// original buffer and quantity of data received. pub async fn recv(&self, buffer: T) -> BufResult { diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index 95f3b8bc..28cd2102 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -302,17 +302,6 @@ impl UnixStream { self.inner.disconnect().await?; Ok(UnixSocket { inner: self.inner }) } - - /// Signifies whether the underlying socket was non-empty after the last - /// receive operation. - /// - /// # Behaviour - /// - /// Returns `Some(..)` only on the IO-URING driver and `None` on other - /// drivers. - pub fn sock_nonempty(&self) -> Option { - self.inner.sock_nonempty() - } } impl AsyncRead for UnixStream {