Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/target
.idea
80 changes: 69 additions & 11 deletions src/transport/webrtc/substream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ use crate::{
};

use bytes::{Buf, BufMut, BytesMut};
use futures::{Future, Stream};
use futures::{task::AtomicWaker, Future, Stream};
use parking_lot::Mutex;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::mpsc::{channel, error::TrySendError, Receiver, Sender};

use std::{
pin::Pin,
Expand Down Expand Up @@ -59,7 +59,7 @@ enum State {
SendClosed,
}

/// Channel-backedn substream.
/// Channel-backed substream. Must be owned and polled by exactly one task at a time.
pub struct Substream {
/// Substream state.
state: Arc<Mutex<State>>,
Expand All @@ -72,6 +72,9 @@ pub struct Substream {

/// RX channel for receiving messages from `peer`.
rx: Receiver<Event>,

/// Shared waker to notify when capacity on a previously full `tx` channel is available.
write_waker: Arc<AtomicWaker>,
}

impl Substream {
Expand All @@ -80,10 +83,13 @@ impl Substream {
let (outbound_tx, outbound_rx) = channel(256);
let (inbound_tx, inbound_rx) = channel(256);
let state = Arc::new(Mutex::new(State::Open));
let waker = Arc::new(AtomicWaker::new());

let handle = SubstreamHandle {
tx: inbound_tx,
rx: outbound_rx,
state: Arc::clone(&state),
write_waker: Arc::clone(&waker),
};

(
Expand All @@ -92,6 +98,7 @@ impl Substream {
tx: outbound_tx,
rx: inbound_rx,
read_buffer: BytesMut::new(),
write_waker: waker,
},
handle,
)
Expand All @@ -107,6 +114,9 @@ pub struct SubstreamHandle {

/// RX channel for receiving messages from `peer`.
rx: Receiver<Event>,

/// Shared waker to notify when capacity on a previously full `rx` channel is available.
write_waker: Arc<AtomicWaker>,
}

impl SubstreamHandle {
Expand Down Expand Up @@ -144,7 +154,9 @@ impl Stream for SubstreamHandle {
type Item = Event;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
let item = self.rx.poll_recv(cx);
self.write_waker.wake();
item
}
}

Expand Down Expand Up @@ -198,17 +210,20 @@ impl tokio::io::AsyncWrite for Substream {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}

// TODO: try to coalesce multiple calls to `poll_write()` into single `Event::Message`

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let future = self.tx.reserve();
futures::pin_mut!(future);
// Register in case channel is full. Do it before checking to avoid lost wakeups.
self.write_waker.register(cx.waker());

let permit = match futures::ready!(future.poll(cx)) {
Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
let permit = match self.tx.try_reserve() {
Ok(permit) => permit,
Err(err) =>
return match err {
TrySendError::Full(_) => Poll::Pending,
TrySendError::Closed(_) =>
Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
},
};

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let frame = buf[..num_bytes].to_vec();
permit.send(Event::Message(frame));

Expand Down Expand Up @@ -458,4 +473,47 @@ mod tests {
)
.await;
}

#[tokio::test]
async fn backpressure_released_wakes_blocked_writer() {
use tokio::time::{sleep, timeout, Duration};

let (mut substream, mut handle) = Substream::new();

// Fill the channel to capacity, same pattern as `backpressure_works`.
for _ in 0..128 {
substream.write_all(&vec![0u8; 2 * MAX_FRAME_SIZE]).await.unwrap();
}

// Spawn a writer task that will try to write once more. This should initially block
// because the channel is full and rely on the AtomicWaker to be woken later.
let writer = tokio::spawn(async move {
substream
.write_all(&vec![1u8; MAX_FRAME_SIZE])
.await
.expect("write should eventually succeed");
});

// Give the writer a short moment to reach the blocked (Pending) state.
sleep(Duration::from_millis(10)).await;
assert!(
!writer.is_finished(),
"writer should be blocked by backpressure"
);

// Now consume a single message from the receiving side. This will:
// - free capacity in the channel
// - call `write_waker.wake()` from `poll_next`
//
// That wake must cause the blocked writer to be polled again and complete its write.
let _ = handle.next().await.expect("expected at least one outbound message");

// The writer should now complete in a timely fashion, proving that:
// - registering the waker before `try_reserve` works (no lost wakeup)
// - the wake from `poll_next` correctly unblocks the writer.
timeout(Duration::from_secs(1), writer)
.await
.expect("writer task did not complete after capacity was freed")
.expect("writer task panicked");
}
}