Skip to content

Commit c6a50af

Browse files
[runtime] Fix Network Flake (#3494)
1 parent 45f7d9f commit c6a50af

2 files changed

Lines changed: 113 additions & 59 deletions

File tree

runtime/src/network/mod.rs

Lines changed: 112 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ stability_scope!(ALPHA, cfg(all(not(target_arch = "wasm32"), feature = "iouring-
1717
#[cfg(test)]
1818
mod tests {
1919
use crate::{IoBuf, IoBufs, Listener, Sink, Stream};
20+
use commonware_utils::sync::Barrier;
2021
use futures::join;
21-
use std::net::SocketAddr;
22+
use std::{net::SocketAddr, sync::Arc};
23+
use tokio::task::JoinSet;
2224

2325
const CLIENT_SEND_DATA: &[u8] = b"client_send_data";
2426
const SERVER_SEND_DATA: &[u8] = b"server_send_data";
@@ -47,12 +49,11 @@ mod tests {
4749
// Get the local address of the listener
4850
let listener_addr = listener.local_addr().expect("Failed to get local address");
4951

50-
let runtime = tokio::runtime::Handle::current();
51-
52-
// Spawn server
53-
let server = runtime.spawn(async move {
52+
// Spawn server. Returning the socket halves keeps them alive until both
53+
// join handles are awaited below.
54+
let server = tokio::spawn(async move {
55+
// Server accepts a client, verifies the payload, and sends a reply.
5456
let (_, mut sink, mut stream) = listener.accept().await.expect("Failed to accept");
55-
5657
let received = stream
5758
.recv(CLIENT_SEND_DATA.len())
5859
.await
@@ -61,10 +62,14 @@ mod tests {
6162
sink.send(IoBuf::from(SERVER_SEND_DATA))
6263
.await
6364
.expect("Failed to send");
65+
(sink, stream)
6466
});
6567

66-
// Spawn client, connect to server, send and receive data over connection
67-
let client = runtime.spawn(async move {
68+
// Spawn client, connect to server, send and receive data over connection.
69+
// Returning the socket halves keeps them alive until both join handles
70+
// are awaited below.
71+
let client = tokio::spawn(async move {
72+
// Client connects to the server, sends a payload, and reads the reply.
6873
// Connect to the server
6974
let (mut sink, mut stream) = network
7075
.dial(listener_addr)
@@ -74,18 +79,18 @@ mod tests {
7479
sink.send(IoBuf::from(CLIENT_SEND_DATA))
7580
.await
7681
.expect("Failed to send data");
77-
7882
let received = stream
7983
.recv(SERVER_SEND_DATA.len())
8084
.await
8185
.expect("Failed to receive data");
8286
assert_eq!(received.coalesce(), SERVER_SEND_DATA);
87+
(sink, stream)
8388
});
8489

8590
// Wait for both tasks to complete
8691
let (server_result, client_result) = join!(server, client);
87-
assert!(server_result.is_ok());
88-
assert!(client_result.is_ok());
92+
server_result.expect("Server task failed");
93+
client_result.expect("Client task failed");
8994
}
9095

9196
// Test sending a multi-buffer payload.
@@ -99,8 +104,6 @@ mod tests {
99104
// Get the local address of the listener
100105
let listener_addr = listener.local_addr().expect("Failed to get local address");
101106

102-
let runtime = tokio::runtime::Handle::current();
103-
104107
// Build one logical message from multiple chunks so this test exercises
105108
// the `IoBufs` send path (instead of the single-buffer fast path).
106109
let message = IoBufs::from(vec![
@@ -112,36 +115,40 @@ mod tests {
112115

113116
// Spawn a server and read exactly the logical message size. The receive
114117
// side should observe the same byte stream regardless of send chunking.
115-
let server = runtime.spawn(async move {
116-
let (_, _sink, mut stream) = listener.accept().await.expect("Failed to accept");
118+
let server = tokio::spawn(async move {
119+
// Server receives the vectored payload as one logical byte stream.
120+
let (_, sink, mut stream) = listener.accept().await.expect("Failed to accept");
117121
let received = stream
118122
.recv(expected.len())
119123
.await
120124
.expect("Failed to receive");
121125
assert_eq!(received.coalesce(), expected.as_ref());
126+
(sink, stream)
122127
});
123128

124129
// Spawn client
125-
let client = runtime.spawn(async move {
130+
let client = tokio::spawn(async move {
131+
// Client connects and sends the pre-built vectored message.
126132
// Connect to the server
127-
let (mut sink, _stream) = network
133+
let (mut sink, stream) = network
128134
.dial(listener_addr)
129135
.await
130136
.expect("Failed to dial server");
131137

132138
// Send the pre-built vectored message.
133139
sink.send(message).await.expect("Failed to send data");
140+
(sink, stream)
134141
});
135142

136143
// Wait for both tasks to complete
137144
let (server_result, client_result) = join!(server, client);
138-
assert!(server_result.is_ok());
139-
assert!(client_result.is_ok());
145+
server_result.expect("Server task failed");
146+
client_result.expect("Client task failed");
140147
}
141148

142149
// Test handling multiple clients
143150
async fn test_network_multiple_clients<N: crate::Network>(network: N) {
144-
let runtime = tokio::runtime::Handle::current();
151+
const NUM_CLIENTS: usize = 3;
145152

146153
// Start a server
147154
let mut listener = network
@@ -150,27 +157,42 @@ mod tests {
150157
.expect("Failed to bind");
151158
let listener_addr = listener.local_addr().expect("Failed to get local address");
152159

160+
// Keep all sockets alive until every participant finishes.
161+
let barrier = Arc::new(Barrier::new(NUM_CLIENTS * 2));
162+
153163
// Server task
154-
let server = runtime.spawn(async move {
164+
let server_barrier = barrier.clone();
165+
let server = tokio::spawn(async move {
155166
// Handle multiple clients
156-
for _ in 0..3 {
167+
let mut set = JoinSet::new();
168+
for _ in 0..NUM_CLIENTS {
157169
let (_, mut sink, mut stream) = listener.accept().await.expect("Failed to accept");
158-
159-
let received = stream
160-
.recv(CLIENT_SEND_DATA.len())
161-
.await
162-
.expect("Failed to receive");
163-
assert_eq!(received.coalesce(), CLIENT_SEND_DATA);
164-
165-
sink.send(IoBuf::from(SERVER_SEND_DATA))
166-
.await
167-
.expect("Failed to send");
170+
let barrier = server_barrier.clone();
171+
set.spawn(async move {
172+
let received = stream
173+
.recv(CLIENT_SEND_DATA.len())
174+
.await
175+
.expect("Failed to receive");
176+
assert_eq!(received.coalesce(), CLIENT_SEND_DATA);
177+
sink.send(IoBuf::from(SERVER_SEND_DATA))
178+
.await
179+
.expect("Failed to send");
180+
181+
// Hold the connection open until every peer has finished.
182+
barrier.wait().await;
183+
});
184+
}
185+
while let Some(result) = set.join_next().await {
186+
result.expect("Server connection task failed");
168187
}
169188
});
170189

171190
// Start multiple clients
172-
let client = runtime.spawn(async move {
173-
for _ in 0..3 {
191+
let mut set = JoinSet::new();
192+
for _ in 0..NUM_CLIENTS {
193+
let network = network.clone();
194+
let barrier = barrier.clone();
195+
set.spawn(async move {
174196
// Connect to the server
175197
let (mut sink, mut stream) = network
176198
.dial(listener_addr)
@@ -187,14 +209,20 @@ mod tests {
187209
.recv(SERVER_SEND_DATA.len())
188210
.await
189211
.expect("Failed to receive data");
212+
190213
// Verify the received data
191214
assert_eq!(received.coalesce(), SERVER_SEND_DATA);
192-
}
193-
});
194215

195-
// Wait for server and all clients
216+
// Hold the connection open until every peer has finished.
217+
barrier.wait().await;
218+
});
219+
}
220+
221+
// Wait for all servers and clients to complete.
222+
while let Some(result) = set.join_next().await {
223+
result.expect("Client task failed");
224+
}
196225
server.await.expect("Server task failed");
197-
client.await.expect("Client task failed");
198226
}
199227

200228
// Test large data transfer
@@ -209,8 +237,9 @@ mod tests {
209237
.expect("Failed to bind");
210238
let listener_addr = listener.local_addr().expect("Failed to get local address");
211239

212-
let runtime = tokio::runtime::Handle::current();
213-
let server = runtime.spawn(async move {
240+
// Spawn server. Returning the socket halves keeps them alive until both
241+
// join handles are awaited below.
242+
let server = tokio::spawn(async move {
214243
let (_, mut sink, mut stream) = listener.accept().await.expect("Failed to accept");
215244

216245
// Receive and echo large data in chunks
@@ -221,10 +250,12 @@ mod tests {
221250
.expect("Failed to receive chunk");
222251
sink.send(received).await.expect("Failed to send chunk");
223252
}
253+
(sink, stream)
224254
});
225255

226-
// Client task
227-
let client = runtime.spawn(async move {
256+
// Client task. Returning the socket halves keeps them alive until both
257+
// join handles are awaited below.
258+
let client = tokio::spawn(async move {
228259
// Connect to the server
229260
let (mut sink, mut stream) = network
230261
.dial(listener_addr)
@@ -245,11 +276,13 @@ mod tests {
245276
.expect("Failed to receive chunk");
246277
assert_eq!(received.coalesce(), &pattern[..]);
247278
}
279+
(sink, stream)
248280
});
249281

250282
// Wait for both tasks to complete
251-
server.await.expect("Server task failed");
252-
client.await.expect("Client task failed");
283+
let (server_result, client_result) = join!(server, client);
284+
server_result.expect("Server task failed");
285+
client_result.expect("Client task failed");
253286
}
254287

255288
// Tests dialing and binding errors
@@ -281,17 +314,17 @@ mod tests {
281314
.expect("Failed to bind");
282315
let listener_addr = listener.local_addr().expect("Failed to get local address");
283316

284-
let runtime = tokio::runtime::Handle::current();
285-
286317
// Server sends data
287-
let server = runtime.spawn(async move {
288-
let (_, mut sink, _) = listener.accept().await.expect("Failed to accept");
318+
let server = tokio::spawn(async move {
319+
let (_, mut sink, stream) = listener.accept().await.expect("Failed to accept");
289320
sink.send(IoBuf::from(DATA)).await.expect("Failed to send");
321+
(sink, stream)
290322
});
291323

292324
// Client receives and tests peek
293-
let client = runtime.spawn(async move {
294-
let (_, mut stream) = network
325+
let client = tokio::spawn(async move {
326+
// Connect to the server
327+
let (sink, mut stream) = network
295328
.dial(listener_addr)
296329
.await
297330
.expect("Failed to dial server");
@@ -323,11 +356,13 @@ mod tests {
323356
// After consuming all data, peek should return empty
324357
let final_peek = stream.peek(100);
325358
assert!(final_peek.is_empty());
359+
(sink, stream)
326360
});
327361

362+
// Wait for both tasks to complete
328363
let (server_result, client_result) = join!(server, client);
329-
assert!(server_result.is_ok());
330-
assert!(client_result.is_ok());
364+
server_result.expect("Server task failed");
365+
client_result.expect("Client task failed");
331366
}
332367

333368
/// Network stress tests
@@ -352,36 +387,55 @@ mod tests {
352387
.unwrap();
353388
let addr = listener.local_addr().unwrap();
354389

390+
// Keep every connection alive until both the client and server halves finish.
391+
let barrier = Arc::new(Barrier::new(NUM_CLIENTS * 2));
392+
355393
// Spawn a server task that echoes messages from many clients.
394+
let server_barrier = barrier.clone();
356395
let server = tokio::spawn(async move {
396+
let mut set = JoinSet::new();
357397
for _ in 0..NUM_CLIENTS {
358398
let (_, mut sink, mut stream) = listener.accept().await.unwrap();
359-
tokio::spawn(async move {
399+
let barrier = server_barrier.clone();
400+
set.spawn(async move {
401+
// Echo every message back to the connected client.
360402
for _ in 0..NUM_MESSAGES {
361403
let received = stream.recv(MESSAGE_SIZE).await.unwrap();
362404
sink.send(received).await.unwrap();
363405
}
406+
407+
// Hold the connection open until every peer has finished.
408+
barrier.wait().await;
364409
});
365410
}
411+
while let Some(result) = set.join_next().await {
412+
result.unwrap();
413+
}
366414
});
367415

368416
// Spawn all clients.
369-
let mut clients = Vec::new();
417+
let mut set = JoinSet::new();
370418
for _ in 0..NUM_CLIENTS {
371419
let network = network.clone();
372-
clients.push(tokio::spawn(async move {
420+
let barrier = barrier.clone();
421+
set.spawn(async move {
422+
// Dial the server and repeatedly verify the echoed payload.
373423
let (mut sink, mut stream) = network.dial(addr).await.unwrap();
374424
let payload = vec![42u8; MESSAGE_SIZE];
375425
for _ in 0..NUM_MESSAGES {
376426
sink.send(payload.clone()).await.unwrap();
377427
let received = stream.recv(MESSAGE_SIZE).await.unwrap();
378428
assert_eq!(received.coalesce(), &payload[..]);
379429
}
380-
}));
430+
431+
// Hold the connection open until every peer has finished.
432+
barrier.wait().await;
433+
});
381434
}
382435

383-
for client in clients {
384-
client.await.unwrap();
436+
// Wait for all servers and clients to complete.
437+
while let Some(result) = set.join_next().await {
438+
result.unwrap();
385439
}
386440
server.await.unwrap();
387441
}

utils/src/sync/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub use parking_lot::{
2626
Condvar, Mutex, MutexGuard, Once, RwLock, RwLockReadGuard, RwLockWriteGuard,
2727
};
2828
pub use tokio::sync::{
29-
Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock,
29+
Barrier, Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock,
3030
RwLockReadGuard as AsyncRwLockReadGuard, RwLockWriteGuard as AsyncRwLockWriteGuard,
3131
};
3232

0 commit comments

Comments
 (0)