@@ -98,6 +98,9 @@ pub struct Substream {
9898 /// Waker to notify when shutdown completes (FIN_ACK received).
9999 shutdown_waker : Arc < AtomicWaker > ,
100100
101+ /// Waker to notify when write state changes (e.g., STOP_SENDING received).
102+ write_waker : Arc < AtomicWaker > ,
103+
101104 /// Timeout for waiting on FIN_ACK after sending FIN.
102105 /// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled.
103106 fin_ack_timeout : Option < Pin < Box < tokio:: time:: Sleep > > > ,
@@ -110,13 +113,15 @@ impl Substream {
110113 let ( inbound_tx, inbound_rx) = channel ( 256 ) ;
111114 let state = Arc :: new ( Mutex :: new ( State :: Open ) ) ;
112115 let shutdown_waker = Arc :: new ( AtomicWaker :: new ( ) ) ;
116+ let write_waker = Arc :: new ( AtomicWaker :: new ( ) ) ;
113117
114118 let handle = SubstreamHandle {
115119 inbound_tx,
116120 outbound_tx : outbound_tx. clone ( ) ,
117121 rx : outbound_rx,
118122 state : Arc :: clone ( & state) ,
119123 shutdown_waker : Arc :: clone ( & shutdown_waker) ,
124+ write_waker : Arc :: clone ( & write_waker) ,
120125 } ;
121126
122127 (
@@ -126,6 +131,7 @@ impl Substream {
126131 rx : inbound_rx,
127132 read_buffer : BytesMut :: new ( ) ,
128133 shutdown_waker,
134+ write_waker,
129135 fin_ack_timeout : None ,
130136 } ,
131137 handle,
@@ -148,6 +154,9 @@ pub struct SubstreamHandle {
148154
149155 /// Waker to notify when shutdown completes (FIN_ACK received).
150156 shutdown_waker : Arc < AtomicWaker > ,
157+
158+ /// Waker to notify when write state changes (e.g., STOP_SENDING received).
159+ write_waker : Arc < AtomicWaker > ,
151160}
152161
153162impl SubstreamHandle {
@@ -187,9 +196,19 @@ impl SubstreamHandle {
187196 }
188197 Flag :: StopSending => {
189198 * self . state . lock ( ) = State :: SendClosed ;
199+ // Wake any blocked poll_write so it can see the state change
200+ self . write_waker . wake ( ) ;
190201 return Ok ( ( ) ) ;
191202 }
192203 Flag :: ResetStream => {
204+ // RESET_STREAM abruptly terminates both sides of the stream
205+ // (matching go-libp2p behavior)
206+ // Close the read side
207+ let _ = self . inbound_tx . try_send ( Event :: RecvClosed ) ;
208+ // Close the write side
209+ * self . state . lock ( ) = State :: SendClosed ;
210+ // Wake any blocked poll_write so it can see the state change
211+ self . write_waker . wake ( ) ;
193212 return Err ( Error :: ConnectionClosed ) ;
194213 }
195214 }
@@ -286,6 +305,9 @@ impl tokio::io::AsyncWrite for Substream {
286305 cx : & mut Context < ' _ > ,
287306 buf : & [ u8 ] ,
288307 ) -> Poll < Result < usize , std:: io:: Error > > {
308+ // Register waker so we get notified on state changes (e.g., STOP_SENDING)
309+ self . write_waker . register ( cx. waker ( ) ) ;
310+
289311 // Reject writes if we're closing or closed
290312 match * self . state . lock ( ) {
291313 State :: SendClosed | State :: Closing | State :: FinSent | State :: FinAcked => {
@@ -299,6 +321,14 @@ impl tokio::io::AsyncWrite for Substream {
299321 Err ( _) => return Poll :: Ready ( Err ( std:: io:: ErrorKind :: BrokenPipe . into ( ) ) ) ,
300322 } ;
301323
324+ // Re-check state after poll_reserve - it may have changed while we were waiting
325+ match * self . state . lock ( ) {
326+ State :: SendClosed | State :: Closing | State :: FinSent | State :: FinAcked => {
327+ return Poll :: Ready ( Err ( std:: io:: ErrorKind :: BrokenPipe . into ( ) ) ) ;
328+ }
329+ State :: Open => { }
330+ }
331+
302332 let num_bytes = std:: cmp:: min ( MAX_FRAME_SIZE , buf. len ( ) ) ;
303333 let frame = buf[ ..num_bytes] . to_vec ( ) ;
304334
@@ -901,8 +931,9 @@ mod tests {
901931 }
902932
903933 #[ tokio:: test]
904- async fn reset_stream_flag_returns_error ( ) {
905- let ( _substream, handle) = Substream :: new ( ) ;
934+ async fn reset_stream_flag_closes_both_sides ( ) {
935+ use tokio:: io:: AsyncWriteExt ;
936+ let ( mut substream, handle) = Substream :: new ( ) ;
906937
907938 // Simulate receiving RESET_STREAM
908939 let result = handle
@@ -914,6 +945,19 @@ mod tests {
914945
915946 // Should return connection closed error
916947 assert ! ( matches!( result, Err ( Error :: ConnectionClosed ) ) ) ;
948+
949+ // Write side should be closed (state = SendClosed)
950+ assert ! ( matches!( * handle. state. lock( ) , State :: SendClosed ) ) ;
951+
952+ // Attempting to write should fail
953+ match substream. write_all ( & vec ! [ 0u8 ; 100 ] ) . await {
954+ Err ( error) => assert_eq ! ( error. kind( ) , std:: io:: ErrorKind :: BrokenPipe ) ,
955+ _ => panic ! ( "write should have failed" ) ,
956+ }
957+
958+ // Read side should also be closed (RecvClosed event was sent)
959+ // The substream's rx channel should have RecvClosed
960+ assert ! ( matches!( substream. rx. try_recv( ) , Ok ( Event :: RecvClosed ) ) ) ;
917961 }
918962
919963 #[ tokio:: test]
@@ -1007,6 +1051,85 @@ mod tests {
10071051 shutdown_task3. await . unwrap ( ) ;
10081052 }
10091053
1054+ #[ tokio:: test]
1055+ async fn stop_sending_wakes_blocked_writer ( ) {
1056+ use tokio:: io:: AsyncWriteExt ;
1057+ let ( mut substream, handle) = Substream :: new ( ) ;
1058+
1059+ // Fill up the channel to cause poll_write to return Pending
1060+ // Channel capacity is 256
1061+ for _ in 0 ..256 {
1062+ substream. write_all ( & [ 1u8 ; 100 ] ) . await . unwrap ( ) ;
1063+ }
1064+
1065+ // Now the next write should block waiting for channel capacity
1066+ let write_task = tokio:: spawn ( async move {
1067+ // This write will block because channel is full
1068+ let result = substream. write_all ( & [ 2u8 ; 100 ] ) . await ;
1069+ // Should fail because STOP_SENDING was received
1070+ assert ! ( result. is_err( ) ) ;
1071+ } ) ;
1072+
1073+ // Give the writer time to block on poll_reserve
1074+ tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
1075+ assert ! ( !write_task. is_finished( ) , "write should be blocked" ) ;
1076+
1077+ // Simulate receiving STOP_SENDING from remote
1078+ handle
1079+ . on_message ( WebRtcMessage {
1080+ payload : None ,
1081+ flag : Some ( Flag :: StopSending ) ,
1082+ } )
1083+ . await
1084+ . unwrap ( ) ;
1085+
1086+ // The write task should wake up and see the state change
1087+ tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , write_task)
1088+ . await
1089+ . expect ( "write task should complete after STOP_SENDING" )
1090+ . unwrap ( ) ;
1091+ }
1092+
1093+ #[ tokio:: test]
1094+ async fn reset_stream_wakes_blocked_writer ( ) {
1095+ use tokio:: io:: AsyncWriteExt ;
1096+ let ( mut substream, handle) = Substream :: new ( ) ;
1097+
1098+ // Fill up the channel to cause poll_write to return Pending
1099+ // Channel capacity is 256
1100+ for _ in 0 ..256 {
1101+ substream. write_all ( & [ 1u8 ; 100 ] ) . await . unwrap ( ) ;
1102+ }
1103+
1104+ // Now the next write should block waiting for channel capacity
1105+ let write_task = tokio:: spawn ( async move {
1106+ // This write will block because channel is full
1107+ let result = substream. write_all ( & [ 2u8 ; 100 ] ) . await ;
1108+ // Should fail because RESET_STREAM was received
1109+ assert ! ( result. is_err( ) ) ;
1110+ } ) ;
1111+
1112+ // Give the writer time to block on poll_reserve
1113+ tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
1114+ assert ! ( !write_task. is_finished( ) , "write should be blocked" ) ;
1115+
1116+ // Simulate receiving RESET_STREAM from remote
1117+ let result = handle
1118+ . on_message ( WebRtcMessage {
1119+ payload : None ,
1120+ flag : Some ( Flag :: ResetStream ) ,
1121+ } )
1122+ . await ;
1123+ // RESET_STREAM returns an error
1124+ assert ! ( result. is_err( ) ) ;
1125+
1126+ // The write task should wake up and see the state change
1127+ tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , write_task)
1128+ . await
1129+ . expect ( "write task should complete after RESET_STREAM" )
1130+ . unwrap ( ) ;
1131+ }
1132+
10101133 #[ tokio:: test]
10111134 async fn shutdown_rejects_new_writes ( ) {
10121135 use tokio:: io:: AsyncWriteExt ;
0 commit comments