@@ -17,8 +17,10 @@ stability_scope!(ALPHA, cfg(all(not(target_arch = "wasm32"), feature = "iouring-
1717#[ cfg( test) ]
1818mod tests {
1919 use crate :: { IoBuf , IoBufs , Listener , Sink , Stream } ;
20+ use commonware_utils:: sync:: Barrier ;
2021 use futures:: join;
21- use std:: net:: SocketAddr ;
22+ use std:: { net:: SocketAddr , sync:: Arc } ;
23+ use tokio:: task:: JoinSet ;
2224
2325 const CLIENT_SEND_DATA : & [ u8 ] = b"client_send_data" ;
2426 const SERVER_SEND_DATA : & [ u8 ] = b"server_send_data" ;
@@ -47,12 +49,11 @@ mod tests {
4749 // Get the local address of the listener
4850 let listener_addr = listener. local_addr ( ) . expect ( "Failed to get local address" ) ;
4951
50- let runtime = tokio :: runtime :: Handle :: current ( ) ;
51-
52- // Spawn server
53- let server = runtime . spawn ( async move {
52+ // Spawn server. Returning the socket halves keeps them alive until both
53+ // join handles are awaited below.
54+ let server = tokio :: spawn ( async move {
55+ // Server accepts a client, verifies the payload, and sends a reply.
5456 let ( _, mut sink, mut stream) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
55-
5657 let received = stream
5758 . recv ( CLIENT_SEND_DATA . len ( ) )
5859 . await
@@ -61,10 +62,14 @@ mod tests {
6162 sink. send ( IoBuf :: from ( SERVER_SEND_DATA ) )
6263 . await
6364 . expect ( "Failed to send" ) ;
65+ ( sink, stream)
6466 } ) ;
6567
66- // Spawn client, connect to server, send and receive data over connection
67- let client = runtime. spawn ( async move {
68+ // Spawn client, connect to server, send and receive data over connection.
69+ // Returning the socket halves keeps them alive until both join handles
70+ // are awaited below.
71+ let client = tokio:: spawn ( async move {
72+ // Client connects to the server, sends a payload, and reads the reply.
6873 // Connect to the server
6974 let ( mut sink, mut stream) = network
7075 . dial ( listener_addr)
@@ -74,18 +79,18 @@ mod tests {
7479 sink. send ( IoBuf :: from ( CLIENT_SEND_DATA ) )
7580 . await
7681 . expect ( "Failed to send data" ) ;
77-
7882 let received = stream
7983 . recv ( SERVER_SEND_DATA . len ( ) )
8084 . await
8185 . expect ( "Failed to receive data" ) ;
8286 assert_eq ! ( received. coalesce( ) , SERVER_SEND_DATA ) ;
87+ ( sink, stream)
8388 } ) ;
8489
8590 // Wait for both tasks to complete
8691 let ( server_result, client_result) = join ! ( server, client) ;
87- assert ! ( server_result. is_ok ( ) ) ;
88- assert ! ( client_result. is_ok ( ) ) ;
92+ server_result. expect ( "Server task failed" ) ;
93+ client_result. expect ( "Client task failed" ) ;
8994 }
9095
9196 // Test sending a multi-buffer payload.
@@ -99,8 +104,6 @@ mod tests {
99104 // Get the local address of the listener
100105 let listener_addr = listener. local_addr ( ) . expect ( "Failed to get local address" ) ;
101106
102- let runtime = tokio:: runtime:: Handle :: current ( ) ;
103-
104107 // Build one logical message from multiple chunks so this test exercises
105108 // the `IoBufs` send path (instead of the single-buffer fast path).
106109 let message = IoBufs :: from ( vec ! [
@@ -112,36 +115,40 @@ mod tests {
112115
113116 // Spawn a server and read exactly the logical message size. The receive
114117 // side should observe the same byte stream regardless of send chunking.
115- let server = runtime. spawn ( async move {
116- let ( _, _sink, mut stream) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
118+ let server = tokio:: spawn ( async move {
119+ // Server receives the vectored payload as one logical byte stream.
120+ let ( _, sink, mut stream) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
117121 let received = stream
118122 . recv ( expected. len ( ) )
119123 . await
120124 . expect ( "Failed to receive" ) ;
121125 assert_eq ! ( received. coalesce( ) , expected. as_ref( ) ) ;
126+ ( sink, stream)
122127 } ) ;
123128
124129 // Spawn client
125- let client = runtime. spawn ( async move {
130+ let client = tokio:: spawn ( async move {
131+ // Client connects and sends the pre-built vectored message.
126132 // Connect to the server
127- let ( mut sink, _stream ) = network
133+ let ( mut sink, stream ) = network
128134 . dial ( listener_addr)
129135 . await
130136 . expect ( "Failed to dial server" ) ;
131137
132138 // Send the pre-built vectored message.
133139 sink. send ( message) . await . expect ( "Failed to send data" ) ;
140+ ( sink, stream)
134141 } ) ;
135142
136143 // Wait for both tasks to complete
137144 let ( server_result, client_result) = join ! ( server, client) ;
138- assert ! ( server_result. is_ok ( ) ) ;
139- assert ! ( client_result. is_ok ( ) ) ;
145+ server_result. expect ( "Server task failed" ) ;
146+ client_result. expect ( "Client task failed" ) ;
140147 }
141148
142149 // Test handling multiple clients
143150 async fn test_network_multiple_clients < N : crate :: Network > ( network : N ) {
144- let runtime = tokio :: runtime :: Handle :: current ( ) ;
151+ const NUM_CLIENTS : usize = 3 ;
145152
146153 // Start a server
147154 let mut listener = network
@@ -150,27 +157,42 @@ mod tests {
150157 . expect ( "Failed to bind" ) ;
151158 let listener_addr = listener. local_addr ( ) . expect ( "Failed to get local address" ) ;
152159
160+ // Keep all sockets alive until every participant finishes.
161+ let barrier = Arc :: new ( Barrier :: new ( NUM_CLIENTS * 2 ) ) ;
162+
153163 // Server task
154- let server = runtime. spawn ( async move {
164+ let server_barrier = barrier. clone ( ) ;
165+ let server = tokio:: spawn ( async move {
155166 // Handle multiple clients
156- for _ in 0 ..3 {
167+ let mut set = JoinSet :: new ( ) ;
168+ for _ in 0 ..NUM_CLIENTS {
157169 let ( _, mut sink, mut stream) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
158-
159- let received = stream
160- . recv ( CLIENT_SEND_DATA . len ( ) )
161- . await
162- . expect ( "Failed to receive" ) ;
163- assert_eq ! ( received. coalesce( ) , CLIENT_SEND_DATA ) ;
164-
165- sink. send ( IoBuf :: from ( SERVER_SEND_DATA ) )
166- . await
167- . expect ( "Failed to send" ) ;
170+ let barrier = server_barrier. clone ( ) ;
171+ set. spawn ( async move {
172+ let received = stream
173+ . recv ( CLIENT_SEND_DATA . len ( ) )
174+ . await
175+ . expect ( "Failed to receive" ) ;
176+ assert_eq ! ( received. coalesce( ) , CLIENT_SEND_DATA ) ;
177+ sink. send ( IoBuf :: from ( SERVER_SEND_DATA ) )
178+ . await
179+ . expect ( "Failed to send" ) ;
180+
181+ // Hold the connection open until every peer has finished.
182+ barrier. wait ( ) . await ;
183+ } ) ;
184+ }
185+ while let Some ( result) = set. join_next ( ) . await {
186+ result. expect ( "Server connection task failed" ) ;
168187 }
169188 } ) ;
170189
171190 // Start multiple clients
172- let client = runtime. spawn ( async move {
173- for _ in 0 ..3 {
191+ let mut set = JoinSet :: new ( ) ;
192+ for _ in 0 ..NUM_CLIENTS {
193+ let network = network. clone ( ) ;
194+ let barrier = barrier. clone ( ) ;
195+ set. spawn ( async move {
174196 // Connect to the server
175197 let ( mut sink, mut stream) = network
176198 . dial ( listener_addr)
@@ -187,14 +209,20 @@ mod tests {
187209 . recv ( SERVER_SEND_DATA . len ( ) )
188210 . await
189211 . expect ( "Failed to receive data" ) ;
212+
190213 // Verify the received data
191214 assert_eq ! ( received. coalesce( ) , SERVER_SEND_DATA ) ;
192- }
193- } ) ;
194215
195- // Wait for server and all clients
216+ // Hold the connection open until every peer has finished.
217+ barrier. wait ( ) . await ;
218+ } ) ;
219+ }
220+
221+ // Wait for all servers and clients to complete.
222+ while let Some ( result) = set. join_next ( ) . await {
223+ result. expect ( "Client task failed" ) ;
224+ }
196225 server. await . expect ( "Server task failed" ) ;
197- client. await . expect ( "Client task failed" ) ;
198226 }
199227
200228 // Test large data transfer
@@ -209,8 +237,9 @@ mod tests {
209237 . expect ( "Failed to bind" ) ;
210238 let listener_addr = listener. local_addr ( ) . expect ( "Failed to get local address" ) ;
211239
212- let runtime = tokio:: runtime:: Handle :: current ( ) ;
213- let server = runtime. spawn ( async move {
240+ // Spawn server. Returning the socket halves keeps them alive until both
241+ // join handles are awaited below.
242+ let server = tokio:: spawn ( async move {
214243 let ( _, mut sink, mut stream) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
215244
216245 // Receive and echo large data in chunks
@@ -221,10 +250,12 @@ mod tests {
221250 . expect ( "Failed to receive chunk" ) ;
222251 sink. send ( received) . await . expect ( "Failed to send chunk" ) ;
223252 }
253+ ( sink, stream)
224254 } ) ;
225255
226- // Client task
227- let client = runtime. spawn ( async move {
256+ // Client task. Returning the socket halves keeps them alive until both
257+ // join handles are awaited below.
258+ let client = tokio:: spawn ( async move {
228259 // Connect to the server
229260 let ( mut sink, mut stream) = network
230261 . dial ( listener_addr)
@@ -245,11 +276,13 @@ mod tests {
245276 . expect ( "Failed to receive chunk" ) ;
246277 assert_eq ! ( received. coalesce( ) , & pattern[ ..] ) ;
247278 }
279+ ( sink, stream)
248280 } ) ;
249281
250282 // Wait for both tasks to complete
251- server. await . expect ( "Server task failed" ) ;
252- client. await . expect ( "Client task failed" ) ;
283+ let ( server_result, client_result) = join ! ( server, client) ;
284+ server_result. expect ( "Server task failed" ) ;
285+ client_result. expect ( "Client task failed" ) ;
253286 }
254287
255288 // Tests dialing and binding errors
@@ -281,17 +314,17 @@ mod tests {
281314 . expect ( "Failed to bind" ) ;
282315 let listener_addr = listener. local_addr ( ) . expect ( "Failed to get local address" ) ;
283316
284- let runtime = tokio:: runtime:: Handle :: current ( ) ;
285-
286317 // Server sends data
287- let server = runtime . spawn ( async move {
288- let ( _, mut sink, _ ) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
318+ let server = tokio :: spawn ( async move {
319+ let ( _, mut sink, stream ) = listener. accept ( ) . await . expect ( "Failed to accept" ) ;
289320 sink. send ( IoBuf :: from ( DATA ) ) . await . expect ( "Failed to send" ) ;
321+ ( sink, stream)
290322 } ) ;
291323
292324 // Client receives and tests peek
293- let client = runtime. spawn ( async move {
294- let ( _, mut stream) = network
325+ let client = tokio:: spawn ( async move {
326+ // Connect to the server
327+ let ( sink, mut stream) = network
295328 . dial ( listener_addr)
296329 . await
297330 . expect ( "Failed to dial server" ) ;
@@ -323,11 +356,13 @@ mod tests {
323356 // After consuming all data, peek should return empty
324357 let final_peek = stream. peek ( 100 ) ;
325358 assert ! ( final_peek. is_empty( ) ) ;
359+ ( sink, stream)
326360 } ) ;
327361
362+ // Wait for both tasks to complete
328363 let ( server_result, client_result) = join ! ( server, client) ;
329- assert ! ( server_result. is_ok ( ) ) ;
330- assert ! ( client_result. is_ok ( ) ) ;
364+ server_result. expect ( "Server task failed" ) ;
365+ client_result. expect ( "Client task failed" ) ;
331366 }
332367
333368 /// Network stress tests
@@ -352,36 +387,55 @@ mod tests {
352387 . unwrap ( ) ;
353388 let addr = listener. local_addr ( ) . unwrap ( ) ;
354389
390+ // Keep every connection alive until both the client and server halves finish.
391+ let barrier = Arc :: new ( Barrier :: new ( NUM_CLIENTS * 2 ) ) ;
392+
355393 // Spawn a server task that echoes messages from many clients.
394+ let server_barrier = barrier. clone ( ) ;
356395 let server = tokio:: spawn ( async move {
396+ let mut set = JoinSet :: new ( ) ;
357397 for _ in 0 ..NUM_CLIENTS {
358398 let ( _, mut sink, mut stream) = listener. accept ( ) . await . unwrap ( ) ;
359- tokio:: spawn ( async move {
399+ let barrier = server_barrier. clone ( ) ;
400+ set. spawn ( async move {
401+ // Echo every message back to the connected client.
360402 for _ in 0 ..NUM_MESSAGES {
361403 let received = stream. recv ( MESSAGE_SIZE ) . await . unwrap ( ) ;
362404 sink. send ( received) . await . unwrap ( ) ;
363405 }
406+
407+ // Hold the connection open until every peer has finished.
408+ barrier. wait ( ) . await ;
364409 } ) ;
365410 }
411+ while let Some ( result) = set. join_next ( ) . await {
412+ result. unwrap ( ) ;
413+ }
366414 } ) ;
367415
368416 // Spawn all clients.
369- let mut clients = Vec :: new ( ) ;
417+ let mut set = JoinSet :: new ( ) ;
370418 for _ in 0 ..NUM_CLIENTS {
371419 let network = network. clone ( ) ;
372- clients. push ( tokio:: spawn ( async move {
420+ let barrier = barrier. clone ( ) ;
421+ set. spawn ( async move {
422+ // Dial the server and repeatedly verify the echoed payload.
373423 let ( mut sink, mut stream) = network. dial ( addr) . await . unwrap ( ) ;
374424 let payload = vec ! [ 42u8 ; MESSAGE_SIZE ] ;
375425 for _ in 0 ..NUM_MESSAGES {
376426 sink. send ( payload. clone ( ) ) . await . unwrap ( ) ;
377427 let received = stream. recv ( MESSAGE_SIZE ) . await . unwrap ( ) ;
378428 assert_eq ! ( received. coalesce( ) , & payload[ ..] ) ;
379429 }
380- } ) ) ;
430+
431+ // Hold the connection open until every peer has finished.
432+ barrier. wait ( ) . await ;
433+ } ) ;
381434 }
382435
383- for client in clients {
384- client. await . unwrap ( ) ;
436+ // Wait for all servers and clients to complete.
437+ while let Some ( result) = set. join_next ( ) . await {
438+ result. unwrap ( ) ;
385439 }
386440 server. await . unwrap ( ) ;
387441 }
0 commit comments