Skip to content

Commit 4155015

Browse files
committed
Add unit tests
1 parent 9b3f916 commit 4155015

File tree

3 files changed

+321
-11
lines changed

3 files changed

+321
-11
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,6 @@ serde = { version = "1.0", features = ["derive"] }
2121

2222
[dev-dependencies]
2323
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"]}
24-
rand = "0.8"
24+
rand = "0.8"
25+
futures-util = { version = "0.3", features = ["io"] }
26+
tokio-util = "0.7"

src/lib.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ use std::{
1414
};
1515
use tokio::sync::{mpsc, oneshot, Mutex};
1616

17+
#[cfg(test)]
18+
mod tests;
19+
1720
#[derive(Deserialize, Serialize)]
1821
struct InternalMessage<T> {
1922
user_message: T,
@@ -153,11 +156,9 @@ pub struct AsyncReadConverse<
153156
pending_reply: Vec<ReplySender<T>>,
154157
}
155158

156-
impl<
157-
R: AsyncRead + Unpin,
158-
W: AsyncWrite + Unpin,
159-
T: Serialize + DeserializeOwned + Unpin,
160-
> AsyncReadConverse<R, W, T> {
159+
impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin>
160+
AsyncReadConverse<R, W, T>
161+
{
161162
pub fn inner(&self) -> &R {
162163
self.raw.inner()
163164
}
@@ -226,7 +227,7 @@ impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwne
226227
} else {
227228
continue;
228229
}
229-
},
230+
}
230231
None => return Poll::Ready(None),
231232
}
232233
}
@@ -243,10 +244,7 @@ pub struct AsyncWriteConverse<W: AsyncWrite + Unpin, T: Serialize + DeserializeO
243244
next_id: u64,
244245
}
245246

246-
impl<
247-
W: AsyncWrite + Unpin,
248-
T: Serialize + DeserializeOwned + Unpin,
249-
> AsyncWriteConverse<W, T> {
247+
impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
250248
pub async fn with_inner<F: FnOnce(&W) -> R, R>(&self, f: F) -> R {
251249
f(self.raw.lock().await.inner())
252250
}

src/tests.rs

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
use std::{
2+
io,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
time::Duration,
6+
};
7+
8+
use futures_io::{AsyncRead, AsyncWrite};
9+
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
10+
use tokio::sync::mpsc::{self, Receiver};
11+
use tokio_util::sync::PollSender;
12+
13+
// What follows is an intentionally obnoxious `AsyncRead` and `AsyncWrite` implementation. Please don't use this outside of tests.
14+
struct BasicChannelSender {
15+
sender: PollSender<Vec<u8>>,
16+
}
17+
18+
impl AsyncWrite for BasicChannelSender {
19+
fn poll_write(
20+
mut self: Pin<&mut Self>,
21+
cx: &mut Context<'_>,
22+
buf: &[u8],
23+
) -> Poll<futures_io::Result<usize>> {
24+
if futures_core::ready!(self.sender.poll_reserve(cx)).is_err() {
25+
return Poll::Ready(Err(io::Error::new(
26+
io::ErrorKind::ConnectionAborted,
27+
"remote hung up",
28+
)));
29+
}
30+
let write_len = buf.len();
31+
self.sender
32+
.send_item(buf.to_vec())
33+
.expect("receiver hung up!");
34+
Poll::Ready(Ok(write_len))
35+
}
36+
37+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
38+
Poll::Ready(Ok(()))
39+
}
40+
41+
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
42+
self.sender.close();
43+
Poll::Ready(Ok(()))
44+
}
45+
}
46+
47+
struct BasicChannelReceiver {
48+
receiver: Receiver<Vec<u8>>,
49+
last_received: Vec<u8>,
50+
}
51+
52+
impl AsyncRead for BasicChannelReceiver {
53+
fn poll_read(
54+
mut self: Pin<&mut Self>,
55+
cx: &mut Context<'_>,
56+
buf: &mut [u8],
57+
) -> Poll<futures_io::Result<usize>> {
58+
let mut len_written = 0;
59+
loop {
60+
if self.last_received.len() > 0 {
61+
let copy_len = self.last_received.len().min(buf.len() - len_written);
62+
buf[len_written..(len_written + copy_len)]
63+
.copy_from_slice(&self.last_received[..copy_len]);
64+
self.last_received = self.last_received.split_off(copy_len);
65+
len_written += copy_len;
66+
if len_written == buf.len() {
67+
return Poll::Ready(Ok(buf.len()));
68+
}
69+
} else {
70+
self.last_received = match self.receiver.poll_recv(cx) {
71+
Poll::Ready(Some(v)) => v,
72+
Poll::Ready(None) => {
73+
return if len_written > 0 {
74+
Poll::Ready(Ok(len_written))
75+
} else {
76+
Poll::Pending
77+
}
78+
}
79+
Poll::Pending => {
80+
return if len_written > 0 {
81+
Poll::Ready(Ok(len_written))
82+
} else {
83+
Poll::Pending
84+
}
85+
}
86+
}
87+
}
88+
}
89+
}
90+
}
91+
92+
fn basic_channel() -> (BasicChannelSender, BasicChannelReceiver) {
93+
let (sender, receiver) = mpsc::channel(32);
94+
(
95+
BasicChannelSender {
96+
sender: PollSender::new(sender),
97+
},
98+
BasicChannelReceiver {
99+
receiver,
100+
last_received: Vec::new(),
101+
},
102+
)
103+
}
104+
105+
// This tests our testing equipment, just makes sure the above implementations are correct.
106+
#[tokio::test(flavor = "multi_thread")]
107+
async fn basic_channel_test() {
108+
{
109+
let (mut sender, mut receiver) = basic_channel();
110+
let message = "Hello World!".as_bytes();
111+
let mut read_buf = vec![0; message.len()];
112+
let write = tokio::spawn(async move { sender.write_all(message).await });
113+
tokio::time::timeout(Duration::from_secs(2), receiver.read_exact(&mut read_buf))
114+
.await
115+
.unwrap()
116+
.unwrap();
117+
write.await.unwrap().unwrap();
118+
assert_eq!(message, read_buf);
119+
}
120+
{
121+
let (sender, mut receiver) = basic_channel();
122+
let mut sender = Some(sender);
123+
for _ in 0..10 {
124+
let message = (0..255).collect::<Vec<u8>>();
125+
let mut read_buf = vec![0; message.len()];
126+
let message_clone = message.clone();
127+
let mut sender_inner = sender.take().unwrap();
128+
let write = tokio::spawn(async move {
129+
sender_inner.write_all(&message_clone).await.unwrap();
130+
sender_inner
131+
});
132+
tokio::time::timeout(Duration::from_secs(2), receiver.read_exact(&mut read_buf))
133+
.await
134+
.unwrap()
135+
.unwrap();
136+
sender = Some(write.await.unwrap());
137+
assert_eq!(message, read_buf);
138+
}
139+
}
140+
}
141+
142+
use std::{
143+
sync::{
144+
atomic::{AtomicU32, Ordering},
145+
Arc,
146+
},
147+
time::Instant,
148+
};
149+
150+
use serde::{Deserialize, Serialize};
151+
152+
use futures_util::StreamExt;
153+
use rand::Rng;
154+
155+
use super::*;
156+
157+
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
158+
pub enum TestMessage {
159+
HelloThere,
160+
GeneralKenobiYouAreABoldOne,
161+
}
162+
163+
#[tokio::test]
164+
async fn basic_dialogue() {
165+
let (server_write, client_read) = basic_channel();
166+
let (client_write, server_read) = basic_channel();
167+
let (server_read, mut server_write) = new_duplex_connection(server_read, server_write);
168+
let (mut client_read, _client_write) = new_duplex_connection(client_read, client_write);
169+
server_read.drive_forever();
170+
tokio::spawn(async move {
171+
while let Some(message) = client_read.next().await {
172+
let mut received_message = message.unwrap();
173+
let message = received_message.take_message();
174+
match message {
175+
TestMessage::HelloThere => received_message
176+
.reply(TestMessage::GeneralKenobiYouAreABoldOne)
177+
.await
178+
.unwrap(),
179+
TestMessage::GeneralKenobiYouAreABoldOne => panic!("Wait, that's my line!"),
180+
}
181+
}
182+
});
183+
assert_eq!(
184+
server_write.ask(TestMessage::HelloThere).await.unwrap(),
185+
TestMessage::GeneralKenobiYouAreABoldOne
186+
);
187+
}
188+
189+
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
190+
pub enum IdentifiableMessage {
191+
FromServer(u32),
192+
FromClient(u32),
193+
}
194+
195+
#[tokio::test(flavor = "multi_thread")]
196+
async fn flurry_of_communication() {
197+
const TEST_DURATION: Duration = Duration::from_secs(1);
198+
const TEST_COUNT: u32 = 5;
199+
let tests_complete = Arc::new(AtomicU32::new(0));
200+
let start = Instant::now();
201+
for _ in 0..TEST_COUNT {
202+
let tests_complete = Arc::clone(&tests_complete);
203+
tokio::spawn(async move {
204+
let (server_write, client_read) = basic_channel();
205+
let (client_write, server_read) = basic_channel();
206+
let (mut server_read, mut server_write) =
207+
new_duplex_connection(server_read, server_write);
208+
let (mut client_read, mut client_write) =
209+
new_duplex_connection(client_read, client_write);
210+
tokio::spawn(async move {
211+
while let Some(message) = client_read.next().await {
212+
let mut received_message = message.unwrap();
213+
let message = received_message.take_message();
214+
match message {
215+
IdentifiableMessage::FromServer(u) => received_message
216+
.reply(IdentifiableMessage::FromClient(u))
217+
.await
218+
.unwrap(),
219+
IdentifiableMessage::FromClient(_) => panic!(
220+
"Received message from client as client, this should never happen"
221+
),
222+
}
223+
}
224+
});
225+
tokio::spawn(async move {
226+
while let Some(message) = server_read.next().await {
227+
let mut received_message = message.unwrap();
228+
let message = received_message.take_message();
229+
match message {
230+
IdentifiableMessage::FromClient(u) => received_message
231+
.reply(IdentifiableMessage::FromServer(u))
232+
.await
233+
.unwrap(),
234+
IdentifiableMessage::FromServer(_) => panic!(
235+
"Received message from server as server, this should never happen"
236+
),
237+
}
238+
}
239+
});
240+
let start = Instant::now();
241+
while start.elapsed() < TEST_DURATION {
242+
let code = rand::thread_rng().gen::<u32>();
243+
if rand::thread_rng().gen::<bool>() {
244+
assert_eq!(
245+
server_write
246+
.ask(IdentifiableMessage::FromServer(code))
247+
.await
248+
.unwrap(),
249+
IdentifiableMessage::FromClient(code)
250+
);
251+
} else {
252+
assert_eq!(
253+
client_write
254+
.ask(IdentifiableMessage::FromClient(code))
255+
.await
256+
.unwrap(),
257+
IdentifiableMessage::FromServer(code)
258+
);
259+
}
260+
}
261+
tests_complete.fetch_add(1, Ordering::Relaxed);
262+
});
263+
}
264+
while tests_complete.load(Ordering::Relaxed) < TEST_COUNT && start.elapsed() < TEST_DURATION * 2
265+
{
266+
}
267+
assert!(start.elapsed() >= TEST_DURATION);
268+
assert!(start.elapsed() < TEST_DURATION * 2);
269+
}
270+
271+
#[tokio::test]
272+
async fn timeout_check() {
273+
let (server_write, client_read) = basic_channel();
274+
let (client_write, server_read) = basic_channel();
275+
let (server_read, mut server_write) = new_duplex_connection(server_read, server_write);
276+
let (mut client_read, _client_write) = new_duplex_connection(client_read, client_write);
277+
server_read.drive_forever();
278+
tokio::spawn(async move {
279+
while let Some(message) = client_read.next().await {
280+
let mut received_message = message.unwrap();
281+
let message = received_message.take_message();
282+
match message {
283+
TestMessage::HelloThere => received_message
284+
.reply(TestMessage::GeneralKenobiYouAreABoldOne)
285+
.await
286+
.unwrap(),
287+
TestMessage::GeneralKenobiYouAreABoldOne => {
288+
// Do nothing.
289+
}
290+
}
291+
}
292+
});
293+
let start = Instant::now();
294+
let timeout = Duration::from_secs(1);
295+
assert!(matches!(
296+
server_write
297+
.ask_timeout(timeout, TestMessage::GeneralKenobiYouAreABoldOne)
298+
.await,
299+
Err(Error::Timeout)
300+
));
301+
let elapsed = start.elapsed();
302+
assert!(elapsed < timeout * 2);
303+
assert!(elapsed >= timeout);
304+
assert!(matches!(
305+
server_write
306+
.ask_timeout(timeout, TestMessage::HelloThere)
307+
.await,
308+
Ok(TestMessage::GeneralKenobiYouAreABoldOne)
309+
));
310+
}

0 commit comments

Comments
 (0)