Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 47 additions & 35 deletions compio-io/src/compat/async_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
fmt::Debug,
io::{self, BufRead},
io,
marker::PhantomPinned,
mem::MaybeUninit,
pin::Pin,
Expand All @@ -10,13 +10,19 @@ use std::{

use pin_project_lite::pin_project;

use crate::{AsyncRead, AsyncWrite, PinBoxFuture, compat::SyncStream, util::DEFAULT_BUF_SIZE};
use crate::{
AsyncRead, AsyncWrite, PinBoxFuture,
compat::{SyncStream, SyncStreamReadHalf, SyncStreamWriteHalf},
util::{DEFAULT_BUF_SIZE, Splittable},
};

pin_project! {
/// A stream wrapper for [`futures_util::io`] traits.
pub struct AsyncStream<S> {
pub struct AsyncStream<S: Splittable> {
#[pin]
inner: SyncStream<S>,
read_inner: SyncStreamReadHalf<S::ReadHalf>,
#[pin]
write_inner: SyncStreamWriteHalf<S::WriteHalf>,
read_future: Option<PinBoxFuture<io::Result<usize>>>,
write_future: Option<PinBoxFuture<io::Result<usize>>>,
shutdown_future: Option<PinBoxFuture<io::Result<()>>>,
Expand All @@ -31,7 +37,7 @@ pin_project! {
}
}

impl<S> AsyncStream<S> {
impl<S: Splittable> AsyncStream<S> {
/// Create [`AsyncStream`] with the stream and default buffer size.
pub fn new(stream: S) -> Self {
Self::new_impl(SyncStream::new(stream))
Expand All @@ -43,8 +49,10 @@ impl<S> AsyncStream<S> {
}

fn new_impl(inner: SyncStream<S>) -> Self {
let (read_inner, write_inner) = inner.split();
Self {
inner,
read_inner,
write_inner,
read_future: None,
write_future: None,
shutdown_future: None,
Expand All @@ -57,20 +65,25 @@ impl<S> AsyncStream<S> {
_p: PhantomPinned,
}
}
}

impl<S> AsyncStream<S>
Comment thread
Berrysoft marked this conversation as resolved.
Outdated
where
S: Splittable<ReadHalf = S, WriteHalf = S>,
{
/// Get the reference of the inner stream.
pub fn get_ref(&self) -> &S {
self.inner.get_ref()
self.read_inner.get_ref()
}

/// Returns a mutable reference to the underlying stream.
pub fn get_mut(&mut self) -> &mut S {
self.inner.get_mut()
self.read_inner.get_mut()
}

/// Consumes the `AsyncStream`, returning the underlying stream.
pub fn into_inner(self) -> S {
self.inner.into_inner()
self.read_inner.into_inner()
}
}

Expand Down Expand Up @@ -246,9 +259,9 @@ fn replace_waker(waker_slot: &mut Option<Waker>, waker: &Waker) {
}
}

impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
impl<S: Splittable + 'static> AsyncStream<S>
where
for<'a> &'a S: AsyncRead,
S::ReadHalf: AsyncRead + Unpin,
{
fn poll_read_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
let this = self.project();
Expand All @@ -257,7 +270,7 @@ where
// - The stream is internally mutable.
// - The future only accesses the corresponding buffer and fields.
// - No access overlap between the futures.
let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
let inner = unsafe { extend_lifetime_mut(this.read_inner.get_mut()) };
let arr = WakerArray([
this.read_waker.as_ref().cloned(),
this.read_uninit_waker.as_ref().cloned(),
Expand All @@ -270,9 +283,9 @@ where
}
}

impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncRead for AsyncStream<S>
impl<S: Splittable + 'static> futures_util::AsyncRead for AsyncStream<S>
where
for<'a> &'a S: AsyncRead,
S::ReadHalf: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
Expand All @@ -285,16 +298,16 @@ where
poll_future_would_block!(
cx,
this.read_waker,
io::Read::read(this.inner.get_mut(), buf),
io::Read::read(this.read_inner.get_mut(), buf),
self.as_mut().poll_read_impl()
)
}
}
}

impl<S: AsyncRead + Unpin + 'static> AsyncStream<S>
impl<S: Splittable + 'static> AsyncStream<S>
where
for<'a> &'a S: AsyncRead,
S::ReadHalf: AsyncRead + Unpin,
{
/// Attempt to read from the `AsyncRead` into `buf`.
///
Expand All @@ -310,16 +323,16 @@ where
poll_future_would_block!(
cx,
this.read_uninit_waker,
this.inner.get_mut().read_buf_uninit(buf),
this.read_inner.get_mut().read_buf_uninit(buf),
self.as_mut().poll_read_impl()
)
}
}
}

impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncStream<S>
impl<S: Splittable + 'static> futures_util::AsyncBufRead for AsyncStream<S>
where
for<'a> &'a S: AsyncRead,
S::ReadHalf: AsyncRead + Unpin,
{
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
replace_waker(self.as_mut().project().read_buf_waker, cx.waker());
Expand All @@ -330,14 +343,15 @@ where
this.read_buf_waker,
// SAFETY: The buffer won't be accessed after the future is ready, and the future
// won't live longer than the stream.
io::BufRead::fill_buf(this.inner.get_mut()).map(|s| unsafe { extend_lifetime(s) }),
io::BufRead::fill_buf(this.read_inner.get_mut())
.map(|s| unsafe { extend_lifetime(s) }),
self.as_mut().poll_read_impl()
)
}
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().inner.consume(amt)
io::BufRead::consume(self.project().read_inner.get_mut(), amt)
}
}

Expand Down Expand Up @@ -417,13 +431,13 @@ impl<S: AsyncRead + Unpin + 'static> futures_util::AsyncBufRead for AsyncReadStr
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().inner.consume(amt)
io::BufRead::consume(self.project().inner.get_mut(), amt)
}
}

impl<S: AsyncWrite + Unpin + 'static> AsyncStream<S>
impl<S: Splittable + 'static> AsyncStream<S>
where
for<'a> &'a S: AsyncWrite,
S::WriteHalf: AsyncWrite + Unpin,
{
fn poll_flush_impl(self: Pin<&mut Self>) -> Poll<io::Result<usize>> {
let this = self.project();
Expand All @@ -432,7 +446,7 @@ where
// - The stream is internally mutable.
// - The future only accesses the corresponding buffer and fields.
// - No access overlap between the futures.
let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
let inner = unsafe { extend_lifetime_mut(this.write_inner.get_mut()) };
let arr = WakerArray([
this.write_waker.as_ref().cloned(),
this.flush_waker.as_ref().cloned(),
Expand All @@ -451,7 +465,7 @@ where
// - The stream is internally mutable.
// - The future only accesses the corresponding buffer and fields.
// - No access overlap between the futures.
let inner = unsafe { extend_lifetime_mut(this.inner.get_mut()) };
let inner = unsafe { extend_lifetime_mut(this.write_inner.get_mut()) };
let arr = WakerArray([
this.write_waker.as_ref().cloned(),
this.flush_waker.as_ref().cloned(),
Expand All @@ -464,9 +478,9 @@ where
}
}

impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncStream<S>
impl<S: Splittable + 'static> futures_util::AsyncWrite for AsyncStream<S>
where
for<'a> &'a S: AsyncWrite,
S::WriteHalf: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
Expand All @@ -483,7 +497,7 @@ where
poll_future_would_block!(
cx,
this.write_waker,
io::Write::write(this.inner.get_mut(), buf),
io::Write::write(this.write_inner.get_mut(), buf),
self.as_mut().poll_flush_impl()
)
}
Expand All @@ -504,7 +518,7 @@ where
replace_waker(self.as_mut().project().close_waker, cx.waker());
// Avoid shutdown on flush because the inner buffer might be passed to the
// driver.
if self.write_future.is_some() || self.inner.has_pending_write() {
if self.write_future.is_some() || self.write_inner.has_pending_write() {
debug_assert!(self.shutdown_future.is_none());
ready!(self.as_mut().poll_flush_impl())?;
}
Expand Down Expand Up @@ -597,11 +611,9 @@ impl<S: AsyncWrite + Unpin + 'static> futures_util::AsyncWrite for AsyncWriteStr
}
}

impl<S: Debug> Debug for AsyncStream<S> {
impl<S: Splittable> Debug for AsyncStream<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncStream")
.field("inner", &self.inner)
.finish_non_exhaustive()
f.debug_struct("AsyncStream").finish_non_exhaustive()
}
}

Expand Down
12 changes: 4 additions & 8 deletions compio-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ futures-rustls = { workspace = true, default-features = false, optional = true,
"logging",
"tls12",
] }
futures-util = { workspace = true, optional = true }
futures-util = { workspace = true }
pin-project-lite = { workspace = true, optional = true }

[dev-dependencies]
Expand All @@ -51,14 +51,10 @@ futures-rustls = { workspace = true, default-features = false, features = [
[features]
default = ["native-tls"]
all = ["native-tls", "rustls"]
rustls = ["dep:rustls", "dep:futures-rustls", "dep:futures-util"]
native-tls = ["dep:native-tls", "dep:futures-util", "dep:pin-project-lite"]
rustls = ["dep:rustls", "dep:futures-rustls"]
native-tls = ["dep:native-tls", "dep:pin-project-lite"]
native-tls-vendored = ["native-tls/vendored"]
py-dynamic-openssl = [
"dep:compio-py-dynamic-openssl",
"dep:futures-util",
"dep:pin-project-lite",
]
py-dynamic-openssl = ["dep:compio-py-dynamic-openssl", "dep:pin-project-lite"]

ring = ["rustls", "rustls/ring", "futures-rustls/ring"]

Expand Down
12 changes: 7 additions & 5 deletions compio-tls/src/adapter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{fmt::Debug, io};

use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream};
use compio_io::{AsyncRead, AsyncWrite, compat::AsyncStream, util::Splittable};

use crate::TlsStream;

Expand Down Expand Up @@ -79,13 +79,14 @@ impl TlsConnector {
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
pub async fn connect<S: AsyncRead + AsyncWrite + Unpin + 'static>(
pub async fn connect<S: AsyncRead + AsyncWrite + Splittable + 'static>(
&self,
domain: &str,
stream: S,
) -> io::Result<TlsStream<S>>
where
for<'a> &'a S: AsyncRead + AsyncWrite,
S::ReadHalf: AsyncRead + Unpin,
S::WriteHalf: AsyncWrite + Unpin,
{
Comment thread
Berrysoft marked this conversation as resolved.
Outdated
match &self.0 {
#[cfg(feature = "native-tls")]
Expand Down Expand Up @@ -178,12 +179,13 @@ impl TlsAcceptor {
/// This is typically used after a new socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client connection.
pub async fn accept<S: AsyncRead + AsyncWrite + Unpin + 'static>(
pub async fn accept<S: AsyncRead + AsyncWrite + Splittable + 'static>(
&self,
stream: S,
) -> io::Result<TlsStream<S>>
where
for<'a> &'a S: AsyncRead + AsyncWrite,
S::ReadHalf: AsyncRead + Unpin,
S::WriteHalf: AsyncWrite + Unpin,
Comment thread
Berrysoft marked this conversation as resolved.
Outdated
{
match &self.0 {
#[cfg(feature = "native-tls")]
Expand Down
Loading
Loading