diff --git a/compio-net/src/split.rs b/compio-net/src/split.rs index 855f3680..d2110397 100644 --- a/compio-net/src/split.rs +++ b/compio-net/src/split.rs @@ -1,10 +1,7 @@ -use std::{error::Error, fmt, io}; +use std::{io, ops::Deref}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; -use compio_driver::AsRawFd; -use compio_io::{ - AsyncRead, AsyncWrite, AsyncWriteZerocopy, ancillary::AsyncWriteAncillaryZerocopy, -}; +use compio_io::{AsyncRead, AsyncWrite}; pub(crate) fn split(stream: &T) -> (ReadHalf<'_, T>, WriteHalf<'_, T>) where @@ -30,6 +27,14 @@ where } } +impl Deref for ReadHalf<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + /// Borrowed write half. #[derive(Debug)] pub struct WriteHalf<'a, T>(&'a T); @@ -55,130 +60,10 @@ where } } -pub(crate) fn into_split(stream: T) -> (OwnedReadHalf, OwnedWriteHalf) -where - for<'a> &'a T: AsyncRead + AsyncWrite, - T: Clone, -{ - (OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream)) -} - -/// Owned read half. -#[derive(Debug)] -pub struct OwnedReadHalf(T); - -impl OwnedReadHalf { - /// Attempts to put the two halves of a `TcpStream` back together and - /// recover the original socket. Succeeds only if the two halves - /// originated from the same call to `into_split`. - pub fn reunite(self, w: OwnedWriteHalf) -> Result> { - if self.0.as_raw_fd() == w.0.as_raw_fd() { - drop(w); - Ok(self.0) - } else { - Err(ReuniteError(self, w)) - } - } -} +impl Deref for WriteHalf<'_, T> { + type Target = T; -impl AsyncRead for OwnedReadHalf -where - for<'a> &'a T: AsyncRead, -{ - async fn read(&mut self, buf: B) -> BufResult { - (&self.0).read(buf).await - } - - async fn read_vectored(&mut self, buf: V) -> BufResult { - (&self.0).read_vectored(buf).await - } -} - -/// Owned write half. -#[derive(Debug)] -pub struct OwnedWriteHalf(T); - -impl AsyncWrite for OwnedWriteHalf -where - for<'a> &'a T: AsyncWrite, -{ - async fn write(&mut self, buf: B) -> BufResult { - (&self.0).write(buf).await - } - - async fn write_vectored(&mut self, buf: B) -> BufResult { - (&self.0).write_vectored(buf).await - } - - async fn flush(&mut self) -> io::Result<()> { - (&self.0).flush().await - } - - async fn shutdown(&mut self) -> io::Result<()> { - (&self.0).shutdown().await - } -} - -impl AsyncWriteZerocopy for OwnedWriteHalf -where - T: AsyncWriteZerocopy, -{ - type BufferReadyFuture = T::BufferReadyFuture; - type VectoredBufferReadyFuture = T::VectoredBufferReadyFuture; - - async fn write_zerocopy( - &mut self, - buf: B, - ) -> BufResult> { - self.0.write_zerocopy(buf).await - } - - async fn write_zerocopy_vectored( - &mut self, - buf: B, - ) -> BufResult> { - self.0.write_zerocopy_vectored(buf).await - } -} - -impl AsyncWriteAncillaryZerocopy for OwnedWriteHalf -where - T: AsyncWriteAncillaryZerocopy, -{ - type BufferReadyFuture = T::BufferReadyFuture; - type VectoredBufferReadyFuture = T::VectoredBufferReadyFuture; - - async fn write_zerocopy_with_ancillary( - &mut self, - buf: B, - control: C, - ) -> BufResult> { - self.0.write_zerocopy_with_ancillary(buf, control).await - } - - async fn write_zerocopy_vectored_with_ancillary( - &mut self, - buf: B, - control: C, - ) -> BufResult> { + fn deref(&self) -> &Self::Target { self.0 - .write_zerocopy_vectored_with_ancillary(buf, control) - .await } } - -/// Error indicating that two halves were not from the same socket, and thus -/// could not be reunited. -#[derive(Debug)] -pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); - -impl fmt::Display for ReuniteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "tried to reunite halves that are not from the same socket" - ) - } -} - -impl Error for ReuniteError {} diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index ab86b4a1..1a99aad7 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -25,8 +25,7 @@ use futures_util::{Stream, StreamExt, stream::FusedStream}; use socket2::{Protocol, SockAddr, Socket as Socket2, Type}; use crate::{ - Extract, Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, - ToSocketAddrsAsync, WriteHalf, Zerocopy, + Extract, Incoming, MSG_NOSIGNAL, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf, Zerocopy, }; /// A TCP socket server, listening for connections. @@ -324,9 +323,9 @@ impl TcpStream { /// used to read and write the stream concurrently. /// /// Unlike [`split`](TcpStream::split), the owned halves can be moved to - /// separate tasks, however this comes at the cost of a heap allocation. - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { - crate::into_split(self) + /// separate tasks. + pub fn into_split(self) -> (Self, Self) { + (self.clone(), self) } /// Create [`PollFd`] from inner socket. @@ -774,11 +773,11 @@ impl AsyncWriteAncillaryZerocopy for &TcpStream { } impl Splittable for TcpStream { - type ReadHalf = OwnedReadHalf; - type WriteHalf = OwnedWriteHalf; + type ReadHalf = Self; + type WriteHalf = Self; fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { - crate::into_split(self) + self.into_split() } } diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index 28cd2102..62ba9637 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -23,10 +23,7 @@ use compio_runtime::fd::PollFd; use futures_util::{Stream, StreamExt, stream::FusedStream}; use socket2::{Domain, SockAddr, Socket as Socket2, Type}; -use crate::{ - Extract, Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf, - Zerocopy, -}; +use crate::{Extract, Incoming, MSG_NOSIGNAL, ReadHalf, Socket, WriteHalf, Zerocopy}; /// A Unix socket server, listening for connections. /// @@ -278,9 +275,9 @@ impl UnixStream { /// used to read and write the stream concurrently. /// /// Unlike [`split`](UnixStream::split), the owned halves can be moved to - /// separate tasks, however this comes at the cost of a heap allocation. - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { - crate::into_split(self) + /// separate tasks. + pub fn into_split(self) -> (Self, Self) { + (self.clone(), self) } /// Create [`PollFd`] from inner socket. @@ -685,11 +682,11 @@ impl AsyncWriteAncillary for &UnixStream { } impl Splittable for UnixStream { - type ReadHalf = OwnedReadHalf; - type WriteHalf = OwnedWriteHalf; + type ReadHalf = Self; + type WriteHalf = Self; fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { - crate::into_split(self) + self.into_split() } } diff --git a/compio-net/tests/split.rs b/compio-net/tests/split.rs index 6f5d952a..aea78abe 100644 --- a/compio-net/tests/split.rs +++ b/compio-net/tests/split.rs @@ -33,32 +33,6 @@ async fn tcp_split() { handle.await.resume_unwind(); } -#[compio_macros::test] -async fn tcp_unsplit() { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let handle = compio_runtime::spawn_blocking(move || { - drop(listener.accept().unwrap()); - drop(listener.accept().unwrap()); - }); - - let stream1 = TcpStream::connect(&addr).await.unwrap(); - let (read1, write1) = stream1.into_split(); - - let stream2 = TcpStream::connect(&addr).await.unwrap(); - let (_, write2) = stream2.into_split(); - - let read1 = match read1.reunite(write2) { - Ok(_) => panic!("Reunite should not succeed"), - Err(err) => err.0, - }; - - read1.reunite(write1).expect("Reunite should succeed"); - - handle.await.resume_unwind(); -} - #[compio_macros::test] async fn unix_split() { let dir = tempfile::Builder::new()