Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 21 additions & 57 deletions tests/multi_packet.rs
Original file line number Diff line number Diff line change
@@ -1,80 +1,44 @@
//! Tests for multi-packet responses using channels.

use futures::TryStreamExt;
use rstest::rstest;
use tokio::sync::mpsc;
use wireframe::Response;
use wireframe_testing::collect_multi_packet;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

#[derive(PartialEq, Debug)]
struct TestMsg(u8);

/// Drain all messages from the stream.
async fn drain_all(stream: wireframe::FrameStream<TestMsg, ()>) -> Vec<TestMsg> {
stream.try_collect::<Vec<_>>().await.expect("stream error")
}

/// Verify that all messages sent through the channel are yielded via
/// `Response::into_stream()` for the `MultiPacket` variant.
#[tokio::test]
async fn multi_packet_yields_messages() {
let (tx, rx) = mpsc::channel(4);
tx.send(TestMsg(1)).await.expect("send");
tx.send(TestMsg(2)).await.expect("send");
drop(tx);

let resp: Response<TestMsg, ()> = Response::MultiPacket(rx);
let received = drain_all(resp.into_stream()).await;
assert_eq!(received, vec![TestMsg(1), TestMsg(2)]);
}

/// Yields no messages when the channel is immediately closed.
#[tokio::test]
async fn multi_packet_empty_channel() {
let (tx, rx) = mpsc::channel(4);
drop(tx);
let resp: Response<TestMsg, ()> = Response::MultiPacket(rx);
let received = drain_all(resp.into_stream()).await;
assert!(received.is_empty());
}
const CAPACITY: usize = 2;

/// Stops yielding when the sender is dropped before all messages are sent.
#[tokio::test]
async fn multi_packet_sender_dropped_before_all_messages() {
let (tx, rx) = mpsc::channel(4);
tx.send(TestMsg(1)).await.expect("send");
drop(tx);
let resp: Response<TestMsg, ()> = Response::MultiPacket(rx);
let received = drain_all(resp.into_stream()).await;
assert_eq!(received, vec![TestMsg(1)]);
}

/// Test that sending fails after the receiver is dropped.
#[tokio::test]
async fn multi_packet_send_fails_after_receiver_dropped() {
let (tx, rx) = mpsc::channel::<TestMsg>(2);
drop(rx);
let error = tx
.send(TestMsg(42))
.await
.expect_err("Send should fail when receiver is dropped");
let mpsc::error::SendError(msg) = error;
assert_eq!(msg, TestMsg(42));
/// Drain all messages from a `FrameStream` for non-channel response variants.
async fn drain_all<F, E: std::fmt::Debug>(stream: wireframe::FrameStream<F, E>) -> Vec<F> {
stream.try_collect::<Vec<_>>().await.expect("stream error")
}

/// Handles more messages than the channel capacity allows.
/// `collect_multi_packet` drains every frame regardless of channel state.
///
/// This covers empty channels, partial sends, and when senders outpace the
/// channel's capacity.
#[rstest(count, case(0), case(1), case(2), case(CAPACITY + 1))]
#[tokio::test]
async fn multi_packet_handles_channel_capacity() {
let (tx, rx) = mpsc::channel(2);
async fn multi_packet_drains_all_messages(count: usize) {
let (tx, rx) = mpsc::channel(CAPACITY);
let send_task = tokio::spawn(async move {
for i in 0..4u8 {
tx.send(TestMsg(i)).await.expect("send");
for i in 0..count {
tx.send(TestMsg(u8::try_from(i).expect("<= u8::MAX")))
.await
.expect("send");
}
});
let resp: Response<TestMsg, ()> = Response::MultiPacket(rx);
let received = drain_all(resp.into_stream()).await;
let received = collect_multi_packet(resp).await;
send_task.await.expect("sender join");
assert_eq!(
received,
vec![TestMsg(0), TestMsg(1), TestMsg(2), TestMsg(3)]
(0..count)
.map(|i| TestMsg(u8::try_from(i).expect("<= u8::MAX")))
.collect::<Vec<_>>()
);
}

Expand Down
14 changes: 2 additions & 12 deletions tests/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::{net::SocketAddr, sync::Arc};

use async_stream::try_stream;
use cucumber::World;
use futures::TryStreamExt;
use tokio::{net::TcpStream, sync::oneshot};
use tokio_util::sync::CancellationToken;
use wireframe::{
Expand All @@ -20,6 +19,7 @@ use wireframe::{
serializer::BincodeSerializer,
server::WireframeServer,
};
use wireframe_testing::collect_multi_packet;

type TestApp = wireframe::app::WireframeApp<BincodeSerializer, (), Envelope>;

Expand Down Expand Up @@ -216,28 +216,18 @@ pub struct MultiPacketWorld {
}

impl MultiPacketWorld {
async fn drain(&mut self, resp: wireframe::Response<u8, ()>) {
let frames = resp
.into_stream()
.try_collect::<Vec<_>>()
.await
.expect("stream error");
self.messages.extend(frames);
}

/// Helper method to process messages through a multi-packet response.
///
/// # Panics
/// Panics if sending to the channel fails.
async fn process_messages(&mut self, messages: &[u8]) {
self.messages.clear();
let (tx, ch_rx) = tokio::sync::mpsc::channel(4);
for &msg in messages {
tx.send(msg).await.expect("send");
}
drop(tx);
let resp: wireframe::Response<u8, ()> = wireframe::Response::MultiPacket(ch_rx);
self.drain(resp).await;
self.messages = collect_multi_packet(resp).await;
}

/// Send messages through a multi-packet response and record them.
Expand Down
3 changes: 3 additions & 0 deletions wireframe_testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

pub mod helpers;
pub mod logging;
pub mod multi_packet;

pub use helpers::{
TEST_MAX_FRAME,
Expand All @@ -39,3 +40,5 @@ pub use helpers::{
run_with_duplex_server,
};
pub use logging::{LoggerHandle, logger};
#[doc(inline)]
pub use multi_packet::collect_multi_packet;
Comment thread
coderabbitai[bot] marked this conversation as resolved.
46 changes: 46 additions & 0 deletions wireframe_testing/src/multi_packet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//! Helpers for draining `Response::MultiPacket` values in tests.
//!
//! These utilities collect all frames from a [`Response::MultiPacket`] into a
//! `Vec`, enabling concise assertions in tests and Cucumber steps.

use wireframe::Response;

/// Collect all frames from a [`Response::MultiPacket`].
///
/// # Examples
///
/// ```rust
/// use tokio::sync::mpsc;
/// use wireframe::Response;
/// use wireframe_testing::collect_multi_packet;
///
/// # async fn demo() {
/// let (tx, rx) = mpsc::channel(4);
/// tx.send(1u8).await.expect("send");
/// drop(tx);
/// let frames = collect_multi_packet(Response::MultiPacket(rx)).await;
/// assert_eq!(frames, vec![1]);
/// # }
/// ```
Comment thread
coderabbitai[bot] marked this conversation as resolved.
///
/// # Panics
/// Panics if `resp` is not [`Response::MultiPacket`]; the panic message names
/// the received variant and is attributed to the caller.
#[must_use]
#[track_caller]
#[allow(ungated_async_fn_track_caller)] // track_caller on async is unstable
pub async fn collect_multi_packet<F, E>(resp: Response<F, E>) -> Vec<F> {
match resp {
Response::MultiPacket(mut rx) => {
let mut frames = Vec::new();
while let Some(frame) = rx.recv().await {
frames.push(frame);
}
frames
}
Response::Single(_) => panic!("collect_multi_packet received Response::Single"),
Response::Vec(_) => panic!("collect_multi_packet received Response::Vec"),
Response::Stream(_) => panic!("collect_multi_packet received Response::Stream"),
Response::Empty => panic!("collect_multi_packet received Response::Empty"),
}
}
Loading