Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
/target
.idea

94 changes: 68 additions & 26 deletions src/transport/webrtc/substream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use crate::{
};

use bytes::{Buf, BufMut, BytesMut};
use futures::{Future, Stream};
use futures::Stream;
use parking_lot::Mutex;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_util::sync::PollSender;

use std::{
pin::Pin,
Expand Down Expand Up @@ -59,16 +60,17 @@ enum State {
SendClosed,
}

/// Channel-backedn substream.
/// Channel-backed substream. Must be owned and polled by exactly one task at a time.
pub struct Substream {
/// Substream state.
state: Arc<Mutex<State>>,

/// Read buffer.
read_buffer: BytesMut,

/// TX channel for sending messages to `peer`.
tx: Sender<Event>,
/// TX channel for sending messages to `peer`, wrapped in a [`PollSender`]
/// so that backpressure is driven by the caller's waker.
tx: PollSender<Event>,

/// RX channel for receiving messages from `peer`.
rx: Receiver<Event>,
Expand All @@ -80,6 +82,7 @@ impl Substream {
let (outbound_tx, outbound_rx) = channel(256);
let (inbound_tx, inbound_rx) = channel(256);
let state = Arc::new(Mutex::new(State::Open));

let handle = SubstreamHandle {
tx: inbound_tx,
rx: outbound_rx,
Expand All @@ -89,7 +92,7 @@ impl Substream {
(
Self {
state,
tx: outbound_tx,
tx: PollSender::new(outbound_tx),
rx: inbound_rx,
read_buffer: BytesMut::new(),
},
Expand All @@ -102,10 +105,10 @@ impl Substream {
pub struct SubstreamHandle {
state: Arc<Mutex<State>>,

/// TX channel for sending messages to `peer`.
/// TX channel for sending inbound messages from `peer` to the associated `Substream`.
tx: Sender<Event>,

/// RX channel for receiving messages from `peer`.
/// RX channel for receiving outbound messages to `peer` from the associated `Substream`.
rx: Receiver<Event>,
}

Expand Down Expand Up @@ -190,49 +193,45 @@ impl tokio::io::AsyncRead for Substream {

impl tokio::io::AsyncWrite for Substream {
fn poll_write(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
if let State::SendClosed = *self.state.lock() {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}

// TODO: try to coalesce multiple calls to `poll_write()` into single `Event::Message`

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let future = self.tx.reserve();
futures::pin_mut!(future);

let permit = match futures::ready!(future.poll(cx)) {
match futures::ready!(self.tx.poll_reserve(cx)) {
Ok(()) => {}
Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
Ok(permit) => permit,
};

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let frame = buf[..num_bytes].to_vec();
permit.send(Event::Message(frame));

Poll::Ready(Ok(num_bytes))
match self.tx.send_item(Event::Message(frame)) {
Ok(()) => Poll::Ready(Ok(num_bytes)),
Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
}
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let future = self.tx.reserve();
futures::pin_mut!(future);

let permit = match futures::ready!(future.poll(cx)) {
match futures::ready!(self.tx.poll_reserve(cx)) {
Ok(()) => {}
Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
Ok(permit) => permit,
};
permit.send(Event::Close);

Poll::Ready(Ok(()))
match self.tx.send_item(Event::Close) {
Ok(()) => Poll::Ready(Ok(())),
Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
}
}
}

Expand Down Expand Up @@ -458,4 +457,47 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn backpressure_released_wakes_blocked_writer() {
use tokio::time::{sleep, timeout, Duration};

let (mut substream, mut handle) = Substream::new();

// Fill the channel to capacity, same pattern as `backpressure_works`.
for _ in 0..128 {
substream.write_all(&vec![0u8; 2 * MAX_FRAME_SIZE]).await.unwrap();
}

// Spawn a writer task that will try to write once more. This should initially block
// because the channel is full and rely on the AtomicWaker to be woken later.
let writer = tokio::spawn(async move {
substream
.write_all(&vec![1u8; MAX_FRAME_SIZE])
.await
.expect("write should eventually succeed");
});

// Give the writer a short moment to reach the blocked (Pending) state.
sleep(Duration::from_millis(10)).await;
assert!(
!writer.is_finished(),
"writer should be blocked by backpressure"
);

// Now consume a single message from the receiving side. This will:
// - free capacity in the channel
// - call `write_waker.wake()` from `poll_next`
//
// That wake must cause the blocked writer to be polled again and complete its write.
let _ = handle.next().await.expect("expected at least one outbound message");

// The writer should now complete in a timely fashion, proving that:
// - registering the waker before `try_reserve` works (no lost wakeup)
// - the wake from `poll_next` correctly unblocks the writer.
timeout(Duration::from_secs(1), writer)
.await
.expect("writer task did not complete after capacity was freed")
.expect("writer task panicked");
}
}
Loading