Skip to content

Commit f7ff305

Browse files
authored
Merge pull request astarte-platform#447 from joshuachp/fix/future-poll-mqtt
fix(mqtt): use a manual future to poll the retention tokens
2 parents a01365d + 0c3267b commit f7ff305

File tree

2 files changed

+39
-44
lines changed

2 files changed

+39
-44
lines changed

src/transport/mqtt/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub(crate) mod mock {
5555

5656
mock! {
5757
pub EventLoop{
58-
// If we don't return a future, the pool function will loop
58+
// If we don't return a future, the poll function will loop
5959
pub fn poll(&mut self) -> impl std::future::Future<Output = Result<Event, ConnectionError>> + Send + 'static;
6060
pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self;
6161
pub fn clean(&mut self);

src/transport/mqtt/retention.rs

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
//! When an interface major version is updated the retention cache must be invalidated. Since the
2626
//! payload will be publish on the new introspection.
2727
28+
use std::future::Future;
29+
use std::pin::Pin;
2830
use std::{collections::HashMap, future::IntoFuture, task::Poll};
2931

3032
use rumqttc::{AckOfPub, Token, TokenError};
31-
use tracing::trace;
33+
use tracing::{trace, warn};
3234

3335
use crate::retention::RetentionId;
3436

@@ -75,23 +77,6 @@ impl MqttRetention {
7577

7678
count
7779
}
78-
79-
fn next_received(&mut self) -> Option<Result<RetentionId, TokenError>> {
80-
let (id, res) = self
81-
.packets
82-
.iter_mut()
83-
.find_map(|(id, v)| match v.check() {
84-
Ok(_) => Some((*id, Ok(*id))),
85-
Err(TokenError::Waiting) => None,
86-
Err(TokenError::Disconnected) => Some((*id, Err(TokenError::Disconnected))),
87-
})?;
88-
89-
self.packets.remove(&id);
90-
91-
trace!("remove packet {id}");
92-
93-
Some(res)
94-
}
9580
}
9681

9782
impl<'a> IntoFuture for &'a mut MqttRetention {
@@ -104,29 +89,42 @@ impl<'a> IntoFuture for &'a mut MqttRetention {
10489
}
10590
}
10691

107-
impl Iterator for MqttRetention {
108-
type Item = Result<RetentionId, TokenError>;
109-
110-
fn next(&mut self) -> Option<Self::Item> {
111-
self.next_received()
112-
}
113-
}
114-
11592
pub(crate) struct MqttRetentionFuture<'a>(&'a mut MqttRetention);
11693

11794
impl std::future::Future for MqttRetentionFuture<'_> {
11895
type Output = Result<RetentionId, TokenError>;
11996

120-
fn poll(
121-
self: std::pin::Pin<&mut Self>,
122-
_cx: &mut std::task::Context<'_>,
123-
) -> Poll<Self::Output> {
124-
let this = self.get_mut();
125-
126-
this.0.queue();
127-
128-
match this.0.next() {
129-
Some(res) => Poll::Ready(res),
97+
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
98+
let this = &mut *self.get_mut().0;
99+
100+
this.queue();
101+
102+
let first = this.packets.iter_mut().find_map(|(id, token)| {
103+
let poll = <Token<AckOfPub> as Future>::poll(Pin::new(token), cx);
104+
105+
match poll {
106+
Poll::Pending => None,
107+
Poll::Ready(Ok(_)) => Some((*id, Ok(*id))),
108+
Poll::Ready(Err(TokenError::Waiting)) => {
109+
warn!(%id, "future returned Ready(Waiting), this should not happend and it could lead to errors on the next poll");
110+
111+
// NOTE: we could return None here, but after some consideration it's safer to
112+
// error and drop the token instead of risking a panic if we poll the
113+
// Future again
114+
Some((*id, Err(TokenError::Disconnected)))
115+
}
116+
Poll::Ready(Err(TokenError::Disconnected)) => {
117+
Some((*id, Err(TokenError::Disconnected)))
118+
}
119+
}
120+
});
121+
122+
match first {
123+
Some((id, res)) => {
124+
this.packets.remove(&id);
125+
126+
Poll::Ready(res)
127+
}
130128
None => Poll::Pending,
131129
}
132130
}
@@ -140,8 +138,8 @@ mod tests {
140138

141139
use super::*;
142140

143-
#[test]
144-
fn should_queue_and_get_next() {
141+
#[tokio::test]
142+
async fn should_queue_and_get_next() {
145143
let (tx, rx) = flume::unbounded();
146144

147145
let mut retention = MqttRetention::new(rx);
@@ -163,16 +161,13 @@ mod tests {
163161

164162
assert_eq!(retention.queue(), 3);
165163

166-
let n = retention.next();
167-
assert!(n.is_none());
168-
169164
t2.resolve(AckOfPub::None);
170165

171-
let n = retention.next().unwrap().unwrap();
166+
let n = retention.into_future().await.unwrap();
172167
assert_eq!(n, RetentionId::Stored(i2));
173168

174169
drop(t1);
175-
let res = retention.next().unwrap();
170+
let res = retention.into_future().await;
176171
assert!(res.is_err(), "expected error but got {:?}", res.unwrap());
177172
}
178173
}

0 commit comments

Comments
 (0)