Skip to content

Commit 43c6ef4

Browse files
feat(driver): add recv_from_managed operation support (#709)
* feat(driver): add recv_from_managed operation support * refactor: use mop to define the RecvFromManaged
1 parent 6d5d254 commit 43c6ef4

6 files changed

Lines changed: 157 additions & 5 deletions

File tree

compio-driver/src/op.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub use crate::sys::op::{
2525
WriteVectored, WriteVectoredAt,
2626
};
2727
#[cfg(io_uring)]
28-
pub use crate::sys::op::{ReadManaged, ReadManagedAt, RecvManaged};
28+
pub use crate::sys::op::{ReadManaged, ReadManagedAt, RecvFromManaged, RecvManaged};
2929
use crate::{Extra, OwnedFd, SharedFd, TakeBuffer, sys::aio::*};
3030

3131
/// Trait to update the buffer length inside the [`BufResult`].
@@ -471,9 +471,10 @@ pub(crate) mod managed {
471471

472472
use compio_buf::IntoInner;
473473
use pin_project_lite::pin_project;
474+
use socket2::SockAddr;
474475

475-
use super::{Read, ReadAt, Recv};
476-
use crate::{BorrowedBuffer, BufferPool, OwnedBuffer, TakeBuffer};
476+
use super::{Read, ReadAt, Recv, RecvFrom};
477+
use crate::{AsFd, BorrowedBuffer, BufferPool, OwnedBuffer, TakeBuffer};
477478

478479
pin_project! {
479480
/// Read a file at specified position into managed buffer.
@@ -600,6 +601,48 @@ pub(crate) mod managed {
600601
Ok(res)
601602
}
602603
}
604+
605+
pin_project! {
606+
/// Receive data and source address into managed buffer.
607+
pub struct RecvFromManaged<S: AsFd> {
608+
#[pin]
609+
pub(crate) op: RecvFrom<OwnedBuffer, S>,
610+
}
611+
}
612+
613+
impl<S: AsFd> RecvFromManaged<S> {
614+
/// Create [`RecvFromManaged`].
615+
pub fn new(fd: S, pool: &BufferPool, len: usize, flags: i32) -> io::Result<Self> {
616+
#[cfg(fusion)]
617+
let pool = pool.as_poll();
618+
Ok(Self {
619+
op: RecvFrom::new(fd, pool.get_buffer(len)?, flags),
620+
})
621+
}
622+
}
623+
624+
impl<S: AsFd> TakeBuffer for RecvFromManaged<S> {
625+
type Buffer<'a> = (BorrowedBuffer<'a>, SockAddr);
626+
type BufferPool = BufferPool;
627+
628+
fn take_buffer(
629+
self,
630+
buffer_pool: &Self::BufferPool,
631+
result: io::Result<usize>,
632+
_: u16,
633+
) -> io::Result<Self::Buffer<'_>> {
634+
let result = result?;
635+
#[cfg(fusion)]
636+
let buffer_pool = buffer_pool.as_poll();
637+
let (slice, addr_buffer, addr_size) = self.op.into_inner();
638+
let addr = unsafe { SockAddr::new(addr_buffer, addr_size) };
639+
// SAFETY: result is valid
640+
let res = unsafe { buffer_pool.create_proxy(slice, result) };
641+
#[cfg(fusion)]
642+
let res = BorrowedBuffer::new_poll(res);
643+
Ok((res, addr))
644+
}
645+
}
603646
}
604647

605648
#[cfg(not(io_uring))]

compio-driver/src/sys/fusion/op.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ op!(<S: AsFd> PathStat(dirfd: S, path: CString, follow_symlink: bool));
115115

116116
macro_rules! mop {
117117
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? ) with $pool:ident) => {
118+
mop!{ < $($ty: $trait),* > $name ( $( $arg: $arg_t ),* ) with $pool, buffer: crate::BorrowedBuffer<'a> }
119+
};
120+
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? ) with $pool:ident, buffer: $buffer:ty) => {
118121
::paste::paste!{
119122
enum [< $name Inner >] <$($ty: $trait),*> {
120123
Poll(crate::op::managed::$name<$($ty),*>),
@@ -159,7 +162,7 @@ macro_rules! mop {
159162

160163
impl<$($ty: $trait),*> crate::TakeBuffer for $name<$($ty),*> {
161164
type BufferPool = crate::BufferPool;
162-
type Buffer<'a> = crate::BorrowedBuffer<'a>;
165+
type Buffer<'a> = $buffer;
163166

164167
fn take_buffer(
165168
self,
@@ -206,3 +209,4 @@ macro_rules! mop {
206209
mop!(<S: AsFd> ReadManagedAt(fd: S, offset: u64, pool: &BufferPool, len: usize) with pool);
207210
mop!(<S: AsFd> ReadManaged(fd: S, pool: &BufferPool, len: usize) with pool);
208211
mop!(<S: AsFd> RecvManaged(fd: S, pool: &BufferPool, len: usize, flags: i32) with pool);
212+
mop!(<S: AsFd> RecvFromManaged(fd: S, pool: &BufferPool, len: usize, flags: i32) with pool, buffer: (crate::BorrowedBuffer<'a>, SockAddr));

compio-driver/src/sys/iocp/op.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,16 @@ unsafe impl<S: AsFd> OpCode for RecvManaged<S> {
578578
}
579579
}
580580

581+
unsafe impl<S: AsFd> OpCode for RecvFromManaged<S> {
582+
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
583+
unsafe { self.project().op.operate(optr) }
584+
}
585+
586+
fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
587+
self.project().op.cancel(optr)
588+
}
589+
}
590+
581591
pin_project! {
582592
/// Receive data from remote into vectored buffer.
583593
pub struct RecvVectored<T: IoVectoredBufMut, S> {

compio-driver/src/sys/iour/op.rs

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,8 @@ mod buf_ring {
917917
};
918918

919919
use io_uring::{opcode, squeue::Flags, types::Fd};
920+
use pin_project_lite::pin_project;
921+
use socket2::{SockAddr, SockAddrStorage, socklen_t};
920922

921923
use super::OpCode;
922924
use crate::{BorrowedBuffer, BufferPool, OpEntry, TakeBuffer};
@@ -1096,6 +1098,83 @@ mod buf_ring {
10961098
res
10971099
}
10981100
}
1101+
1102+
pin_project! {
1103+
/// Receive data and source address into managed buffer.
1104+
pub struct RecvFromManaged<S> {
1105+
fd: S,
1106+
buffer_group: u16,
1107+
flags: i32,
1108+
addr: SockAddrStorage,
1109+
addr_len: socklen_t,
1110+
iovec: libc::iovec,
1111+
msg: libc::msghdr,
1112+
_p: PhantomPinned,
1113+
}
1114+
}
1115+
1116+
impl<S> RecvFromManaged<S> {
1117+
/// Create [`RecvFromManaged`].
1118+
pub fn new(fd: S, buffer_pool: &BufferPool, len: usize, flags: i32) -> io::Result<Self> {
1119+
#[cfg(fusion)]
1120+
let buffer_pool = buffer_pool.as_io_uring();
1121+
let len: u32 = len.try_into().map_err(|_| {
1122+
io::Error::new(io::ErrorKind::InvalidInput, "required length too long")
1123+
})?;
1124+
let addr = SockAddrStorage::zeroed();
1125+
Ok(Self {
1126+
fd,
1127+
buffer_group: buffer_pool.buffer_group(),
1128+
flags,
1129+
addr_len: addr.size_of() as _,
1130+
addr,
1131+
iovec: libc::iovec {
1132+
iov_base: ptr::null_mut(),
1133+
iov_len: len as _,
1134+
},
1135+
msg: unsafe { std::mem::zeroed() },
1136+
_p: PhantomPinned,
1137+
})
1138+
}
1139+
}
1140+
1141+
unsafe impl<S: AsFd> OpCode for RecvFromManaged<S> {
1142+
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
1143+
let this = self.project();
1144+
this.msg.msg_name = this.addr as *mut _ as _;
1145+
this.msg.msg_namelen = *this.addr_len;
1146+
this.msg.msg_iov = this.iovec as *const _ as *mut _;
1147+
this.msg.msg_iovlen = 1;
1148+
opcode::RecvMsg::new(Fd(this.fd.as_fd().as_raw_fd()), this.msg)
1149+
.flags(*this.flags as _)
1150+
.buf_group(*this.buffer_group)
1151+
.build()
1152+
.flags(Flags::BUFFER_SELECT)
1153+
.into()
1154+
}
1155+
}
1156+
1157+
impl<S> TakeBuffer for RecvFromManaged<S> {
1158+
type Buffer<'a> = (BorrowedBuffer<'a>, SockAddr);
1159+
type BufferPool = BufferPool;
1160+
1161+
fn take_buffer(
1162+
self,
1163+
buffer_pool: &Self::BufferPool,
1164+
result: io::Result<usize>,
1165+
buffer_id: u16,
1166+
) -> io::Result<Self::Buffer<'_>> {
1167+
#[cfg(fusion)]
1168+
let buffer_pool = buffer_pool.as_io_uring();
1169+
let result = result.inspect_err(|_| buffer_pool.reuse_buffer(buffer_id))?;
1170+
let addr = unsafe { SockAddr::new(self.addr, self.addr_len) };
1171+
// SAFETY: result is valid
1172+
let buffer = unsafe { buffer_pool.get_buffer(buffer_id, result) }?;
1173+
#[cfg(fusion)]
1174+
let buffer = BorrowedBuffer::new_io_uring(buffer);
1175+
Ok((buffer, addr))
1176+
}
1177+
}
10991178
}
11001179

1101-
pub use buf_ring::{ReadManaged, ReadManagedAt, RecvManaged};
1180+
pub use buf_ring::{ReadManaged, ReadManagedAt, RecvFromManaged, RecvManaged};

compio-driver/src/sys/poll/op.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,20 @@ unsafe impl<S: AsFd> OpCode for crate::op::managed::RecvManaged<S> {
858858
}
859859
}
860860

861+
unsafe impl<S: AsFd> OpCode for crate::op::managed::RecvFromManaged<S> {
862+
fn pre_submit(self: Pin<&mut Self>) -> io::Result<Decision> {
863+
self.project().op.pre_submit()
864+
}
865+
866+
fn op_type(self: Pin<&mut Self>) -> Option<OpType> {
867+
self.project().op.op_type()
868+
}
869+
870+
fn operate(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
871+
self.project().op.operate()
872+
}
873+
}
874+
861875
impl<T: IoVectoredBufMut, S: AsFd> RecvVectored<T, S> {
862876
unsafe fn call(self: Pin<&mut Self>) -> libc::ssize_t {
863877
let this = self.project();

compio-driver/src/sys/stub/op.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,5 @@ impl<S: AsFd> OpCode for crate::op::managed::ReadManagedAt<S> {}
256256
impl<S: AsFd> OpCode for crate::op::managed::ReadManaged<S> {}
257257

258258
impl<S: AsFd> OpCode for crate::op::managed::RecvManaged<S> {}
259+
260+
impl<S: AsFd> OpCode for crate::op::managed::RecvFromManaged<S> {}

0 commit comments

Comments
 (0)