Skip to content

Commit 6068058

Browse files
committed
relax Sender and Receiver to not require &mut self
1 parent 119eb61 commit 6068058

2 files changed

Lines changed: 66 additions & 60 deletions

File tree

benches/throughput.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ impl<T: 'static + Send + Default> Routine<T> for Async {
7373
struct Block;
7474

7575
impl<T: 'static + Send + Default> Routine<T> for Block {
76-
fn produce(mut tx: RingSender<T>, limit: usize) -> JoinHandle<usize> {
76+
fn produce(tx: RingSender<T>, limit: usize) -> JoinHandle<usize> {
7777
let producer = iter::from_fn(move || tx.send(T::default()).ok());
7878
task::spawn_blocking(move || producer.take(limit).count())
7979
}
8080

81-
fn consume(mut rx: RingReceiver<T>, limit: usize) -> JoinHandle<usize> {
81+
fn consume(rx: RingReceiver<T>, limit: usize) -> JoinHandle<usize> {
8282
let consumer = iter::from_fn(move || loop {
8383
match rx.try_recv() {
8484
Ok(m) => return Some(m),

src/channel.rs

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl<T> RingSender<T> {
4141
/// * If the internal ring buffer is full, the oldest pending message is overwritten
4242
/// and returned as `Ok(Some(_))`, otherwise `Ok(None)` is returned.
4343
/// * If the channel is disconnected, [`SendError::Disconnected`] is returned.
44-
pub fn send(&mut self, message: T) -> Result<Option<T>, SendError<T>> {
44+
pub fn send(&self, message: T) -> Result<Option<T>, SendError<T>> {
4545
if self.handle.receivers.load(Ordering::Acquire) > 0 {
4646
let overwritten = self.handle.buffer.push(message);
4747

@@ -154,27 +154,13 @@ impl<T> RingReceiver<T> {
154154
}
155155
}
156156

157-
/// Receives a message through the channel (requires [feature] `"futures_api"`).
158-
///
159-
/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
160-
/// * If the internal ring buffer is empty, the call blocks until a message is sent
161-
/// or the channel disconnects.
162-
/// * If the channel is disconnected and the internal ring buffer is empty,
163-
/// [`RecvError::Disconnected`] is returned.
164-
///
165-
/// [feature]: index.html#optional-features
166-
#[cfg(feature = "futures_api")]
167-
pub fn recv(&mut self) -> Result<T, RecvError> {
168-
futures::executor::block_on(futures::StreamExt::next(self)).ok_or(RecvError::Disconnected)
169-
}
170-
171157
/// Receives a message through the channel without blocking.
172158
///
173159
/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
174160
/// * If the internal ring buffer is empty, [`TryRecvError::Empty`] is returned.
175161
/// * If the channel is disconnected and the internal ring buffer is empty,
176162
/// [`TryRecvError::Disconnected`] is returned.
177-
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
163+
pub fn try_recv(&self) -> Result<T, TryRecvError> {
178164
// We must check whether the channel is connected using acquire ordering before we look at
179165
// the buffer, in order to ensure that the loads associated with popping from the buffer
180166
// happen after the stores associated with a push into the buffer that may have happened
@@ -185,6 +171,49 @@ impl<T> RingReceiver<T> {
185171
self.handle.buffer.pop().ok_or(TryRecvError::Disconnected)
186172
}
187173
}
174+
175+
/// Receives a message through the channel (requires [feature] `"futures_api"`).
176+
///
177+
/// * If the internal ring buffer isn't empty, the oldest pending message is returned.
178+
/// * If the internal ring buffer is empty, the call blocks until a message is sent
179+
/// or the channel disconnects.
180+
/// * If the channel is disconnected and the internal ring buffer is empty,
181+
/// [`RecvError::Disconnected`] is returned.
182+
///
183+
/// [feature]: index.html#optional-features
184+
#[cfg(feature = "futures_api")]
185+
pub fn recv(&self) -> Result<T, RecvError> {
186+
futures::executor::block_on(futures::future::poll_fn(|ctx| self.poll(ctx)))
187+
.ok_or(RecvError::Disconnected)
188+
}
189+
190+
#[cfg(feature = "futures_api")]
191+
fn poll(&self, ctx: &mut Context<'_>) -> Poll<Option<T>> {
192+
match self.try_recv() {
193+
result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
194+
self.handle.waitlist.remove(self.slot);
195+
Poll::Ready(result.ok())
196+
}
197+
198+
Err(TryRecvError::Empty) => {
199+
self.handle.waitlist.insert(self.slot, ctx.waker().clone());
200+
201+
// A full memory barrier is necessary to ensure that storing the waker
202+
// happens before attempting to retrieve a message from the buffer.
203+
fence(Ordering::SeqCst);
204+
205+
// Look at the buffer again in case a new message has been sent in the meantime.
206+
match self.try_recv() {
207+
result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
208+
self.handle.waitlist.remove(self.slot);
209+
Poll::Ready(result.ok())
210+
}
211+
212+
Err(TryRecvError::Empty) => Poll::Pending,
213+
}
214+
}
215+
}
216+
}
188217
}
189218

190219
impl<T> Clone for RingReceiver<T> {
@@ -221,31 +250,8 @@ impl<T> Drop for RingReceiver<T> {
221250
impl<T> Stream for RingReceiver<T> {
222251
type Item = T;
223252

224-
fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225-
match self.try_recv() {
226-
result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
227-
self.handle.waitlist.remove(self.slot);
228-
Poll::Ready(result.ok())
229-
}
230-
231-
Err(TryRecvError::Empty) => {
232-
self.handle.waitlist.insert(self.slot, ctx.waker().clone());
233-
234-
// A full memory barrier is necessary to ensure that storing the waker
235-
// happens before attempting to retrieve a message from the buffer.
236-
fence(Ordering::SeqCst);
237-
238-
// Look at the buffer again in case a new message has been sent in the meantime.
239-
match self.try_recv() {
240-
result @ Ok(_) | result @ Err(TryRecvError::Disconnected) => {
241-
self.handle.waitlist.remove(self.slot);
242-
Poll::Ready(result.ok())
243-
}
244-
245-
Err(TryRecvError::Empty) => Poll::Pending,
246-
}
247-
}
248-
}
253+
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254+
self.poll(ctx)
249255
}
250256
}
251257

@@ -270,7 +276,7 @@ impl<T> Stream for RingReceiver<T> {
270276
///
271277
/// // Open a channel to transmit the time elapsed since the beginning of the countdown.
272278
/// // We only need a buffer of size 1, since we're only interested in the current value.
273-
/// let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(1)?);
279+
/// let (tx, rx) = ring_channel(NonZeroUsize::try_from(1)?);
274280
///
275281
/// thread::spawn(move || {
276282
/// let countdown = Instant::now() + Duration::from_secs(10);
@@ -472,7 +478,7 @@ mod tests {
472478
let (tx, _rx) = ring_channel(NonZeroUsize::try_from(cap)?);
473479

474480
rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| {
475-
let mut tx = tx.clone();
481+
let tx = tx.clone();
476482
spawn_blocking(move || assert!(tx.send(msg).is_ok()))
477483
}))?;
478484
}
@@ -486,7 +492,7 @@ mod tests {
486492
let (tx, _) = ring_channel(NonZeroUsize::try_from(cap)?);
487493

488494
rt.block_on(iter(msgs).map(Ok).try_for_each_concurrent(None, |msg| {
489-
let mut tx = tx.clone();
495+
let tx = tx.clone();
490496
spawn_blocking(move || assert_eq!(tx.send(msg), Err(SendError::Disconnected(msg))))
491497
}))?;
492498
}
@@ -496,7 +502,7 @@ mod tests {
496502
#[strategy(1..=10usize)] cap: usize,
497503
#[any(size_range(#cap..=10).lift())] msgs: Vec<u8>,
498504
) {
499-
let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
505+
let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
500506
let overwritten = msgs.len() - min(msgs.len(), cap);
501507

502508
for &msg in &msgs[..cap] {
@@ -528,7 +534,7 @@ mod tests {
528534
try_join_all(
529535
iter::repeat(rx)
530536
.take(msgs.len())
531-
.map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())),
537+
.map(|rx| spawn_blocking(move || rx.try_recv().unwrap())),
532538
)
533539
.await
534540
})?;
@@ -552,7 +558,7 @@ mod tests {
552558
try_join_all(
553559
iter::repeat(rx)
554560
.take(msgs.len())
555-
.map(|mut rx| spawn_blocking(move || rx.try_recv().unwrap())),
561+
.map(|rx| spawn_blocking(move || rx.try_recv().unwrap())),
556562
)
557563
.await
558564
})?;
@@ -573,7 +579,7 @@ mod tests {
573579
repeat(rx)
574580
.take(n)
575581
.map(Ok)
576-
.try_for_each_concurrent(None, |mut rx| {
582+
.try_for_each_concurrent(None, |rx| {
577583
spawn_blocking(move || assert_eq!(rx.try_recv(), Err(TryRecvError::Empty)))
578584
}),
579585
)?;
@@ -591,7 +597,7 @@ mod tests {
591597
repeat(rx)
592598
.take(n)
593599
.map(Ok)
594-
.try_for_each_concurrent(None, |mut rx| {
600+
.try_for_each_concurrent(None, |rx| {
595601
spawn_blocking(move || {
596602
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected))
597603
})
@@ -615,7 +621,7 @@ mod tests {
615621
try_join_all(
616622
iter::repeat(rx)
617623
.take(msgs.len())
618-
.map(|mut rx| spawn_blocking(move || rx.recv().unwrap())),
624+
.map(|rx| spawn_blocking(move || rx.recv().unwrap())),
619625
)
620626
.await
621627
})?;
@@ -640,7 +646,7 @@ mod tests {
640646
try_join_all(
641647
iter::repeat(rx)
642648
.take(msgs.len())
643-
.map(|mut rx| spawn_blocking(move || rx.recv().unwrap())),
649+
.map(|rx| spawn_blocking(move || rx.recv().unwrap())),
644650
)
645651
.await
646652
})?;
@@ -662,7 +668,7 @@ mod tests {
662668
repeat(rx)
663669
.take(n)
664670
.map(Ok)
665-
.try_for_each_concurrent(None, |mut rx| {
671+
.try_for_each_concurrent(None, |rx| {
666672
spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected)))
667673
}),
668674
)?;
@@ -685,7 +691,7 @@ mod tests {
685691
let consumer = repeat(rx)
686692
.take(n)
687693
.map(Ok)
688-
.try_for_each_concurrent(None, |mut rx| {
694+
.try_for_each_concurrent(None, |rx| {
689695
spawn_blocking(move || assert_eq!(rx.recv(), Err(RecvError::Disconnected)))
690696
});
691697

@@ -704,14 +710,14 @@ mod tests {
704710
let producer = repeat(tx)
705711
.take(n)
706712
.map(Ok)
707-
.try_for_each_concurrent(None, |mut tx| {
713+
.try_for_each_concurrent(None, |tx| {
708714
spawn_blocking(move || assert!(tx.send(()).is_ok()))
709715
});
710716

711717
let consumer = repeat(rx)
712718
.take(n)
713719
.map(Ok)
714-
.try_for_each_concurrent(None, |mut rx| {
720+
.try_for_each_concurrent(None, |rx| {
715721
spawn_blocking(move || assert_eq!(rx.recv(), Ok(())))
716722
});
717723

@@ -727,7 +733,7 @@ mod tests {
727733
#[any(size_range(1..=10).lift())] msgs: Vec<u8>,
728734
) {
729735
let rt = runtime::Builder::new_multi_thread().build()?;
730-
let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);
736+
let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
731737
let overwritten = msgs.len() - min(msgs.len(), cap);
732738

733739
assert_eq!(rt.block_on(iter(&msgs).map(Ok).forward(&mut tx)), Ok(()));
@@ -773,7 +779,7 @@ mod tests {
773779
#[any(size_range(1..=10).lift())] msgs: Vec<u8>,
774780
) {
775781
let rt = runtime::Builder::new_multi_thread().build()?;
776-
let (mut tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
782+
let (tx, rx) = ring_channel(NonZeroUsize::try_from(cap)?);
777783
let overwritten = msgs.len() - min(msgs.len(), cap);
778784

779785
for &msg in &msgs {
@@ -809,7 +815,7 @@ mod tests {
809815
#[cfg(not(miri))] // will_wake sometimes returns false under miri
810816
#[proptest]
811817
fn receiver_withdraws_waker_if_channel_not_empty(#[strategy(1..=10usize)] cap: usize, msg: u8) {
812-
let (mut tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);
818+
let (tx, mut rx) = ring_channel(NonZeroUsize::try_from(cap)?);
813819

814820
let waker = Arc::new(MockWaker).into();
815821
let mut ctx = Context::from_waker(&waker);

0 commit comments

Comments
 (0)