diff --git a/compio-io/Cargo.toml b/compio-io/Cargo.toml index 280926adc..fed0e1e7e 100644 --- a/compio-io/Cargo.toml +++ b/compio-io/Cargo.toml @@ -13,6 +13,7 @@ repository = { workspace = true } compio-buf = { workspace = true, features = ["arrayvec", "bytes"] } futures-util = { workspace = true, features = ["sink"] } paste = { workspace = true } +pin-project-lite = { workspace = true, optional = true } synchrony = { workspace = true, features = ["bilock"] } thiserror = { workspace = true, optional = true } serde = { version = "1.0.219", optional = true } @@ -25,7 +26,7 @@ futures-executor = "0.3.30" [features] default = [] -compat = ["futures-util/io"] +compat = ["futures-util/io", "dep:pin-project-lite"] sync = [] # Codecs diff --git a/compio-io/src/buffer.rs b/compio-io/src/buffer.rs index 1c77ef54b..3b9f0d999 100644 --- a/compio-io/src/buffer.rs +++ b/compio-io/src/buffer.rs @@ -97,6 +97,10 @@ impl Buffer { fn buf_mut(&mut self) -> &mut B { self.inner_mut().as_inner_mut() } + + pub(crate) fn has_inner(&self) -> bool { + self.0.is_some() + } } impl Buffer { diff --git a/compio-io/src/compat/async_stream.rs b/compio-io/src/compat/async_stream.rs index f25f31fdf..540a27225 100644 --- a/compio-io/src/compat/async_stream.rs +++ b/compio-io/src/compat/async_stream.rs @@ -1,21 +1,34 @@ use std::{ fmt::Debug, io::{self, BufRead}, + marker::PhantomPinned, mem::MaybeUninit, pin::Pin, - task::{Context, Poll}, + sync::Arc, + task::{Context, Poll, Wake, Waker, ready}, }; -use crate::{PinBoxFuture, compat::SyncStream}; +use pin_project_lite::pin_project; -/// A stream wrapper for [`futures_util::io`] traits. -pub struct AsyncStream { - // The futures keep the reference to the inner stream, so we need to pin - // the inner stream to make sure the reference is valid. - inner: Pin>>, - read_future: Option>>, - write_future: Option>>, - shutdown_future: Option>>, +use crate::{AsyncRead, AsyncWrite, PinBoxFuture, compat::SyncStream, util::DEFAULT_BUF_SIZE}; + +pin_project! { + /// A stream wrapper for [`futures_util::io`] traits. + pub struct AsyncStream { + #[pin] + inner: SyncStream, + read_future: Option>>, + write_future: Option>>, + shutdown_future: Option>>, + read_waker: Option, + read_uninit_waker: Option, + read_buf_waker: Option, + write_waker: Option, + flush_waker: Option, + close_waker: Option, + #[pin] + _p: PhantomPinned, + } } impl AsyncStream { @@ -31,10 +44,78 @@ impl AsyncStream { fn new_impl(inner: SyncStream) -> Self { Self { - inner: Box::pin(inner), + inner, read_future: None, write_future: None, shutdown_future: None, + read_waker: None, + read_uninit_waker: None, + read_buf_waker: None, + write_waker: None, + flush_waker: None, + close_waker: None, + _p: PhantomPinned, + } + } + + /// Get the reference of the inner stream. + pub fn get_ref(&self) -> &S { + self.inner.get_ref() + } + + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + self.inner.get_mut() + } + + /// Consumes the `SyncStream`, returning the underlying stream. + pub fn into_inner(self) -> S { + self.inner.into_inner() + } +} + +pin_project! { + /// A read stream wrapper for [`futures_util::io`]. + /// + /// It doesn't support write and shutdown operations, making looser + /// requirements on the inner stream. + pub struct AsyncReadStream { + #[pin] + inner: SyncStream, + read_future: Option>>, + read_waker: Option, + read_uninit_waker: Option, + read_buf_waker: Option, + #[pin] + _p: PhantomPinned, + } +} + +impl AsyncReadStream { + /// Create [`AsyncReadStream`] with the stream and default buffer size. + pub fn new(stream: S) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, stream) + } + + /// Create [`AsyncReadStream`] with the stream and buffer size. + pub fn with_capacity(cap: usize, stream: S) -> Self { + Self::new_impl(SyncStream::with_limits2( + cap, + 0, + cap, + SyncStream::::DEFAULT_MAX_BUFFER, + stream, + )) + } + + fn new_impl(inner: SyncStream) -> Self { + Self { + inner, + read_future: None, + read_waker: None, + read_uninit_waker: None, + read_buf_waker: None, + _p: PhantomPinned, } } @@ -42,6 +123,78 @@ impl AsyncStream { pub fn get_ref(&self) -> &S { self.inner.get_ref() } + + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + self.inner.get_mut() + } + + /// Consumes the `SyncStream`, returning the underlying stream. + pub fn into_inner(self) -> S { + self.inner.into_inner() + } +} + +pin_project! { + /// A write stream wrapper for [`futures_util::io`]. + /// + /// It doesn't support read operations, making looser requirements on the inner stream. + pub struct AsyncWriteStream { + #[pin] + inner: SyncStream, + write_future: Option>>, + shutdown_future: Option>>, + write_waker: Option, + flush_waker: Option, + close_waker: Option, + #[pin] + _p: PhantomPinned, + } +} + +impl AsyncWriteStream { + /// Create [`AsyncWriteStream`] with the stream and default buffer size. + pub fn new(stream: S) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, stream) + } + + /// Create [`AsyncWriteStream`] with the stream and buffer size. + pub fn with_capacity(cap: usize, stream: S) -> Self { + Self::new_impl(SyncStream::with_limits2( + 0, + cap, + cap, + SyncStream::::DEFAULT_MAX_BUFFER, + stream, + )) + } + + fn new_impl(inner: SyncStream) -> Self { + Self { + inner, + write_future: None, + shutdown_future: None, + write_waker: None, + flush_waker: None, + close_waker: None, + _p: PhantomPinned, + } + } + + /// Get the reference of the inner stream. + pub fn get_ref(&self) -> &S { + self.inner.get_ref() + } + + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + self.inner.get_mut() + } + + /// Consumes the `SyncStream`, returning the underlying stream. + pub fn into_inner(self) -> S { + self.inner.into_inner() + } } macro_rules! poll_future { @@ -62,48 +215,171 @@ macro_rules! poll_future { } macro_rules! poll_future_would_block { - ($f:expr, $cx:expr, $e:expr, $io:expr) => {{ - if let Some(mut f) = $f.take() { - if f.as_mut().poll($cx).is_pending() { - $f.replace(f); - return Poll::Pending; - } - } - + ($cx:expr, $w:expr, $io:expr, $f:expr) => {{ match $io { - Ok(len) => Poll::Ready(Ok(len)), + Ok(res) => { + $w.take(); + return Poll::Ready(Ok(res)); + } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - $f.replace(Box::pin($e)); - $cx.waker().wake_by_ref(); - Poll::Pending + ready!($f)?; + } + Err(e) => { + $w.take(); + return Poll::Ready(Err(e)); } - Err(e) => Poll::Ready(Err(e)), } }}; } -impl futures_util::AsyncRead for AsyncStream { +unsafe fn extend_lifetime_mut(t: &mut T) -> &'static mut T { + unsafe { &mut *(t as *mut T) } +} + +unsafe fn extend_lifetime(t: &T) -> &'static T { + unsafe { &*(t as *const T) } +} + +fn replace_waker(waker_slot: &mut Option, waker: &Waker) { + if !waker_slot.as_ref().is_some_and(|w| w.will_wake(waker)) { + waker_slot.replace(waker.clone()); + } +} + +impl AsyncStream +where + for<'a> &'a S: AsyncRead, +{ + fn poll_read_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); + // SAFETY: + // - The future won't live longer than the stream. + // - The stream is internally mutable. + // - The future only accesses the corresponding buffer and fields. + // - No access overlap between the futures. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.read_waker.as_ref().cloned(), + this.read_uninit_waker.as_ref().cloned(), + this.read_buf_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.read_future, cx, inner.fill_read_buf()); + Poll::Ready(res) + } +} + +impl futures_util::AsyncRead for AsyncStream +where + for<'a> &'a S: AsyncRead, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { + replace_waker(self.as_mut().project().read_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_waker, + io::Read::read(this.inner.get_mut(), buf), + self.as_mut().poll_read_impl() + ) + } + } +} + +impl AsyncStream +where + for<'a> &'a S: AsyncRead, +{ + /// Attempt to read from the `AsyncRead` into `buf`. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + pub fn poll_read_uninit( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [MaybeUninit], + ) -> Poll> { + replace_waker(self.as_mut().project().read_uninit_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_uninit_waker, + this.inner.get_mut().read_buf_uninit(buf), + self.as_mut().poll_read_impl() + ) + } + } +} + +impl futures_util::AsyncBufRead for AsyncStream +where + for<'a> &'a S: AsyncRead, +{ + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + replace_waker(self.as_mut().project().read_buf_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_buf_waker, + // SAFETY: The buffer won't be accessed after the future is ready, and the future + // won't live longer than the stream. + io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }), + self.as_mut().poll_read_impl() + ) + } + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().inner.consume(amt) + } +} + +impl AsyncReadStream { + fn poll_read_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); // SAFETY: - // - The futures won't live longer than the stream. - // - The inner stream is pinned. - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; + // - The future won't live longer than the stream. + // - The stream is `Unpin`. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.read_waker.as_ref().cloned(), + this.read_uninit_waker.as_ref().cloned(), + this.read_buf_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.read_future, cx, inner.fill_read_buf()); + Poll::Ready(res) + } +} - poll_future_would_block!( - self.read_future, - cx, - inner.fill_read_buf(), - io::Read::read(inner, buf) - ) +impl futures_util::AsyncRead for AsyncReadStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + replace_waker(self.as_mut().project().read_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_waker, + io::Read::read(this.inner.get_mut(), buf), + self.as_mut().poll_read_impl() + ) + } } } -impl AsyncStream { +impl AsyncReadStream { /// Attempt to read from the `AsyncRead` into `buf`. /// /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. @@ -112,80 +388,212 @@ impl AsyncStream { cx: &mut Context<'_>, buf: &mut [MaybeUninit], ) -> Poll> { - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; - poll_future_would_block!( - self.read_future, - cx, - inner.fill_read_buf(), - inner.read_buf_uninit(buf) - ) + replace_waker(self.as_mut().project().read_uninit_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_uninit_waker, + this.inner.get_mut().read_buf_uninit(buf), + self.as_mut().poll_read_impl() + ) + } } } - -impl futures_util::AsyncBufRead for AsyncStream { +impl futures_util::AsyncBufRead for AsyncReadStream { fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; - poll_future_would_block!( - self.read_future, - cx, - inner.fill_read_buf(), - // SAFETY: anyway the slice won't be used after free. - io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) }) - ) + replace_waker(self.as_mut().project().read_buf_waker, cx.waker()); + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.read_buf_waker, + // SAFETY: The buffer won't be accessed after the future is ready, and the future + // won't live longer than the stream. + io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }), + self.as_mut().poll_read_impl() + ) + } } - fn consume(mut self: Pin<&mut Self>, amt: usize) { - unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) } + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().inner.consume(amt) } } -impl futures_util::AsyncWrite for AsyncStream { +impl AsyncStream +where + for<'a> &'a S: AsyncWrite, +{ + fn poll_flush_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); + // SAFETY: + // - The future won't live longer than the stream. + // - The stream is internally mutable. + // - The future only accesses the corresponding buffer and fields. + // - No access overlap between the futures. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.write_waker.as_ref().cloned(), + this.flush_waker.as_ref().cloned(), + this.close_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.write_future, cx, inner.flush_write_buf()); + Poll::Ready(res) + } + + fn poll_close_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); + // SAFETY: + // - The future won't live longer than the stream. + // - The stream is internally mutable. + // - The future only accesses the corresponding buffer and fields. + // - No access overlap between the futures. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.write_waker.as_ref().cloned(), + this.flush_waker.as_ref().cloned(), + this.close_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown()); + Poll::Ready(res) + } +} + +impl futures_util::AsyncWrite for AsyncStream +where + for<'a> &'a S: AsyncWrite, +{ fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + replace_waker(self.as_mut().project().write_waker, cx.waker()); if self.shutdown_future.is_some() { debug_assert!(self.write_future.is_none()); - return Poll::Pending; + ready!(self.as_mut().poll_close_impl())?; + } + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.write_waker, + io::Write::write(this.inner.get_mut(), buf), + self.as_mut().poll_flush_impl() + ) } - - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; - poll_future_would_block!( - self.write_future, - cx, - inner.flush_write_buf(), - io::Write::write(inner, buf) - ) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + replace_waker(self.as_mut().project().flush_waker, cx.waker()); if self.shutdown_future.is_some() { debug_assert!(self.write_future.is_none()); - return Poll::Pending; + ready!(self.as_mut().poll_close_impl())?; + } + let res = ready!(self.as_mut().poll_flush_impl()); + self.project().flush_waker.take(); + Poll::Ready(res.map(|_| ())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + replace_waker(self.as_mut().project().close_waker, cx.waker()); + // Avoid shutdown on flush because the inner buffer might be passed to the + // driver. + if self.write_future.is_some() || self.inner.has_pending_write() { + debug_assert!(self.shutdown_future.is_none()); + ready!(self.as_mut().poll_flush_impl())?; } + let res = ready!(self.as_mut().poll_close_impl()); + self.project().close_waker.take(); + Poll::Ready(res) + } +} + +impl AsyncWriteStream { + fn poll_flush_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); + // SAFETY: + // - The future won't live longer than the stream. + // - The stream is `Unpin`. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.write_waker.as_ref().cloned(), + this.flush_waker.as_ref().cloned(), + this.close_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.write_future, cx, inner.flush_write_buf()); + Poll::Ready(res) + } - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; - let res = poll_future!(self.write_future, cx, inner.flush_write_buf()); + fn poll_close_impl(self: Pin<&mut Self>) -> Poll> { + let this = self.project(); + // SAFETY: + // - The future won't live longer than the stream. + // - The stream is `Unpin`. + let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) }; + let arr = WakerArray([ + this.write_waker.as_ref().cloned(), + this.flush_waker.as_ref().cloned(), + this.close_waker.as_ref().cloned(), + ]); + let waker = Waker::from(Arc::new(arr)); + let cx = &mut Context::from_waker(&waker); + let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown()); + Poll::Ready(res) + } +} + +impl futures_util::AsyncWrite for AsyncWriteStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + replace_waker(self.as_mut().project().write_waker, cx.waker()); + if self.shutdown_future.is_some() { + debug_assert!(self.write_future.is_none()); + ready!(self.as_mut().poll_close_impl())?; + } + loop { + let this = self.as_mut().project(); + poll_future_would_block!( + cx, + this.write_waker, + io::Write::write(this.inner.get_mut(), buf), + self.as_mut().poll_flush_impl() + ) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + replace_waker(self.as_mut().project().flush_waker, cx.waker()); + if self.shutdown_future.is_some() { + debug_assert!(self.write_future.is_none()); + ready!(self.as_mut().poll_close_impl())?; + } + let res = ready!(self.as_mut().poll_flush_impl()); + self.project().flush_waker.take(); Poll::Ready(res.map(|_| ())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + replace_waker(self.as_mut().project().close_waker, cx.waker()); // Avoid shutdown on flush because the inner buffer might be passed to the // driver. if self.write_future.is_some() || self.inner.has_pending_write() { debug_assert!(self.shutdown_future.is_none()); - self.poll_flush(cx) - } else { - let inner: &'static mut SyncStream = - unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; - let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown()); - Poll::Ready(res) + ready!(self.as_mut().poll_flush_impl())?; } + let res = ready!(self.as_mut().poll_close_impl()); + self.project().close_waker.take(); + Poll::Ready(res) } } @@ -197,17 +605,30 @@ impl Debug for AsyncStream { } } +struct WakerArray([Option; N]); + +impl Wake for WakerArray { + fn wake(self: Arc) { + self.0.iter().for_each(|w| { + if let Some(w) = w { + w.wake_by_ref() + } + }); + } +} + #[cfg(test)] mod test { use futures_executor::block_on; use futures_util::AsyncWriteExt; - use super::AsyncStream; + use super::AsyncWriteStream; #[test] fn close() { block_on(async { - let mut stream = AsyncStream::new(Vec::::new()); + let stream = AsyncWriteStream::new(Vec::::new()); + let mut stream = std::pin::pin!(stream); let n = stream.write(b"hello").await.unwrap(); assert_eq!(n, 5); stream.close().await.unwrap(); diff --git a/compio-io/src/compat/sync_stream.rs b/compio-io/src/compat/sync_stream.rs index 3f37c487c..906f4e3c5 100644 --- a/compio-io/src/compat/sync_stream.rs +++ b/compio-io/src/compat/sync_stream.rs @@ -44,7 +44,7 @@ pub struct SyncStream { impl SyncStream { // 64MiB max - const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; + pub(crate) const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; /// Creates a new `SyncStream` with default buffer sizes. /// @@ -74,6 +74,23 @@ impl SyncStream { } } + pub(crate) fn with_limits2( + read_capacity: usize, + write_capacity: usize, + base_capacity: usize, + max_buffer_size: usize, + stream: S, + ) -> Self { + Self { + inner: stream, + read_buf: Buffer::with_capacity(read_capacity), + write_buf: Buffer::with_capacity(write_capacity), + eof: false, + base_capacity, + max_buffer_size, + } + } + /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.inner @@ -95,8 +112,12 @@ impl SyncStream { } /// Returns the available bytes in the read buffer. - fn available_read(&self) -> &[u8] { - self.read_buf.buffer() + fn available_read(&self) -> io::Result<&[u8]> { + if self.read_buf.has_inner() { + Ok(self.read_buf.buffer()) + } else { + Err(would_block("the read buffer is in use")) + } } /// Marks `amt` bytes as consumed from the read buffer. @@ -158,7 +179,7 @@ impl Read for SyncStream { impl BufRead for SyncStream { fn fill_buf(&mut self) -> io::Result<&[u8]> { - let available = self.available_read(); + let available = self.available_read()?; if available.is_empty() && !self.eof { return Err(would_block("need to fill read buffer")); @@ -179,6 +200,9 @@ impl Write for SyncStream { /// capacity. In the latter case, it may write partial data before /// returning `WouldBlock`. fn write(&mut self, buf: &[u8]) -> io::Result { + if !self.write_buf.has_inner() { + return Err(would_block("the write buffer is in use")); + } // Check if we should flush first if self.write_buf.need_flush() && !self.write_buf.is_empty() { return Err(would_block("need to flush write buffer")); diff --git a/compio-io/tests/compat.rs b/compio-io/tests/compat.rs index f7cb77501..c00387416 100644 --- a/compio-io/tests/compat.rs +++ b/compio-io/tests/compat.rs @@ -1,6 +1,6 @@ use std::io::Cursor; -use compio_io::compat::AsyncStream; +use compio_io::compat::{AsyncReadStream, AsyncWriteStream}; use futures_executor::block_on; use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; @@ -8,7 +8,8 @@ use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; fn async_compat_read() { block_on(async { let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..]; - let mut stream = AsyncStream::new(src); + let stream = AsyncReadStream::new(src); + let mut stream = std::pin::pin!(stream); let mut buf = [0; 6]; let len = stream.read(&mut buf).await.unwrap(); @@ -27,7 +28,8 @@ fn async_compat_read() { fn async_compat_bufread() { block_on(async { let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..]; - let mut stream = AsyncStream::new(src); + let stream = AsyncReadStream::new(src); + let mut stream = std::pin::pin!(stream); let slice = stream.fill_buf().await.unwrap(); assert_eq!(slice, [1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]); @@ -45,7 +47,8 @@ fn async_compat_bufread() { fn async_compat_write() { block_on(async { let dst = Cursor::new([0u8; 10]); - let mut stream = AsyncStream::new(dst); + let stream = AsyncWriteStream::new(dst); + let mut stream = std::pin::pin!(stream); let len = stream.write(&[1, 1, 4, 5, 1, 4]).await.unwrap(); stream.flush().await.unwrap(); @@ -55,7 +58,9 @@ fn async_compat_write() { assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 0, 0, 0, 0]); let dst = Cursor::new([0u8; 10]); - let mut stream = AsyncStream::with_capacity(10, dst); + let stream = AsyncWriteStream::with_capacity(10, dst); + let mut stream = std::pin::pin!(stream); + let len = stream .write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]) .await @@ -71,7 +76,8 @@ fn async_compat_write() { fn async_compat_flush_fail() { block_on(async { let dst = Cursor::new([0u8; 10]); - let mut stream = AsyncStream::new(dst); + let stream = AsyncWriteStream::new(dst); + let mut stream = std::pin::pin!(stream); let len = stream .write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]) .await diff --git a/compio-tls/src/adapter.rs b/compio-tls/src/adapter.rs index 098e4cb33..1e4bddcb9 100644 --- a/compio-tls/src/adapter.rs +++ b/compio-tls/src/adapter.rs @@ -82,11 +82,14 @@ impl TlsConnector { /// example, a TCP connection to a remote server. That stream is then /// provided here to perform the client half of a connection to a /// TLS-powered server. - pub async fn connect( + pub async fn connect( &self, domain: &str, stream: S, - ) -> io::Result> { + ) -> io::Result> + where + for<'a> &'a S: AsyncRead + AsyncWrite, + { match &self.0 { #[cfg(feature = "native-tls")] TlsConnectorInner::NativeTls(c) => { @@ -97,7 +100,7 @@ impl TlsConnector { let client = c .connect( domain.to_string().try_into().map_err(io::Error::other)?, - AsyncStream::new(stream), + Box::pin(AsyncStream::new(stream)), ) .await?; Ok(TlsStream::from(client)) @@ -172,10 +175,13 @@ impl TlsAcceptor { /// This is typically used after a new socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform /// the server half of accepting a client connection. - pub async fn accept( + pub async fn accept( &self, stream: S, - ) -> io::Result> { + ) -> io::Result> + where + for<'a> &'a S: AsyncRead + AsyncWrite, + { match &self.0 { #[cfg(feature = "native-tls")] TlsAcceptorInner::NativeTls(c) => { @@ -183,7 +189,7 @@ impl TlsAcceptor { } #[cfg(feature = "rustls")] TlsAcceptorInner::Rustls(c) => { - let server = c.accept(AsyncStream::new(stream)).await?; + let server = c.accept(Box::pin(AsyncStream::new(stream))).await?; Ok(TlsStream::from(server)) } #[cfg(feature = "py-dynamic-openssl")] diff --git a/compio-tls/src/maybe.rs b/compio-tls/src/maybe.rs index 6a037990e..4e2c409d7 100644 --- a/compio-tls/src/maybe.rs +++ b/compio-tls/src/maybe.rs @@ -46,6 +46,7 @@ impl MaybeTlsStream { impl AsyncRead for MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin + 'static, + for<'a> &'a S: AsyncRead + AsyncWrite, { async fn read(&mut self, buf: B) -> BufResult { match &mut self.0 { @@ -58,6 +59,7 @@ where impl AsyncWrite for MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin + 'static, + for<'a> &'a S: AsyncRead + AsyncWrite, { async fn write(&mut self, buf: B) -> BufResult { match &mut self.0 { diff --git a/compio-tls/src/rtls.rs b/compio-tls/src/rtls.rs index db6bcc579..381dff6ea 100644 --- a/compio-tls/src/rtls.rs +++ b/compio-tls/src/rtls.rs @@ -16,19 +16,25 @@ use crate::TlsStream; /// A lazy TLS acceptor that performs the initial handshake and allows access to /// the [`ClientHello`] message before completing the handshake. -pub struct LazyConfigAcceptor(futures_rustls::LazyConfigAcceptor>); +pub struct LazyConfigAcceptor(futures_rustls::LazyConfigAcceptor>>>); -impl LazyConfigAcceptor { +impl LazyConfigAcceptor +where + for<'a> &'a S: AsyncRead + AsyncWrite, +{ /// Create a new [`LazyConfigAcceptor`] with the given acceptor and stream. pub fn new(acceptor: Acceptor, s: S) -> Self { Self(futures_rustls::LazyConfigAcceptor::new( acceptor, - AsyncStream::new(s), + Box::pin(AsyncStream::new(s)), )) } } -impl Future for LazyConfigAcceptor { +impl Future for LazyConfigAcceptor +where + for<'a> &'a S: AsyncRead + AsyncWrite, +{ type Output = Result, io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -38,9 +44,12 @@ impl Future for LazyConfigAcceptor { /// A TLS acceptor that has completed the initial handshake and allows access to /// the [`ClientHello`] message. -pub struct StartHandshake(futures_rustls::StartHandshake>); +pub struct StartHandshake(futures_rustls::StartHandshake>>>); -impl StartHandshake { +impl StartHandshake +where + for<'a> &'a S: AsyncRead + AsyncWrite, +{ /// Get the [`ClientHello`] message from the initial handshake. pub fn client_hello(&self) -> ClientHello<'_> { self.0.client_hello() diff --git a/compio-tls/src/stream.rs b/compio-tls/src/stream.rs index f02923131..9b68b27a9 100644 --- a/compio-tls/src/stream.rs +++ b/compio-tls/src/stream.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "rustls")] +use std::pin::Pin; use std::{borrow::Cow, io, mem::MaybeUninit}; use compio_buf::{BufResult, IoBuf, IoBufMut}; @@ -12,7 +14,7 @@ enum TlsStreamInner { #[cfg(feature = "native-tls")] NativeTls(native_tls::TlsStream>), #[cfg(feature = "rustls")] - Rustls(futures_rustls::TlsStream>), + Rustls(futures_rustls::TlsStream>>>), #[cfg(feature = "py-dynamic-openssl")] PyDynamicOpenSsl(compio_py_dynamic_openssl::ssl::SslStream>), #[cfg(not(any( @@ -69,8 +71,8 @@ impl From>> for TlsStream { #[cfg(feature = "rustls")] #[doc(hidden)] -impl From>> for TlsStream { - fn from(value: futures_rustls::client::TlsStream>) -> Self { +impl From>>>> for TlsStream { + fn from(value: futures_rustls::client::TlsStream>>>) -> Self { Self(TlsStreamInner::Rustls(futures_rustls::TlsStream::Client( value, ))) @@ -79,8 +81,8 @@ impl From>> for TlsStream #[cfg(feature = "rustls")] #[doc(hidden)] -impl From>> for TlsStream { - fn from(value: futures_rustls::server::TlsStream>) -> Self { +impl From>>>> for TlsStream { + fn from(value: futures_rustls::server::TlsStream>>>) -> Self { Self(TlsStreamInner::Rustls(futures_rustls::TlsStream::Server( value, ))) @@ -124,7 +126,10 @@ where } } -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream +where + for<'a> &'a S: AsyncRead + AsyncWrite, +{ async fn read(&mut self, mut buf: B) -> BufResult { let slice = buf.as_uninit(); slice.fill(MaybeUninit::new(0)); @@ -188,7 +193,10 @@ async fn flush_impl(s: &mut native_tls::TlsStream>) Ok(()) } -impl AsyncWrite for TlsStream { +impl AsyncWrite for TlsStream +where + for<'a> &'a S: AsyncRead + AsyncWrite, +{ async fn write(&mut self, buf: T) -> BufResult { let slice = buf.as_init(); match &mut self.0 { diff --git a/compio-ws/src/tls.rs b/compio-ws/src/tls.rs index 0719cc320..ca41e2a1c 100644 --- a/compio-ws/src/tls.rs +++ b/compio-ws/src/tls.rs @@ -118,7 +118,8 @@ async fn wrap_stream( mode: Mode, ) -> Result, Error> where - S: AsyncRead + AsyncWrite + 'static, + S: AsyncRead + AsyncWrite + Unpin + 'static, + for<'a> &'a S: AsyncRead + AsyncWrite, { match mode { Mode::Plain => Ok(MaybeTlsStream::new_plain(socket)), @@ -176,6 +177,7 @@ pub async fn client_async_tls( where R: IntoClientRequest, S: AsyncRead + AsyncWrite + Unpin + 'static, + for<'a> &'a S: AsyncRead + AsyncWrite, { client_async_tls_with_config(request, stream, None, None).await } @@ -191,6 +193,7 @@ pub async fn client_async_tls_with_config( where R: IntoClientRequest, S: AsyncRead + AsyncWrite + Unpin + 'static, + for<'a> &'a S: AsyncRead + AsyncWrite, { let request: Request = request.into_client_request()?;