Skip to content

Commit f75d13c

Browse files
authored
feat(io): add traits for reading/writing with ancillary data (#717)
* feat(io): add socket traits for cmsg send/recv * fix(net): drop socket trait impl for udp * fix(io): use associated type for addr in trait * fix(net): use socket2::SockAddr for uds * fix(io,net): change of design * fix(driver): allow using empty cmsg buffer to omit This is not done in a way of making control param optional because send_with_ancillary() semantically requires a control buffer, but at the same time, send_with_ancillary() with empty control buffer is not practically identical to just send() as the latter may use different impl like send_zc in iour. Therefore, I decided to implicitly allow empty cmsg buffers as something like "enforce sendmsg impl, but send without control buffer".
1 parent 38ac58b commit f75d13c

8 files changed

Lines changed: 382 additions & 29 deletions

File tree

compio-driver/src/sys/iocp/op.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
11221122
/// This function will panic if the control message buffer is misaligned.
11231123
pub fn new(fd: S, buffer: T, control: C, addr: Option<SockAddr>, flags: i32) -> Self {
11241124
assert!(
1125-
control.buf_ptr().cast::<CMSGHDR>().is_aligned(),
1125+
control.buf_len() == 0 || control.buf_ptr().cast::<CMSGHDR>().is_aligned(),
11261126
"misaligned control message buffer"
11271127
);
11281128
Self {
@@ -1151,7 +1151,11 @@ unsafe impl<T: IoVectoredBuf, C: IoBuf, S: AsFd> OpCode for SendMsg<T, C, S> {
11511151
let this = self.project();
11521152

11531153
*this.slices = this.buffer.as_ref().sys_slices();
1154-
let control = this.control.as_ref().sys_slice();
1154+
let control = if this.control.buf_len() == 0 {
1155+
SysSlice::null()
1156+
} else {
1157+
this.control.as_ref().sys_slice()
1158+
};
11551159
*this.msg = match this.addr.as_ref() {
11561160
Some(addr) => WSAMSG {
11571161
name: addr.as_ptr() as _,

compio-driver/src/sys/unix_op.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
815815
/// This function will panic if the control message buffer is misaligned.
816816
pub fn new(fd: S, buffer: T, control: C, addr: Option<SockAddr>, flags: i32) -> Self {
817817
assert!(
818-
control.buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
818+
control.buf_len() == 0 || control.buf_ptr().cast::<libc::cmsghdr>().is_aligned(),
819819
"misaligned control message buffer"
820820
);
821821
Self {
@@ -845,7 +845,11 @@ impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
845845
}
846846
this.msg.msg_iov = this.slices.as_ptr() as _;
847847
this.msg.msg_iovlen = this.slices.len() as _;
848-
this.msg.msg_control = this.control.buf_ptr() as _;
848+
this.msg.msg_control = if this.control.buf_len() == 0 {
849+
std::ptr::null_mut()
850+
} else {
851+
this.control.buf_ptr() as _
852+
};
849853
this.msg.msg_controllen = this.control.buf_len() as _;
850854
}
851855
}

compio-io/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ windows-sys = { workspace = true, optional = true, features = [
3131
] }
3232

3333
[dev-dependencies]
34+
compio-net = { workspace = true }
35+
compio-runtime = { workspace = true }
3436
tokio = { workspace = true, features = ["macros", "rt"] }
3537
serde = { version = "1.0.219", features = ["derive"] }
3638
futures-executor = "0.3.30"

compio-io/src/ancillary/io.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#[cfg(feature = "allocator_api")]
2+
use std::alloc::Allocator;
3+
4+
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut, t_alloc};
5+
6+
/// Trait for asynchronous read with ancillary (control) data.
7+
/// Intended for connected stream sockets (TCP, Unix streams) where no source
8+
/// address is needed.
9+
pub trait AsyncReadAncillary {
10+
/// Read data with ancillary data into an owned buffer.
11+
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
12+
&mut self,
13+
buffer: T,
14+
control: C,
15+
) -> BufResult<(usize, usize), (T, C)>;
16+
17+
/// Read data with ancillary data into a vectored buffer.
18+
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
19+
&mut self,
20+
buffer: T,
21+
control: C,
22+
) -> BufResult<(usize, usize), (T, C)>;
23+
}
24+
25+
impl<A: AsyncReadAncillary + ?Sized> AsyncReadAncillary for &mut A {
26+
#[inline]
27+
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
28+
&mut self,
29+
buffer: T,
30+
control: C,
31+
) -> BufResult<(usize, usize), (T, C)> {
32+
(**self).read_with_ancillary(buffer, control).await
33+
}
34+
35+
#[inline]
36+
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
37+
&mut self,
38+
buffer: T,
39+
control: C,
40+
) -> BufResult<(usize, usize), (T, C)> {
41+
(**self).read_vectored_with_ancillary(buffer, control).await
42+
}
43+
}
44+
45+
impl<A: AsyncReadAncillary + ?Sized, #[cfg(feature = "allocator_api")] Alloc: Allocator>
46+
AsyncReadAncillary for t_alloc!(Box, A, Alloc)
47+
{
48+
#[inline]
49+
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
50+
&mut self,
51+
buffer: T,
52+
control: C,
53+
) -> BufResult<(usize, usize), (T, C)> {
54+
(**self).read_with_ancillary(buffer, control).await
55+
}
56+
57+
#[inline]
58+
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
59+
&mut self,
60+
buffer: T,
61+
control: C,
62+
) -> BufResult<(usize, usize), (T, C)> {
63+
(**self).read_vectored_with_ancillary(buffer, control).await
64+
}
65+
}
66+
67+
/// Trait for asynchronous write with ancillary (control) data.
68+
/// Intended for connected stream sockets (TCP, Unix streams) where no
69+
/// destination address is needed.
70+
pub trait AsyncWriteAncillary {
71+
/// Write data with ancillary data from an owned buffer.
72+
async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
73+
&mut self,
74+
buffer: T,
75+
control: C,
76+
) -> BufResult<usize, (T, C)>;
77+
78+
/// Write data with ancillary data from a vectored buffer.
79+
async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
80+
&mut self,
81+
buffer: T,
82+
control: C,
83+
) -> BufResult<usize, (T, C)>;
84+
}
85+
86+
impl<A: AsyncWriteAncillary + ?Sized> AsyncWriteAncillary for &mut A {
87+
#[inline]
88+
async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
89+
&mut self,
90+
buffer: T,
91+
control: C,
92+
) -> BufResult<usize, (T, C)> {
93+
(**self).write_with_ancillary(buffer, control).await
94+
}
95+
96+
#[inline]
97+
async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
98+
&mut self,
99+
buffer: T,
100+
control: C,
101+
) -> BufResult<usize, (T, C)> {
102+
(**self)
103+
.write_vectored_with_ancillary(buffer, control)
104+
.await
105+
}
106+
}
107+
108+
impl<A: AsyncWriteAncillary + ?Sized, #[cfg(feature = "allocator_api")] Alloc: Allocator>
109+
AsyncWriteAncillary for t_alloc!(Box, A, Alloc)
110+
{
111+
#[inline]
112+
async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
113+
&mut self,
114+
buffer: T,
115+
control: C,
116+
) -> BufResult<usize, (T, C)> {
117+
(**self).write_with_ancillary(buffer, control).await
118+
}
119+
120+
#[inline]
121+
async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
122+
&mut self,
123+
buffer: T,
124+
control: C,
125+
) -> BufResult<usize, (T, C)> {
126+
(**self)
127+
.write_vectored_with_ancillary(buffer, control)
128+
.await
129+
}
130+
}

compio-io/src/ancillary/mod.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
//! ancillary data payloads.
1616
//! - [`CodecError`]: Error type for encoding/decoding operations.
1717
//!
18+
//! # Traits
19+
//!
20+
//! - [`AsyncReadAncillary`]: read data together with ancillary data
21+
//! - [`AsyncWriteAncillary`]: write data together with ancillary data
22+
//!
1823
//! # Functions
1924
//!
2025
//! - [`ancillary_space`]: Helper function to calculate ancillary message size
@@ -27,33 +32,52 @@
2732
//!
2833
//! # Example
2934
//!
35+
//! Send and receive a file descriptor over a Unix socket pair using
36+
//! `SCM_RIGHTS`:
37+
//!
3038
//! ```
31-
//! use compio_io::ancillary::{AncillaryBuf, AncillaryIter, CodecError, ancillary_space};
39+
//! # #[cfg(unix)] {
40+
//! use std::os::unix::io::RawFd;
41+
//!
42+
//! use compio_io::ancillary::*;
43+
//! use compio_net::UnixStream;
44+
//!
45+
//! const BUF_SIZE: usize = ancillary_space::<RawFd>();
3246
//!
33-
//! const LEVEL: i32 = 1;
34-
//! const TYPE: i32 = 2;
47+
//! # compio_runtime::Runtime::new().unwrap().block_on(async {
48+
//! // Create a socket pair.
49+
//! let (std_a, std_b) = std::os::unix::net::UnixStream::pair().unwrap();
50+
//! let mut a = UnixStream::from_std(std_a).unwrap();
51+
//! let mut b = UnixStream::from_std(std_b).unwrap();
3552
//!
36-
//! // Build a buffer containing two `u32` ancillary messages.
37-
//! let mut buf = AncillaryBuf::<{ ancillary_space::<u32>() * 2 }>::new();
38-
//! let mut builder = buf.builder();
39-
//! builder.push(LEVEL, TYPE, &42u32).unwrap();
40-
//! builder.push(LEVEL, TYPE, &43u32).unwrap();
41-
//! // Buffer is full, cannot add more messages.
42-
//! assert!(matches!(
43-
//! builder.push(LEVEL, TYPE, &44u32),
44-
//! Err(CodecError::BufferTooSmall)
45-
//! ));
53+
//! // Pass fd 0 (stdin) as ancillary data via SCM_RIGHTS.
54+
//! let mut ctrl_send = AncillaryBuf::<BUF_SIZE>::new();
55+
//! let mut builder = ctrl_send.builder();
56+
//! builder
57+
//! .push(libc::SOL_SOCKET, libc::SCM_RIGHTS, &(0 as RawFd))
58+
//! .unwrap();
4659
//!
47-
//! // Read back the messages.
48-
//! unsafe {
49-
//! let mut iter = AncillaryIter::new(&buf);
50-
//! let msg = iter.next().unwrap();
51-
//! assert_eq!(msg.level(), LEVEL);
52-
//! assert_eq!(msg.ty(), TYPE);
53-
//! assert_eq!(msg.data::<u32>().unwrap(), 42u32);
54-
//! assert_eq!(iter.next().unwrap().data::<u32>().unwrap(), 43u32);
55-
//! assert!(iter.next().is_none());
56-
//! }
60+
//! // Send the payload together with the ancillary data.
61+
//! a.write_with_ancillary(b"hello", ctrl_send).await.0.unwrap();
62+
//!
63+
//! // Receive on the other end.
64+
//! let payload = Vec::with_capacity(5);
65+
//! let ctrl_recv = AncillaryBuf::<BUF_SIZE>::new();
66+
//! let ((_, ctrl_len), (payload, ctrl_recv)) =
67+
//! b.read_with_ancillary(payload, ctrl_recv).await.unwrap();
68+
//!
69+
//! assert_eq!(&payload[..], b"hello");
70+
//!
71+
//! // Parse the received ancillary messages.
72+
//! let mut iter = unsafe { AncillaryIter::new(&ctrl_recv[..ctrl_len]) };
73+
//! let msg = iter.next().unwrap();
74+
//! assert_eq!(msg.level(), libc::SOL_SOCKET);
75+
//! assert_eq!(msg.ty(), libc::SCM_RIGHTS);
76+
//! // The kernel duplicates the fd, so the received value may differ.
77+
//! let _received_fd = unsafe { msg.data::<RawFd>() };
78+
//! assert!(iter.next().is_none());
79+
//! # });
80+
//! # }
5781
//! ```
5882
5983
use std::{
@@ -67,6 +91,10 @@ use compio_buf::{IoBuf, IoBufMut, SetLen};
6791
#[cfg(windows)]
6892
use windows_sys::Win32::Networking::WinSock;
6993

94+
mod io;
95+
96+
pub use self::io::*;
97+
7098
cfg_if::cfg_if! {
7199
if #[cfg(windows)] {
72100
#[path = "windows.rs"]

compio-net/src/tcp.rs

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ use std::{
88

99
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
1010
use compio_driver::impl_raw_fd;
11-
use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
11+
use compio_io::{
12+
AsyncRead, AsyncReadManaged, AsyncWrite,
13+
ancillary::{AsyncReadAncillary, AsyncWriteAncillary},
14+
util::Splittable,
15+
};
1216
use compio_runtime::{BorrowedBuffer, BufferPool, fd::PollFd};
1317
use futures_util::{Stream, StreamExt, stream::FusedStream};
1418
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
@@ -451,6 +455,52 @@ impl AsyncReadManaged for &TcpStream {
451455
}
452456
}
453457

458+
impl AsyncReadAncillary for TcpStream {
459+
#[inline]
460+
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
461+
&mut self,
462+
buffer: T,
463+
control: C,
464+
) -> BufResult<(usize, usize), (T, C)> {
465+
(&*self).read_with_ancillary(buffer, control).await
466+
}
467+
468+
#[inline]
469+
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
470+
&mut self,
471+
buffer: T,
472+
control: C,
473+
) -> BufResult<(usize, usize), (T, C)> {
474+
(&*self).read_vectored_with_ancillary(buffer, control).await
475+
}
476+
}
477+
478+
impl AsyncReadAncillary for &TcpStream {
479+
#[inline]
480+
async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
481+
&mut self,
482+
buffer: T,
483+
control: C,
484+
) -> BufResult<(usize, usize), (T, C)> {
485+
self.inner
486+
.recv_msg(buffer, control, 0)
487+
.await
488+
.map_res(|(res, len, _addr)| (res, len))
489+
}
490+
491+
#[inline]
492+
async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
493+
&mut self,
494+
buffer: T,
495+
control: C,
496+
) -> BufResult<(usize, usize), (T, C)> {
497+
self.inner
498+
.recv_msg_vectored(buffer, control, 0)
499+
.await
500+
.map_res(|(res, len, _addr)| (res, len))
501+
}
502+
}
503+
454504
impl AsyncWrite for TcpStream {
455505
#[inline]
456506
async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
@@ -495,6 +545,48 @@ impl AsyncWrite for &TcpStream {
495545
}
496546
}
497547

548+
impl AsyncWriteAncillary for TcpStream {
549+
#[inline]
550+
async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
551+
&mut self,
552+
buffer: T,
553+
control: C,
554+
) -> BufResult<usize, (T, C)> {
555+
(&*self).write_with_ancillary(buffer, control).await
556+
}
557+
558+
#[inline]
559+
async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
560+
&mut self,
561+
buffer: T,
562+
control: C,
563+
) -> BufResult<usize, (T, C)> {
564+
(&*self)
565+
.write_vectored_with_ancillary(buffer, control)
566+
.await
567+
}
568+
}
569+
570+
impl AsyncWriteAncillary for &TcpStream {
571+
#[inline]
572+
async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
573+
&mut self,
574+
buffer: T,
575+
control: C,
576+
) -> BufResult<usize, (T, C)> {
577+
self.inner.send_msg(buffer, control, None, 0).await
578+
}
579+
580+
#[inline]
581+
async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
582+
&mut self,
583+
buffer: T,
584+
control: C,
585+
) -> BufResult<usize, (T, C)> {
586+
self.inner.send_msg_vectored(buffer, control, None, 0).await
587+
}
588+
}
589+
498590
impl Splittable for TcpStream {
499591
type ReadHalf = OwnedReadHalf<Self>;
500592
type WriteHalf = OwnedWriteHalf<Self>;

0 commit comments

Comments
 (0)