Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions actix-ws/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

## Unreleased

- feat: Add `handle_with_protocols()` for `Sec-WebSocket-Protocol` negotiation [#479]
- feat: Add optional typed message codecs with serde_json support.
- feat: Implement `Sink<Message>` for `Session`
- fix: Ignore empty continuation chunks [#660]
- fix: Truncate oversized control-frame payloads to avoid emitting invalid frames [#508]
- fix: Fix continuation overflow handling

[#479]: https://github.com/actix/actix-extras/issues/479
[#660]: https://github.com/actix/actix-extras/pull/660
[#508]: https://github.com/actix/actix-extras/issues/508

Expand Down
16 changes: 16 additions & 0 deletions actix-ws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ See `examples/json.rs` and run it with:
cargo run -p actix-ws --features serde-json --example json
```

## WebSocket Sub-Protocols

Use `handle_with_protocols` when your server supports one or more
`Sec-WebSocket-Protocol` values.

```rust
let (response, session, msg_stream) = actix_ws::handle_with_protocols(
&req,
body,
&["graphql-transport-ws", "graphql-ws"],
)?;
```

When there is an overlap, the first protocol offered by the client that the server supports is
returned in the handshake response.

## Resources

- [API Documentation](https://docs.rs/actix-ws)
Expand Down
142 changes: 139 additions & 3 deletions actix-ws/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! WebSockets for Actix Web, without actors.
//!
//! For usage, see documentation on [`handle()`].
//! For usage, see documentation on [`handle()`] and [`handle_with_protocols()`].

#![warn(missing_docs)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
Expand All @@ -12,7 +12,7 @@ use actix_http::{
body::{BodyStream, MessageBody},
ws::handshake,
};
use actix_web::{web, HttpRequest, HttpResponse};
use actix_web::{http::header, web, HttpRequest, HttpResponse};
use tokio::sync::mpsc::channel;

mod aggregated;
Expand All @@ -28,6 +28,8 @@ pub use self::{

/// Begin handling websocket traffic
///
/// To negotiate sub-protocols via `Sec-WebSocket-Protocol`, use [`handle_with_protocols`].
///
/// ```no_run
/// use std::io;
/// use actix_web::{middleware::Logger, web, App, HttpRequest, HttpServer, Responder};
Expand Down Expand Up @@ -72,7 +74,21 @@ pub fn handle(
req: &HttpRequest,
body: web::Payload,
) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
let mut response = handshake(req.head())?;
handle_with_protocols(req, body, &[])
}

/// Begin handling websocket traffic with optional sub-protocol negotiation.
///
/// The first protocol offered by the client in the `Sec-WebSocket-Protocol` header that also
/// appears in `protocols` is returned in the handshake response.
///
/// If there is no overlap, no `Sec-WebSocket-Protocol` header is set in the response.
pub fn handle_with_protocols(
req: &HttpRequest,
body: web::Payload,
protocols: &[&str],
) -> Result<(HttpResponse, Session, MessageStream), actix_web::Error> {
let mut response = handshake_with_protocols(req, protocols)?;
let (tx, rx) = channel(32);

Ok((
Expand All @@ -83,3 +99,123 @@ pub fn handle(
MessageStream::new(body.into_inner()),
))
}

fn handshake_with_protocols(
req: &HttpRequest,
protocols: &[&str],
) -> Result<actix_http::ResponseBuilder, actix_http::ws::HandshakeError> {
let mut response = handshake(req.head())?;

if let Some(protocol) = select_protocol(req, protocols) {
response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol));
}

Ok(response)
}

fn select_protocol<'a>(req: &'a HttpRequest, protocols: &[&str]) -> Option<&'a str> {
for requested_protocols in req.headers().get_all(header::SEC_WEBSOCKET_PROTOCOL) {
let Ok(requested_protocols) = requested_protocols.to_str() else {
continue;
};

for requested_protocol in requested_protocols.split(',').map(str::trim) {
if requested_protocol.is_empty() {
continue;
}

if protocols
.iter()
.any(|supported_protocol| *supported_protocol == requested_protocol)
{
return Some(requested_protocol);
}
}
}

None
}

#[cfg(test)]
mod tests {
use actix_web::{
http::header::{self, HeaderValue},
test::TestRequest,
HttpRequest,
};

use super::handshake_with_protocols;

fn ws_request(protocols: Option<&'static str>) -> HttpRequest {
let mut req = TestRequest::default()
.insert_header((header::UPGRADE, HeaderValue::from_static("websocket")))
.insert_header((header::CONNECTION, HeaderValue::from_static("upgrade")))
.insert_header((
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
))
.insert_header((
header::SEC_WEBSOCKET_KEY,
HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="),
));

if let Some(protocols) = protocols {
req = req.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocols));
}

req.to_http_request()
}

#[test]
fn handshake_selects_first_supported_client_protocol() {
let req = ws_request(Some("p1,p2,p3"));

let response = handshake_with_protocols(&req, &["p3", "p2"])
.unwrap()
.finish();

assert_eq!(
response.headers().get(header::SEC_WEBSOCKET_PROTOCOL),
Some(&HeaderValue::from_static("p2")),
);
}

#[test]
fn handshake_omits_protocol_header_without_overlap() {
let req = ws_request(Some("p1,p2,p3"));

let response = handshake_with_protocols(&req, &["graphql"])
.unwrap()
.finish();

assert!(response
.headers()
.get(header::SEC_WEBSOCKET_PROTOCOL)
.is_none());
}

#[test]
fn handshake_supports_multiple_protocol_headers() {
let req = TestRequest::default()
.insert_header((header::UPGRADE, HeaderValue::from_static("websocket")))
.insert_header((header::CONNECTION, HeaderValue::from_static("upgrade")))
.insert_header((
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
))
.insert_header((
header::SEC_WEBSOCKET_KEY,
HeaderValue::from_static("x3JJHMbDL1EzLkh9GBhXDw=="),
))
.append_header((header::SEC_WEBSOCKET_PROTOCOL, "p1"))
.append_header((header::SEC_WEBSOCKET_PROTOCOL, "p2"))
.to_http_request();

let response = handshake_with_protocols(&req, &["p2"]).unwrap().finish();

assert_eq!(
response.headers().get(header::SEC_WEBSOCKET_PROTOCOL),
Some(&HeaderValue::from_static("p2")),
);
}
}
Loading