55//! ```
66//! use axum::{
77//! extract::ws::{WebSocketUpgrade, WebSocket},
8- //! routing::get ,
8+ //! routing::any ,
99//! response::{IntoResponse, Response},
1010//! Router,
1111//! };
1212//!
13- //! let app = Router::new().route("/ws", get (handler));
13+ //! let app = Router::new().route("/ws", any (handler));
1414//!
1515//! async fn handler(ws: WebSocketUpgrade) -> Response {
1616//! ws.on_upgrade(handle_socket)
4040//! use axum::{
4141//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
4242//! response::Response,
43- //! routing::get ,
43+ //! routing::any ,
4444//! Router,
4545//! };
4646//!
5858//! }
5959//!
6060//! let app = Router::new()
61- //! .route("/ws", get (handler))
61+ //! .route("/ws", any (handler))
6262//! .with_state(AppState { /* ... */ });
6363//! # let _: Router = app;
6464//! ```
@@ -101,7 +101,7 @@ use futures_util::{
101101use http:: {
102102 header:: { self , HeaderMap , HeaderName , HeaderValue } ,
103103 request:: Parts ,
104- Method , StatusCode ,
104+ Method , StatusCode , Version ,
105105} ;
106106use hyper_util:: rt:: TokioIo ;
107107use sha1:: { Digest , Sha1 } ;
@@ -121,17 +121,20 @@ use tokio_tungstenite::{
121121
122122/// Extractor for establishing WebSocket connections.
123123///
124- /// Note: This extractor requires the request method to be `GET` so it should
125- /// always be used with [`get`](crate::routing::get). Requests with other methods will be
126- /// rejected .
124+ /// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
125+ /// in later versions, `CONNECT` is used instead.
126+ /// To support both, it should be used with [`any`](crate::routing::any) .
127127///
128128/// See the [module docs](self) for an example.
129+ ///
130+ /// [`MethodFilter`]: crate::routing::MethodFilter
129131#[ cfg_attr( docsrs, doc( cfg( feature = "ws" ) ) ) ]
130132pub struct WebSocketUpgrade < F = DefaultOnFailedUpgrade > {
131133 config : WebSocketConfig ,
132134 /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
133135 protocol : Option < HeaderValue > ,
134- sec_websocket_key : HeaderValue ,
136+ /// `None` if HTTP/2+ WebSockets are used.
137+ sec_websocket_key : Option < HeaderValue > ,
135138 on_upgrade : hyper:: upgrade:: OnUpgrade ,
136139 on_failed_upgrade : F ,
137140 sec_websocket_protocol : Option < HeaderValue > ,
@@ -212,12 +215,12 @@ impl<F> WebSocketUpgrade<F> {
212215 /// ```
213216 /// use axum::{
214217 /// extract::ws::{WebSocketUpgrade, WebSocket},
215- /// routing::get ,
218+ /// routing::any ,
216219 /// response::{IntoResponse, Response},
217220 /// Router,
218221 /// };
219222 ///
220- /// let app = Router::new().route("/ws", get (handler));
223+ /// let app = Router::new().route("/ws", any (handler));
221224 ///
222225 /// async fn handler(ws: WebSocketUpgrade) -> Response {
223226 /// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@@ -329,25 +332,34 @@ impl<F> WebSocketUpgrade<F> {
329332 callback ( socket) . await ;
330333 } ) ;
331334
332- #[ allow( clippy:: declare_interior_mutable_const) ]
333- const UPGRADE : HeaderValue = HeaderValue :: from_static ( "upgrade" ) ;
334- #[ allow( clippy:: declare_interior_mutable_const) ]
335- const WEBSOCKET : HeaderValue = HeaderValue :: from_static ( "websocket" ) ;
336-
337- let mut builder = Response :: builder ( )
338- . status ( StatusCode :: SWITCHING_PROTOCOLS )
339- . header ( header:: CONNECTION , UPGRADE )
340- . header ( header:: UPGRADE , WEBSOCKET )
341- . header (
342- header:: SEC_WEBSOCKET_ACCEPT ,
343- sign ( self . sec_websocket_key . as_bytes ( ) ) ,
344- ) ;
345-
346- if let Some ( protocol) = self . protocol {
347- builder = builder. header ( header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
348- }
335+ if let Some ( sec_websocket_key) = & self . sec_websocket_key {
336+ // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
337+
338+ #[ allow( clippy:: declare_interior_mutable_const) ]
339+ const UPGRADE : HeaderValue = HeaderValue :: from_static ( "upgrade" ) ;
340+ #[ allow( clippy:: declare_interior_mutable_const) ]
341+ const WEBSOCKET : HeaderValue = HeaderValue :: from_static ( "websocket" ) ;
342+
343+ let mut builder = Response :: builder ( )
344+ . status ( StatusCode :: SWITCHING_PROTOCOLS )
345+ . header ( header:: CONNECTION , UPGRADE )
346+ . header ( header:: UPGRADE , WEBSOCKET )
347+ . header (
348+ header:: SEC_WEBSOCKET_ACCEPT ,
349+ sign ( sec_websocket_key. as_bytes ( ) ) ,
350+ ) ;
351+
352+ if let Some ( protocol) = self . protocol {
353+ builder = builder. header ( header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
354+ }
349355
350- builder. body ( Body :: empty ( ) ) . unwrap ( )
356+ builder. body ( Body :: empty ( ) ) . unwrap ( )
357+ } else {
358+ // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
359+ // with a 2XX with an empty body:
360+ // <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
361+ Response :: new ( Body :: empty ( ) )
362+ }
351363 }
352364}
353365
@@ -387,28 +399,49 @@ where
387399 type Rejection = WebSocketUpgradeRejection ;
388400
389401 async fn from_request_parts ( parts : & mut Parts , _state : & S ) -> Result < Self , Self :: Rejection > {
390- if parts. method != Method :: GET {
391- return Err ( MethodNotGet . into ( ) ) ;
392- }
402+ let sec_websocket_key = if parts. version <= Version :: HTTP_11 {
403+ if parts. method != Method :: GET {
404+ return Err ( MethodNotGet . into ( ) ) ;
405+ }
393406
394- if !header_contains ( & parts. headers , header:: CONNECTION , "upgrade" ) {
395- return Err ( InvalidConnectionHeader . into ( ) ) ;
396- }
407+ if !header_contains ( & parts. headers , header:: CONNECTION , "upgrade" ) {
408+ return Err ( InvalidConnectionHeader . into ( ) ) ;
409+ }
397410
398- if !header_eq ( & parts. headers , header:: UPGRADE , "websocket" ) {
399- return Err ( InvalidUpgradeHeader . into ( ) ) ;
400- }
411+ if !header_eq ( & parts. headers , header:: UPGRADE , "websocket" ) {
412+ return Err ( InvalidUpgradeHeader . into ( ) ) ;
413+ }
414+
415+ Some (
416+ parts
417+ . headers
418+ . get ( header:: SEC_WEBSOCKET_KEY )
419+ . ok_or ( WebSocketKeyHeaderMissing ) ?
420+ . clone ( ) ,
421+ )
422+ } else {
423+ if parts. method != Method :: CONNECT {
424+ return Err ( MethodNotConnect . into ( ) ) ;
425+ }
426+
427+ // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
428+ // with.
429+ #[ cfg( feature = "http2" ) ]
430+ if parts
431+ . extensions
432+ . get :: < hyper:: ext:: Protocol > ( )
433+ . map_or ( true , |p| p. as_str ( ) != "websocket" )
434+ {
435+ return Err ( InvalidProtocolPseudoheader . into ( ) ) ;
436+ }
437+
438+ None
439+ } ;
401440
402441 if !header_eq ( & parts. headers , header:: SEC_WEBSOCKET_VERSION , "13" ) {
403442 return Err ( InvalidWebSocketVersionHeader . into ( ) ) ;
404443 }
405444
406- let sec_websocket_key = parts
407- . headers
408- . get ( header:: SEC_WEBSOCKET_KEY )
409- . ok_or ( WebSocketKeyHeaderMissing ) ?
410- . clone ( ) ;
411-
412445 let on_upgrade = parts
413446 . extensions
414447 . remove :: < hyper:: upgrade:: OnUpgrade > ( )
@@ -706,6 +739,13 @@ pub mod rejection {
706739 pub struct MethodNotGet ;
707740 }
708741
742+ define_rejection ! {
743+ #[ status = METHOD_NOT_ALLOWED ]
744+ #[ body = "Request method must be `CONNECT`" ]
745+ /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
746+ pub struct MethodNotConnect ;
747+ }
748+
709749 define_rejection ! {
710750 #[ status = BAD_REQUEST ]
711751 #[ body = "Connection header did not include 'upgrade'" ]
@@ -720,6 +760,13 @@ pub mod rejection {
720760 pub struct InvalidUpgradeHeader ;
721761 }
722762
763+ define_rejection ! {
764+ #[ status = BAD_REQUEST ]
765+ #[ body = "`:protocol` pseudo-header did not include 'websocket'" ]
766+ /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
767+ pub struct InvalidProtocolPseudoheader ;
768+ }
769+
723770 define_rejection ! {
724771 #[ status = BAD_REQUEST ]
725772 #[ body = "`Sec-WebSocket-Version` header did not include '13'" ]
@@ -755,8 +802,10 @@ pub mod rejection {
755802 /// extractor can fail.
756803 pub enum WebSocketUpgradeRejection {
757804 MethodNotGet ,
805+ MethodNotConnect ,
758806 InvalidConnectionHeader ,
759807 InvalidUpgradeHeader ,
808+ InvalidProtocolPseudoheader ,
760809 InvalidWebSocketVersionHeader ,
761810 WebSocketKeyHeaderMissing ,
762811 ConnectionNotUpgradable ,
@@ -838,14 +887,18 @@ mod tests {
838887 use std:: future:: ready;
839888
840889 use super :: * ;
841- use crate :: { routing:: get , test_helpers:: spawn_service, Router } ;
890+ use crate :: { routing:: any , test_helpers:: spawn_service, Router } ;
842891 use http:: { Request , Version } ;
892+ use http_body_util:: BodyExt as _;
893+ use hyper_util:: rt:: TokioExecutor ;
894+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
895+ use tokio:: net:: TcpStream ;
843896 use tokio_tungstenite:: tungstenite;
844897 use tower:: ServiceExt ;
845898
846899 #[ crate :: test]
847900 async fn rejects_http_1_0_requests ( ) {
848- let svc = get ( |ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
901+ let svc = any ( |ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
849902 let rejection = ws. unwrap_err ( ) ;
850903 assert ! ( matches!(
851904 rejection,
@@ -874,7 +927,7 @@ mod tests {
874927 async fn handler ( ws : WebSocketUpgrade ) -> Response {
875928 ws. on_upgrade ( |_| async { } )
876929 }
877- let _: Router = Router :: new ( ) . route ( "/" , get ( handler) ) ;
930+ let _: Router = Router :: new ( ) . route ( "/" , any ( handler) ) ;
878931 }
879932
880933 #[ allow( dead_code) ]
@@ -883,16 +936,61 @@ mod tests {
883936 ws. on_failed_upgrade ( |_error : Error | println ! ( "oops!" ) )
884937 . on_upgrade ( |_| async { } )
885938 }
886- let _: Router = Router :: new ( ) . route ( "/" , get ( handler) ) ;
939+ let _: Router = Router :: new ( ) . route ( "/" , any ( handler) ) ;
887940 }
888941
889942 #[ crate :: test]
890943 async fn integration_test ( ) {
891- let app = Router :: new ( ) . route (
892- "/echo" ,
893- get ( |ws : WebSocketUpgrade | ready ( ws. on_upgrade ( handle_socket) ) ) ,
894- ) ;
944+ let addr = spawn_service ( echo_app ( ) ) ;
945+ let ( socket, _response) = tokio_tungstenite:: connect_async ( format ! ( "ws://{addr}/echo" ) )
946+ . await
947+ . unwrap ( ) ;
948+ test_echo_app ( socket) . await ;
949+ }
950+
951+ #[ crate :: test]
952+ #[ cfg( feature = "http2" ) ]
953+ async fn http2 ( ) {
954+ let addr = spawn_service ( echo_app ( ) ) ;
955+ let io = TokioIo :: new ( TcpStream :: connect ( addr) . await . unwrap ( ) ) ;
956+ let ( mut send_request, conn) =
957+ hyper:: client:: conn:: http2:: Builder :: new ( TokioExecutor :: new ( ) )
958+ . handshake ( io)
959+ . await
960+ . unwrap ( ) ;
961+
962+ // Wait a little for the SETTINGS frame to go through…
963+ for _ in 0 ..10 {
964+ tokio:: task:: yield_now ( ) . await ;
965+ }
966+ assert ! ( conn. is_extended_connect_protocol_enabled( ) ) ;
967+ tokio:: spawn ( async {
968+ conn. await . unwrap ( ) ;
969+ } ) ;
895970
971+ let req = Request :: builder ( )
972+ . method ( Method :: CONNECT )
973+ . extension ( hyper:: ext:: Protocol :: from_static ( "websocket" ) )
974+ . uri ( "/echo" )
975+ . header ( "sec-websocket-version" , "13" )
976+ . header ( "Host" , "server.example.com" )
977+ . body ( Body :: empty ( ) )
978+ . unwrap ( ) ;
979+
980+ let response = send_request. send_request ( req) . await . unwrap ( ) ;
981+ let status = response. status ( ) ;
982+ if status != 200 {
983+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
984+ let body = std:: str:: from_utf8 ( & body) . unwrap ( ) ;
985+ panic ! ( "response status was {}: {body}" , status) ;
986+ }
987+ let upgraded = hyper:: upgrade:: on ( response) . await . unwrap ( ) ;
988+ let upgraded = TokioIo :: new ( upgraded) ;
989+ let socket = WebSocketStream :: from_raw_socket ( upgraded, protocol:: Role :: Client , None ) . await ;
990+ test_echo_app ( socket) . await ;
991+ }
992+
993+ fn echo_app ( ) -> Router {
896994 async fn handle_socket ( mut socket : WebSocket ) {
897995 while let Some ( Ok ( msg) ) = socket. recv ( ) . await {
898996 match msg {
@@ -908,11 +1006,13 @@ mod tests {
9081006 }
9091007 }
9101008
911- let addr = spawn_service ( app) ;
912- let ( mut socket, _response) = tokio_tungstenite:: connect_async ( format ! ( "ws://{addr}/echo" ) )
913- . await
914- . unwrap ( ) ;
1009+ Router :: new ( ) . route (
1010+ "/echo" ,
1011+ any ( |ws : WebSocketUpgrade | ready ( ws. on_upgrade ( handle_socket) ) ) ,
1012+ )
1013+ }
9151014
1015+ async fn test_echo_app < S : AsyncRead + AsyncWrite + Unpin > ( mut socket : WebSocketStream < S > ) {
9161016 let input = tungstenite:: Message :: Text ( "foobar" . to_owned ( ) ) ;
9171017 socket. send ( input. clone ( ) ) . await . unwrap ( ) ;
9181018 let output = socket. next ( ) . await . unwrap ( ) . unwrap ( ) ;
0 commit comments