Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/target
.idea
37 changes: 26 additions & 11 deletions src/transport/webrtc/substream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ use crate::{
};

use bytes::{Buf, BufMut, BytesMut};
use futures::{Future, Stream};
use futures::{task::AtomicWaker, Future, Stream};
use parking_lot::Mutex;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::mpsc::{channel, error::TrySendError, Receiver, Sender};

use std::{
pin::Pin,
Expand Down Expand Up @@ -59,7 +59,7 @@ enum State {
SendClosed,
}

/// Channel-backedn substream.
/// Channel-backed substream. Must be owned and polled by exactly one task at a time.
pub struct Substream {
/// Substream state.
state: Arc<Mutex<State>>,
Expand All @@ -72,6 +72,9 @@ pub struct Substream {

/// RX channel for receiving messages from `peer`.
rx: Receiver<Event>,

/// Shared waker to notify when capacity on a previously full `tx` channel is available.
write_waker: Arc<AtomicWaker>,
}

impl Substream {
Expand All @@ -80,10 +83,13 @@ impl Substream {
let (outbound_tx, outbound_rx) = channel(256);
let (inbound_tx, inbound_rx) = channel(256);
let state = Arc::new(Mutex::new(State::Open));
let waker = Arc::new(AtomicWaker::new());

let handle = SubstreamHandle {
tx: inbound_tx,
rx: outbound_rx,
state: Arc::clone(&state),
write_waker: Arc::clone(&waker),
};

(
Expand All @@ -92,6 +98,7 @@ impl Substream {
tx: outbound_tx,
rx: inbound_rx,
read_buffer: BytesMut::new(),
write_waker: waker,
},
handle,
)
Expand All @@ -107,6 +114,9 @@ pub struct SubstreamHandle {

/// RX channel for receiving messages from `peer`.
rx: Receiver<Event>,

/// Shared waker to notify when capacity on a previously full `rx` channel is available.
write_waker: Arc<AtomicWaker>,
}

impl SubstreamHandle {
Expand Down Expand Up @@ -144,7 +154,9 @@ impl Stream for SubstreamHandle {
type Item = Event;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
let item = self.rx.poll_recv(cx);
self.write_waker.wake();
item
}
}

Expand Down Expand Up @@ -198,17 +210,20 @@ impl tokio::io::AsyncWrite for Substream {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}

// TODO: try to coalesce multiple calls to `poll_write()` into single `Event::Message`

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let future = self.tx.reserve();
futures::pin_mut!(future);
// Register in case channel is full. Do it before checking to avoid lost wakeups.
self.write_waker.register(cx.waker());

let permit = match futures::ready!(future.poll(cx)) {
Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
let permit = match self.tx.try_reserve() {
Ok(permit) => permit,
Err(err) =>
return match err {
TrySendError::Full(_) => Poll::Pending,
TrySendError::Closed(_) =>
Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
},
};

let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
let frame = buf[..num_bytes].to_vec();
permit.send(Event::Message(frame));

Expand Down