diff --git a/Cargo.lock b/Cargo.lock index ea44b9ce..9e32c66e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2319,6 +2319,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.16" @@ -2985,6 +2996,7 @@ dependencies = [ "static_assertions", "thiserror 2.0.16", "tokio", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 17ed773c..3c8972f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ tokio = { version = "1.47.1", default-features = false, features = [ "io-util", ] } tokio-util = { version = "0.7.16", features = ["rt", "codec"] } +tokio-stream = "0.1.17" futures = "0.3.31" async-trait = "0.1.89" bytes = "1.10.1" diff --git a/docs/multi-packet-and-streaming-responses-design.md b/docs/multi-packet-and-streaming-responses-design.md index 4481f46d..016675f3 100644 --- a/docs/multi-packet-and-streaming-responses-design.md +++ b/docs/multi-packet-and-streaming-responses-design.md @@ -110,6 +110,11 @@ This design allows simple, single-frame handlers to remain unchanged (`Ok(my_frame.into())`) while providing powerful and efficient options for more complex cases. +To simplify consumption, `Response::into_stream` converts any `Response` +variant into a `FrameStream`. Downstream code can iterate over frames without +matching `MultiPacket` or wiring channels. Both `Response::Vec` with an empty +vector and `Response::Empty` yield an empty stream. + ### 4.2 The `WireframeError` Enum To enable more robust error handling, a generic error enum will be introduced. diff --git a/src/response.rs b/src/response.rs index 58c3b205..7ef2df3e 100644 --- a/src/response.rs +++ b/src/response.rs @@ -31,8 +31,9 @@ use std::pin::Pin; -use futures::stream::Stream; +use futures::{Stream, StreamExt, stream}; use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; /// A type alias for a type-erased, dynamically dispatched stream of frames. /// @@ -113,6 +114,45 @@ impl From> for Response { fn from(v: Vec) -> Self { Response::Vec(v) } } +impl Response { + /// Convert this response into a stream of frames. + /// + /// `Response::Vec` with no frames and `Response::Empty` produce an empty + /// stream. + /// + /// # Examples + /// + /// ``` + /// use futures::TryStreamExt; + /// use wireframe::Response; + /// + /// # async fn demo() { + /// let (tx, rx) = tokio::sync::mpsc::channel(1); + /// tx.send(1u8).await.expect("send"); + /// drop(tx); + /// let resp: Response = Response::MultiPacket(rx); + /// let frames: Vec = resp + /// .into_stream() + /// .try_collect() + /// .await + /// .expect("stream error"); + /// assert_eq!(frames, vec![1]); + /// # } + /// ``` + #[must_use] + pub fn into_stream(self) -> FrameStream { + match self { + Response::Single(f) => { + stream::once(async move { Ok::>(f) }).boxed() + } + Response::Vec(frames) => stream::iter(frames.into_iter().map(Ok)).boxed(), + Response::Stream(s) => s, + Response::MultiPacket(rx) => ReceiverStream::new(rx).map(Ok).boxed(), + Response::Empty => stream::empty().boxed(), + } + } +} + /// A generic error type for wireframe operations. /// /// # Examples diff --git a/tests/multi_packet.rs b/tests/multi_packet.rs index 73cd97d4..beed6178 100644 --- a/tests/multi_packet.rs +++ b/tests/multi_packet.rs @@ -1,21 +1,19 @@ //! Tests for multi-packet responses using channels. +use futures::TryStreamExt; use tokio::sync::mpsc; use wireframe::Response; #[derive(PartialEq, Debug)] struct TestMsg(u8); -/// Drain all messages from the receiver. -async fn drain_all(mut rx: mpsc::Receiver) -> Vec { - let mut messages = Vec::new(); - while let Some(msg) = rx.recv().await { - messages.push(msg); - } - messages +/// Drain all messages from the stream. +async fn drain_all(stream: wireframe::FrameStream) -> Vec { + stream.try_collect::>().await.expect("stream error") } -/// Verifies that all messages sent through the channel are yielded by `Response::MultiPacket`. +/// Verify that all messages sent through the channel are yielded via +/// `Response::into_stream()` for the `MultiPacket` variant. #[tokio::test] async fn multi_packet_yields_messages() { let (tx, rx) = mpsc::channel(4); @@ -24,11 +22,7 @@ async fn multi_packet_yields_messages() { drop(tx); let resp: Response = Response::MultiPacket(rx); - let received = if let Response::MultiPacket(rx) = resp { - drain_all(rx).await - } else { - unreachable!() - }; + let received = drain_all(resp.into_stream()).await; assert_eq!(received, vec![TestMsg(1), TestMsg(2)]); } @@ -38,11 +32,7 @@ async fn multi_packet_empty_channel() { let (tx, rx) = mpsc::channel(4); drop(tx); let resp: Response = Response::MultiPacket(rx); - let received = if let Response::MultiPacket(rx) = resp { - drain_all(rx).await - } else { - unreachable!() - }; + let received = drain_all(resp.into_stream()).await; assert!(received.is_empty()); } @@ -53,11 +43,7 @@ async fn multi_packet_sender_dropped_before_all_messages() { tx.send(TestMsg(1)).await.expect("send"); drop(tx); let resp: Response = Response::MultiPacket(rx); - let received = if let Response::MultiPacket(rx) = resp { - drain_all(rx).await - } else { - unreachable!() - }; + let received = drain_all(resp.into_stream()).await; assert_eq!(received, vec![TestMsg(1)]); } @@ -71,14 +57,26 @@ async fn multi_packet_handles_channel_capacity() { } }); let resp: Response = Response::MultiPacket(rx); - let received = if let Response::MultiPacket(rx) = resp { - drain_all(rx).await - } else { - unreachable!() - }; + let received = drain_all(resp.into_stream()).await; send_task.await.expect("sender join"); assert_eq!( received, vec![TestMsg(0), TestMsg(1), TestMsg(2), TestMsg(3)] ); } + +/// Returns an empty stream for an empty vector response. +#[tokio::test] +async fn vec_empty_returns_empty_stream() { + let resp: Response = Response::Vec(Vec::new()); + let received = drain_all(resp.into_stream()).await; + assert!(received.is_empty()); +} + +/// `Response::Empty` yields no frames. +#[tokio::test] +async fn empty_returns_empty_stream() { + let resp: Response = Response::Empty; + let received = drain_all(resp.into_stream()).await; + assert!(received.is_empty()); +} diff --git a/tests/world.rs b/tests/world.rs index 54b12114..816b17de 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -8,6 +8,7 @@ use std::{net::SocketAddr, sync::Arc}; use async_stream::try_stream; use cucumber::World; +use futures::TryStreamExt; use tokio::{net::TcpStream, sync::oneshot}; use tokio_util::sync::CancellationToken; use wireframe::{ @@ -216,11 +217,12 @@ pub struct MultiPacketWorld { impl MultiPacketWorld { async fn drain(&mut self, resp: wireframe::Response) { - if let wireframe::Response::MultiPacket(mut mp_rx) = resp { - while let Some(msg) = mp_rx.recv().await { - self.messages.push(msg); - } - } + let frames = resp + .into_stream() + .try_collect::>() + .await + .expect("stream error"); + self.messages.extend(frames); } /// Helper method to process messages through a multi-packet response.