Skip to content

Commit fa13f18

Browse files
committed
Let TCP peek methods take &self
This changes the `TcpStream::peek` and the various other peek methods to take `&self` instead of `&mut self`, ensuring tokio API compatibility and making them usable in more contexts. To achieve this, we wrap the `ReadHalf` internals in `Mutex` and `AtomicBool`, respectively. Note that we can't use `RefCell`/`Cell` here because we also require `ReadHalf` to be `Sync`.
1 parent af06e81 commit fa13f18

File tree

3 files changed

+39
-36
lines changed

3 files changed

+39
-36
lines changed

src/net/tcp/split_owned.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,16 @@ impl OwnedReadHalf {
4040
/// Attempts to receive data on the socket, without removing that data from
4141
/// the queue, registering the current task for wakeup if data is not yet
4242
/// available.
43-
pub fn poll_peek(
44-
mut self: Pin<&mut Self>,
45-
cx: &mut Context<'_>,
46-
buf: &mut ReadBuf,
47-
) -> Poll<io::Result<usize>> {
48-
Pin::new(&mut self.inner).poll_peek(cx, buf)
43+
pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<io::Result<usize>> {
44+
self.inner.poll_peek(cx, buf)
4945
}
5046

5147
/// Receives data on the socket from the remote address to which it is
5248
/// connected, without removing that data from the queue. On success,
5349
/// returns the number of bytes peeked.
5450
///
5551
/// Successive calls return the same data.
56-
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52+
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
5753
self.inner.peek(buf).await
5854
}
5955
}

src/net/tcp/stream.rs

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use std::{
55
io::{self, Error, Result},
66
net::SocketAddr,
77
pin::Pin,
8-
sync::Arc,
8+
sync::{
9+
atomic::{AtomicBool, Ordering},
10+
Arc, Mutex,
11+
},
912
task::{ready, Context, Poll},
1013
};
1114
use tokio::{
@@ -40,11 +43,11 @@ impl TcpStream {
4043
let pair = Arc::new(pair);
4144
let read_half = ReadHalf {
4245
pair: pair.clone(),
43-
rx: Rx {
46+
rx: Mutex::new(Rx {
4447
recv: receiver,
4548
buffer: None,
46-
},
47-
is_closed: false,
49+
}),
50+
is_closed: AtomicBool::new(false),
4851
};
4952

5053
let write_half = WriteHalf {
@@ -171,23 +174,23 @@ impl TcpStream {
171174
/// returns the number of bytes peeked.
172175
///
173176
/// Successive calls return the same data.
174-
pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
177+
pub async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
175178
self.read_half.peek(buf).await
176179
}
177180

178181
/// Attempts to receive data on the socket, without removing that data from
179182
/// the queue, registering the current task for wakeup if data is not yet
180183
/// available.
181-
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
184+
pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
182185
self.read_half.poll_peek(cx, buf)
183186
}
184187
}
185188

186189
pub(crate) struct ReadHalf {
187190
pub(crate) pair: Arc<SocketPair>,
188-
rx: Rx,
191+
rx: Mutex<Rx>,
189192
/// FIN received, EOF for reads
190-
is_closed: bool,
193+
is_closed: AtomicBool,
191194
}
192195

193196
struct Rx {
@@ -200,27 +203,33 @@ struct Rx {
200203
}
201204

202205
impl ReadHalf {
203-
fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
204-
if self.is_closed || buf.capacity() == 0 {
206+
fn is_closed(&self) -> bool {
207+
self.is_closed.load(Ordering::Acquire)
208+
}
209+
210+
fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
211+
if self.is_closed() || buf.capacity() == 0 {
205212
return Poll::Ready(Ok(()));
206213
}
207214

208-
if let Some(bytes) = self.rx.buffer.take() {
209-
self.rx.buffer = Self::put_slice(bytes, buf);
215+
let mut rx = self.rx.lock().unwrap();
216+
217+
if let Some(bytes) = rx.buffer.take() {
218+
rx.buffer = Self::put_slice(bytes, buf);
210219

211220
return Poll::Ready(Ok(()));
212221
}
213222

214-
match ready!(self.rx.recv.poll_recv(cx)) {
223+
match ready!(rx.recv.poll_recv(cx)) {
215224
Some(seg) => {
216225
tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Recv");
217226

218227
match seg {
219228
SequencedSegment::Data(bytes) => {
220-
self.rx.buffer = Self::put_slice(bytes, buf);
229+
rx.buffer = Self::put_slice(bytes, buf);
221230
}
222231
SequencedSegment::Fin => {
223-
self.is_closed = true;
232+
self.is_closed.store(true, Ordering::Release);
224233
}
225234
}
226235

@@ -251,36 +260,34 @@ impl ReadHalf {
251260
}
252261
}
253262

254-
pub(crate) fn poll_peek(
255-
&mut self,
256-
cx: &mut Context<'_>,
257-
buf: &mut ReadBuf,
258-
) -> Poll<Result<usize>> {
259-
if self.is_closed || buf.capacity() == 0 {
263+
pub(crate) fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
264+
if self.is_closed() || buf.capacity() == 0 {
260265
return Poll::Ready(Ok(0));
261266
}
262267

268+
let mut rx = self.rx.lock().unwrap();
269+
263270
// If we have buffered data, peek from it
264-
if let Some(bytes) = &self.rx.buffer {
271+
if let Some(bytes) = &rx.buffer {
265272
let len = std::cmp::min(bytes.len(), buf.remaining());
266273
buf.put_slice(&bytes[..len]);
267274
return Poll::Ready(Ok(len));
268275
}
269276

270-
match ready!(self.rx.recv.poll_recv(cx)) {
277+
match ready!(rx.recv.poll_recv(cx)) {
271278
Some(seg) => {
272279
tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");
273280

274281
match seg {
275282
SequencedSegment::Data(bytes) => {
276283
let len = std::cmp::min(bytes.len(), buf.remaining());
277284
buf.put_slice(&bytes[..len]);
278-
self.rx.buffer = Some(bytes);
285+
rx.buffer = Some(bytes);
279286

280287
Poll::Ready(Ok(len))
281288
}
282289
SequencedSegment::Fin => {
283-
self.is_closed = true;
290+
self.is_closed.store(true, Ordering::Release);
284291
Poll::Ready(Ok(0))
285292
}
286293
}
@@ -292,7 +299,7 @@ impl ReadHalf {
292299
}
293300
}
294301

295-
pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
302+
pub(crate) async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
296303
let mut buf = ReadBuf::new(buf);
297304
poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
298305
}
@@ -302,7 +309,7 @@ impl Debug for ReadHalf {
302309
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303310
f.debug_struct("ReadHalf")
304311
.field("pair", &self.pair)
305-
.field("is_closed", &self.is_closed)
312+
.field("is_closed", &self.is_closed())
306313
.finish()
307314
}
308315
}
@@ -422,7 +429,7 @@ impl Debug for WriteHalf {
422429

423430
impl AsyncRead for ReadHalf {
424431
fn poll_read(
425-
mut self: Pin<&mut Self>,
432+
self: Pin<&mut Self>,
426433
cx: &mut Context<'_>,
427434
buf: &mut ReadBuf,
428435
) -> Poll<Result<()>> {

tests/tcp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ fn peek_empty_buffer() -> Result {
821821
});
822822

823823
sim.client("client", async move {
824-
let mut s = TcpStream::connect(("server", PORT)).await?;
824+
let s = TcpStream::connect(("server", PORT)).await?;
825825

826826
// no-op peek with empty buffer
827827
let mut buf = [0; 0];

0 commit comments

Comments
 (0)