Skip to content

Commit 73f76ec

Browse files
committed
webrtc: Wake blocked writers and close both sides on RESET_STREAM
1 parent a4effe2 commit 73f76ec

File tree

1 file changed

+125
-2
lines changed

1 file changed

+125
-2
lines changed

src/transport/webrtc/substream.rs

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ pub struct Substream {
9898
/// Waker to notify when shutdown completes (FIN_ACK received).
9999
shutdown_waker: Arc<AtomicWaker>,
100100

101+
/// Waker to notify when write state changes (e.g., STOP_SENDING received).
102+
write_waker: Arc<AtomicWaker>,
103+
101104
/// Timeout for waiting on FIN_ACK after sending FIN.
102105
/// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled.
103106
fin_ack_timeout: Option<Pin<Box<tokio::time::Sleep>>>,
@@ -110,13 +113,15 @@ impl Substream {
110113
let (inbound_tx, inbound_rx) = channel(256);
111114
let state = Arc::new(Mutex::new(State::Open));
112115
let shutdown_waker = Arc::new(AtomicWaker::new());
116+
let write_waker = Arc::new(AtomicWaker::new());
113117

114118
let handle = SubstreamHandle {
115119
inbound_tx,
116120
outbound_tx: outbound_tx.clone(),
117121
rx: outbound_rx,
118122
state: Arc::clone(&state),
119123
shutdown_waker: Arc::clone(&shutdown_waker),
124+
write_waker: Arc::clone(&write_waker),
120125
};
121126

122127
(
@@ -126,6 +131,7 @@ impl Substream {
126131
rx: inbound_rx,
127132
read_buffer: BytesMut::new(),
128133
shutdown_waker,
134+
write_waker,
129135
fin_ack_timeout: None,
130136
},
131137
handle,
@@ -148,6 +154,9 @@ pub struct SubstreamHandle {
148154

149155
/// Waker to notify when shutdown completes (FIN_ACK received).
150156
shutdown_waker: Arc<AtomicWaker>,
157+
158+
/// Waker to notify when write state changes (e.g., STOP_SENDING received).
159+
write_waker: Arc<AtomicWaker>,
151160
}
152161

153162
impl SubstreamHandle {
@@ -187,9 +196,19 @@ impl SubstreamHandle {
187196
}
188197
Flag::StopSending => {
189198
*self.state.lock() = State::SendClosed;
199+
// Wake any blocked poll_write so it can see the state change
200+
self.write_waker.wake();
190201
return Ok(());
191202
}
192203
Flag::ResetStream => {
204+
// RESET_STREAM abruptly terminates both sides of the stream
205+
// (matching go-libp2p behavior)
206+
// Close the read side
207+
let _ = self.inbound_tx.try_send(Event::RecvClosed);
208+
// Close the write side
209+
*self.state.lock() = State::SendClosed;
210+
// Wake any blocked poll_write so it can see the state change
211+
self.write_waker.wake();
193212
return Err(Error::ConnectionClosed);
194213
}
195214
}
@@ -286,6 +305,9 @@ impl tokio::io::AsyncWrite for Substream {
286305
cx: &mut Context<'_>,
287306
buf: &[u8],
288307
) -> Poll<Result<usize, std::io::Error>> {
308+
// Register waker so we get notified on state changes (e.g., STOP_SENDING)
309+
self.write_waker.register(cx.waker());
310+
289311
// Reject writes if we're closing or closed
290312
match *self.state.lock() {
291313
State::SendClosed | State::Closing | State::FinSent | State::FinAcked => {
@@ -299,6 +321,14 @@ impl tokio::io::AsyncWrite for Substream {
299321
Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
300322
};
301323

324+
// Re-check state after poll_reserve - it may have changed while we were waiting
325+
match *self.state.lock() {
326+
State::SendClosed | State::Closing | State::FinSent | State::FinAcked => {
327+
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
328+
}
329+
State::Open => {}
330+
}
331+
302332
let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len());
303333
let frame = buf[..num_bytes].to_vec();
304334

@@ -901,8 +931,9 @@ mod tests {
901931
}
902932

903933
#[tokio::test]
904-
async fn reset_stream_flag_returns_error() {
905-
let (_substream, handle) = Substream::new();
934+
async fn reset_stream_flag_closes_both_sides() {
935+
use tokio::io::AsyncWriteExt;
936+
let (mut substream, handle) = Substream::new();
906937

907938
// Simulate receiving RESET_STREAM
908939
let result = handle
@@ -914,6 +945,19 @@ mod tests {
914945

915946
// Should return connection closed error
916947
assert!(matches!(result, Err(Error::ConnectionClosed)));
948+
949+
// Write side should be closed (state = SendClosed)
950+
assert!(matches!(*handle.state.lock(), State::SendClosed));
951+
952+
// Attempting to write should fail
953+
match substream.write_all(&vec![0u8; 100]).await {
954+
Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe),
955+
_ => panic!("write should have failed"),
956+
}
957+
958+
// Read side should also be closed (RecvClosed event was sent)
959+
// The substream's rx channel should have RecvClosed
960+
assert!(matches!(substream.rx.try_recv(), Ok(Event::RecvClosed)));
917961
}
918962

919963
#[tokio::test]
@@ -1007,6 +1051,85 @@ mod tests {
10071051
shutdown_task3.await.unwrap();
10081052
}
10091053

1054+
#[tokio::test]
1055+
async fn stop_sending_wakes_blocked_writer() {
1056+
use tokio::io::AsyncWriteExt;
1057+
let (mut substream, handle) = Substream::new();
1058+
1059+
// Fill up the channel to cause poll_write to return Pending
1060+
// Channel capacity is 256
1061+
for _ in 0..256 {
1062+
substream.write_all(&[1u8; 100]).await.unwrap();
1063+
}
1064+
1065+
// Now the next write should block waiting for channel capacity
1066+
let write_task = tokio::spawn(async move {
1067+
// This write will block because channel is full
1068+
let result = substream.write_all(&[2u8; 100]).await;
1069+
// Should fail because STOP_SENDING was received
1070+
assert!(result.is_err());
1071+
});
1072+
1073+
// Give the writer time to block on poll_reserve
1074+
tokio::time::sleep(Duration::from_millis(10)).await;
1075+
assert!(!write_task.is_finished(), "write should be blocked");
1076+
1077+
// Simulate receiving STOP_SENDING from remote
1078+
handle
1079+
.on_message(WebRtcMessage {
1080+
payload: None,
1081+
flag: Some(Flag::StopSending),
1082+
})
1083+
.await
1084+
.unwrap();
1085+
1086+
// The write task should wake up and see the state change
1087+
tokio::time::timeout(Duration::from_secs(1), write_task)
1088+
.await
1089+
.expect("write task should complete after STOP_SENDING")
1090+
.unwrap();
1091+
}
1092+
1093+
#[tokio::test]
1094+
async fn reset_stream_wakes_blocked_writer() {
1095+
use tokio::io::AsyncWriteExt;
1096+
let (mut substream, handle) = Substream::new();
1097+
1098+
// Fill up the channel to cause poll_write to return Pending
1099+
// Channel capacity is 256
1100+
for _ in 0..256 {
1101+
substream.write_all(&[1u8; 100]).await.unwrap();
1102+
}
1103+
1104+
// Now the next write should block waiting for channel capacity
1105+
let write_task = tokio::spawn(async move {
1106+
// This write will block because channel is full
1107+
let result = substream.write_all(&[2u8; 100]).await;
1108+
// Should fail because RESET_STREAM was received
1109+
assert!(result.is_err());
1110+
});
1111+
1112+
// Give the writer time to block on poll_reserve
1113+
tokio::time::sleep(Duration::from_millis(10)).await;
1114+
assert!(!write_task.is_finished(), "write should be blocked");
1115+
1116+
// Simulate receiving RESET_STREAM from remote
1117+
let result = handle
1118+
.on_message(WebRtcMessage {
1119+
payload: None,
1120+
flag: Some(Flag::ResetStream),
1121+
})
1122+
.await;
1123+
// RESET_STREAM returns an error
1124+
assert!(result.is_err());
1125+
1126+
// The write task should wake up and see the state change
1127+
tokio::time::timeout(Duration::from_secs(1), write_task)
1128+
.await
1129+
.expect("write task should complete after RESET_STREAM")
1130+
.unwrap();
1131+
}
1132+
10101133
#[tokio::test]
10111134
async fn shutdown_rejects_new_writes() {
10121135
use tokio::io::AsyncWriteExt;

0 commit comments

Comments
 (0)