@@ -5,7 +5,10 @@ use std::{
55 io:: { self , Error , Result } ,
66 net:: SocketAddr ,
77 pin:: Pin ,
8- sync:: Arc ,
8+ sync:: {
9+ atomic:: { AtomicBool , Ordering } ,
10+ Arc , Mutex ,
11+ } ,
912 task:: { ready, Context , Poll } ,
1013} ;
1114use tokio:: {
@@ -40,11 +43,11 @@ impl TcpStream {
4043 let pair = Arc :: new ( pair) ;
4144 let read_half = ReadHalf {
4245 pair : pair. clone ( ) ,
43- rx : Rx {
46+ rx : Mutex :: new ( Rx {
4447 recv : receiver,
4548 buffer : None ,
46- } ,
47- is_closed : false ,
49+ } ) ,
50+ is_closed : AtomicBool :: new ( false ) ,
4851 } ;
4952
5053 let write_half = WriteHalf {
@@ -171,23 +174,23 @@ impl TcpStream {
171174 /// returns the number of bytes peeked.
172175 ///
173176 /// Successive calls return the same data.
174- pub async fn peek ( & mut self , buf : & mut [ u8 ] ) -> Result < usize > {
177+ pub async fn peek ( & self , buf : & mut [ u8 ] ) -> Result < usize > {
175178 self . read_half . peek ( buf) . await
176179 }
177180
178181 /// Attempts to receive data on the socket, without removing that data from
179182 /// the queue, registering the current task for wakeup if data is not yet
180183 /// available.
181- pub fn poll_peek ( & mut self , cx : & mut Context < ' _ > , buf : & mut ReadBuf ) -> Poll < Result < usize > > {
184+ pub fn poll_peek ( & self , cx : & mut Context < ' _ > , buf : & mut ReadBuf ) -> Poll < Result < usize > > {
182185 self . read_half . poll_peek ( cx, buf)
183186 }
184187}
185188
186189pub ( crate ) struct ReadHalf {
187190 pub ( crate ) pair : Arc < SocketPair > ,
188- rx : Rx ,
191+ rx : Mutex < Rx > ,
189192 /// FIN received, EOF for reads
190- is_closed : bool ,
193+ is_closed : AtomicBool ,
191194}
192195
193196struct Rx {
@@ -200,27 +203,33 @@ struct Rx {
200203}
201204
202205impl ReadHalf {
203- fn poll_read_priv ( & mut self , cx : & mut Context < ' _ > , buf : & mut ReadBuf ) -> Poll < Result < ( ) > > {
204- if self . is_closed || buf. capacity ( ) == 0 {
206+ fn is_closed ( & self ) -> bool {
207+ self . is_closed . load ( Ordering :: Acquire )
208+ }
209+
210+ fn poll_read_priv ( & self , cx : & mut Context < ' _ > , buf : & mut ReadBuf ) -> Poll < Result < ( ) > > {
211+ if self . is_closed ( ) || buf. capacity ( ) == 0 {
205212 return Poll :: Ready ( Ok ( ( ) ) ) ;
206213 }
207214
208- if let Some ( bytes) = self . rx . buffer . take ( ) {
209- self . rx . buffer = Self :: put_slice ( bytes, buf) ;
215+ let mut rx = self . rx . lock ( ) . unwrap ( ) ;
216+
217+ if let Some ( bytes) = rx. buffer . take ( ) {
218+ rx. buffer = Self :: put_slice ( bytes, buf) ;
210219
211220 return Poll :: Ready ( Ok ( ( ) ) ) ;
212221 }
213222
214- match ready ! ( self . rx. recv. poll_recv( cx) ) {
223+ match ready ! ( rx. recv. poll_recv( cx) ) {
215224 Some ( seg) => {
216225 tracing:: trace!( target: TRACING_TARGET , src = ?self . pair. remote, dst = ?self . pair. local, protocol = %seg, "Recv" ) ;
217226
218227 match seg {
219228 SequencedSegment :: Data ( bytes) => {
220- self . rx . buffer = Self :: put_slice ( bytes, buf) ;
229+ rx. buffer = Self :: put_slice ( bytes, buf) ;
221230 }
222231 SequencedSegment :: Fin => {
223- self . is_closed = true ;
232+ self . is_closed . store ( true , Ordering :: Release ) ;
224233 }
225234 }
226235
@@ -251,36 +260,34 @@ impl ReadHalf {
251260 }
252261 }
253262
254- pub ( crate ) fn poll_peek (
255- & mut self ,
256- cx : & mut Context < ' _ > ,
257- buf : & mut ReadBuf ,
258- ) -> Poll < Result < usize > > {
259- if self . is_closed || buf. capacity ( ) == 0 {
263+ pub ( crate ) fn poll_peek ( & self , cx : & mut Context < ' _ > , buf : & mut ReadBuf ) -> Poll < Result < usize > > {
264+ if self . is_closed ( ) || buf. capacity ( ) == 0 {
260265 return Poll :: Ready ( Ok ( 0 ) ) ;
261266 }
262267
268+ let mut rx = self . rx . lock ( ) . unwrap ( ) ;
269+
263270 // If we have buffered data, peek from it
264- if let Some ( bytes) = & self . rx . buffer {
271+ if let Some ( bytes) = & rx. buffer {
265272 let len = std:: cmp:: min ( bytes. len ( ) , buf. remaining ( ) ) ;
266273 buf. put_slice ( & bytes[ ..len] ) ;
267274 return Poll :: Ready ( Ok ( len) ) ;
268275 }
269276
270- match ready ! ( self . rx. recv. poll_recv( cx) ) {
277+ match ready ! ( rx. recv. poll_recv( cx) ) {
271278 Some ( seg) => {
272279 tracing:: trace!( target: TRACING_TARGET , src = ?self . pair. remote, dst = ?self . pair. local, protocol = %seg, "Peek" ) ;
273280
274281 match seg {
275282 SequencedSegment :: Data ( bytes) => {
276283 let len = std:: cmp:: min ( bytes. len ( ) , buf. remaining ( ) ) ;
277284 buf. put_slice ( & bytes[ ..len] ) ;
278- self . rx . buffer = Some ( bytes) ;
285+ rx. buffer = Some ( bytes) ;
279286
280287 Poll :: Ready ( Ok ( len) )
281288 }
282289 SequencedSegment :: Fin => {
283- self . is_closed = true ;
290+ self . is_closed . store ( true , Ordering :: Release ) ;
284291 Poll :: Ready ( Ok ( 0 ) )
285292 }
286293 }
@@ -292,7 +299,7 @@ impl ReadHalf {
292299 }
293300 }
294301
295- pub ( crate ) async fn peek ( & mut self , buf : & mut [ u8 ] ) -> Result < usize > {
302+ pub ( crate ) async fn peek ( & self , buf : & mut [ u8 ] ) -> Result < usize > {
296303 let mut buf = ReadBuf :: new ( buf) ;
297304 poll_fn ( |cx| self . poll_peek ( cx, & mut buf) ) . await
298305 }
@@ -302,7 +309,7 @@ impl Debug for ReadHalf {
302309 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
303310 f. debug_struct ( "ReadHalf" )
304311 . field ( "pair" , & self . pair )
305- . field ( "is_closed" , & self . is_closed )
312+ . field ( "is_closed" , & self . is_closed ( ) )
306313 . finish ( )
307314 }
308315}
@@ -422,7 +429,7 @@ impl Debug for WriteHalf {
422429
423430impl AsyncRead for ReadHalf {
424431 fn poll_read (
425- mut self : Pin < & mut Self > ,
432+ self : Pin < & mut Self > ,
426433 cx : & mut Context < ' _ > ,
427434 buf : & mut ReadBuf ,
428435 ) -> Poll < Result < ( ) > > {
0 commit comments