diff --git a/Cargo.toml b/Cargo.toml index d35d747..4bc5b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,12 @@ [workspace] resolver = "2" members = [ - "ratchet_rs", + "ratchet_axum", "ratchet_core", "ratchet_deflate", "ratchet_ext", "ratchet_fixture", + "ratchet_rs", "ratchet_rs/autobahn/client", "ratchet_rs/autobahn/server", "ratchet_rs/autobahn/split_client", @@ -43,3 +44,10 @@ flate2 = { version = "1.0", default-features = false } anyhow = "1.0" serde_json = "1.0" tracing-subscriber = "0.3.18" +hyper = "1.4.1" +axum = "0.7.5" +axum-core = "0.4.3" +hyper-util = "0.1.0" +pin-project = "1.1.5" +async-trait = "0.1.79" +sha1 = "0.10.4" diff --git a/ratchet_axum/Cargo.toml b/ratchet_axum/Cargo.toml new file mode 100644 index 0000000..e9a4cb8 --- /dev/null +++ b/ratchet_axum/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "ratchet_axum" +description = "Axum Integration for Ratchet" +readme = "README.md" +repository = "https://github.com/swimos/ratchet/" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +license.workspace = true +categories.workspace = true + + +[dependencies] +ratchet_rs = { version = "1.0.3", path = "../ratchet_rs" } +hyper = { workspace = true } +axum-core = { workspace = true } +hyper-util = { workspace = true , features = ["tokio"]} +pin-project = { workspace = true } +async-trait = { workspace = true } +base64 = { workspace = true } +http = { workspace = true } +sha1 = { workspace = true } + +[dev-dependencies] +axum = { workspace = true } +tokio = { workspace = true, features = ["full"] } +bytes = { workspace = true } + +[[example]] +name = "axum" diff --git a/ratchet_axum/examples/axum.rs b/ratchet_axum/examples/axum.rs new file mode 100644 index 0000000..8235fca --- /dev/null +++ b/ratchet_axum/examples/axum.rs @@ -0,0 +1,40 @@ +use axum::{response::IntoResponse, routing::get, Router}; +use bytes::BytesMut; +use ratchet_axum::{UpgradeFut, WebSocketUpgrade}; +use ratchet_rs::{Message, NegotiatedExtension, NoExt, PayloadType, Role, WebSocketConfig}; + +#[tokio::main] +async fn main() { + let app = Router::new().route("/", get(ws_handler)); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +async fn handle_client(fut: UpgradeFut) { + let io = fut.await.unwrap(); + let mut websocket = ratchet_rs::WebSocket::from_upgraded( + WebSocketConfig::default(), + io, + NegotiatedExtension::from(NoExt), + BytesMut::new(), + Role::Server, + ); + let mut buf = BytesMut::new(); + + loop { + match websocket.read(&mut buf).await.unwrap() { + Message::Text => { + websocket.write(&mut buf, PayloadType::Text).await.unwrap(); + buf.clear(); + } + _ => break, + } + } +} + +async fn ws_handler(incoming_upgrade: WebSocketUpgrade) -> impl IntoResponse { + let (response, fut) = incoming_upgrade.upgrade().unwrap(); + tokio::task::spawn(async move { handle_client(fut).await }); + response +} diff --git a/ratchet_axum/src/lib.rs b/ratchet_axum/src/lib.rs new file mode 100644 index 0000000..fc28323 --- /dev/null +++ b/ratchet_axum/src/lib.rs @@ -0,0 +1,160 @@ +//todo missing docs + +#![deny( + // missing_docs, + missing_copy_implementations, + missing_debug_implementations, + trivial_numeric_casts, + unstable_features, + unused_must_use, + unused_mut, + unused_imports, + unused_import_braces +)] + +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use axum_core::body::Body; +use base64; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use http::HeaderMap; +use hyper::Response; +use hyper_util::rt::TokioIo; +use pin_project::pin_project; +use sha1::Digest; +use sha1::Sha1; + +const HEADER_CONNECTION: &str = "upgrade"; +const HEADER_UPGRADE: &str = "websocket"; +const WEBSOCKET_VERSION: &[u8] = b"13"; + +type Error = hyper::Error; + +#[derive(Debug)] +pub struct WebSocketUpgrade { + key: String, + headers: HeaderMap, + on_upgrade: hyper::upgrade::OnUpgrade, + pub permessage_deflate: bool, +} + +impl WebSocketUpgrade { + pub fn upgrade(self) -> Result<(Response, UpgradeFut), Error> { + let mut builder = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, HEADER_CONNECTION) + .header(hyper::header::UPGRADE, HEADER_UPGRADE) + .header(hyper::header::SEC_WEBSOCKET_ACCEPT, self.key); + + if self.permessage_deflate { + builder = builder.header( + hyper::header::SEC_WEBSOCKET_EXTENSIONS, + "permessage-deflate", + ); + } + + let response = builder + .body(Body::default()) + .expect("bug: failed to build response"); + + let stream = UpgradeFut { + inner: self.on_upgrade, + headers: self.headers, + }; + + Ok((response, stream)) + } + + // pub fn upgrade_2(self, f: F) -> Response + // where + // F: FnOnce(UpgradedServer, E>) -> Fut, + // Fut: Future, + // E: Extension, + // { + // + // + // } +} + +#[async_trait::async_trait] +impl axum_core::extract::FromRequestParts for WebSocketUpgrade +where + S: Sync, +{ + type Rejection = hyper::StatusCode; + + async fn from_request_parts( + parts: &mut http::request::Parts, + _state: &S, + ) -> Result { + let key = parts + .headers + .get(http::header::SEC_WEBSOCKET_KEY) + .ok_or(hyper::StatusCode::BAD_REQUEST)?; + + if parts + .headers + .get(http::header::SEC_WEBSOCKET_VERSION) + .map(|v| v.as_bytes()) + != Some(WEBSOCKET_VERSION) + { + return Err(hyper::StatusCode::BAD_REQUEST); + } + + let permessage_deflate = parts + .headers + .get(http::header::SEC_WEBSOCKET_EXTENSIONS) + .map(|val| { + val.to_str() + .unwrap_or_default() + .to_lowercase() + .contains("permessage-deflate") + }) + .unwrap_or(false); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(hyper::StatusCode::BAD_REQUEST)?; + + Ok(Self { + on_upgrade, + key: sec_websocket_protocol(key.as_bytes()), + headers: parts.headers.clone(), + permessage_deflate, + }) + } +} + +/// A future that resolves to a websocket stream when the associated HTTP upgrade completes. +#[pin_project] +#[derive(Debug)] +pub struct UpgradeFut { + #[pin] + inner: hyper::upgrade::OnUpgrade, + pub headers: HeaderMap, +} + +impl std::future::Future for UpgradeFut { + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + let upgraded = match this.inner.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => x, + }; + Poll::Ready(upgraded.map(|u| TokioIo::new(u)).map_err(|e| e.into())) + } +} + +fn sec_websocket_protocol(key: &[u8]) -> String { + let mut sha1 = Sha1::new(); + sha1.update(key); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); // magic string + let result = sha1.finalize(); + STANDARD.encode(&result[..]) +} diff --git a/ratchet_rs/Cargo.toml b/ratchet_rs/Cargo.toml index dd198e3..38fe70b 100644 --- a/ratchet_rs/Cargo.toml +++ b/ratchet_rs/Cargo.toml @@ -17,6 +17,8 @@ split = ["ratchet_core/split"] fixture = ["ratchet_core/fixture"] [dependencies] +axum = { workspace = true, optional = true } +hyper = { workspace = true, features = ["http1", "server", "client"], optional = true } ratchet_core = { version = "1.0.3", path = "../ratchet_core" } ratchet_ext = { version = "1.0.3", path = "../ratchet_ext" } ratchet_deflate = { version = "1.0.3", path = "../ratchet_deflate", optional = true } @@ -49,4 +51,4 @@ name = "client" required-features = ["split"] [[example]] -name = "server" \ No newline at end of file +name = "server"