diff --git a/actix-ws/CHANGELOG.md b/actix-ws/CHANGELOG.md index e151e1317..9ac282b7a 100644 --- a/actix-ws/CHANGELOG.md +++ b/actix-ws/CHANGELOG.md @@ -2,12 +2,14 @@ ## Unreleased +- feat: Add `handle_with_protocols()` for `Sec-WebSocket-Protocol` negotiation [#479] - feat: Add optional typed message codecs with serde_json support. - feat: Implement `Sink` for `Session` - fix: Ignore empty continuation chunks [#660] - fix: Truncate oversized control-frame payloads to avoid emitting invalid frames [#508] - fix: Fix continuation overflow handling +[#479]: https://github.com/actix/actix-extras/issues/479 [#660]: https://github.com/actix/actix-extras/pull/660 [#508]: https://github.com/actix/actix-extras/issues/508 diff --git a/actix-ws/README.md b/actix-ws/README.md index bb370361f..2b3e70709 100644 --- a/actix-ws/README.md +++ b/actix-ws/README.md @@ -68,6 +68,22 @@ See `examples/json.rs` and run it with: cargo run -p actix-ws --features serde-json --example json ``` +## WebSocket Sub-Protocols + +Use `handle_with_protocols` when your server supports one or more +`Sec-WebSocket-Protocol` values. + +```rust +let (response, session, msg_stream) = actix_ws::handle_with_protocols( + &req, + body, + &["graphql-transport-ws", "graphql-ws"], +)?; +``` + +When there is an overlap, the first protocol offered by the client that the server supports is +returned in the handshake response. + ## Resources - [API Documentation](https://docs.rs/actix-ws) diff --git a/actix-ws/src/lib.rs b/actix-ws/src/lib.rs index 2ede4ef29..2d359dc58 100644 --- a/actix-ws/src/lib.rs +++ b/actix-ws/src/lib.rs @@ -1,6 +1,6 @@ //! WebSockets for Actix Web, without actors. //! -//! For usage, see documentation on [`handle()`]. +//! For usage, see documentation on [`handle()`] and [`handle_with_protocols()`]. #![warn(missing_docs)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] @@ -12,7 +12,7 @@ use actix_http::{ body::{BodyStream, MessageBody}, ws::handshake, }; -use actix_web::{web, HttpRequest, HttpResponse}; +use actix_web::{http::header, web, HttpRequest, HttpResponse}; use tokio::sync::mpsc::channel; mod aggregated; @@ -28,6 +28,8 @@ pub use self::{ /// Begin handling websocket traffic /// +/// To negotiate sub-protocols via `Sec-WebSocket-Protocol`, use [`handle_with_protocols`]. +/// /// ```no_run /// use std::io; /// use actix_web::{middleware::Logger, web, App, HttpRequest, HttpServer, Responder}; @@ -72,7 +74,21 @@ pub fn handle( req: &HttpRequest, body: web::Payload, ) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> { - let mut response = handshake(req.head())?; + handle_with_protocols(req, body, &[]) +} + +/// Begin handling websocket traffic with optional sub-protocol negotiation. +/// +/// The first protocol offered by the client in the `Sec-WebSocket-Protocol` header that also +/// appears in `protocols` is returned in the handshake response. +/// +/// If there is no overlap, no `Sec-WebSocket-Protocol` header is set in the response. +pub fn handle_with_protocols( + req: &HttpRequest, + body: web::Payload, + protocols: &[&str], +) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> { + let mut response = handshake_with_protocols(req, protocols)?; let (tx, rx) = channel(32); Ok(( @@ -83,3 +99,123 @@ pub fn handle( MessageStream::new(body.into_inner()), )) } + +fn handshake_with_protocols( + req: &HttpRequest, + protocols: &[&str], +) -> Result { + let mut response = handshake(req.head())?; + + if let Some(protocol) = select_protocol(req, protocols) { + response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol)); + } + + Ok(response) +} + +fn select_protocol<'a>(req: &'a HttpRequest, protocols: &[&str]) -> Option<&'a str> { + for requested_protocols in req.headers().get_all(header::SEC_WEBSOCKET_PROTOCOL) { + let Ok(requested_protocols) = requested_protocols.to_str() else { + continue; + }; + + for requested_protocol in requested_protocols.split(',').map(str::trim) { + if requested_protocol.is_empty() { + continue; + } + + if protocols + .iter() + .any(|supported_protocol| *supported_protocol == requested_protocol) + { + return Some(requested_protocol); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use actix_web::{ + http::header::{self, HeaderValue}, + test::TestRequest, + HttpRequest, + }; + + use super::handshake_with_protocols; + + fn ws_request(protocols: Option<&'static str>) -> HttpRequest { + let mut req = TestRequest::default() + .insert_header((header::UPGRADE, HeaderValue::from_static("websocket"))) + .insert_header((header::CONNECTION, HeaderValue::from_static("upgrade"))) + .insert_header(( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + )) + .insert_header(( + header::SEC_WEBSOCKET_KEY, + HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="), + )); + + if let Some(protocols) = protocols { + req = req.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocols)); + } + + req.to_http_request() + } + + #[test] + fn handshake_selects_first_supported_client_protocol() { + let req = ws_request(Some("p1,p2,p3")); + + let response = handshake_with_protocols(&req, &["p3", "p2"]) + .unwrap() + .finish(); + + assert_eq!( + response.headers().get(header::SEC_WEBSOCKET_PROTOCOL), + Some(&HeaderValue::from_static("p2")), + ); + } + + #[test] + fn handshake_omits_protocol_header_without_overlap() { + let req = ws_request(Some("p1,p2,p3")); + + let response = handshake_with_protocols(&req, &["graphql"]) + .unwrap() + .finish(); + + assert!(response + .headers() + .get(header::SEC_WEBSOCKET_PROTOCOL) + .is_none()); + } + + #[test] + fn handshake_supports_multiple_protocol_headers() { + let req = TestRequest::default() + .insert_header((header::UPGRADE, HeaderValue::from_static("websocket"))) + .insert_header((header::CONNECTION, HeaderValue::from_static("upgrade"))) + .insert_header(( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + )) + .insert_header(( + header::SEC_WEBSOCKET_KEY, + HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="), + )) + .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p1")) + .append_header((header::SEC_WEBSOCKET_PROTOCOL, "p2")) + .to_http_request(); + + let response = handshake_with_protocols(&req, &["p2"]).unwrap().finish(); + + assert_eq!( + response.headers().get(header::SEC_WEBSOCKET_PROTOCOL), + Some(&HeaderValue::from_static("p2")), + ); + } +}