diff --git a/.gitignore b/.gitignore index 19c012f2..3a8cabc9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ /target .idea - diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index 2ef3f293..f839fd83 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -24,9 +24,10 @@ use crate::{ }; use bytes::{Buf, BufMut, BytesMut}; -use futures::{Future, Stream}; +use futures::Stream; use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio_util::sync::PollSender; use std::{ pin::Pin, @@ -59,7 +60,7 @@ enum State { SendClosed, } -/// Channel-backedn substream. +/// Channel-backed substream. Must be owned and polled by exactly one task at a time. pub struct Substream { /// Substream state. state: Arc>, @@ -67,8 +68,9 @@ pub struct Substream { /// Read buffer. read_buffer: BytesMut, - /// TX channel for sending messages to `peer`. - tx: Sender, + /// TX channel for sending messages to `peer`, wrapped in a [`PollSender`] + /// so that backpressure is driven by the caller's waker. + tx: PollSender, /// RX channel for receiving messages from `peer`. rx: Receiver, @@ -80,6 +82,7 @@ impl Substream { let (outbound_tx, outbound_rx) = channel(256); let (inbound_tx, inbound_rx) = channel(256); let state = Arc::new(Mutex::new(State::Open)); + let handle = SubstreamHandle { tx: inbound_tx, rx: outbound_rx, @@ -89,7 +92,7 @@ impl Substream { ( Self { state, - tx: outbound_tx, + tx: PollSender::new(outbound_tx), rx: inbound_rx, read_buffer: BytesMut::new(), }, @@ -98,14 +101,14 @@ impl Substream { } } -/// Substream handle that is given to the transport backend. +/// Substream handle that is given to the WebRTC transport backend. pub struct SubstreamHandle { state: Arc>, - /// TX channel for sending messages to `peer`. + /// TX channel for sending inbound messages from `peer` to the associated `Substream`. tx: Sender, - /// RX channel for receiving messages from `peer`. + /// RX channel for receiving outbound messages to `peer` from the associated `Substream`. rx: Receiver, } @@ -190,7 +193,7 @@ impl tokio::io::AsyncRead for Substream { impl tokio::io::AsyncWrite for Substream { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { @@ -198,21 +201,18 @@ impl tokio::io::AsyncWrite for Substream { return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); } - // TODO: try to coalesce multiple calls to `poll_write()` into single `Event::Message` - - let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); - let future = self.tx.reserve(); - futures::pin_mut!(future); - - let permit = match futures::ready!(future.poll(cx)) { + match futures::ready!(self.tx.poll_reserve(cx)) { + Ok(()) => {} Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - Ok(permit) => permit, }; + let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); let frame = buf[..num_bytes].to_vec(); - permit.send(Event::Message(frame)); - Poll::Ready(Ok(num_bytes)) + match self.tx.send_item(Event::Message(frame)) { + Ok(()) => Poll::Ready(Ok(num_bytes)), + Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), + } } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -220,19 +220,18 @@ impl tokio::io::AsyncWrite for Substream { } fn poll_shutdown( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let future = self.tx.reserve(); - futures::pin_mut!(future); - - let permit = match futures::ready!(future.poll(cx)) { + match futures::ready!(self.tx.poll_reserve(cx)) { + Ok(()) => {} Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - Ok(permit) => permit, }; - permit.send(Event::Close); - Poll::Ready(Ok(())) + match self.tx.send_item(Event::Close) { + Ok(()) => Poll::Ready(Ok(())), + Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), + } } } @@ -458,4 +457,47 @@ mod tests { ) .await; } + + #[tokio::test] + async fn backpressure_released_wakes_blocked_writer() { + use tokio::time::{sleep, timeout, Duration}; + + let (mut substream, mut handle) = Substream::new(); + + // Fill the channel to capacity, same pattern as `backpressure_works`. + for _ in 0..128 { + substream.write_all(&vec![0u8; 2 * MAX_FRAME_SIZE]).await.unwrap(); + } + + // Spawn a writer task that will try to write once more. This should initially block + // because the channel is full and rely on the AtomicWaker to be woken later. + let writer = tokio::spawn(async move { + substream + .write_all(&vec![1u8; MAX_FRAME_SIZE]) + .await + .expect("write should eventually succeed"); + }); + + // Give the writer a short moment to reach the blocked (Pending) state. + sleep(Duration::from_millis(10)).await; + assert!( + !writer.is_finished(), + "writer should be blocked by backpressure" + ); + + // Now consume a single message from the receiving side. This will: + // - free capacity in the channel + // - call `write_waker.wake()` from `poll_next` + // + // That wake must cause the blocked writer to be polled again and complete its write. + let _ = handle.next().await.expect("expected at least one outbound message"); + + // The writer should now complete in a timely fashion, proving that: + // - registering the waker before `try_reserve` works (no lost wakeup) + // - the wake from `poll_next` correctly unblocks the writer. + timeout(Duration::from_secs(1), writer) + .await + .expect("writer task did not complete after capacity was freed") + .expect("writer task panicked"); + } }