Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 13 additions & 128 deletions compio-net/src/split.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::{error::Error, fmt, io};
use std::{io, ops::Deref};

use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::AsRawFd;
use compio_io::{
AsyncRead, AsyncWrite, AsyncWriteZerocopy, ancillary::AsyncWriteAncillaryZerocopy,
};
use compio_io::{AsyncRead, AsyncWrite};

pub(crate) fn split<T>(stream: &T) -> (ReadHalf<'_, T>, WriteHalf<'_, T>)
where
Expand All @@ -30,6 +27,14 @@ where
}
}

impl<T> Deref for ReadHalf<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
self.0
}
}

/// Borrowed write half.
#[derive(Debug)]
pub struct WriteHalf<'a, T>(&'a T);
Expand All @@ -55,130 +60,10 @@ where
}
}

pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>)
where
for<'a> &'a T: AsyncRead + AsyncWrite,
T: Clone,
{
(OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream))
}

/// Owned read half.
#[derive(Debug)]
pub struct OwnedReadHalf<T>(T);

impl<T: AsRawFd> OwnedReadHalf<T> {
/// Attempts to put the two halves of a `TcpStream` back together and
/// recover the original socket. Succeeds only if the two halves
/// originated from the same call to `into_split`.
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
if self.0.as_raw_fd() == w.0.as_raw_fd() {
drop(w);
Ok(self.0)
} else {
Err(ReuniteError(self, w))
}
}
}
impl<T> Deref for WriteHalf<'_, T> {
type Target = T;

impl<T> AsyncRead for OwnedReadHalf<T>
where
for<'a> &'a T: AsyncRead,
{
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
(&self.0).read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
(&self.0).read_vectored(buf).await
}
}

/// Owned write half.
#[derive(Debug)]
pub struct OwnedWriteHalf<T>(T);

impl<T> AsyncWrite for OwnedWriteHalf<T>
where
for<'a> &'a T: AsyncWrite,
{
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
(&self.0).write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
(&self.0).write_vectored(buf).await
}

async fn flush(&mut self) -> io::Result<()> {
(&self.0).flush().await
}

async fn shutdown(&mut self) -> io::Result<()> {
(&self.0).shutdown().await
}
}

impl<T> AsyncWriteZerocopy for OwnedWriteHalf<T>
where
T: AsyncWriteZerocopy,
{
type BufferReadyFuture<B: IoBuf> = T::BufferReadyFuture<B>;
type VectoredBufferReadyFuture<B: IoVectoredBuf> = T::VectoredBufferReadyFuture<B>;

async fn write_zerocopy<B: IoBuf>(
&mut self,
buf: B,
) -> BufResult<usize, Self::BufferReadyFuture<B>> {
self.0.write_zerocopy(buf).await
}

async fn write_zerocopy_vectored<B: IoVectoredBuf>(
&mut self,
buf: B,
) -> BufResult<usize, Self::VectoredBufferReadyFuture<B>> {
self.0.write_zerocopy_vectored(buf).await
}
}

impl<T> AsyncWriteAncillaryZerocopy for OwnedWriteHalf<T>
where
T: AsyncWriteAncillaryZerocopy,
{
type BufferReadyFuture<B: IoBuf, C: IoBuf> = T::BufferReadyFuture<B, C>;
type VectoredBufferReadyFuture<B: IoVectoredBuf, C: IoBuf> = T::VectoredBufferReadyFuture<B, C>;

async fn write_zerocopy_with_ancillary<B: IoBuf, C: IoBuf>(
&mut self,
buf: B,
control: C,
) -> BufResult<usize, Self::BufferReadyFuture<B, C>> {
self.0.write_zerocopy_with_ancillary(buf, control).await
}

async fn write_zerocopy_vectored_with_ancillary<B: IoVectoredBuf, C: IoBuf>(
&mut self,
buf: B,
control: C,
) -> BufResult<usize, Self::VectoredBufferReadyFuture<B, C>> {
fn deref(&self) -> &Self::Target {
self.0
.write_zerocopy_vectored_with_ancillary(buf, control)
.await
}
}

/// Error indicating that two halves were not from the same socket, and thus
/// could not be reunited.
#[derive(Debug)]
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);

impl<T> fmt::Display for ReuniteError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}

impl<T: fmt::Debug> Error for ReuniteError<T> {}
15 changes: 7 additions & 8 deletions compio-net/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use futures_util::{Stream, StreamExt, stream::FusedStream};
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};

use crate::{
Extract, Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket,
ToSocketAddrsAsync, WriteHalf, Zerocopy,
Extract, Incoming, MSG_NOSIGNAL, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf, Zerocopy,
};

/// A TCP socket server, listening for connections.
Expand Down Expand Up @@ -324,9 +323,9 @@ impl TcpStream {
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](TcpStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
/// separate tasks.
pub fn into_split(self) -> (Self, Self) {
(self.clone(), self)
}

/// Create [`PollFd`] from inner socket.
Expand Down Expand Up @@ -774,11 +773,11 @@ impl AsyncWriteAncillaryZerocopy for &TcpStream {
}

impl Splittable for TcpStream {
type ReadHalf = OwnedReadHalf<Self>;
type WriteHalf = OwnedWriteHalf<Self>;
type ReadHalf = Self;
type WriteHalf = Self;

fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::into_split(self)
self.into_split()
}
}

Expand Down
17 changes: 7 additions & 10 deletions compio-net/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ use compio_runtime::fd::PollFd;
use futures_util::{Stream, StreamExt, stream::FusedStream};
use socket2::{Domain, SockAddr, Socket as Socket2, Type};

use crate::{
Extract, Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf,
Zerocopy,
};
use crate::{Extract, Incoming, MSG_NOSIGNAL, ReadHalf, Socket, WriteHalf, Zerocopy};

/// A Unix socket server, listening for connections.
///
Expand Down Expand Up @@ -278,9 +275,9 @@ impl UnixStream {
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](UnixStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
/// separate tasks.
pub fn into_split(self) -> (Self, Self) {
(self.clone(), self)
}

/// Create [`PollFd`] from inner socket.
Expand Down Expand Up @@ -685,11 +682,11 @@ impl AsyncWriteAncillary for &UnixStream {
}

impl Splittable for UnixStream {
type ReadHalf = OwnedReadHalf<Self>;
type WriteHalf = OwnedWriteHalf<Self>;
type ReadHalf = Self;
type WriteHalf = Self;

fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::into_split(self)
self.into_split()
}
}

Expand Down
26 changes: 0 additions & 26 deletions compio-net/tests/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,6 @@ async fn tcp_split() {
handle.await.resume_unwind();
}

#[compio_macros::test]
async fn tcp_unsplit() {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();

let handle = compio_runtime::spawn_blocking(move || {
drop(listener.accept().unwrap());
drop(listener.accept().unwrap());
});

let stream1 = TcpStream::connect(&addr).await.unwrap();
let (read1, write1) = stream1.into_split();

let stream2 = TcpStream::connect(&addr).await.unwrap();
let (_, write2) = stream2.into_split();

let read1 = match read1.reunite(write2) {
Ok(_) => panic!("Reunite should not succeed"),
Err(err) => err.0,
};

read1.reunite(write1).expect("Reunite should succeed");

handle.await.resume_unwind();
}

#[compio_macros::test]
async fn unix_split() {
let dir = tempfile::Builder::new()
Expand Down
Loading