Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compio-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ repository = { workspace = true }
compio-buf = { workspace = true, features = ["arrayvec", "bytes"] }
futures-util = { workspace = true, features = ["sink"] }
paste = { workspace = true }
pin-project-lite = { workspace = true, optional = true }
synchrony = { workspace = true, features = ["bilock"] }
thiserror = { workspace = true, optional = true }
serde = { version = "1.0.219", optional = true }
Expand All @@ -25,7 +26,7 @@ futures-executor = "0.3.30"

[features]
default = []
compat = ["futures-util/io"]
compat = ["futures-util/io", "dep:pin-project-lite"]
sync = []

# Codecs
Expand Down
99 changes: 58 additions & 41 deletions compio-io/src/compat/async_stream.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
use std::{
fmt::Debug,
io::{self, BufRead},
marker::PhantomPinned,
mem::MaybeUninit,
pin::Pin,
task::{Context, Poll},
};

use pin_project_lite::pin_project;

use crate::{PinBoxFuture, compat::SyncStream};

/// A stream wrapper for [`futures_util::io`] traits.
pub struct AsyncStream<S> {
// The futures keep the reference to the inner stream, so we need to pin
// the inner stream to make sure the reference is valid.
inner: Pin<Box<SyncStream<S>>>,
read_future: Option<PinBoxFuture<io::Result<usize>>>,
write_future: Option<PinBoxFuture<io::Result<usize>>>,
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
pin_project! {
/// A stream wrapper for [`futures_util::io`] traits.
pub struct AsyncStream<S> {
// The futures keep the reference to the inner stream, so we need to pin
// the inner stream to make sure the reference is valid.
#[pin]
inner: SyncStream<S>,
read_future: Option<PinBoxFuture<io::Result<usize>>>,
write_future: Option<PinBoxFuture<io::Result<usize>>>,
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
#[pin]
_p: PhantomPinned,
}
}

impl<S> AsyncStream<S> {
Expand All @@ -31,10 +39,11 @@ impl<S> AsyncStream<S> {

fn new_impl(inner: SyncStream<S>) -> Self {
Self {
inner: Box::pin(inner),
inner,
read_future: None,
write_future: None,
shutdown_future: None,
_p: PhantomPinned,
}
}

Expand Down Expand Up @@ -82,68 +91,74 @@ macro_rules! poll_future_would_block {
}};
}

impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> {
unsafe fn extend_lifetime<T>(t: &mut T) -> &'static mut T {
unsafe { &mut *(t as *mut T) }
}

impl<S: crate::AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncStream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
// SAFETY:
// - The futures won't live longer than the stream.
// - The inner stream is pinned.
let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
// - The future won't live longer than the stream.
// - The stream is `Unpin`, and is internally mutable.
// - The future only accesses the corresponding buffer and fields.
// - No access overlap between the futures.
let inner: &'static mut SyncStream<S> = unsafe { extend_lifetime(this.inner.get_mut()) };
Comment thread
Berrysoft marked this conversation as resolved.
Outdated

poll_future_would_block!(
self.read_future,
this.read_future,
cx,
inner.fill_read_buf(),
io::Read::read(inner, buf)
)
}
}

impl<S: crate::AsyncRead + 'static> AsyncStream<S> {
impl<S: crate::AsyncRead + Unpin + 'static> AsyncStream<S> {
/// Attempt to read from the `AsyncRead` into `buf`.
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
pub fn poll_read_uninit(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [MaybeUninit<u8>],
) -> Poll<io::Result<usize>> {
let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
let this = self.project();
let inner: &'static mut SyncStream<S> = unsafe { extend_lifetime(this.inner.get_mut()) };
poll_future_would_block!(
self.read_future,
this.read_future,
cx,
inner.fill_read_buf(),
inner.read_buf_uninit(buf)
)
}
}

impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
impl<S: crate::AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncStream<S> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let this = self.project();
let inner: &'static mut SyncStream<S> = unsafe { extend_lifetime(this.inner.get_mut()) };
poll_future_would_block!(
self.read_future,
this.read_future,
cx,
inner.fill_read_buf(),
// SAFETY: anyway the slice won't be used after free.
io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) })
)
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) }
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().inner.consume(amt)
}
}

impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> {
impl<S: crate::AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncStream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Expand All @@ -152,38 +167,39 @@ impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S>
return Poll::Pending;
}

let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
let this = self.project();
let inner: &'static mut SyncStream<S> = unsafe { extend_lifetime(this.inner.get_mut()) };
poll_future_would_block!(
self.write_future,
this.write_future,
cx,
inner.flush_write_buf(),
io::Write::write(inner, buf)
)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.shutdown_future.is_some() {
debug_assert!(self.write_future.is_none());
return Poll::Pending;
}

let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
let res = poll_future!(self.write_future, cx, inner.flush_write_buf());
let this = self.project();
let inner: &'static mut SyncStream<S> = unsafe { extend_lifetime(this.inner.get_mut()) };
let res = poll_future!(this.write_future, cx, inner.flush_write_buf());
Poll::Ready(res.map(|_| ()))
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// Avoid shutdown on flush because the inner buffer might be passed to the
// driver.
if self.write_future.is_some() || self.inner.has_pending_write() {
debug_assert!(self.shutdown_future.is_none());
self.poll_flush(cx)
} else {
let this = self.project();
let inner: &'static mut SyncStream<S> =
unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) };
let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown());
unsafe { extend_lifetime(this.inner.get_mut()) };
let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown());
Poll::Ready(res)
}
}
Expand All @@ -207,7 +223,8 @@ mod test {
#[test]
fn close() {
block_on(async {
let mut stream = AsyncStream::new(Vec::<u8>::new());
let stream = AsyncStream::new(Vec::<u8>::new());
let mut stream = std::pin::pin!(stream);
let n = stream.write(b"hello").await.unwrap();
assert_eq!(n, 5);
stream.close().await.unwrap();
Expand Down
16 changes: 11 additions & 5 deletions compio-io/tests/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
fn async_compat_read() {
block_on(async {
let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..];
let mut stream = AsyncStream::new(src);
let stream = AsyncStream::new(src);
let mut stream = std::pin::pin!(stream);

let mut buf = [0; 6];
let len = stream.read(&mut buf).await.unwrap();
Expand All @@ -27,7 +28,8 @@ fn async_compat_read() {
fn async_compat_bufread() {
block_on(async {
let src = &[1u8, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0][..];
let mut stream = AsyncStream::new(src);
let stream = AsyncStream::new(src);
let mut stream = std::pin::pin!(stream);

let slice = stream.fill_buf().await.unwrap();
assert_eq!(slice, [1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]);
Expand All @@ -45,7 +47,8 @@ fn async_compat_bufread() {
fn async_compat_write() {
block_on(async {
let dst = Cursor::new([0u8; 10]);
let mut stream = AsyncStream::new(dst);
let stream = AsyncStream::new(dst);
let mut stream = std::pin::pin!(stream);

let len = stream.write(&[1, 1, 4, 5, 1, 4]).await.unwrap();
stream.flush().await.unwrap();
Expand All @@ -55,7 +58,9 @@ fn async_compat_write() {
assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 0, 0, 0, 0]);

let dst = Cursor::new([0u8; 10]);
let mut stream = AsyncStream::with_capacity(10, dst);
let stream = AsyncStream::with_capacity(10, dst);
let mut stream = std::pin::pin!(stream);

let len = stream
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
.await
Expand All @@ -71,7 +76,8 @@ fn async_compat_write() {
fn async_compat_flush_fail() {
block_on(async {
let dst = Cursor::new([0u8; 10]);
let mut stream = AsyncStream::new(dst);
let stream = AsyncStream::new(dst);
let mut stream = std::pin::pin!(stream);
let len = stream
.write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0])
.await
Expand Down
8 changes: 4 additions & 4 deletions compio-tls/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl TlsConnector {
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
pub async fn connect<S: AsyncRead + AsyncWrite + 'static>(
pub async fn connect<S: AsyncRead + AsyncWrite + Unpin + 'static>(
&self,
domain: &str,
stream: S,
Expand All @@ -97,7 +97,7 @@ impl TlsConnector {
let client = c
.connect(
domain.to_string().try_into().map_err(io::Error::other)?,
AsyncStream::new(stream),
Box::pin(AsyncStream::new(stream)),
)
.await?;
Ok(TlsStream::from(client))
Expand Down Expand Up @@ -172,7 +172,7 @@ impl TlsAcceptor {
/// This is typically used after a new socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client connection.
pub async fn accept<S: AsyncRead + AsyncWrite + 'static>(
pub async fn accept<S: AsyncRead + AsyncWrite + Unpin + 'static>(
&self,
stream: S,
) -> io::Result<TlsStream<S>> {
Expand All @@ -183,7 +183,7 @@ impl TlsAcceptor {
}
#[cfg(feature = "rustls")]
TlsAcceptorInner::Rustls(c) => {
let server = c.accept(AsyncStream::new(stream)).await?;
let server = c.accept(Box::pin(AsyncStream::new(stream))).await?;
Ok(TlsStream::from(server))
}
#[cfg(feature = "py-dynamic-openssl")]
Expand Down
12 changes: 6 additions & 6 deletions compio-tls/src/rtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ use crate::TlsStream;

/// A lazy TLS acceptor that performs the initial handshake and allows access to
/// the [`ClientHello`] message before completing the handshake.
pub struct LazyConfigAcceptor<S>(futures_rustls::LazyConfigAcceptor<AsyncStream<S>>);
pub struct LazyConfigAcceptor<S>(futures_rustls::LazyConfigAcceptor<Pin<Box<AsyncStream<S>>>>);

impl<S: AsyncRead + AsyncWrite + 'static> LazyConfigAcceptor<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> LazyConfigAcceptor<S> {
/// Create a new [`LazyConfigAcceptor`] with the given acceptor and stream.
pub fn new(acceptor: Acceptor, s: S) -> Self {
Self(futures_rustls::LazyConfigAcceptor::new(
acceptor,
AsyncStream::new(s),
Box::pin(AsyncStream::new(s)),
))
}
}

impl<S: AsyncRead + AsyncWrite + 'static> Future for LazyConfigAcceptor<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> Future for LazyConfigAcceptor<S> {
type Output = Result<StartHandshake<S>, io::Error>;

Comment thread
Berrysoft marked this conversation as resolved.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Expand All @@ -38,9 +38,9 @@ impl<S: AsyncRead + AsyncWrite + 'static> Future for LazyConfigAcceptor<S> {

/// A TLS acceptor that has completed the initial handshake and allows access to
/// the [`ClientHello`] message.
pub struct StartHandshake<S>(futures_rustls::StartHandshake<AsyncStream<S>>);
pub struct StartHandshake<S>(futures_rustls::StartHandshake<Pin<Box<AsyncStream<S>>>>);

impl<S: AsyncRead + AsyncWrite + 'static> StartHandshake<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + 'static> StartHandshake<S> {
/// Get the [`ClientHello`] message from the initial handshake.
pub fn client_hello(&self) -> ClientHello<'_> {
self.0.client_hello()
Expand Down
Loading