22
33use crate :: { Error , Sink as SinkTrait , Stream as StreamTrait } ;
44use bytes:: Bytes ;
5+ use commonware_utils:: { StableBuf , StableBufMut } ;
56use futures:: channel:: oneshot;
67use std:: {
78 collections:: VecDeque ,
@@ -41,10 +42,10 @@ pub struct Sink {
4142}
4243
4344impl SinkTrait for Sink {
44- async fn send ( & mut self , msg : & [ u8 ] ) -> Result < ( ) , Error > {
45+ async fn send < B : StableBuf > ( & mut self , msg : B ) -> Result < ( ) , Error > {
4546 let ( os_send, data) = {
4647 let mut channel = self . channel . lock ( ) . unwrap ( ) ;
47- channel. buffer . extend ( msg) ;
48+ channel. buffer . extend ( msg. as_ref ( ) ) ;
4849
4950 // If there is a waiter and the buffer is large enough,
5051 // return the waiter (while clearing the waiter field).
@@ -74,16 +75,16 @@ pub struct Stream {
7475}
7576
7677impl StreamTrait for Stream {
77- async fn recv ( & mut self , buf : & mut [ u8 ] ) -> Result < ( ) , Error > {
78+ async fn recv < B : StableBufMut > ( & mut self , mut buf : B ) -> Result < B , Error > {
7879 let os_recv = {
7980 let mut channel = self . channel . lock ( ) . unwrap ( ) ;
8081
8182 // If the message is fully available in the buffer,
8283 // drain the value into buf and return.
8384 if channel. buffer . len ( ) >= buf. len ( ) {
8485 let b: Vec < u8 > = channel. buffer . drain ( 0 ..buf. len ( ) ) . collect ( ) ;
85- buf. copy_from_slice ( & b) ;
86- return Ok ( ( ) ) ;
86+ buf. put_slice ( & b) ;
87+ return Ok ( buf ) ;
8788 }
8889
8990 // Otherwise, populate the waiter.
@@ -95,8 +96,9 @@ impl StreamTrait for Stream {
9596
9697 // Wait for the waiter to be resolved.
9798 let data = os_recv. await . map_err ( |_| Error :: RecvFailed ) ?;
98- buf. copy_from_slice ( & data) ;
99- Ok ( ( ) )
99+ assert_eq ! ( data. len( ) , buf. len( ) ) ;
100+ buf. put_slice ( & data) ;
101+ Ok ( buf)
100102 }
101103}
102104
@@ -111,52 +113,45 @@ mod tests {
111113 #[ test]
112114 fn test_send_recv ( ) {
113115 let ( mut sink, mut stream) = Channel :: init ( ) ;
114-
115- let data = b"hello world" ;
116- let mut buf = vec ! [ 0 ; data. len( ) ] ;
116+ let data = b"hello world" . to_vec ( ) ;
117117
118118 block_on ( async {
119- sink. send ( data) . await . unwrap ( ) ;
120- stream. recv ( & mut buf) . await . unwrap ( ) ;
119+ sink. send ( data. clone ( ) ) . await . unwrap ( ) ;
120+ let buf = stream. recv ( vec ! [ 0 ; data. len( ) ] ) . await . unwrap ( ) ;
121+ assert_eq ! ( buf, data) ;
121122 } ) ;
122-
123- assert_eq ! ( buf, data) ;
124123 }
125124
126125 #[ test]
127126 fn test_send_recv_partial_multiple ( ) {
128127 let ( mut sink, mut stream) = Channel :: init ( ) ;
129-
130- let data1 = b"hello" ;
131- let data2 = b"world" ;
132- let mut buf1 = vec ! [ 0 ; data1. len( ) ] ;
133- let mut buf2 = vec ! [ 0 ; data2. len( ) ] ;
128+ let data = b"hello" . to_vec ( ) ;
129+ let data2 = b" world" . to_vec ( ) ;
134130
135131 block_on ( async {
136- sink. send ( data1 ) . await . unwrap ( ) ;
132+ sink. send ( data ) . await . unwrap ( ) ;
137133 sink. send ( data2) . await . unwrap ( ) ;
138- stream. recv ( & mut buf1[ 0 ..3 ] ) . await . unwrap ( ) ;
139- stream. recv ( & mut buf1[ 3 ..] ) . await . unwrap ( ) ;
140- stream. recv ( & mut buf2) . await . unwrap ( ) ;
134+ let buf = stream. recv ( vec ! [ 0 ; 5 ] ) . await . unwrap ( ) ;
135+ assert_eq ! ( buf, b"hello" ) ;
136+ let buf = stream. recv ( buf) . await . unwrap ( ) ;
137+ assert_eq ! ( buf, b" worl" ) ;
138+ let buf = stream. recv ( vec ! [ 0 ; 1 ] ) . await . unwrap ( ) ;
139+ assert_eq ! ( buf, b"d" ) ;
141140 } ) ;
142-
143- assert_eq ! ( buf1, data1) ;
144- assert_eq ! ( buf2, data2) ;
145141 }
146142
147143 #[ test]
148144 fn test_send_recv_async ( ) {
149145 let ( mut sink, mut stream) = Channel :: init ( ) ;
150146
151147 let data = b"hello world" ;
152- let mut buf = vec ! [ 0 ; data. len( ) ] ;
153-
154- block_on ( async {
155- futures:: try_join!( stream. recv( & mut buf) , async {
148+ let buf = block_on ( async {
149+ futures:: try_join!( stream. recv( vec![ 0 ; data. len( ) ] ) , async {
156150 sleep( Duration :: from_millis( 10_000 ) ) ;
157- sink. send( data) . await
151+ sink. send( data. to_vec ( ) ) . await
158152 } , )
159- . unwrap ( ) ;
153+ . unwrap ( )
154+ . 0
160155 } ) ;
161156
162157 assert_eq ! ( buf, data) ;
@@ -170,8 +165,7 @@ mod tests {
170165 // If the oneshot sender is dropped before the oneshot receiver is resolved,
171166 // the recv function should return an error.
172167 executor. start ( |_| async move {
173- let mut buf = vec ! [ 0 ; 5 ] ;
174- let ( v, _) = join ! ( stream. recv( & mut buf) , async {
168+ let ( v, _) = join ! ( stream. recv( vec![ 0 ; 5 ] ) , async {
175169 // Take the waiter and drop it.
176170 sink. channel. lock( ) . unwrap( ) . waiter. take( ) ;
177171 } , ) ;
@@ -187,12 +181,10 @@ mod tests {
187181 // If the waiter value has a min, but the oneshot receiver is dropped,
188182 // the send function should return an error when attempting to send the data.
189183 executor. start ( |context| async move {
190- let mut buf = vec ! [ 0 ; 5 ] ;
191-
192184 // Create a waiter using a recv call.
193185 // But then drop the receiver.
194186 select ! {
195- v = stream. recv( & mut buf ) => {
187+ v = stream. recv( vec! [ 0 ; 5 ] ) => {
196188 panic!( "unexpected value: {:?}" , v) ;
197189 } ,
198190 _ = context. sleep( Duration :: from_millis( 100 ) ) => {
@@ -202,7 +194,7 @@ mod tests {
202194 drop ( stream) ;
203195
204196 // Try to send a message (longer than the requested amount), but the receiver is dropped.
205- let result = sink. send ( b"hello world" ) . await ;
197+ let result = sink. send ( b"hello world" . to_vec ( ) ) . await ;
206198 assert ! ( matches!( result, Err ( Error :: SendFailed ) ) ) ;
207199 } ) ;
208200 }
@@ -214,9 +206,8 @@ mod tests {
214206
215207 // If there is no data to read, test that the recv function just blocks. A timeout should return first.
216208 executor. start ( |context| async move {
217- let mut buf = vec ! [ 0 ; 5 ] ;
218209 select ! {
219- v = stream. recv( & mut buf ) => {
210+ v = stream. recv( vec! [ 0 ; 5 ] ) => {
220211 panic!( "unexpected value: {:?}" , v) ;
221212 } ,
222213 _ = context. sleep( Duration :: from_millis( 100 ) ) => {
0 commit comments