diff --git a/Cargo.toml b/Cargo.toml index 7947e34e7..1e01fd981 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ compio-quic = { path = "./compio-quic", version = "0.7.0", default-features = fa compio-ws = { path = "./compio-ws", version = "0.3.0", default-features = false } bytes = "1.7.1" +bytemuck = "1.25.0" cfg_aliases = "0.2.1" cfg-if = "1.0.0" compio-send-wrapper = "0.7.0" diff --git a/compio-io/Cargo.toml b/compio-io/Cargo.toml index c607851a5..019e3cf0a 100644 --- a/compio-io/Cargo.toml +++ b/compio-io/Cargo.toml @@ -20,6 +20,7 @@ cfg-if = { workspace = true, optional = true } thiserror = { workspace = true, optional = true } serde = { version = "1.0.219", optional = true } serde_json = { version = "1.0.140", optional = true } +bytemuck = { workspace = true, optional = true, features = ["min_const_generics"] } [target.'cfg(unix)'.dependencies] libc = { workspace = true, optional = true } @@ -30,7 +31,6 @@ windows-sys = { workspace = true, optional = true, features = [ ] } [dev-dependencies] -aligned-array = "1.0.1" tokio = { workspace = true, features = ["macros", "rt"] } serde = { version = "1.0.219", features = ["derive"] } futures-executor = "0.3.30" @@ -58,3 +58,7 @@ required-features = ["compat"] [[test]] name = "framed" required-features = ["codec-serde-json"] + +[[test]] +name = "ancillary" +required-features = ["ancillary"] diff --git a/compio-io/src/ancillary/bytemuck_ext.rs b/compio-io/src/ancillary/bytemuck_ext.rs new file mode 100644 index 000000000..82b0c095f --- /dev/null +++ b/compio-io/src/ancillary/bytemuck_ext.rs @@ -0,0 +1,74 @@ +//! Extension module for automatic [`AncillaryData`] implementation via +//! bytemuck. +//! +//! See [`BitwiseAncillaryData`] for details. + +use std::mem::MaybeUninit; + +pub use bytemuck::{Pod, Zeroable}; + +use super::{AncillaryData, CodecError, copy_from_bytes, copy_to_bytes}; + +/// Marker trait to enable automatic `AncillaryData` implementation via +/// bytemuck. +/// +/// Types that implement this trait (which requires [`bytemuck::Pod`]) will +/// automatically implement [`AncillaryData`] using a simple byte-wise +/// encoding/decoding. +/// +/// # Example +/// +/// ``` +/// use compio_io::ancillary::bytemuck_ext; +/// +/// #[derive(Clone, Copy)] +/// #[repr(C)] +/// struct MyType { +/// value: u32, +/// } +/// +/// unsafe impl bytemuck_ext::Zeroable for MyType {} +/// unsafe impl bytemuck_ext::Pod for MyType {} +/// impl bytemuck_ext::BitwiseAncillaryData for MyType {} +/// +/// // Now MyType automatically implements AncillaryData +/// ``` +pub trait BitwiseAncillaryData: Pod {} + +impl AncillaryData for T { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + unsafe { copy_to_bytes(self, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + unsafe { copy_from_bytes(buffer) } + } +} + +macro_rules! impl_bytemuck_marker { + ($($t:ty),* $(,)?) => { + $( + impl BitwiseAncillaryData for $t {} + )* + }; +} + +impl_bytemuck_marker!( + (), + u8, + u16, + u32, + u64, + u128, + usize, + i8, + i16, + i32, + i64, + i128, + isize, + f32, + f64, +); + +impl BitwiseAncillaryData for [T; N] {} diff --git a/compio-io/src/ancillary/mod.rs b/compio-io/src/ancillary/mod.rs index 95a281158..f4b9ce508 100644 --- a/compio-io/src/ancillary/mod.rs +++ b/compio-io/src/ancillary/mod.rs @@ -5,17 +5,62 @@ //! //! # Types //! -//! - [`AncillaryRef`]: A reference to a single ancillary data entry. -//! - [`AncillaryIter`]: An iterator over a buffer of ancillary messages. -//! - [`AncillaryBuilder`]: A builder for constructing ancillary messages into a -//! caller-supplied send buffer. //! - [`AncillaryBuf`]: A fixed-size, properly aligned stack buffer for -//! ancillary data +//! ancillary messages. +//! - [`AncillaryBuilder`]: A builder for constructing ancillary messages into a +//! [`AncillaryBuf`]. +//! - [`AncillaryIter`]: An iterator over a buffer of ancillary messages. +//! - [`AncillaryRef`]: A reference to a single ancillary data entry. +//! - [`AncillaryData`]: Trait for types that can be encoded/decoded as +//! ancillary data payloads. +//! - [`CodecError`]: Error type for encoding/decoding operations. +//! +//! # Functions +//! +//! - [`ancillary_space`]: Helper function to calculate ancillary message size +//! for a type. +//! +//! # Modules +//! +//! - [`bytemuck_ext`]: Extension module for automatic [`AncillaryData`] +//! implementation via bytemuck (requires `bytemuck` feature). +//! +//! # Example +//! +//! ``` +//! use compio_io::ancillary::{AncillaryBuf, AncillaryIter, CodecError, ancillary_space}; +//! +//! const LEVEL: i32 = 1; +//! const TYPE: i32 = 2; +//! +//! // Build a buffer containing two `u32` ancillary messages. +//! let mut buf = AncillaryBuf::<{ ancillary_space::() * 2 }>::new(); +//! let mut builder = buf.builder(); +//! builder.push(LEVEL, TYPE, &42u32).unwrap(); +//! builder.push(LEVEL, TYPE, &43u32).unwrap(); +//! // Buffer is full, cannot add more messages. +//! assert!(matches!( +//! builder.push(LEVEL, TYPE, &44u32), +//! Err(CodecError::BufferTooSmall) +//! )); +//! +//! // Read back the messages. +//! unsafe { +//! let mut iter = AncillaryIter::new(&buf); +//! let msg = iter.next().unwrap(); +//! assert_eq!(msg.level(), LEVEL); +//! assert_eq!(msg.ty(), TYPE); +//! assert_eq!(msg.data::().unwrap(), 42u32); +//! assert_eq!(iter.next().unwrap().data::().unwrap(), 43u32); +//! assert!(iter.next().is_none()); +//! } +//! ``` use std::{ marker::PhantomData, mem::MaybeUninit, ops::{Deref, DerefMut}, + ptr, }; use compio_buf::{IoBuf, IoBufMut, SetLen}; @@ -31,6 +76,8 @@ cfg_if::cfg_if! { mod sys; } } +#[cfg(feature = "bytemuck")] +pub mod bytemuck_ext; /// Reference to an ancillary (control) message. pub struct AncillaryRef<'a>(sys::CMsgRef<'a>); @@ -52,14 +99,9 @@ impl AncillaryRef<'_> { self.0.len() as _ } - /// Returns a reference to the data of the control message. - /// - /// # Safety - /// - /// The data part must be properly aligned and contains an initialized - /// instance of `T`. - pub unsafe fn data(&self) -> &T { - unsafe { self.0.data() } + /// Returns a copy of the data in the control message. + pub fn data(&self) -> Result { + self.0.decode_data() } } @@ -101,54 +143,41 @@ impl<'a> Iterator for AncillaryIter<'a> { } /// Helper to construct ancillary (control) messages. -pub struct AncillaryBuilder<'a> { +pub struct AncillaryBuilder<'a, const N: usize> { inner: sys::CMsgIter, - len: usize, - _p: PhantomData<&'a mut ()>, + buffer: &'a mut AncillaryBuf, } -impl<'a> AncillaryBuilder<'a> { - /// Create [`AncillaryBuilder`] with the given buffer. The buffer will be - /// zeroed on creation. - /// - /// # Panics - /// - /// This function will panic if the buffer is too short or not properly - /// aligned. - pub fn new(buffer: &'a mut [MaybeUninit]) -> Self { +impl<'a, const N: usize> AncillaryBuilder<'a, N> { + fn new(buffer: &'a mut AncillaryBuf) -> Self { // TODO: optimize zeroing - buffer.fill(MaybeUninit::new(0)); - Self { - inner: sys::CMsgIter::new(buffer.as_ptr().cast(), buffer.len()), - len: 0, - _p: PhantomData, - } - } - - /// Finishes building, returns length of the control message. - pub fn finish(self) -> usize { - self.len + buffer.as_uninit().fill(MaybeUninit::new(0)); + buffer.len = 0; + let inner = sys::CMsgIter::new(buffer.as_ptr(), buffer.buf_capacity()); + Self { inner, buffer } } - /// Try to append a control message entry into the buffer. If the buffer - /// does not have enough space or is not properly aligned with the value - /// type, returns `None`. - pub fn try_push(&mut self, level: i32, ty: i32, value: T) -> Option<()> { - if !self.inner.is_aligned::() || !self.inner.is_space_enough::() { - return None; + /// Append a control message into the buffer. + pub fn push( + &mut self, + level: i32, + ty: i32, + value: &T, + ) -> Result<(), CodecError> { + if !self.inner.is_space_enough(T::SIZE) { + return Err(CodecError::BufferTooSmall); } - // SAFETY: the buffer is zeroed and the pointer is valid and aligned - unsafe { - let mut cmsg = self.inner.current_mut()?; - cmsg.set_level(level); - cmsg.set_ty(ty); - self.len += cmsg.set_data(value); + // SAFETY: AncillaryBuf guarantees the buffer is zeroed and properly aligned, + // and we have checked the space. + let mut cmsg = unsafe { self.inner.current_mut() }.expect("sufficient space"); + cmsg.set_level(level); + cmsg.set_ty(ty); + self.buffer.len += cmsg.encode_data(value)?; - self.inner.next(); - } + unsafe { self.inner.next() }; - Some(()) + Ok(()) } } @@ -175,6 +204,16 @@ impl AncillaryBuf { _align: [], } } + + /// Create [`AncillaryBuilder`] with this buffer. The buffer will be zeroed + /// on creation. + /// + /// # Panics + /// + /// This function will panic if this buffer is too short. + pub fn builder(&mut self) -> AncillaryBuilder<'_, N> { + AncillaryBuilder::new(self) + } } impl Default for AncillaryBuf { @@ -215,3 +254,290 @@ impl DerefMut for AncillaryBuf { &mut self.inner[0..self.len] } } + +// Deprecated compio_net::CMsgRef +#[doc(hidden)] +pub struct CMsgRef<'a>(sys::CMsgRef<'a>); + +impl CMsgRef<'_> { + pub fn level(&self) -> i32 { + self.0.level() + } + + pub fn ty(&self) -> i32 { + self.0.ty() + } + + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.0.len() as _ + } + + /// Returns a reference to the data of the control message. + /// + /// # Safety + /// + /// The data part must be properly aligned and contains an initialized + /// instance of `T`. + pub unsafe fn data(&self) -> &T { + unsafe { self.0.data() } + } +} + +// Deprecated compio_net::CMsgIter +#[doc(hidden)] +pub struct CMsgIter<'a> { + inner: sys::CMsgIter, + _p: PhantomData<&'a ()>, +} + +impl<'a> CMsgIter<'a> { + /// Create [`CMsgIter`] with the given buffer. + /// + /// # Panics + /// + /// This function will panic if the buffer is too short or not properly + /// aligned. + /// + /// # Safety + /// + /// The buffer should contain valid control messages. + pub unsafe fn new(buffer: &'a [u8]) -> Self { + Self { + inner: sys::CMsgIter::new(buffer.as_ptr(), buffer.len()), + _p: PhantomData, + } + } +} + +impl<'a> Iterator for CMsgIter<'a> { + type Item = CMsgRef<'a>; + + fn next(&mut self) -> Option { + unsafe { + let cmsg = self.inner.current(); + self.inner.next(); + cmsg.map(CMsgRef) + } + } +} + +// Deprecated compio_net::CMsgBuilder +#[doc(hidden)] +pub struct CMsgBuilder<'a> { + inner: sys::CMsgIter, + len: usize, + _p: PhantomData<&'a mut ()>, +} + +impl<'a> CMsgBuilder<'a> { + pub fn new(buffer: &'a mut [MaybeUninit]) -> Self { + buffer.fill(MaybeUninit::new(0)); + Self { + inner: sys::CMsgIter::new(buffer.as_ptr().cast(), buffer.len()), + len: 0, + _p: PhantomData, + } + } + + pub fn finish(self) -> usize { + self.len + } + + pub fn try_push(&mut self, level: i32, ty: i32, value: T) -> Option<()> { + if !self.inner.is_aligned::() || !self.inner.is_space_enough(std::mem::size_of::()) { + return None; + } + + // SAFETY: the buffer is zeroed and the pointer is valid and aligned + unsafe { + let mut cmsg = self.inner.current_mut()?; + cmsg.set_level(level); + cmsg.set_ty(ty); + self.len += cmsg.set_data(value); + + self.inner.next(); + } + + Some(()) + } +} + +/// Returns the buffer size required to hold one ancillary message carrying a +/// value of type `T`. +/// +/// This is the platform-appropriate equivalent of `CMSG_SPACE(T::SIZE)` on +/// Unix or `WSA_CMSG_SPACE(T::SIZE)` on Windows, and can be used as a const +/// generic argument for [`AncillaryBuf`]. +pub const fn ancillary_space() -> usize { + #[cfg(unix)] + // SAFETY: CMSG_SPACE is always safe + unsafe { + libc::CMSG_SPACE(T::SIZE as libc::c_uint) as usize + } + + #[cfg(windows)] + sys::wsa_cmsg_space(T::SIZE) +} + +/// Error that can occur when encoding or decoding ancillary data. +#[derive(Debug)] +pub enum CodecError { + /// The provided buffer is too small to hold the encoded data. + BufferTooSmall, + /// Another error occurred during encoding or decoding. + Other(Box), +} + +impl CodecError { + /// Create a new [`CodecError::Other`] from any error type. + pub fn other(error: impl Into>) -> Self { + Self::Other(error.into()) + } + + /// Attempt to downcast the error to a concrete type. + /// + /// Returns `Some(&T)` if the error is of type `T`, otherwise `None`. + pub fn downcast_ref(&self) -> Option<&T> { + match self { + Self::Other(e) => e.downcast_ref(), + _ => None, + } + } + + /// Attempt to downcast the error to a concrete type. + /// + /// Returns `Some(&mut T)` if the error is of type `T`, otherwise `None`. + pub fn downcast_mut(&mut self) -> Option<&mut T> { + match self { + Self::Other(e) => e.downcast_mut(), + _ => None, + } + } +} + +impl std::fmt::Display for CodecError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BufferTooSmall => write!(f, "buffer too small for encoding/decoding"), + Self::Other(e) => write!(f, "codec error: {}", e), + } + } +} + +impl std::error::Error for CodecError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Other(e) => Some(e.as_ref()), + _ => None, + } + } +} + +/// Trait for types that can be encoded and decoded as ancillary data payloads. +/// +/// This trait enables a type to be used as the data payload in control messages +/// (ancillary data). Types implementing this trait can be passed to +/// [`AncillaryBuilder::push`] and retrieved via [`AncillaryRef::data`]. +/// +/// # Built-in Implementations +/// +/// This trait is implemented for the following platform-specific types: +/// +/// - Unix: `libc::in_addr`, `libc::in_pktinfo`, `libc::in6_pktinfo` +/// - Windows: `IN_PKTINFO`, `IN6_PKTINFO` +/// +/// When the `bytemuck` feature is enabled, this trait is also automatically +/// implemented for types that implement [`bytemuck_ext::BitwiseAncillaryData`]: +/// +/// - Primitive types: `()`, `u8`, `u16`, `u32`, `u64`, `u128`, `usize`, `i8`, +/// `i16`, `i32`, `i64`, `i128`, `isize`, `f32`, `f64` +/// - Fixed-size arrays of the above types (up to size 512) +/// +/// For custom types with the `bytemuck` feature enabled, you can implement +/// [`bytemuck_ext::BitwiseAncillaryData`] to automatically get +/// [`AncillaryData`] (see [`bytemuck_ext`] for details). Otherwise, you must +/// manually implement this trait with custom encoding/decoding logic. +/// +/// # Example +/// +/// ``` +/// use std::mem::MaybeUninit; +/// +/// use compio_io::ancillary::{AncillaryData, CodecError}; +/// +/// struct MyData { +/// value: u32, +/// } +/// +/// impl AncillaryData for MyData { +/// const SIZE: usize = std::mem::size_of::(); +/// +/// fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { +/// if buffer.len() < Self::SIZE { +/// return Err(CodecError::BufferTooSmall); +/// } +/// let bytes = self.value.to_ne_bytes(); +/// for (i, &byte) in bytes.iter().enumerate() { +/// buffer[i] = MaybeUninit::new(byte); +/// } +/// Ok(()) +/// } +/// +/// fn decode(buffer: &[u8]) -> Result { +/// if buffer.len() < Self::SIZE { +/// return Err(CodecError::BufferTooSmall); +/// } +/// let mut bytes = [0u8; 4]; +/// bytes.copy_from_slice(&buffer[..4]); +/// Ok(MyData { +/// value: u32::from_ne_bytes(bytes), +/// }) +/// } +/// } +/// ``` +pub trait AncillaryData: Sized { + /// The size in bytes of the encoded representation. + /// + /// This defaults to `std::mem::size_of::()` but can be overridden + /// for types with custom encoding. + const SIZE: usize = std::mem::size_of::(); + + /// Encode this value into the provided buffer. + /// + /// # Errors + /// + /// Returns [`CodecError::BufferTooSmall`] if the buffer is too small to + /// hold the encoded data, or [`CodecError::Other`] for other encoding + /// errors. + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError>; + + /// Decode a value from the provided buffer. + /// + /// # Errors + /// + /// Returns [`CodecError::BufferTooSmall`] if the buffer is too small, + /// or [`CodecError::Other`] for other decoding errors. + fn decode(buffer: &[u8]) -> Result; +} + +unsafe fn copy_to_bytes( + src: &T, + dest: &mut [MaybeUninit], +) -> Result<(), CodecError> { + if dest.len() < T::SIZE { + return Err(CodecError::BufferTooSmall); + } + unsafe { + ptr::copy_nonoverlapping::(src as *const T as _, dest.as_mut_ptr() as _, T::SIZE); + } + Ok(()) +} + +unsafe fn copy_from_bytes(src: &[u8]) -> Result { + if src.len() < T::SIZE { + return Err(CodecError::BufferTooSmall); + } + let src_ptr = src.as_ptr() as *const T; + unsafe { Ok(ptr::read_unaligned(src_ptr)) } +} diff --git a/compio-io/src/ancillary/unix.rs b/compio-io/src/ancillary/unix.rs index 4cf11502c..0563b494d 100644 --- a/compio-io/src/ancillary/unix.rs +++ b/compio-io/src/ancillary/unix.rs @@ -1,5 +1,9 @@ +use std::{mem::MaybeUninit, slice}; + use libc::{CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE, c_int, cmsghdr, msghdr}; +use super::{AncillaryData, CodecError, copy_from_bytes, copy_to_bytes}; + pub(crate) struct CMsgRef<'a>(&'a cmsghdr); impl CMsgRef<'_> { @@ -15,6 +19,12 @@ impl CMsgRef<'_> { self.0.cmsg_len as _ } + pub(crate) fn decode_data(&self) -> Result { + let data_ptr = unsafe { CMSG_DATA(self.0) } as *const u8; + let buffer = unsafe { slice::from_raw_parts(data_ptr, self.len()) }; + T::decode(buffer) + } + pub(crate) unsafe fn data(&self) -> &T { unsafe { let data_ptr = CMSG_DATA(self.0); @@ -34,6 +44,14 @@ impl CMsgMut<'_> { self.0.cmsg_type = ty; } + pub(crate) fn encode_data(&mut self, value: &T) -> Result { + self.0.cmsg_len = unsafe { CMSG_LEN(T::SIZE as _) } as _; + let data_ptr = unsafe { CMSG_DATA(self.0) } as *mut MaybeUninit; + let buffer = unsafe { slice::from_raw_parts_mut(data_ptr, T::SIZE) }; + value.encode(buffer)?; + Ok(unsafe { CMSG_SPACE(T::SIZE as _) } as _) + } + pub(crate) unsafe fn set_data(&mut self, data: T) -> usize { unsafe { self.0.cmsg_len = CMSG_LEN(std::mem::size_of::() as _) as _; @@ -83,9 +101,9 @@ impl CMsgIter { self.msg.msg_control.cast::().is_aligned() } - pub(crate) fn is_space_enough(&self) -> bool { + pub(crate) fn is_space_enough(&self, space: usize) -> bool { if !self.cmsg.is_null() { - let space = unsafe { CMSG_SPACE(std::mem::size_of::() as _) as usize }; + let space = unsafe { CMSG_SPACE(space as _) as usize }; #[allow(clippy::unnecessary_cast)] let max = self.msg.msg_control as usize + self.msg.msg_controllen as usize; self.cmsg as usize + space <= max @@ -94,3 +112,56 @@ impl CMsgIter { } } } + +impl AncillaryData for libc::in_addr { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + unsafe { copy_to_bytes(self, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + unsafe { copy_from_bytes(buffer) } + } +} + +#[cfg(any(target_os = "linux", target_os = "android"))] +impl AncillaryData for libc::in_pktinfo { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() }; + pktinfo.ipi_ifindex = self.ipi_ifindex; + pktinfo.ipi_spec_dst.s_addr = self.ipi_spec_dst.s_addr; + pktinfo.ipi_addr.s_addr = self.ipi_addr.s_addr; + unsafe { copy_to_bytes(&pktinfo, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + let pktinfo: libc::in_pktinfo = unsafe { copy_from_bytes(buffer) }?; + Ok(libc::in_pktinfo { + ipi_ifindex: pktinfo.ipi_ifindex, + ipi_spec_dst: libc::in_addr { + s_addr: pktinfo.ipi_spec_dst.s_addr, + }, + ipi_addr: libc::in_addr { + s_addr: pktinfo.ipi_addr.s_addr, + }, + }) + } +} + +impl AncillaryData for libc::in6_pktinfo { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() }; + pktinfo.ipi6_ifindex = self.ipi6_ifindex; + pktinfo.ipi6_addr.s6_addr = self.ipi6_addr.s6_addr; + unsafe { copy_to_bytes(&pktinfo, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + let pktinfo: libc::in6_pktinfo = unsafe { copy_from_bytes(buffer) }?; + Ok(libc::in6_pktinfo { + ipi6_ifindex: pktinfo.ipi6_ifindex, + ipi6_addr: libc::in6_addr { + s6_addr: pktinfo.ipi6_addr.s6_addr, + }, + }) + } +} diff --git a/compio-io/src/ancillary/windows.rs b/compio-io/src/ancillary/windows.rs index 26120a485..81dc0a444 100644 --- a/compio-io/src/ancillary/windows.rs +++ b/compio-io/src/ancillary/windows.rs @@ -1,9 +1,14 @@ use std::{ - mem::{align_of, size_of}, + mem::{MaybeUninit, align_of, size_of}, ptr::null_mut, + slice, }; -use windows_sys::Win32::Networking::WinSock::{CMSGHDR, WSABUF, WSAMSG}; +use windows_sys::Win32::Networking::WinSock::{ + self, CMSGHDR, IN_PKTINFO, IN6_PKTINFO, WSABUF, WSAMSG, +}; + +use super::{AncillaryData, CodecError, copy_from_bytes, copy_to_bytes}; // Macros from https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h #[inline] @@ -50,7 +55,7 @@ unsafe fn wsa_cmsg_data(cmsg: *const CMSGHDR) -> *mut u8 { } #[inline] -const fn wsa_cmsg_space(length: usize) -> usize { +pub(crate) const fn wsa_cmsg_space(length: usize) -> usize { WSA_CMSGDATA_OFFSET + wsa_cmsghdr_align(length) } @@ -74,6 +79,12 @@ impl CMsgRef<'_> { self.0.cmsg_len } + pub fn decode_data(&self) -> Result { + let data_ptr = unsafe { wsa_cmsg_data(self.0) } as *const u8; + let buffer = unsafe { slice::from_raw_parts(data_ptr, self.len()) }; + T::decode(buffer) + } + pub unsafe fn data(&self) -> &T { unsafe { let data_ptr = wsa_cmsg_data(self.0); @@ -93,6 +104,14 @@ impl CMsgMut<'_> { self.0.cmsg_type = ty; } + pub(crate) fn encode_data(&mut self, value: &T) -> Result { + let data_ptr = unsafe { wsa_cmsg_data(self.0) } as *mut MaybeUninit; + let buffer = unsafe { slice::from_raw_parts_mut(data_ptr, T::SIZE) }; + value.encode(buffer)?; + self.0.cmsg_len = wsa_cmsg_len(T::SIZE as _) as _; + Ok(wsa_cmsg_space(T::SIZE as _)) + } + pub(crate) unsafe fn set_data(&mut self, data: T) -> usize { self.0.cmsg_len = wsa_cmsg_len(size_of::() as _) as _; unsafe { @@ -144,9 +163,9 @@ impl CMsgIter { self.msg.Control.buf.cast::().is_aligned() } - pub(crate) fn is_space_enough(&self) -> bool { + pub(crate) fn is_space_enough(&self, space: usize) -> bool { if !self.cmsg.is_null() { - let space = wsa_cmsg_space(size_of::() as _); + let space = wsa_cmsg_space(space as _); let max = self.msg.Control.buf as usize + self.msg.Control.len as usize; self.cmsg as usize + space <= max } else { @@ -154,3 +173,49 @@ impl CMsgIter { } } } + +impl AncillaryData for IN_PKTINFO { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + let mut pktinfo: IN_PKTINFO = unsafe { std::mem::zeroed() }; + unsafe { + pktinfo.ipi_addr.S_un.S_addr = self.ipi_addr.S_un.S_addr; + } + pktinfo.ipi_ifindex = self.ipi_ifindex; + unsafe { copy_to_bytes(&pktinfo, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + let pktinfo: IN_PKTINFO = unsafe { copy_from_bytes(buffer) }?; + Ok(IN_PKTINFO { + ipi_addr: WinSock::IN_ADDR { + S_un: WinSock::IN_ADDR_0 { + S_addr: unsafe { pktinfo.ipi_addr.S_un.S_addr }, + }, + }, + ipi_ifindex: pktinfo.ipi_ifindex, + }) + } +} + +impl AncillaryData for IN6_PKTINFO { + fn encode(&self, buffer: &mut [MaybeUninit]) -> Result<(), CodecError> { + let mut pktinfo: IN6_PKTINFO = unsafe { std::mem::zeroed() }; + unsafe { + pktinfo.ipi6_addr.u.Byte = self.ipi6_addr.u.Byte; + } + pktinfo.ipi6_ifindex = self.ipi6_ifindex; + unsafe { copy_to_bytes(&pktinfo, buffer) } + } + + fn decode(buffer: &[u8]) -> Result { + let pktinfo: IN6_PKTINFO = unsafe { copy_from_bytes(buffer) }?; + Ok(IN6_PKTINFO { + ipi6_addr: WinSock::IN6_ADDR { + u: WinSock::IN6_ADDR_0 { + Byte: unsafe { pktinfo.ipi6_addr.u.Byte }, + }, + }, + ipi6_ifindex: pktinfo.ipi6_ifindex, + }) + } +} diff --git a/compio-io/tests/ancillary.rs b/compio-io/tests/ancillary.rs index 21389ecbc..ff8521f98 100644 --- a/compio-io/tests/ancillary.rs +++ b/compio-io/tests/ancillary.rs @@ -1,35 +1,47 @@ use std::mem::MaybeUninit; -use aligned_array::{A8, Aligned}; -use compio_buf::{IoBuf, IoBufMut}; -use compio_io::ancillary::{AncillaryBuilder, AncillaryIter}; +use compio_buf::IoBuf; +use compio_io::ancillary::{AncillaryBuf, AncillaryIter, CMsgBuilder}; #[test] fn test_cmsg() { - let mut buf: Aligned = Aligned([0u8; 64]); - let mut builder = AncillaryBuilder::new(buf.as_uninit()); + let mut buf = AncillaryBuf::<128>::new(); + let mut builder = buf.builder(); - builder.try_push(0, 0, ()).unwrap(); // 16 / 12 - builder.try_push(1, 1, u32::MAX).unwrap(); // 16 + 4 + 4 / 12 + 4 - builder.try_push(2, 2, i64::MIN).unwrap(); // 16 + 8 / 12 + 8 - let len = builder.finish(); - assert!(len == 64 || len == 48); + builder.push(0, 0, &()).unwrap(); // 16 / 12 + builder.push(1, 1, &u8::MAX).unwrap(); // 16 + 1 + 7 / 12 + 1 + 3 + builder.push(2, 2, &u32::MAX).unwrap(); // 16 + 4 + 4 / 12 + 4 + builder.push(3, 3, &i64::MIN).unwrap(); // 16 + 8 / 12 + 8 + builder.push(4, 4, &[0; 1]).unwrap(); // 16 + 1 + 7 / 12 + 1 + 3 + assert!(buf.buf_len() == 112 || buf.buf_len() == 80); unsafe { - let buf = buf.slice(..len); let mut iter = AncillaryIter::new(&buf); let cmsg = iter.next().unwrap(); - assert_eq!((cmsg.level(), cmsg.ty(), cmsg.data::<()>()), (0, 0, &())); + assert_eq!( + (cmsg.level(), cmsg.ty(), cmsg.data::<()>().unwrap()), + (0, 0, ()) + ); + let cmsg = iter.next().unwrap(); + assert_eq!( + (cmsg.level(), cmsg.ty(), cmsg.data::().unwrap()), + (1, 1, u8::MAX) + ); + let cmsg = iter.next().unwrap(); + assert_eq!( + (cmsg.level(), cmsg.ty(), cmsg.data::().unwrap()), + (2, 2, u32::MAX) + ); let cmsg = iter.next().unwrap(); assert_eq!( - (cmsg.level(), cmsg.ty(), cmsg.data::()), - (1, 1, &u32::MAX) + (cmsg.level(), cmsg.ty(), cmsg.data::().unwrap()), + (3, 3, i64::MIN) ); let cmsg = iter.next().unwrap(); assert_eq!( - (cmsg.level(), cmsg.ty(), cmsg.data::()), - (2, 2, &i64::MIN) + (cmsg.level(), cmsg.ty(), cmsg.data().unwrap()), + (4, 4, [0; 1]) ); assert!(iter.next().is_none()); } @@ -38,13 +50,12 @@ fn test_cmsg() { #[test] #[should_panic] fn invalid_buffer_length() { - let mut buf = [MaybeUninit::new(0u8); 1]; - AncillaryBuilder::new(&mut buf); + AncillaryBuf::<1>::new().builder(); } #[test] #[should_panic] fn invalid_buffer_alignment() { let mut buf = [MaybeUninit::new(0u8); 64]; - AncillaryBuilder::new(&mut buf[1..]); + CMsgBuilder::new(&mut buf[1..]); } diff --git a/compio-net/src/lib.rs b/compio-net/src/lib.rs index a2d47d4cd..e9f8ae721 100644 --- a/compio-net/src/lib.rs +++ b/compio-net/src/lib.rs @@ -28,21 +28,21 @@ mod unix; since = "0.12.0", note = "use `compio_io::ancillary::AncillaryRef` instead" )] -pub type CMsgRef<'a> = compio_io::ancillary::AncillaryRef<'a>; +pub type CMsgRef<'a> = compio_io::ancillary::CMsgRef<'a>; /// An iterator for control messages. #[deprecated( since = "0.12.0", note = "use `compio_io::ancillary::AncillaryIter` instead" )] -pub type CMsgIter<'a> = compio_io::ancillary::AncillaryIter<'a>; +pub type CMsgIter<'a> = compio_io::ancillary::CMsgIter<'a>; /// Helper to construct control message. #[deprecated( since = "0.12.0", - note = "use `compio_io::ancillary::AncillaryBuilder` instead" + note = "use `compio_io::ancillary::AncillaryBuf::builder()` instead" )] -pub type CMsgBuilder<'a> = compio_io::ancillary::AncillaryBuilder<'a>; +pub type CMsgBuilder<'a> = compio_io::ancillary::CMsgBuilder<'a>; /// Providing functionalities to wait for readiness. #[deprecated(since = "0.12.0", note = "Use `compio::runtime::fd::PollFd` instead")] diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 6de9d18e1..b5baf28c3 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -15,7 +15,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] # Workspace dependencies -compio-io = { workspace = true, features = ["ancillary"] } +compio-io = { workspace = true, features = ["ancillary", "bytemuck"] } compio-buf = { workspace = true, features = ["bytes"] } compio-log = { workspace = true } compio-net = { workspace = true } diff --git a/compio-quic/src/socket.rs b/compio-quic/src/socket.rs index 34f01ed57..fd3c4995c 100644 --- a/compio-quic/src/socket.rs +++ b/compio-quic/src/socket.rs @@ -15,8 +15,8 @@ use std::{ sync::atomic::Ordering, }; -use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetLen, buf_try}; -use compio_io::ancillary::{AncillaryBuf, AncillaryBuilder, AncillaryIter}; +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, buf_try}; +use compio_io::ancillary::{AncillaryBuf, AncillaryIter, CodecError}; use compio_net::UdpSocket; use quinn_proto::{EcnCodepoint, Transmit}; #[cfg(windows)] @@ -290,68 +290,83 @@ impl Socket { #[allow(unused_mut)] let mut stride = len; - // SAFETY: `control` contains valid data - unsafe { - for cmsg in AncillaryIter::new(&control) { + let res = (|| { + // SAFETY: `control` contains valid data + for cmsg in unsafe { AncillaryIter::new(&control) } { #[cfg(windows)] const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32; match (cmsg.level(), cmsg.ty()) { // ECN #[cfg(unix)] - (libc::IPPROTO_IP, libc::IP_TOS) => ecn_bits = *cmsg.data::(), + (libc::IPPROTO_IP, libc::IP_TOS) => ecn_bits = cmsg.data::()?, #[cfg(all(unix, not(any(non_freebsd, solarish))))] - (libc::IPPROTO_IP, libc::IP_RECVTOS) => ecn_bits = *cmsg.data::(), + (libc::IPPROTO_IP, libc::IP_RECVTOS) => ecn_bits = cmsg.data::()?, #[cfg(unix)] (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => { // NOTE: It's OK to use `c_int` instead of `u8` on Apple systems - ecn_bits = *cmsg.data::() as u8 + ecn_bits = cmsg.data::()? as u8 } #[cfg(windows)] (WinSock::IPPROTO_IP, WinSock::IP_ECN) | (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { - ecn_bits = *cmsg.data::() as u8 + ecn_bits = cmsg.data::()? as u8 } // pktinfo / destination address #[cfg(linux_all)] (libc::IPPROTO_IP, libc::IP_PKTINFO) => { - let pktinfo = cmsg.data::(); + let pktinfo = cmsg.data::()?; local_ip = Some(IpAddr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes())); } #[cfg(any(bsd, solarish, apple))] (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { - let in_addr = cmsg.data::(); + let in_addr = cmsg.data::()?; local_ip = Some(IpAddr::from(in_addr.s_addr.to_ne_bytes())); } #[cfg(windows)] (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { - let pktinfo = cmsg.data::(); - local_ip = Some(IpAddr::from(pktinfo.ipi_addr.S_un.S_addr.to_ne_bytes())); + let pktinfo = cmsg.data::()?; + local_ip = Some(IpAddr::from( + // SAFETY: S_addr is a valid representation of the union for IPv4 + // addresses + unsafe { pktinfo.ipi_addr.S_un.S_addr }.to_ne_bytes(), + )); } #[cfg(unix)] (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { - let pktinfo = cmsg.data::(); + let pktinfo = cmsg.data::()?; local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.s6_addr)); } #[cfg(windows)] (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { - let pktinfo = cmsg.data::(); - local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.u.Byte)); + let pktinfo = cmsg.data::()?; + // SAFETY: Byte is a valid representation of the union for IPv6 addresses + local_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte })); } // GRO #[cfg(linux_all)] - (libc::SOL_UDP, libc::UDP_GRO) => stride = *cmsg.data::() as usize, + (libc::SOL_UDP, libc::UDP_GRO) => stride = cmsg.data::()? as usize, #[cfg(windows)] (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => { - stride = *cmsg.data::() as usize + stride = cmsg.data::()? as usize } _ => {} } } - } + Ok::<(), CodecError>(()) + })(); + let ((), buffer) = buf_try!(BufResult( + res.map_err(|e| match e { + CodecError::BufferTooSmall => { + io::Error::new(io::ErrorKind::InvalidData, "cmsg_len is too small") + } + CodecError::Other(e) => io::Error::other(e), + }), + buffer + )); let meta = RecvMeta { remote, @@ -363,26 +378,29 @@ impl Socket { BufResult(Ok(meta), buffer) } - pub async fn send(&self, buffer: T, transmit: &Transmit) -> T { + fn construct_control_message( + &self, + transmit: &Transmit, + ) -> Result, CodecError> { let is_ipv4 = transmit.destination.ip().to_canonical().is_ipv4(); let ecn = transmit.ecn.map_or(0, |x| x as u8); - let mut control = AncillaryBuf::::new(); - let mut builder = AncillaryBuilder::new(control.as_uninit()); + let mut control = AncillaryBuf::new(); + let mut builder = control.builder(); // ECN if is_ipv4 { #[cfg(all(unix, not(any(freebsd, netbsd))))] - builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_int); + builder.push(libc::IPPROTO_IP, libc::IP_TOS, &(ecn as libc::c_int))?; #[cfg(freebsd)] - builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_uchar); + builder.push(libc::IPPROTO_IP, libc::IP_TOS, &(ecn as libc::c_uchar))?; #[cfg(windows)] - builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn as i32); + builder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, &(ecn as i32))?; } else { #[cfg(unix)] - builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn as libc::c_int); + builder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, &(ecn as libc::c_int))?; #[cfg(windows)] - builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn as i32); + builder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, &(ecn as i32))?; } // pktinfo / destination address @@ -396,7 +414,7 @@ impl Socket { ipi_spec_dst: libc::in_addr { s_addr: addr }, ipi_addr: libc::in_addr { s_addr: 0 }, }; - builder.try_push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo); + builder.push(libc::IPPROTO_IP, libc::IP_PKTINFO, &pktinfo)?; } #[cfg(any(bsd, solarish, apple))] { @@ -407,7 +425,7 @@ impl Socket { if encode_src_ip_v4 { let addr = libc::in_addr { s_addr: addr }; - builder.try_push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr); + builder.push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, &addr)?; } } #[cfg(windows)] @@ -418,7 +436,7 @@ impl Socket { }, ipi_ifindex: 0, }; - builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + builder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, &pktinfo)?; } } Some(IpAddr::V6(ip)) => { @@ -430,7 +448,7 @@ impl Socket { s6_addr: ip.octets(), }, }; - builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + builder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, &pktinfo)?; } #[cfg(windows)] { @@ -440,7 +458,7 @@ impl Socket { }, ipi6_ifindex: 0, }; - builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + builder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, &pktinfo)?; } } None => {} @@ -451,21 +469,24 @@ impl Socket { && segment_size < transmit.size { #[cfg(linux_all)] - builder.try_push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size as u16); + builder.push(libc::SOL_UDP, libc::UDP_SEGMENT, &(segment_size as u16))?; #[cfg(windows)] - builder.try_push( + builder.push( WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE, - segment_size as u32, - ); + &(segment_size as u32), + )?; #[cfg(not(any(linux_all, windows)))] let _ = segment_size; } - let len = builder.finish(); - // SAFETY: AncillaryBuilder ensures the buffer is initialized within len - unsafe { control.set_len(len) }; + Ok(control) + } + pub async fn send(&self, buffer: T, transmit: &Transmit) -> T { + let mut control = self + .construct_control_message(transmit) + .expect("CMSG_LEN should be large enough"); let mut buffer = buffer.slice(0..transmit.size); loop { diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 06996c8e0..1dfa50fe5 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -118,6 +118,7 @@ memmap2 = ["compio-buf/memmap2"] criterion = ["compio-runtime?/criterion"] sync = ["compio-driver/sync", "compio-quic?/sync", "compio-io?/sync"] +bytemuck = ["compio-io?/bytemuck"] enable_log = ["compio-log/enable_log"]