Skip to content

Commit 0e0fe75

Browse files
authored
feat(socketio): return emitted data in case of emission error. (#262)
* feat(engineio/socket): add a `PermitIterator` + `reserve` fn * feat(socketio/error): hold back a value for `SocketError` * feat(socketio/socket): use permits for `emit` / `ack_emit`. * fix(socketio/socket): move emit with permit from `Socket` to `PermitIterator` * fix(socketio/errors): fix debug impl for `BroadcastError` * docs(example/private-messaging): fix error handling * feat(socketio/socket): refactor `Socket::send` fn to use `PermitIteratorExt` * fix(socketio): fix fmt * fix(socketio/socket): clippy lints * feat(socketio): splitted operators * feat(socketio/op): socket reference in `Operators` * fix(socketio/errors): add `Sync` bound on `adapter` errs. By adding a `Sync` bound on the `dyn Error` it is then possible to convert it to an `anyhow::error`. * doc(examples): fix `emit_with_ack` fn * fix(clippy): `field assignment outside of initializer for an instance created with Default::default()` * feat(socketio/socket): `emit_with_ack` return value in case of error * test(socketio/socket): fix tests for new `emit_with_ack` * fix(clippy): `useless conversion to the same type: errors::SocketError<()>` * doc(socketio/socket): improve doc to match new `emit_*` fns. * feat(engineio/socket): inline exposed emit fns * doc(socketio/op): fix doctest use after moved
1 parent a11edd6 commit 0e0fe75

File tree

14 files changed

+735
-208
lines changed

14 files changed

+735
-208
lines changed

e2e/socketioxide/socketioxide.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ fn on_connect(socket: SocketRef, Data(data): Data<Value>) {
3939
|s: SocketRef, Data::<Value>(data), Bin(bin)| async move {
4040
let ack = s
4141
.bin(bin)
42-
.emit_with_ack::<Value>("emit-with-ack", data)
42+
.emit_with_ack::<_, Value>("emit-with-ack", data)
4343
.unwrap()
4444
.await
4545
.unwrap();

engineioxide/src/socket.rs

+52
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,47 @@ impl From<&Error> for Option<DisconnectReason> {
115115
}
116116
}
117117

118+
/// A permit to emit a message to the client.
119+
/// A permit holds a place in the internal channel to send one packet to the client.
120+
pub struct Permit<'a> {
121+
inner: mpsc::Permit<'a, Packet>,
122+
}
123+
impl Permit<'_> {
124+
/// Consume the permit and emit a message to the client.
125+
#[inline]
126+
pub fn emit(self, msg: String) {
127+
self.inner.send(Packet::Message(msg));
128+
}
129+
/// Consume the permit and emit a binary message to the client.
130+
#[inline]
131+
pub fn emit_binary(self, data: Vec<u8>) {
132+
self.inner.send(Packet::Binary(data));
133+
}
134+
}
135+
136+
/// An [`Iterator`] over the permits returned by the [`reserve`](Socket::reserve) function
137+
#[derive(Debug)]
138+
pub struct PermitIterator<'a> {
139+
inner: mpsc::PermitIterator<'a, Packet>,
140+
}
141+
142+
impl<'a> Iterator for PermitIterator<'a> {
143+
type Item = Permit<'a>;
144+
145+
#[inline]
146+
fn next(&mut self) -> Option<Self::Item> {
147+
let inner = self.inner.next()?;
148+
Some(Permit { inner })
149+
}
150+
}
151+
impl ExactSizeIterator for PermitIterator<'_> {
152+
#[inline]
153+
fn len(&self) -> usize {
154+
self.inner.len()
155+
}
156+
}
157+
impl std::iter::FusedIterator for PermitIterator<'_> {}
158+
118159
/// A [`Socket`] represents a client connection to the server.
119160
/// It is agnostic to the [`TransportType`].
120161
///
@@ -348,6 +389,17 @@ where
348389
TransportType::from(self.transport.load(Ordering::Relaxed))
349390
}
350391

392+
/// Reserve `n` permits to emit multiple messages and ensure that there is enough
393+
/// space in the internal chan.
394+
///
395+
/// If the internal chan is full, the function will return a [`TrySendError::Full`] error.
396+
/// If the socket is closed, the function will return a [`TrySendError::Closed`] error.
397+
#[inline]
398+
pub fn reserve(&self, n: usize) -> Result<PermitIterator<'_>, TrySendError<()>> {
399+
let inner = self.internal_tx.try_reserve_many(n)?;
400+
Ok(PermitIterator { inner })
401+
}
402+
351403
/// Emits a message to the client.
352404
///
353405
/// If the transport is in websocket mode, the message is directly sent as a text frame.

examples/private-messaging/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ tracing-subscriber.workspace = true
1818
tracing.workspace = true
1919
serde.workspace = true
2020
serde_json.workspace = true
21+
anyhow = "1.0"

examples/private-messaging/src/handlers.rs

+7-18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::anyhow;
12
use serde::{Deserialize, Serialize};
23
use socketioxide::extract::{Data, SocketRef, State, TryData};
34
use tracing::error;
@@ -81,28 +82,20 @@ pub fn on_connection(
8182
});
8283
}
8384

84-
#[derive(Debug)]
85-
enum ConnectError {
86-
InvalidUsername,
87-
EncodeError(serde_json::Error),
88-
SocketError(socketioxide::SendError),
89-
BroadcastError(socketioxide::BroadcastError),
90-
}
91-
9285
/// Handles the connection of a new user
9386
fn session_connect(
9487
s: &SocketRef,
9588
auth: Result<Auth, serde_json::Error>,
9689
Sessions(session_state): &Sessions,
9790
Messages(msg_state): &Messages,
98-
) -> Result<(), ConnectError> {
99-
let auth = auth.map_err(ConnectError::EncodeError)?;
91+
) -> Result<(), anyhow::Error> {
92+
let auth = auth?;
10093
let mut sessions = session_state.write().unwrap();
10194
if let Some(session) = auth.session_id.and_then(|id| sessions.get_mut(&id)) {
10295
session.connected = true;
10396
s.extensions.insert(session.clone());
10497
} else {
105-
let username = auth.username.ok_or(ConnectError::InvalidUsername)?;
98+
let username = auth.username.ok_or(anyhow!("invalid username"))?;
10699
let session = Session::new(username);
107100
s.extensions.insert(session.clone());
108101

@@ -113,8 +106,7 @@ fn session_connect(
113106
let session = s.extensions.get::<Session>().unwrap();
114107

115108
s.join(session.user_id.to_string()).ok();
116-
s.emit("session", session.clone())
117-
.map_err(ConnectError::SocketError)?;
109+
s.emit("session", session.clone())?;
118110

119111
let users = session_state
120112
.read()
@@ -134,13 +126,10 @@ fn session_connect(
134126
})
135127
.collect::<Vec<_>>();
136128

137-
s.emit("users", [users])
138-
.map_err(ConnectError::SocketError)?;
129+
s.emit("users", [users])?;
139130

140131
let res = UserConnectedRes::new(&session, vec![]);
141132

142-
s.broadcast()
143-
.emit("user connected", res)
144-
.map_err(ConnectError::BroadcastError)?;
133+
s.broadcast().emit("user connected", res)?;
145134
Ok(())
146135
}

socketioxide/src/ack.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub struct AckResponse<T> {
3434
pub binary: Vec<Vec<u8>>,
3535
}
3636

37-
pub(crate) type AckResult<T = Value> = Result<AckResponse<T>, AckError>;
37+
pub(crate) type AckResult<T = Value> = Result<AckResponse<T>, AckError<()>>;
3838

3939
pin_project_lite::pin_project! {
4040
/// A [`Future`] of [`AckResponse`] received from the client with its corresponding [`Sid`].
@@ -56,7 +56,7 @@ impl<T> Future for AckResultWithId<T> {
5656
let v = match v {
5757
Ok(Ok(Ok(v))) => Ok(v),
5858
Ok(Ok(Err(e))) => Err(e),
59-
Ok(Err(_)) => Err(AckError::Socket(SocketError::Closed)),
59+
Ok(Err(_)) => Err(AckError::Socket(SocketError::Closed(()))),
6060
Err(_) => Err(AckError::Timeout),
6161
};
6262
Poll::Ready((*project.id, v))
@@ -102,7 +102,7 @@ pin_project_lite::pin_project! {
102102
/// let (svc, io) = SocketIo::new_svc();
103103
/// io.ns("/test", move |socket: SocketRef| async move {
104104
/// // We wait for the acknowledgement of the first emit (only one in this case)
105-
/// let ack = socket.emit_with_ack::<String>("test", "test").unwrap().await;
105+
/// let ack = socket.emit_with_ack::<_, String>("test", "test").unwrap().await;
106106
/// println!("Ack: {:?}", ack);
107107
///
108108
/// // We apply the `for_each` StreamExt fn to the AckStream
@@ -466,7 +466,7 @@ mod test {
466466
socket2.disconnect().unwrap();
467467
let (id, ack) = stream.next().await.unwrap();
468468
assert_eq!(id, sid);
469-
assert!(matches!(ack, Err(AckError::Socket(SocketError::Closed))));
469+
assert!(matches!(ack, Err(AckError::Socket(SocketError::Closed(_)))));
470470
assert!(stream.next().await.is_none());
471471
}
472472
#[tokio::test]
@@ -481,7 +481,7 @@ mod test {
481481

482482
assert!(matches!(
483483
stream.next().await.unwrap().1.unwrap_err(),
484-
AckError::Socket(SocketError::Closed)
484+
AckError::Socket(SocketError::Closed(_))
485485
));
486486
}
487487

@@ -495,7 +495,7 @@ mod test {
495495

496496
assert!(matches!(
497497
stream.await.unwrap_err(),
498-
AckError::Socket(SocketError::Closed)
498+
AckError::Socket(SocketError::Closed(_))
499499
));
500500
}
501501

socketioxide/src/adapter.rs

+48-33
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,12 @@ pub type Room = Cow<'static, str>;
3131
pub enum BroadcastFlags {
3232
/// Broadcast only to the current server
3333
Local,
34-
/// Broadcast to all servers
34+
/// Broadcast to all clients except the sender
3535
Broadcast,
36-
/// Add a custom timeout to the ack callback
37-
Timeout(Duration),
3836
}
3937

4038
/// Options that can be used to modify the behavior of the broadcast methods.
41-
#[derive(Clone, Debug)]
39+
#[derive(Clone, Debug, Default)]
4240
pub struct BroadcastOptions {
4341
/// The flags to apply to the broadcast.
4442
pub flags: HashSet<BroadcastFlags>,
@@ -49,24 +47,13 @@ pub struct BroadcastOptions {
4947
/// The socket id of the sender.
5048
pub sid: Option<Sid>,
5149
}
52-
impl BroadcastOptions {
53-
pub(crate) fn new(sid: Option<Sid>) -> Self {
54-
Self {
55-
flags: HashSet::new(),
56-
rooms: HashSet::new(),
57-
except: HashSet::new(),
58-
sid,
59-
}
60-
}
61-
}
62-
6350
//TODO: Make an AsyncAdapter trait
6451
/// An adapter is responsible for managing the state of the server.
6552
/// This adapter can be implemented to share the state between multiple servers.
6653
/// The default adapter is the [`LocalAdapter`], which stores the state in memory.
6754
pub trait Adapter: std::fmt::Debug + Send + Sync + 'static {
6855
/// An error that can occur when using the adapter. The default [`LocalAdapter`] has an [`Infallible`] error.
69-
type Error: std::error::Error + Into<AdapterError> + Send + 'static;
56+
type Error: std::error::Error + Into<AdapterError> + Send + Sync + 'static;
7057

7158
/// Create a new adapter and give the namespace ref to retrieve sockets.
7259
fn new(ns: Weak<Namespace<Self>>) -> Self
@@ -92,8 +79,12 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static {
9279
fn broadcast(&self, packet: Packet<'_>, opts: BroadcastOptions) -> Result<(), BroadcastError>;
9380

9481
/// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
95-
fn broadcast_with_ack(&self, packet: Packet<'static>, opts: BroadcastOptions)
96-
-> AckInnerStream;
82+
fn broadcast_with_ack(
83+
&self,
84+
packet: Packet<'static>,
85+
opts: BroadcastOptions,
86+
timeout: Option<Duration>,
87+
) -> AckInnerStream;
9788

9889
/// Returns the sockets ids that match the [`BroadcastOptions`].
9990
fn sockets(&self, rooms: impl RoomParam) -> Result<Vec<Sid>, Self::Error>;
@@ -211,23 +202,20 @@ impl Adapter for LocalAdapter {
211202
&self,
212203
packet: Packet<'static>,
213204
opts: BroadcastOptions,
205+
timeout: Option<Duration>,
214206
) -> AckInnerStream {
215-
let duration = opts.flags.iter().find_map(|flag| match flag {
216-
BroadcastFlags::Timeout(duration) => Some(*duration),
217-
_ => None,
218-
});
219207
let sockets = self.apply_opts(opts);
220208
#[cfg(feature = "tracing")]
221209
tracing::debug!(
222210
"broadcasting packet to {} sockets: {:?}",
223211
sockets.len(),
224212
sockets.iter().map(|s| s.id).collect::<Vec<_>>()
225213
);
226-
AckInnerStream::broadcast(packet, sockets, duration)
214+
AckInnerStream::broadcast(packet, sockets, timeout)
227215
}
228216

229217
fn sockets(&self, rooms: impl RoomParam) -> Result<Vec<Sid>, Infallible> {
230-
let mut opts = BroadcastOptions::new(None);
218+
let mut opts = BroadcastOptions::default();
231219
opts.rooms.extend(rooms.into_room_iter());
232220
Ok(self
233221
.apply_opts(opts)
@@ -421,7 +409,10 @@ mod test {
421409
let adapter = LocalAdapter::new(Arc::downgrade(&ns));
422410
adapter.add_all(socket, ["room1"]).unwrap();
423411

424-
let mut opts = BroadcastOptions::new(Some(socket));
412+
let mut opts = BroadcastOptions {
413+
sid: Some(socket),
414+
..Default::default()
415+
};
425416
opts.rooms = hash_set!["room1".into()];
426417
adapter.add_sockets(opts, "room2").unwrap();
427418
let rooms_map = adapter.rooms.read().unwrap();
@@ -438,7 +429,10 @@ mod test {
438429
let adapter = LocalAdapter::new(Arc::downgrade(&ns));
439430
adapter.add_all(socket, ["room1"]).unwrap();
440431

441-
let mut opts = BroadcastOptions::new(Some(socket));
432+
let mut opts = BroadcastOptions {
433+
sid: Some(socket),
434+
..Default::default()
435+
};
442436
opts.rooms = hash_set!["room1".into()];
443437
adapter.add_sockets(opts, "room2").unwrap();
444438

@@ -450,7 +444,10 @@ mod test {
450444
assert!(rooms_map.get("room2").unwrap().contains(&socket));
451445
}
452446

453-
let mut opts = BroadcastOptions::new(Some(socket));
447+
let mut opts = BroadcastOptions {
448+
sid: Some(socket),
449+
..Default::default()
450+
};
454451
opts.rooms = hash_set!["room1".into()];
455452
adapter.del_sockets(opts, "room2").unwrap();
456453

@@ -507,7 +504,10 @@ mod test {
507504
.add_all(socket2, ["room2", "room3", "room6"])
508505
.unwrap();
509506

510-
let mut opts = BroadcastOptions::new(Some(socket0));
507+
let mut opts = BroadcastOptions {
508+
sid: Some(socket0),
509+
..Default::default()
510+
};
511511
opts.rooms = hash_set!["room5".into()];
512512
adapter.disconnect_socket(opts).unwrap();
513513

@@ -533,33 +533,48 @@ mod test {
533533
.unwrap();
534534

535535
// socket 2 is the sender
536-
let mut opts = BroadcastOptions::new(Some(socket2));
536+
let mut opts = BroadcastOptions {
537+
sid: Some(socket2),
538+
..Default::default()
539+
};
537540
opts.rooms = hash_set!["room1".into()];
538541
opts.except = hash_set!["room2".into()];
539542
let sockets = adapter.fetch_sockets(opts).unwrap();
540543
assert_eq!(sockets.len(), 1);
541544
assert_eq!(sockets[0].id, socket1);
542545

543-
let mut opts = BroadcastOptions::new(Some(socket2));
546+
let mut opts = BroadcastOptions {
547+
sid: Some(socket2),
548+
..Default::default()
549+
};
544550
opts.flags.insert(BroadcastFlags::Broadcast);
545551
let sockets = adapter.fetch_sockets(opts).unwrap();
546552
assert_eq!(sockets.len(), 2);
547553
sockets.iter().for_each(|s| {
548554
assert!(s.id == socket0 || s.id == socket1);
549555
});
550556

551-
let mut opts = BroadcastOptions::new(Some(socket2));
557+
let mut opts = BroadcastOptions {
558+
sid: Some(socket2),
559+
..Default::default()
560+
};
552561
opts.flags.insert(BroadcastFlags::Broadcast);
553562
opts.except = hash_set!["room2".into()];
554563
let sockets = adapter.fetch_sockets(opts).unwrap();
555564
assert_eq!(sockets.len(), 1);
556565

557-
let opts = BroadcastOptions::new(Some(socket2));
566+
let opts = BroadcastOptions {
567+
sid: Some(socket2),
568+
..Default::default()
569+
};
558570
let sockets = adapter.fetch_sockets(opts).unwrap();
559571
assert_eq!(sockets.len(), 1);
560572
assert_eq!(sockets[0].id, socket2);
561573

562-
let opts = BroadcastOptions::new(Some(Sid::new()));
574+
let opts = BroadcastOptions {
575+
sid: Some(Sid::new()),
576+
..Default::default()
577+
};
563578
let sockets = adapter.fetch_sockets(opts).unwrap();
564579
assert_eq!(sockets.len(), 0);
565580
}

0 commit comments

Comments
 (0)