diff --git a/src/schema/webrtc.proto b/src/schema/webrtc.proto
index f36e04f94..852f3c6c1 100644
--- a/src/schema/webrtc.proto
+++ b/src/schema/webrtc.proto
@@ -12,6 +12,10 @@ message Message {
     // The sender abruptly terminates the sending part of the stream. The
     // receiver MAY discard any data that it already received on that stream.
     RESET_STREAM = 2;
+    // Sending the FIN_ACK flag acknowledges the previous receipt of a message
+    // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient
+    // confidence that the remote has received all sent messages.
+    FIN_ACK = 3;
   }

   optional Flag flag = 1;
diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs
index c47e0024d..65e7d80a0 100644
--- a/src/transport/webrtc/connection.rs
+++ b/src/transport/webrtc/connection.rs
@@ -27,6 +27,7 @@ use crate::{
     substream::Substream,
     transport::{
         webrtc::{
+            schema::webrtc::message::Flag,
             substream::{Event as SubstreamEvent, Substream as WebRtcSubstream, SubstreamHandle},
             util::WebRtcMessage,
         },
@@ -263,7 +264,7 @@ impl WebRtcConnection {
         let fallback_names = std::mem::take(&mut context.fallback_names);
         let (dialer_state, message) =
             WebRtcDialerState::propose(context.protocol.clone(), fallback_names)?;
-        let message = WebRtcMessage::encode(message);
+        let message = WebRtcMessage::encode(message, None);

         self.rtc
             .channel(channel_id)
@@ -330,7 +331,10 @@ impl WebRtcConnection {
         self.rtc
             .channel(channel_id)
             .ok_or(Error::ChannelDoesntExist)?
-            .write(true, WebRtcMessage::encode(response.to_vec()).as_ref())
+            .write(
+                true,
+                WebRtcMessage::encode(response.to_vec(), None).as_ref(),
+            )
             .map_err(Error::WebRtc)?;

         let protocol = negotiated.ok_or(Error::SubstreamDoesntExist)?;
@@ -452,7 +456,7 @@
             target: LOG_TARGET,
             peer = ?self.peer,
             ?channel_id,
-            flags = message.flags,
+            flag = ?message.flag,
             data_len = message.payload.as_ref().map_or(0usize, |payload| payload.len()),
             "handle inbound message",
         );
@@ -598,20 +602,26 @@ impl WebRtcConnection {
         Ok(())
     }

-    /// Handle outbound data.
-    fn on_outbound_data(&mut self, channel_id: ChannelId, data: Vec<u8>) -> crate::Result<()> {
+    /// Handle outbound data with optional flag.
+    fn on_outbound_data(
+        &mut self,
+        channel_id: ChannelId,
+        data: Vec<u8>,
+        flag: Option<Flag>,
+    ) -> crate::Result<()> {
         tracing::trace!(
             target: LOG_TARGET,
             peer = ?self.peer,
             ?channel_id,
             data_len = ?data.len(),
+            ?flag,
             "send data",
         );

         self.rtc
             .channel(channel_id)
             .ok_or(Error::ChannelDoesntExist)?
-            .write(true, WebRtcMessage::encode(data).as_ref())
+            .write(true, WebRtcMessage::encode(data, flag).as_ref())
             .map_err(Error::WebRtc)
             .map(|_| ())
     }
@@ -788,7 +798,7 @@
                 },
                 event = self.handles.next() => match event {
                     None => unreachable!(),
-                    Some((channel_id, None | Some(SubstreamEvent::Close))) => {
+                    Some((channel_id, None)) => {
                         tracing::trace!(
                             target: LOG_TARGET,
                             peer = ?self.peer,
@@ -800,11 +810,12 @@
                         self.channels.insert(channel_id, ChannelState::Closing);
                         self.handles.remove(&channel_id);
                     }
-                    Some((channel_id, Some(SubstreamEvent::Message(data)))) => {
-                        if let Err(error) = self.on_outbound_data(channel_id, data) {
+                    Some((channel_id, Some(SubstreamEvent::Message { payload, flag }))) => {
+                        if let Err(error) = self.on_outbound_data(channel_id, payload, flag) {
                             tracing::debug!(
                                 target: LOG_TARGET,
                                 ?channel_id,
+                                ?flag,
                                 ?error,
                                 "failed to send data to remote peer",
                             );
diff --git a/src/transport/webrtc/opening.rs b/src/transport/webrtc/opening.rs
index ba6e2af7a..c2a0827e8 100644
--- a/src/transport/webrtc/opening.rs
+++ b/src/transport/webrtc/opening.rs
@@ -207,7 +207,7 @@ impl OpeningWebRtcConnection {
         };

         // create first noise handshake and send it to remote peer
-        let payload = WebRtcMessage::encode(context.first_message(Role::Dialer)?);
+        let payload = WebRtcMessage::encode(context.first_message(Role::Dialer)?, None);

         self.rtc
             .channel(self.noise_channel_id)
@@ -300,7 +300,7 @@ impl OpeningWebRtcConnection {
         };

         // create second noise handshake message and send it to remote
-        let payload = WebRtcMessage::encode(context.second_message()?);
+        let payload = WebRtcMessage::encode(context.second_message()?, None);

         let mut channel =
             self.rtc.channel(self.noise_channel_id).ok_or(Error::ChannelDoesntExist)?;
diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs
index f839fd83e..ce3333d7e 100644
--- a/src/transport/webrtc/substream.rs
+++ b/src/transport/webrtc/substream.rs
@@ -24,7 +24,7 @@ use crate::{
 };

 use bytes::{Buf, BufMut, BytesMut};
-use futures::Stream;
+use futures::{task::AtomicWaker, Future, Stream};
 use parking_lot::Mutex;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use tokio_util::sync::PollSender;
@@ -33,31 +33,51 @@ use std::{
     pin::Pin,
     sync::Arc,
     task::{Context, Poll},
+    time::Duration,
 };

 /// Maximum frame size.
 const MAX_FRAME_SIZE: usize = 16384;

+/// Timeout for waiting on FIN_ACK after sending FIN.
+/// Matches go-libp2p's 5 second stream close timeout.
+#[cfg(not(test))]
+const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Shorter timeout for tests.
+#[cfg(test)]
+const FIN_ACK_TIMEOUT: Duration = Duration::from_secs(2);
+
 /// Substream event.
 #[derive(Debug, PartialEq, Eq)]
 pub enum Event {
     /// Receiver closed.
     RecvClosed,

-    /// Send/receive message.
-    Message(Vec<u8>),
-
-    /// Close substream.
-    Close,
+    /// Send/receive message with optional flag.
+    Message {
+        payload: Vec<u8>,
+        flag: Option<Flag>,
+    },
 }

 /// Substream stream.
+#[derive(Debug, Clone, Copy)]
 enum State {
     /// Substream is fully open.
     Open,

     /// Remote is no longer interested in receiving anything.
     SendClosed,
+
+    /// Shutdown initiated, flushing pending data before sending FIN.
+    Closing,
+
+    /// We sent FIN, waiting for FIN_ACK.
+    FinSent,
+
+    /// We received FIN_ACK, write half is closed.
+    FinAcked,
 }

 /// Channel-backed substream. Must be owned and polled by exactly one task at a time.
@@ -74,6 +94,16 @@ pub struct Substream {
     /// RX channel for receiving messages from `peer`.
     rx: Receiver<Event>,
+
+    /// Waker to notify when shutdown completes (FIN_ACK received).
+    shutdown_waker: Arc<AtomicWaker>,
+
+    /// Waker to notify when write state changes (e.g., STOP_SENDING received).
+    write_waker: Arc<AtomicWaker>,
+
+    /// Timeout for waiting on FIN_ACK after sending FIN.
+    /// Boxed to maintain Unpin for Substream while allowing the Sleep to be polled.
+    fin_ack_timeout: Option<Pin<Box<tokio::time::Sleep>>>,
 }

 impl Substream {
@@ -82,11 +112,17 @@ impl Substream {
         let (outbound_tx, outbound_rx) = channel(256);
         let (inbound_tx, inbound_rx) = channel(256);
         let state = Arc::new(Mutex::new(State::Open));
+        let shutdown_waker = Arc::new(AtomicWaker::new());
+        let write_waker = Arc::new(AtomicWaker::new());

         let handle = SubstreamHandle {
-            tx: inbound_tx,
+            inbound_tx,
+            outbound_tx: outbound_tx.clone(),
             rx: outbound_rx,
             state: Arc::clone(&state),
+            shutdown_waker: Arc::clone(&shutdown_waker),
+            write_waker: Arc::clone(&write_waker),
+            read_closed: std::sync::atomic::AtomicBool::new(false),
         };

         (
@@ -95,6 +131,9 @@
                 tx: PollSender::new(outbound_tx),
                 rx: inbound_rx,
                 read_buffer: BytesMut::new(),
+                shutdown_waker,
+                write_waker,
+                fin_ack_timeout: None,
             },
             handle,
         )
     }
@@ -106,36 +145,112 @@ pub struct SubstreamHandle {
     state: Arc<Mutex<State>>,

     /// TX channel for sending inbound messages from `peer` to the associated `Substream`.
-    tx: Sender<Event>,
+    inbound_tx: Sender<Event>,
+
+    /// TX channel for sending outbound messages to `peer` (e.g., FIN_ACK responses).
+    outbound_tx: Sender<Event>,

     /// RX channel for receiving outbound messages to `peer` from the associated `Substream`.
     rx: Receiver<Event>,
+
+    /// Waker to notify when shutdown completes (FIN_ACK received).
+    shutdown_waker: Arc<AtomicWaker>,
+
+    /// Waker to notify when write state changes (e.g., STOP_SENDING received).
+    write_waker: Arc<AtomicWaker>,
+
+    /// Whether we've already sent RecvClosed to the inbound channel.
+    /// Prevents duplicate RecvClosed events if multiple FIN messages are received.
+    read_closed: std::sync::atomic::AtomicBool,
 }

 impl SubstreamHandle {
     /// Handle message received from a remote peer.
     ///
-    /// If the message contains any flags, handle them first and appropriately close the correct
-    /// side of the substream. If the message contained any payload, send it to the protocol for
-    /// further processing.
+    /// Process an incoming WebRTC message, handling any payload and flags.
+    ///
+    /// Payload is processed first (if present), then flags are handled. This ensures that
+    /// a FIN message containing final data will deliver that data before signaling closure.
     pub async fn on_message(&self, message: WebRtcMessage) -> crate::Result<()> {
-        if let Some(flags) = message.flags {
-            if flags == Flag::Fin as i32 {
-                self.tx.send(Event::RecvClosed).await?;
-            }
-
-            if flags & 1 == Flag::StopSending as i32 {
-                *self.state.lock() = State::SendClosed;
-            }
-
-            if flags & 2 == Flag::ResetStream as i32 {
-                return Err(Error::ConnectionClosed);
+        // Process payload first, before handling flags.
+        // This ensures that if a FIN message contains data, we deliver it before closing.
+ if let Some(payload) = message.payload { + if !payload.is_empty() { + self.inbound_tx + .send(Event::Message { + payload, + flag: None, + }) + .await?; } } - if let Some(payload) = message.payload { - if !payload.is_empty() { - return self.tx.send(Event::Message(payload)).await.map_err(From::from); + // Now handle flags + if let Some(flag) = message.flag { + match flag { + Flag::Fin => { + // Guard against duplicate FIN messages - only send RecvClosed once + if self.read_closed.swap(true, std::sync::atomic::Ordering::SeqCst) { + // Already processed FIN, ignore duplicate + tracing::debug!( + target: "litep2p::webrtc::substream", + "received duplicate FIN, ignoring" + ); + return Ok(()); + } + + // Received FIN from remote, close our read half + self.inbound_tx.send(Event::RecvClosed).await?; + + // Send FIN_ACK back to remote using try_send to avoid blocking. + // If the channel is full, the remote will timeout waiting for FIN_ACK + // and handle it gracefully. This prevents deadlock if the outbound + // channel is blocked due to backpressure. + if let Err(e) = self.outbound_tx.try_send(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck), + }) { + tracing::warn!( + target: "litep2p::webrtc::substream", + ?e, + "failed to send FIN_ACK, remote will timeout" + ); + } + return Ok(()); + } + Flag::FinAck => { + // Received FIN_ACK, we can now fully close our write half + let mut state = self.state.lock(); + if matches!(*state, State::FinSent) { + *state = State::FinAcked; + // Wake up any task waiting on shutdown + self.shutdown_waker.wake(); + } else { + tracing::warn!( + target: "litep2p::webrtc::substream", + ?state, + "received FIN_ACK in unexpected state, ignoring" + ); + } + return Ok(()); + } + Flag::StopSending => { + *self.state.lock() = State::SendClosed; + // Wake any blocked poll_write so it can see the state change + self.write_waker.wake(); + return Ok(()); + } + Flag::ResetStream => { + // RESET_STREAM abruptly terminates both sides of the stream + // (matching go-libp2p behavior) + // Close the read side + let _ = self.inbound_tx.try_send(Event::RecvClosed); + // Close the write side + *self.state.lock() = State::SendClosed; + // Wake any blocked poll_write so it can see the state change + self.write_waker.wake(); + return Err(Error::ConnectionClosed); + } } } @@ -147,7 +262,27 @@ impl Stream for SubstreamHandle { type Item = Event; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) + // First, try to drain any pending outbound messages + match self.rx.poll_recv(cx) { + Poll::Ready(Some(event)) => return Poll::Ready(Some(event)), + Poll::Ready(None) => { + // Outbound channel closed (all senders dropped) + return Poll::Ready(None); + } + Poll::Pending => { + // No messages available, check if we should signal closure + } + } + + // Check if Substream has been dropped (inbound channel closed) + // When Substream is dropped, there will be no more outbound messages + // Since we've already tried to recv above and got Pending, we know the queue is empty + // Therefore, it's safe to signal closure + if self.inbound_tx.is_closed() { + return Poll::Ready(None); + } + + Poll::Pending } } @@ -169,19 +304,19 @@ impl tokio::io::AsyncRead for Substream { } match futures::ready!(self.rx.poll_recv(cx)) { - None | Some(Event::Close) | Some(Event::RecvClosed) => + None | Some(Event::RecvClosed) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), - Some(Event::Message(message)) => { - if message.len() > MAX_FRAME_SIZE { + 
Some(Event::Message { payload, flag: _ }) => { + if payload.len() > MAX_FRAME_SIZE { return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())); } - match buf.remaining() >= message.len() { - true => buf.put_slice(&message), + match buf.remaining() >= payload.len() { + true => buf.put_slice(&payload), false => { let remaining = buf.remaining(); - buf.put_slice(&message[..remaining]); - self.read_buffer.put_slice(&message[remaining..]); + buf.put_slice(&payload[..remaining]); + self.read_buffer.put_slice(&payload[remaining..]); } } @@ -197,8 +332,15 @@ impl tokio::io::AsyncWrite for Substream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if let State::SendClosed = *self.state.lock() { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + // Register waker so we get notified on state changes (e.g., STOP_SENDING) + self.write_waker.register(cx.waker()); + + // Reject writes if we're closing or closed + match *self.state.lock() { + State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + State::Open => {} } match futures::ready!(self.tx.poll_reserve(cx)) { @@ -206,10 +348,21 @@ impl tokio::io::AsyncWrite for Substream { Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), }; + // Re-check state after poll_reserve - it may have changed while we were waiting + match *self.state.lock() { + State::SendClosed | State::Closing | State::FinSent | State::FinAcked => { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + State::Open => {} + } + let num_bytes = std::cmp::min(MAX_FRAME_SIZE, buf.len()); let frame = buf[..num_bytes].to_vec(); - match self.tx.send_item(Event::Message(frame)) { + match self.tx.send_item(Event::Message { + payload: frame, + flag: None, + }) { Ok(()) => Poll::Ready(Ok(num_bytes)), Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), } @@ -223,13 +376,105 @@ impl tokio::io::AsyncWrite for Substream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + // State machine for proper shutdown: + // 1. Transition to Closing (stops accepting new writes) + // 2. Flush pending data + // 3. Send FIN flag + // 4. Transition to FinSent + // 5. Wait for FIN_ACK + // 6. Transition to FinAcked and complete + + let current_state = *self.state.lock(); + + match current_state { + // Already received FIN_ACK, shutdown complete + State::FinAcked => return Poll::Ready(Ok(())), + + // Sent FIN, waiting for FIN_ACK - poll timeout and return Pending + State::FinSent => { + // Register waker FIRST to avoid race condition with on_message + self.shutdown_waker.register(cx.waker()); + + // Re-check state after waker registration in case FIN_ACK arrived + // between the initial state check and waker registration + if matches!(*self.state.lock(), State::FinAcked) { + return Poll::Ready(Ok(())); + } + + // Poll the timeout - if it fires, force shutdown completion + if let Some(timeout) = self.fin_ack_timeout.as_mut() { + if timeout.as_mut().poll(cx).is_ready() { + tracing::debug!( + target: "litep2p::webrtc::substream", + "FIN_ACK timeout exceeded, forcing shutdown completion" + ); + *self.state.lock() = State::FinAcked; + return Poll::Ready(Ok(())); + } + } + + return Poll::Pending; + } + + // First call to shutdown - transition to Closing + State::Open => { + *self.state.lock() = State::Closing; + } + + State::Closing => { + // Already in closing state, continue with shutdown process. 
+ // Guard against duplicate FIN sends: if timeout is already set, we've + // already sent FIN and are waiting for FIN_ACK. This shouldn't happen + // with correct AsyncWrite usage (&mut self), but provides defense in depth. + if self.fin_ack_timeout.is_some() { + self.shutdown_waker.register(cx.waker()); + return Poll::Pending; + } + } + + State::SendClosed => { + // Remote closed send, we can still send FIN + } + } + + // Flush any pending data + // Note: Currently poll_flush is a no-op, but the channel backpressure + // provides implicit flushing since we wait for poll_reserve below + futures::ready!(self.as_mut().poll_flush(cx))?; + + // Reserve space to send FIN match futures::ready!(self.tx.poll_reserve(cx)) { Ok(()) => {} Err(_) => return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), }; - match self.tx.send_item(Event::Close) { - Ok(()) => Poll::Ready(Ok(())), + // Send message with FIN flag + match self.tx.send_item(Event::Message { + payload: vec![], + flag: Some(Flag::Fin), + }) { + Ok(()) => { + // Race condition mitigation strategy: + // 1. Transition to FinSent FIRST so on_message can recognize FIN_ACK (if waker + // registered first, FIN_ACK would be ignored since state != FinSent) + // 2. Register waker so we'll be notified on future FIN_ACK arrivals + // 3. Re-check state to catch FIN_ACK that arrived between steps 1 and 2 (wake() + // called before waker registered has no effect, but state changed) + *self.state.lock() = State::FinSent; + self.shutdown_waker.register(cx.waker()); + if matches!(*self.state.lock(), State::FinAcked) { + return Poll::Ready(Ok(())); + } + + // Initialize the timeout for FIN_ACK + let mut timeout = Box::pin(tokio::time::sleep(FIN_ACK_TIMEOUT)); + // Poll the timeout once to register it with tokio's timer + // This ensures we'll be woken when it expires + let _ = timeout.as_mut().poll(cx); + self.fin_ack_timeout = Some(timeout); + + Poll::Pending + } Err(_) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())), } } @@ -247,7 +492,13 @@ mod tests { substream.write_all(&vec![0u8; 1337]).await.unwrap(); - assert_eq!(handle.next().await, Some(Event::Message(vec![0u8; 1337]))); + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![0u8; 1337], + flag: None + }) + ); futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { Poll::Pending => Poll::Ready(()), @@ -264,13 +515,25 @@ mod tests { assert_eq!( handle.rx.recv().await, - Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + Some(Event::Message { + payload: vec![0u8; MAX_FRAME_SIZE], + flag: None, + }) ); assert_eq!( handle.rx.recv().await, - Some(Event::Message(vec![0u8; MAX_FRAME_SIZE])) + Some(Event::Message { + payload: vec![0u8; MAX_FRAME_SIZE], + flag: None, + }) + ); + assert_eq!( + handle.rx.recv().await, + Some(Event::Message { + payload: vec![0u8; 1], + flag: None, + }) ); - assert_eq!(handle.rx.recv().await, Some(Event::Message(vec![0u8; 1]))); futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { Poll::Pending => Poll::Ready(()), @@ -295,10 +558,38 @@ mod tests { let (mut substream, mut handle) = Substream::new(); substream.write_all(&vec![1u8; 1337]).await.unwrap(); - substream.shutdown().await.unwrap(); - assert_eq!(handle.next().await, Some(Event::Message(vec![1u8; 1337]))); - assert_eq!(handle.next().await, Some(Event::Close)); + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + assert_eq!( + handle.next().await, + Some(Event::Message { + 
payload: vec![1u8; 1337], + flag: None, + }) + ); + // After shutdown, should send FIN flag + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Send FIN_ACK to complete shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + shutdown_task.await.unwrap(); } #[tokio::test] @@ -307,7 +598,7 @@ mod tests { handle .on_message(WebRtcMessage { payload: None, - flags: Some(0i32), + flag: Some(Flag::Fin), }) .await .unwrap(); @@ -321,7 +612,14 @@ mod tests { #[tokio::test] async fn read_small_frame() { let (mut substream, handle) = Substream::new(); - handle.tx.send(Event::Message(vec![1u8; 256])).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: vec![1u8; 256], + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 2048]; @@ -349,7 +647,14 @@ mod tests { let mut first = vec![1u8; 256]; first.extend_from_slice(&vec![2u8; 256]); - handle.tx.send(Event::Message(first)).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: first, + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 256]; @@ -385,8 +690,22 @@ mod tests { let mut first = vec![1u8; 256]; first.extend_from_slice(&vec![2u8; 256]); - handle.tx.send(Event::Message(first)).await.unwrap(); - handle.tx.send(Event::Message(vec![4u8; 2048])).await.unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: first, + flag: None, + }) + .await + .unwrap(); + handle + .inbound_tx + .send(Event::Message { + payload: vec![4u8; 2048], + flag: None, + }) + .await + .unwrap(); let mut buf = vec![0u8; 256]; @@ -500,4 +819,697 @@ mod tests { .expect("writer task did not complete after capacity was freed") .expect("writer task panicked"); } + + #[tokio::test] + async fn fin_flag_sent_on_shutdown() { + let (mut substream, mut handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Should receive FIN flag + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify state is FinSent + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete shutdown cleanly (avoids waiting for timeout) + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Wait for shutdown to complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn fin_ack_response_on_receiving_fin() { + let (mut substream, mut handle) = Substream::new(); + + // Spawn task to consume inbound events sent to the substream + let consumer_task = tokio::spawn(async move { + // Substream should receive RecvClosed + let mut buf = vec![0u8; 1024]; + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed + } + other => panic!("Unexpected result: {:?}", other), + } + }); + + // Simulate receiving FIN from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Wait for consumer task to complete + consumer_task.await.unwrap(); + + // Verify FIN_ACK was sent outbound to network + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + } + + #[tokio::test] + async fn fin_ack_received_transitions_to_fin_acked() { + let 
(mut substream, handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Verify we're in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Simulate receiving FIN_ACK from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should transition to FinAcked + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should now complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn full_fin_handshake() { + let (mut substream, mut handle) = Substream::new(); + + // Write some data + substream.write_all(&vec![1u8; 100]).await.unwrap(); + + // Spawn shutdown in background since it will wait for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Verify data was sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![1u8; 100], + flag: None, + }) + ); + + // Verify FIN was sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Simulate receiving FIN_ACK + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should be in FinAcked state + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should now complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn stop_sending_flag_closes_send_half() { + let (mut substream, handle) = Substream::new(); + + // Simulate receiving STOP_SENDING + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + // Should transition to SendClosed + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Attempting to write should fail + match substream.write_all(&vec![0u8; 100]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("write should have failed"), + } + } + + #[tokio::test] + async fn reset_stream_flag_closes_both_sides() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Simulate receiving RESET_STREAM + let result = handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + + // Should return connection closed error + assert!(matches!(result, Err(Error::ConnectionClosed))); + + // Write side should be closed (state = SendClosed) + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Attempting to write should fail + match substream.write_all(&vec![0u8; 100]).await { + Err(error) => assert_eq!(error.kind(), std::io::ErrorKind::BrokenPipe), + _ => panic!("write should have failed"), + } + + // Read side should also be closed (RecvClosed event was sent) + // The substream's rx channel should have RecvClosed + assert!(matches!(substream.rx.try_recv(), Ok(Event::RecvClosed))); + } + + #[tokio::test] + async fn fin_ack_does_not_trigger_other_flag() { + let (mut substream, handle) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Verify we're in FinSent state + 
assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now simulate receiving FIN_ACK (value = 3) + // This should NOT trigger STOP_SENDING (value = 1) or RESET_STREAM (value = 2) + // even though 3 & 1 == 1 and 3 & 2 == 2 + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Should transition to FinAcked, not SendClosed + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should complete + shutdown_task.await.unwrap(); + + // Writing should still work (not closed by STOP_SENDING) + // Note: We already sent FIN, so write won't actually work, but the state check happens + // first + } + + #[tokio::test] + async fn flags_are_mutually_exclusive() { + let (_substream, handle) = Substream::new(); + + // Test that STOP_SENDING (1) is handled correctly + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + assert!(matches!(*handle.state.lock(), State::SendClosed)); + + // Create a new substream for RESET_STREAM test + let (_substream2, handle2) = Substream::new(); + + // Test that RESET_STREAM (2) is handled correctly + let result = handle2 + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + + assert!(matches!(result, Err(Error::ConnectionClosed))); + + // Create a new substream for FIN test + let (mut substream3, handle3) = Substream::new(); + + // Spawn shutdown since it waits for FIN_ACK + let shutdown_task3 = tokio::spawn(async move { + substream3.shutdown().await.unwrap(); + }); + + // Wait a bit for FIN to be sent + tokio::task::yield_now().await; + + // Test that FIN_ACK (3) is handled correctly + handle3 + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + assert!(matches!(*handle3.state.lock(), State::FinAcked)); + + // Shutdown should complete + shutdown_task3.await.unwrap(); + } + + #[tokio::test] + async fn stop_sending_wakes_blocked_writer() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Fill up the channel to cause poll_write to return Pending + // Channel capacity is 256 + for _ in 0..256 { + substream.write_all(&[1u8; 100]).await.unwrap(); + } + + // Now the next write should block waiting for channel capacity + let write_task = tokio::spawn(async move { + // This write will block because channel is full + let result = substream.write_all(&[2u8; 100]).await; + // Should fail because STOP_SENDING was received + assert!(result.is_err()); + }); + + // Give the writer time to block on poll_reserve + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(!write_task.is_finished(), "write should be blocked"); + + // Simulate receiving STOP_SENDING from remote + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::StopSending), + }) + .await + .unwrap(); + + // The write task should wake up and see the state change + tokio::time::timeout(Duration::from_secs(1), write_task) + .await + .expect("write task should complete after STOP_SENDING") + .unwrap(); + } + + #[tokio::test] + async fn reset_stream_wakes_blocked_writer() { + use tokio::io::AsyncWriteExt; + let (mut substream, handle) = Substream::new(); + + // Fill up the channel to cause poll_write to return Pending + // Channel capacity is 256 + for _ in 0..256 { + substream.write_all(&[1u8; 100]).await.unwrap(); + } + + // Now the next write should block waiting for channel capacity + let write_task = 
tokio::spawn(async move { + // This write will block because channel is full + let result = substream.write_all(&[2u8; 100]).await; + // Should fail because RESET_STREAM was received + assert!(result.is_err()); + }); + + // Give the writer time to block on poll_reserve + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(!write_task.is_finished(), "write should be blocked"); + + // Simulate receiving RESET_STREAM from remote + let result = handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::ResetStream), + }) + .await; + // RESET_STREAM returns an error + assert!(result.is_err()); + + // The write task should wake up and see the state change + tokio::time::timeout(Duration::from_secs(1), write_task) + .await + .expect("write task should complete after RESET_STREAM") + .unwrap(); + } + + #[tokio::test] + async fn shutdown_rejects_new_writes() { + use tokio::io::AsyncWriteExt; + let (mut substream, mut handle) = Substream::new(); + + // Write some data + substream.write_all(&vec![1u8; 100]).await.unwrap(); + + // Spawn shutdown in background + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for data and FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![1u8; 100], + flag: None, + }) + ); + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify we transitioned through Closing to FinSent + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Shutdown should complete + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn shutdown_idempotent() { + use tokio::io::AsyncWriteExt; + let (mut substream, mut handle) = Substream::new(); + + // Spawn first shutdown + let shutdown_task1 = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + substream + }); + + // Wait for FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Send FIN_ACK to complete first shutdown + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // First shutdown should complete + let mut substream = shutdown_task1.await.unwrap(); + + // Second shutdown should succeed without error (already in FinAcked state) + substream.shutdown().await.unwrap(); + assert!(matches!(*handle.state.lock(), State::FinAcked)); + } + + #[tokio::test] + async fn shutdown_timeout_without_fin_ack() { + use tokio::time::{timeout, Duration}; + + let (mut substream, mut handle) = Substream::new(); + + // Spawn shutdown in background + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify we're in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // DON'T send FIN_ACK - let it timeout + // The shutdown should complete after FIN_ACK_TIMEOUT (2 seconds in tests) + // Add a bit of buffer to the timeout + let result = timeout(Duration::from_secs(4), shutdown_task).await; + + assert!(result.is_ok(), "Shutdown should complete after timeout"); + assert!( + 
result.unwrap().is_ok(), + "Shutdown should succeed after timeout" + ); + + // Should have transitioned to FinAcked after timeout + assert!(matches!(*handle.state.lock(), State::FinAcked)); + } + + #[tokio::test] + async fn closing_state_blocks_writes() { + use tokio::io::AsyncWriteExt; + + let (mut substream, handle) = Substream::new(); + + // Manually transition to Closing state + *handle.state.lock() = State::Closing; + + // Attempt to write should fail + let result = substream.write_all(&vec![1u8; 100]).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::BrokenPipe); + } + + #[tokio::test] + async fn handle_signals_closure_after_substream_dropped() { + use futures::StreamExt; + + let (mut substream, mut handle) = Substream::new(); + + // Complete shutdown handshake (client-initiated) + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + // Substream will be dropped here + }); + + // Receive FIN + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Send FIN_ACK + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Wait for shutdown to complete and Substream to drop + shutdown_task.await.unwrap(); + + // Verify handle signals closure (returns None) + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should signal closure after Substream is dropped" + ); + } + + #[tokio::test] + async fn server_side_closure_after_receiving_fin() { + use futures::StreamExt; + + let (mut substream, mut handle) = Substream::new(); + + // Spawn task to consume from substream (server side) + let server_task = tokio::spawn(async move { + let mut buf = vec![0u8; 1024]; + // This should fail because we receive RecvClosed + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed by FIN + } + other => panic!("Unexpected result: {:?}", other), + } + // Substream dropped here (server closes after receiving FIN) + }); + + // Remote (client) sends FIN + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Verify FIN_ACK was sent back + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + + // Wait for server to close substream + server_task.await.unwrap(); + + // Verify handle signals closure (returns None) - this is the key fix! + assert_eq!( + handle.next().await, + None, + "SubstreamHandle should signal closure after server receives FIN and drops Substream" + ); + } + + #[tokio::test] + async fn simultaneous_close() { + // Test simultaneous close where both sides send FIN at the same time. + // This verifies that: + // 1. Both sides can be in FinSent state simultaneously + // 2. Both sides correctly respond to FIN with FIN_ACK even when in FinSent state + // 3. 
Both sides eventually transition to FinAcked + + let (mut substream, mut handle) = Substream::new(); + + // Local side initiates shutdown (sends FIN, transitions to FinSent) + let shutdown_task = tokio::spawn(async move { + substream.shutdown().await.unwrap(); + }); + + // Wait for local FIN to be sent + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::Fin) + }) + ); + + // Verify local is in FinSent state + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now simulate remote also sending FIN (simultaneous close) + // This should trigger FIN_ACK response even though we're in FinSent state + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // Local should send FIN_ACK in response to remote's FIN + assert_eq!( + handle.next().await, + Some(Event::Message { + payload: vec![], + flag: Some(Flag::FinAck) + }) + ); + + // Local should still be in FinSent (waiting for FIN_ACK from remote) + assert!(matches!(*handle.state.lock(), State::FinSent)); + + // Now remote sends FIN_ACK (completing their side of the handshake) + handle + .on_message(WebRtcMessage { + payload: None, + flag: Some(Flag::FinAck), + }) + .await + .unwrap(); + + // Local should now transition to FinAcked + assert!(matches!(*handle.state.lock(), State::FinAcked)); + + // Shutdown should complete successfully + shutdown_task.await.unwrap(); + } + + #[tokio::test] + async fn fin_with_payload_delivers_data_before_close() { + // Test that when a FIN message contains payload data, the data is delivered + // to the substream before the RecvClosed event. This is important because + // the spec allows a FIN message to contain final data. + + let (mut substream, handle) = Substream::new(); + + // Simulate receiving FIN with payload from remote + handle + .on_message(WebRtcMessage { + payload: Some(b"final data".to_vec()), + flag: Some(Flag::Fin), + }) + .await + .unwrap(); + + // First, we should receive the payload data + let mut buf = vec![0u8; 1024]; + let n = substream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"final data"); + + // Then, subsequent read should fail with BrokenPipe (RecvClosed) + match substream.read(&mut buf).await { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + // Expected - read half closed after FIN + } + other => panic!("Expected BrokenPipe error, got: {:?}", other), + } + } } diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index 55917afc6..ae050d50d 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -18,74 +18,97 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{codec::unsigned_varint::UnsignedVarint, error::ParseError, transport::webrtc::schema}; +use crate::{ + error::ParseError, + transport::webrtc::schema::{self, webrtc::message::Flag}, +}; use prost::Message; -use tokio_util::codec::{Decoder, Encoder}; -/// WebRTC mesage. +/// WebRTC message. #[derive(Debug)] pub struct WebRtcMessage { /// Payload. pub payload: Option>, - // Flags. - pub flags: Option, + /// Flag. + pub flag: Option, } impl WebRtcMessage { - /// Encode WebRTC message. - pub fn encode(payload: Vec) -> Vec { + /// Encode WebRTC message with optional flag. + /// + /// Uses a single allocation by pre-calculating the total size and encoding + /// the varint length prefix and protobuf message directly into the output buffer. 
+    pub fn encode(payload: Vec<u8>, flag: Option<Flag>) -> Vec<u8> {
         let protobuf_payload = schema::webrtc::Message {
             message: (!payload.is_empty()).then_some(payload),
-            flag: None,
+            flag: flag.map(|f| f as i32),
         };
-        let mut payload = Vec::with_capacity(protobuf_payload.encoded_len());
-        protobuf_payload
-            .encode(&mut payload)
-            .expect("Vec to provide needed capacity");
-        let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4);
-        let mut codec = UnsignedVarint::new(None);
-        let _result = codec.encode(payload.into(), &mut out_buf);
+        // Calculate sizes upfront for single allocation with exact capacity
+        let protobuf_len = protobuf_payload.encoded_len();
+        // Varint uses 7 bits per byte, so calculate exact length needed
+        // ilog2 gives the position of the highest set bit (0-indexed), divide by 7 for varint bytes
+        let varint_len = if protobuf_len == 0 {
+            1
+        } else {
+            (protobuf_len.ilog2() as usize / 7) + 1
+        };
-        out_buf.into()
-    }
+        // Single allocation for the entire output with exact size
+        let mut out_buf = Vec::with_capacity(varint_len + protobuf_len);
-    /// Encode WebRTC message with flags.
-    #[allow(unused)]
-    pub fn encode_with_flags(payload: Vec<u8>, flags: i32) -> Vec<u8> {
-        let protobuf_payload = schema::webrtc::Message {
-            message: (!payload.is_empty()).then_some(payload),
-            flag: Some(flags),
-        };
-        let mut payload = Vec::with_capacity(protobuf_payload.encoded_len());
+        // Encode varint length prefix directly
+        let mut varint_buf = unsigned_varint::encode::usize_buffer();
+        let varint_slice = unsigned_varint::encode::usize(protobuf_len, &mut varint_buf);
+        out_buf.extend_from_slice(varint_slice);
+
+        // Encode protobuf directly into output buffer
         protobuf_payload
-            .encode(&mut payload)
+            .encode(&mut out_buf)
             .expect("Vec to provide needed capacity");
-        let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4);
-        let mut codec = UnsignedVarint::new(None);
-        let _result = codec.encode(payload.into(), &mut out_buf);
-
-        out_buf.into()
+        out_buf
     }

     /// Decode payload into [`WebRtcMessage`].
+    ///
+    /// Decodes the varint length prefix directly from the slice without allocations,
+    /// then decodes the protobuf message from the remaining bytes.
+    ///
+    /// # Flag handling
+    ///
+    /// Unknown flag values (e.g., from a newer protocol version) are logged as warnings
+    /// and treated as `None` for forward compatibility. This allows the message payload
+    /// to still be processed even if the flag is not recognized.
    pub fn decode(payload: &[u8]) -> Result<Self, ParseError> {
-        // TODO: https://github.com/paritytech/litep2p/issues/352 set correct size
-        let mut codec = UnsignedVarint::new(None);
-        let mut data = bytes::BytesMut::from(payload);
-        let result = codec
-            .decode(&mut data)
-            .map_err(|_| ParseError::InvalidData)?
- .ok_or(ParseError::InvalidData)?; - - match schema::webrtc::Message::decode(result) { - Ok(message) => Ok(Self { - payload: message.message, - flags: message.flag, - }), + // Decode varint length prefix directly from slice (no allocation) + // Returns (decoded_length, remaining_bytes_after_varint) + let (len, remaining) = + unsigned_varint::decode::usize(payload).map_err(|_| ParseError::InvalidData)?; + + // Get exactly `len` bytes of protobuf data (no allocation) + let protobuf_data = remaining.get(..len).ok_or(ParseError::InvalidData)?; + + match schema::webrtc::Message::decode(protobuf_data) { + Ok(message) => { + let flag = message.flag.and_then(|f| match Flag::try_from(f) { + Ok(flag) => Some(flag), + Err(_) => { + tracing::warn!( + target: "litep2p::webrtc", + ?f, + "received message with unknown flag value, ignoring flag" + ); + None + } + }); + Ok(Self { + payload: message.message, + flag, + }) + } Err(_) => Err(ParseError::InvalidData), } } @@ -96,29 +119,30 @@ mod tests { use super::*; #[test] - fn with_payload_no_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec()); + fn with_payload_no_flag() { + let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, None); + assert_eq!(decoded.flag, None); } #[test] - fn with_payload_and_flags() { - let message = WebRtcMessage::encode_with_flags("Hello, world!".as_bytes().to_vec(), 1i32); + fn with_payload_and_flag() { + let message = + WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(Flag::StopSending)); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, Some(1i32)); + assert_eq!(decoded.flag, Some(Flag::StopSending)); } #[test] - fn no_payload_with_flags() { - let message = WebRtcMessage::encode_with_flags(vec![], 2i32); + fn no_payload_with_flag() { + let message = WebRtcMessage::encode(vec![], Some(Flag::ResetStream)); let decoded = WebRtcMessage::decode(&message).unwrap(); assert_eq!(decoded.payload, None); - assert_eq!(decoded.flags, Some(2i32)); + assert_eq!(decoded.flag, Some(Flag::ResetStream)); } }
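For reference, a minimal round-trip sketch of the reworked `WebRtcMessage` API in the diff above. It assumes crate-internal visibility of the `webrtc` module, and the helper name `fin_roundtrip` is illustrative only, not part of the change: the encoder now takes the flag directly and the decoder yields a typed `Flag` instead of a raw `i32`.

use crate::transport::webrtc::{schema::webrtc::message::Flag, util::WebRtcMessage};

// Encode a FIN that also carries the final chunk of data: the output is the
// varint length prefix followed by the protobuf-encoded message.
fn fin_roundtrip() -> Result<(), crate::error::ParseError> {
    let encoded = WebRtcMessage::encode(b"last bytes".to_vec(), Some(Flag::Fin));

    // On the receiving side, decode() recovers both halves; on_message() then
    // delivers the payload first, signals RecvClosed, and queues a FIN_ACK reply.
    let decoded = WebRtcMessage::decode(&encoded)?;
    assert_eq!(decoded.payload.as_deref(), Some(&b"last bytes"[..]));
    assert_eq!(decoded.flag, Some(Flag::Fin));
    Ok(())
}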