diff --git a/Cargo.toml b/Cargo.toml index d35d747..4bc5b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,12 @@ [workspace] resolver = "2" members = [ - "ratchet_rs", + "ratchet_axum", "ratchet_core", "ratchet_deflate", "ratchet_ext", "ratchet_fixture", + "ratchet_rs", "ratchet_rs/autobahn/client", "ratchet_rs/autobahn/server", "ratchet_rs/autobahn/split_client", @@ -43,3 +44,10 @@ flate2 = { version = "1.0", default-features = false } anyhow = "1.0" serde_json = "1.0" tracing-subscriber = "0.3.18" +hyper = "1.4.1" +axum = "0.7.5" +axum-core = "0.4.3" +hyper-util = "0.1.0" +pin-project = "1.1.5" +async-trait = "0.1.79" +sha1 = "0.10.4" diff --git a/ratchet_axum/Cargo.toml b/ratchet_axum/Cargo.toml new file mode 100644 index 0000000..e9a4cb8 --- /dev/null +++ b/ratchet_axum/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "ratchet_axum" +description = "Axum Integration for Ratchet" +readme = "README.md" +repository = "https://github.com/swimos/ratchet/" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +license.workspace = true +categories.workspace = true + + +[dependencies] +ratchet_rs = { version = "1.0.3", path = "../ratchet_rs" } +hyper = { workspace = true } +axum-core = { workspace = true } +hyper-util = { workspace = true , features = ["tokio"]} +pin-project = { workspace = true } +async-trait = { workspace = true } +base64 = { workspace = true } +http = { workspace = true } +sha1 = { workspace = true } + +[dev-dependencies] +axum = { workspace = true } +tokio = { workspace = true, features = ["full"] } +bytes = { workspace = true } + +[[example]] +name = "axum" diff --git a/ratchet_axum/examples/axum.rs b/ratchet_axum/examples/axum.rs new file mode 100644 index 0000000..8235fca --- /dev/null +++ b/ratchet_axum/examples/axum.rs @@ -0,0 +1,40 @@ +use axum::{response::IntoResponse, routing::get, Router}; +use bytes::BytesMut; +use ratchet_axum::{UpgradeFut, WebSocketUpgrade}; +use ratchet_rs::{Message, NegotiatedExtension, NoExt, PayloadType, Role, WebSocketConfig}; + +#[tokio::main] +async fn main() { + let app = Router::new().route("/", get(ws_handler)); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +async fn handle_client(fut: UpgradeFut) { + let io = fut.await.unwrap(); + let mut websocket = ratchet_rs::WebSocket::from_upgraded( + WebSocketConfig::default(), + io, + NegotiatedExtension::from(NoExt), + BytesMut::new(), + Role::Server, + ); + let mut buf = BytesMut::new(); + + loop { + match websocket.read(&mut buf).await.unwrap() { + Message::Text => { + websocket.write(&mut buf, PayloadType::Text).await.unwrap(); + buf.clear(); + } + _ => break, + } + } +} + +async fn ws_handler(incoming_upgrade: WebSocketUpgrade) -> impl IntoResponse { + let (response, fut) = incoming_upgrade.upgrade().unwrap(); + tokio::task::spawn(async move { handle_client(fut).await }); + response +} diff --git a/ratchet_axum/src/lib.rs b/ratchet_axum/src/lib.rs new file mode 100644 index 0000000..fc28323 --- /dev/null +++ b/ratchet_axum/src/lib.rs @@ -0,0 +1,160 @@ +//todo missing docs + +#![deny( + // missing_docs, + missing_copy_implementations, + missing_debug_implementations, + trivial_numeric_casts, + unstable_features, + unused_must_use, + unused_mut, + unused_imports, + unused_import_braces +)] + +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use axum_core::body::Body; +use base64; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use http::HeaderMap; +use hyper::Response; +use hyper_util::rt::TokioIo; +use pin_project::pin_project; +use sha1::Digest; +use sha1::Sha1; + +const HEADER_CONNECTION: &str = "upgrade"; +const HEADER_UPGRADE: &str = "websocket"; +const WEBSOCKET_VERSION: &[u8] = b"13"; + +type Error = hyper::Error; + +#[derive(Debug)] +pub struct WebSocketUpgrade { + key: String, + headers: HeaderMap, + on_upgrade: hyper::upgrade::OnUpgrade, + pub permessage_deflate: bool, +} + +impl WebSocketUpgrade { + pub fn upgrade(self) -> Result<(Response
, UpgradeFut), Error> { + let mut builder = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, HEADER_CONNECTION) + .header(hyper::header::UPGRADE, HEADER_UPGRADE) + .header(hyper::header::SEC_WEBSOCKET_ACCEPT, self.key); + + if self.permessage_deflate { + builder = builder.header( + hyper::header::SEC_WEBSOCKET_EXTENSIONS, + "permessage-deflate", + ); + } + + let response = builder + .body(Body::default()) + .expect("bug: failed to build response"); + + let stream = UpgradeFut { + inner: self.on_upgrade, + headers: self.headers, + }; + + Ok((response, stream)) + } + + // pub fn upgrade_2