Skip to content

Commit 4ffd473

Browse files
authored
feat(socketio/connect): Add middlewares to namespace (#280)
* feat(socketio/handler): add a `with` middleware fn to any connect handler. * feat(socketio/packet): custom `connect_error` packet * feat(socketio/handler): return error from middleware. * fix(clippy): async fn call rather than impl `Future` * fix(fmt) * chore: bump MSRV to 1.75.0 * test(socketio/connect): add middleware tests * feat(socketio/connect): connect to ns only after connect handler result. * fix(fmt) * fix(clippy) * test(socketio/ack): fix ack test for new `connect` handler behaviour * test(io): unused `Result` lint * feat(socketio/conect): correct behaviour with connect after middleware * Revert "test(socketio/ack): fix ack test for new `connect` handler behaviour" This reverts commit 53ab208. * test(socketio/connect): fix ws message assertion * test: minor improvements * doc(socketio/connect): improve doc and code readability * feat(socketio/connect): emit extractor errors in middlewares * chore(bench/heaptrack): add middleware to bench * doc(socketio/connect): wip doc * doc(socketio/connect): wip doc * doc(socketio/connect): wip doc * test(socket): set connected for dummy socket * test(socket): add test for connect status * feat(socketio/connect): block emission if socket is not connected * feat(socketio/socket): disconnect status before calling `disconnect` handler. * doc(socketio/connect): document middlewares * chore(bench): remove middleware from bench * doc(socketio/connect): specify middleware behavior for `Data` extractor * doc(socketio/connect): specify middleware behavior for `Data` extractor
1 parent 39d700a commit 4ffd473

File tree

15 files changed

+581
-78
lines changed

15 files changed

+581
-78
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[workspace.package]
22
version = "0.11.0"
33
edition = "2021"
4-
rust-version = "1.67.0"
4+
rust-version = "1.75.0"
55
authors = ["Théodore Prévot <"]
66
repository = "https://github.com/totodore/socketioxide"
77
homepage = "https://github.com/totodore/socketioxide"

examples/private-messaging/src/handlers.rs

+6-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use anyhow::anyhow;
22
use serde::{Deserialize, Serialize};
33
use socketioxide::extract::{Data, SocketRef, State, TryData};
4-
use tracing::error;
54
use uuid::Uuid;
65

76
use crate::store::{Message, Messages, Session, Sessions};
@@ -39,18 +38,7 @@ struct PrivateMessageReq {
3938
content: String,
4039
}
4140

42-
pub fn on_connection(
43-
s: SocketRef,
44-
TryData(auth): TryData<Auth>,
45-
sessions: State<Sessions>,
46-
msgs: State<Messages>,
47-
) {
48-
if let Err(e) = session_connect(&s, auth, sessions.0, msgs.0) {
49-
error!("Failed to connect: {:?}", e);
50-
s.disconnect().ok();
51-
return;
52-
}
53-
41+
pub fn on_connection(s: SocketRef) {
5442
s.on(
5543
"private message",
5644
|s: SocketRef, Data(PrivateMessageReq { to, content }), State(Messages(msg))| {
@@ -83,11 +71,11 @@ pub fn on_connection(
8371
}
8472

8573
/// Handles the connection of a new user
86-
fn session_connect(
87-
s: &SocketRef,
88-
auth: Result<Auth, serde_json::Error>,
89-
Sessions(session_state): &Sessions,
90-
Messages(msg_state): &Messages,
74+
pub fn authenticate_middleware(
75+
s: SocketRef,
76+
TryData(auth): TryData<Auth>,
77+
State(Sessions(session_state)): State<Sessions>,
78+
State(Messages(msg_state)): State<Messages>,
9179
) -> Result<(), anyhow::Error> {
9280
let auth = auth?;
9381
let mut sessions = session_state.write().unwrap();

examples/private-messaging/src/main.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use socketioxide::SocketIo;
1+
use socketioxide::{handler::ConnectHandler, SocketIo};
22
use tower::ServiceBuilder;
33
use tower_http::{cors::CorsLayer, services::ServeDir};
44
use tracing::info;
@@ -22,7 +22,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2222
.with_state(Messages::default())
2323
.build_layer();
2424

25-
io.ns("/", handlers::on_connection);
25+
io.ns(
26+
"/",
27+
handlers::on_connection.with(handlers::authenticate_middleware),
28+
);
2629

2730
let app = axum::Router::new()
2831
.nest_service("/", ServeDir::new("dist"))

socketioxide/src/client.rs

+15-9
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,21 @@ impl<A: Adapter> Client<A> {
4646
#[cfg(feature = "tracing")]
4747
tracing::debug!("auth: {:?}", auth);
4848

49-
let sid = esocket.id;
5049
if let Some(ns) = self.get_ns(ns_path) {
51-
ns.connect(sid, esocket.clone(), auth, self.config.clone())?;
52-
53-
// cancel the connect timeout task for v5
54-
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
55-
tx.send(()).ok();
56-
}
57-
50+
let esocket = esocket.clone();
51+
let config = self.config.clone();
52+
tokio::spawn(async move {
53+
if ns
54+
.connect(esocket.id, esocket.clone(), auth, config)
55+
.await
56+
.is_ok()
57+
{
58+
// cancel the connect timeout task for v5
59+
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
60+
tx.send(()).ok();
61+
}
62+
}
63+
});
5864
Ok(())
5965
} else if ProtocolVersion::from(esocket.protocol) == ProtocolVersion::V4 && ns_path == "/" {
6066
#[cfg(feature = "tracing")]
@@ -64,7 +70,7 @@ impl<A: Adapter> Client<A> {
6470
esocket.close(EIoDisconnectReason::TransportClose);
6571
Ok(())
6672
} else {
67-
let packet = Packet::invalid_namespace(ns_path).into();
73+
let packet = Packet::connect_error(ns_path, "Invalid namespace").into();
6874
if let Err(_e) = esocket.emit(packet) {
6975
#[cfg(feature = "tracing")]
7076
tracing::error!("error while sending invalid namespace packet: {}", _e);

socketioxide/src/errors.rs

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub enum Error {
2424
Adapter(#[from] AdapterError),
2525
}
2626

27+
pub(crate) struct ConnectFail;
28+
2729
/// Error type for ack operations.
2830
#[derive(thiserror::Error, Debug)]
2931
pub enum AckError<T> {

0 commit comments

Comments
 (0)