Skip to content

Commit 266ff72

Browse files
[utils] Add Reservation (#3683)
1 parent f63e0d0 commit 266ff72

3 files changed

Lines changed: 272 additions & 16 deletions

File tree

utils/src/channel/fallible.rs

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
//! let result = sender.request(|tx| Message::Query { responder: tx }).await;
1717
//! ```
1818
19-
use super::{mpsc, oneshot};
19+
use super::{
20+
mpsc, oneshot,
21+
reservation::{Reservation, ReservationExt},
22+
};
23+
use std::future::Future;
2024

2125
/// Extension trait for channel operations that may fail due to disconnection.
2226
///
@@ -44,7 +48,7 @@ pub trait FallibleExt<T> {
4448
/// .request(|tx| Message::Dialable { responder: tx })
4549
/// .await;
4650
/// ```
47-
fn request<R, F>(&self, make_msg: F) -> impl std::future::Future<Output = Option<R>> + Send
51+
fn request<R, F>(&self, make_msg: F) -> impl Future<Output = Option<R>> + Send
4852
where
4953
R: Send,
5054
F: FnOnce(oneshot::Sender<R>) -> T + Send;
@@ -53,11 +57,7 @@ pub trait FallibleExt<T> {
5357
///
5458
/// This is a convenience wrapper around [`request`](Self::request) for cases
5559
/// where you have a sensible default value.
56-
fn request_or<R, F>(
57-
&self,
58-
make_msg: F,
59-
default: R,
60-
) -> impl std::future::Future<Output = R> + Send
60+
fn request_or<R, F>(&self, make_msg: F, default: R) -> impl Future<Output = R> + Send
6161
where
6262
R: Send,
6363
F: FnOnce(oneshot::Sender<R>) -> T + Send;
@@ -66,7 +66,7 @@ pub trait FallibleExt<T> {
6666
///
6767
/// This is a convenience wrapper around [`request`](Self::request) for types
6868
/// that implement [`Default`].
69-
fn request_or_default<R, F>(&self, make_msg: F) -> impl std::future::Future<Output = R> + Send
69+
fn request_or_default<R, F>(&self, make_msg: F) -> impl Future<Output = R> + Send
7070
where
7171
R: Default + Send,
7272
F: FnOnce(oneshot::Sender<R>) -> T + Send;
@@ -116,7 +116,7 @@ pub trait AsyncFallibleExt<T> {
116116
/// may have been dropped during shutdown. The return value can
117117
/// be ignored if the caller doesn't need to know whether the
118118
/// send succeeded.
119-
fn send_lossy(&self, msg: T) -> impl std::future::Future<Output = bool> + Send;
119+
fn send_lossy(&self, msg: T) -> impl Future<Output = bool> + Send;
120120

121121
/// Try to send a message without blocking, returning `true` if successful.
122122
///
@@ -125,28 +125,32 @@ pub trait AsyncFallibleExt<T> {
125125
/// disconnected.
126126
fn try_send_lossy(&self, msg: T) -> bool;
127127

128+
/// Attempts to send immediately, reserving the message when the channel is full.
129+
///
130+
/// Returns `None` if the value was sent immediately or the receiver has been dropped.
131+
#[must_use = "await and send any reservation"]
132+
fn send_or_reserve_lossy(&self, msg: T) -> Option<Reservation<T>>
133+
where
134+
T: 'static;
135+
128136
/// Send a request message containing a oneshot responder and await the response.
129137
///
130138
/// Returns `None` if:
131139
/// - The receiver has been dropped (send fails)
132140
/// - The responder is dropped without sending (receive fails)
133-
fn request<R, F>(&self, make_msg: F) -> impl std::future::Future<Output = Option<R>> + Send
141+
fn request<R, F>(&self, make_msg: F) -> impl Future<Output = Option<R>> + Send
134142
where
135143
R: Send,
136144
F: FnOnce(oneshot::Sender<R>) -> T + Send;
137145

138146
/// Send a request and return the provided default on failure.
139-
fn request_or<R, F>(
140-
&self,
141-
make_msg: F,
142-
default: R,
143-
) -> impl std::future::Future<Output = R> + Send
147+
fn request_or<R, F>(&self, make_msg: F, default: R) -> impl Future<Output = R> + Send
144148
where
145149
R: Send,
146150
F: FnOnce(oneshot::Sender<R>) -> T + Send;
147151

148152
/// Send a request and return `R::default()` on failure.
149-
fn request_or_default<R, F>(&self, make_msg: F) -> impl std::future::Future<Output = R> + Send
153+
fn request_or_default<R, F>(&self, make_msg: F) -> impl Future<Output = R> + Send
150154
where
151155
R: Default + Send,
152156
F: FnOnce(oneshot::Sender<R>) -> T + Send;
@@ -161,6 +165,13 @@ impl<T: Send> AsyncFallibleExt<T> for mpsc::Sender<T> {
161165
self.try_send(msg).is_ok()
162166
}
163167

168+
fn send_or_reserve_lossy(&self, msg: T) -> Option<Reservation<T>>
169+
where
170+
T: 'static,
171+
{
172+
self.send_or_reserve(msg).ok().flatten()
173+
}
174+
164175
async fn request<R, F>(&self, make_msg: F) -> Option<R>
165176
where
166177
R: Send,
@@ -365,6 +376,61 @@ mod tests {
365376
assert!(!tx.try_send_lossy(TestMessage::FireAndForget(42)));
366377
}
367378

379+
// send_or_reserve_lossy tests
380+
381+
#[test]
382+
fn test_send_or_reserve_lossy_success() {
383+
let (tx, mut rx) = mpsc::channel(1);
384+
385+
assert!(tx
386+
.send_or_reserve_lossy(TestMessage::FireAndForget(42))
387+
.is_none());
388+
assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(42))));
389+
}
390+
391+
#[test]
392+
fn test_send_or_reserve_lossy_disconnected() {
393+
let (tx, rx) = mpsc::channel::<TestMessage>(1);
394+
drop(rx);
395+
396+
assert!(tx
397+
.send_or_reserve_lossy(TestMessage::FireAndForget(42))
398+
.is_none());
399+
}
400+
401+
#[test_async]
402+
async fn test_send_or_reserve_lossy_reserves_when_full() {
403+
let (tx, mut rx) = mpsc::channel(1);
404+
tx.try_send(TestMessage::FireAndForget(1)).unwrap();
405+
406+
let reservation = tx
407+
.send_or_reserve_lossy(TestMessage::FireAndForget(2))
408+
.expect("receiver should be open");
409+
410+
assert!(matches!(
411+
rx.recv().await,
412+
Some(TestMessage::FireAndForget(1))
413+
));
414+
reservation.await.unwrap().send();
415+
assert!(matches!(
416+
rx.recv().await,
417+
Some(TestMessage::FireAndForget(2))
418+
));
419+
}
420+
421+
#[test_async]
422+
async fn test_send_or_reserve_lossy_reserved_disconnected() {
423+
let (tx, rx) = mpsc::channel(1);
424+
tx.try_send(TestMessage::FireAndForget(1)).unwrap();
425+
426+
let reservation = tx
427+
.send_or_reserve_lossy(TestMessage::FireAndForget(2))
428+
.expect("receiver should be open");
429+
drop(rx);
430+
431+
assert!(reservation.await.is_err());
432+
}
433+
368434
// OneshotExt tests
369435

370436
#[test]

utils/src/channel/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Utilities for working with channels.
22
33
pub mod fallible;
4+
pub mod reservation;
45
pub mod ring;
56
pub mod tracked;
67

utils/src/channel/reservation.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
//! Channel reservation helpers.
2+
3+
use super::mpsc::{
4+
self,
5+
error::{SendError, TrySendError},
6+
OwnedPermit,
7+
};
8+
use std::{
9+
future::Future,
10+
pin::Pin,
11+
task::{Context, Poll},
12+
};
13+
14+
// The reserve future only reports channel closure; the message value is stored separately.
15+
type ReserveResult<T> = Result<OwnedPermit<T>, SendError<()>>;
16+
17+
// Tokio's `reserve_owned` future is not nameable, so box it instead of exposing a future parameter.
18+
type ReserveFuture<T> = Pin<Box<dyn Future<Output = ReserveResult<T>> + Send>>;
19+
20+
/// A reserved channel slot bundled with the value to send.
21+
#[must_use = "call send to deliver the reserved message"]
22+
pub struct Reserved<T> {
23+
permit: OwnedPermit<T>,
24+
value: T,
25+
}
26+
27+
impl<T> Reserved<T> {
28+
/// Sends the buffered value through the reserved slot.
29+
pub fn send(self) -> mpsc::Sender<T> {
30+
self.permit.send(self.value)
31+
}
32+
}
33+
34+
/// A future that waits for a channel slot and keeps ownership of the value.
35+
#[must_use = "await the reservation to acquire a channel slot"]
36+
pub struct Reservation<T> {
37+
future: ReserveFuture<T>,
38+
value: Option<T>,
39+
}
40+
41+
impl<T> Reservation<T> {
42+
fn new(future: impl Future<Output = ReserveResult<T>> + Send + 'static, value: T) -> Self {
43+
Self {
44+
future: Box::pin(future),
45+
value: Some(value),
46+
}
47+
}
48+
}
49+
50+
impl<T> Unpin for Reservation<T> {}
51+
52+
impl<T> Future for Reservation<T> {
53+
type Output = Result<Reserved<T>, SendError<T>>;
54+
55+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
56+
let permit = match self.future.as_mut().poll(cx) {
57+
Poll::Pending => return Poll::Pending,
58+
Poll::Ready(permit) => permit,
59+
};
60+
let value = self
61+
.value
62+
.take()
63+
.expect("reservation polled after completion");
64+
Poll::Ready(match permit {
65+
Ok(permit) => Ok(Reserved { permit, value }),
66+
Err(SendError(())) => Err(SendError(value)),
67+
})
68+
}
69+
}
70+
71+
/// Extension trait for bounded channel sends that can reserve capacity.
72+
pub trait ReservationExt<T> {
73+
/// Attempts to send immediately, reserving the message when the channel is full.
74+
///
75+
/// Returns:
76+
/// - `Ok(None)` when the value was sent immediately.
77+
/// - `Ok(Some(_))` when the channel was full. Await the reservation and call
78+
/// [`Reserved::send`] to deliver the value.
79+
/// - `Err(_)` when the receiver has been dropped.
80+
#[must_use = "await and send any reservation"]
81+
fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
82+
where
83+
T: 'static;
84+
}
85+
86+
impl<T: Send> ReservationExt<T> for mpsc::Sender<T> {
87+
fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
88+
where
89+
T: 'static,
90+
{
91+
match self.try_send(value) {
92+
Ok(()) => Ok(None),
93+
Err(TrySendError::Full(value)) => {
94+
Ok(Some(Reservation::new(self.clone().reserve_owned(), value)))
95+
}
96+
Err(TrySendError::Closed(value)) => Err(SendError(value)),
97+
}
98+
}
99+
}
100+
101+
#[cfg(test)]
102+
mod tests {
103+
use super::*;
104+
use commonware_macros::test_async;
105+
use std::collections::BTreeMap;
106+
107+
#[test]
108+
fn test_send_or_reserve_sends_immediately() {
109+
let (sender, mut receiver) = mpsc::channel(1);
110+
assert!(sender.send_or_reserve(1).unwrap().is_none());
111+
assert_eq!(receiver.try_recv(), Ok(1));
112+
}
113+
114+
#[test]
115+
fn test_send_or_reserve_closed_returns_value() {
116+
let (sender, receiver) = mpsc::channel(1);
117+
drop(receiver);
118+
119+
match sender.send_or_reserve(1) {
120+
Ok(_) => panic!("send should fail"),
121+
Err(SendError(value)) => assert_eq!(value, 1),
122+
}
123+
}
124+
125+
#[test_async]
126+
async fn test_send_or_reserve_waits_for_capacity() {
127+
let (sender, mut receiver) = mpsc::channel(1);
128+
sender.try_send(1).unwrap();
129+
130+
let reservation = sender
131+
.send_or_reserve(2)
132+
.unwrap()
133+
.expect("channel should be full");
134+
assert_eq!(receiver.recv().await, Some(1));
135+
reservation.await.unwrap().send();
136+
assert_eq!(receiver.recv().await, Some(2));
137+
}
138+
139+
#[test_async]
140+
async fn test_send_or_reserve_returns_value_when_closed_while_waiting() {
141+
let (sender, receiver) = mpsc::channel(1);
142+
sender.try_send(1).unwrap();
143+
144+
let reservation = sender
145+
.send_or_reserve(2)
146+
.unwrap()
147+
.expect("channel should be full");
148+
drop(receiver);
149+
150+
match reservation.await {
151+
Ok(_) => panic!("reservation should fail"),
152+
Err(SendError(value)) => assert_eq!(value, 2),
153+
}
154+
}
155+
156+
#[test_async]
157+
async fn test_send_or_reserve_reservations_can_be_stored() {
158+
let (sender, mut receiver) = mpsc::channel(1);
159+
sender.try_send(0).unwrap();
160+
161+
let mut reservations = Vec::new();
162+
reservations.push(
163+
sender
164+
.send_or_reserve(1)
165+
.unwrap()
166+
.expect("channel should be full"),
167+
);
168+
169+
let mut reservation_map = BTreeMap::new();
170+
reservation_map.insert(
171+
"next",
172+
sender
173+
.send_or_reserve(2)
174+
.unwrap()
175+
.expect("channel should be full"),
176+
);
177+
178+
assert_eq!(receiver.recv().await, Some(0));
179+
reservations.pop().unwrap().await.unwrap().send();
180+
assert_eq!(receiver.recv().await, Some(1));
181+
reservation_map
182+
.remove("next")
183+
.unwrap()
184+
.await
185+
.unwrap()
186+
.send();
187+
assert_eq!(receiver.recv().await, Some(2));
188+
}
189+
}

0 commit comments

Comments
 (0)