11use axum:: extract:: ws:: { Message , WebSocket } ;
2- use axum:: extract:: { FromRequest , State , WebSocketUpgrade } ;
2+ use axum:: extract:: { ConnectInfo , FromRequest , State , WebSocketUpgrade } ;
33use axum:: response:: { IntoResponse , Response } ;
44use axum:: routing:: get;
5+ use axum:: serve:: ListenerExt ;
56use futures_util:: { SinkExt , StreamExt } ;
67use std:: collections:: HashMap ;
8+ use std:: net:: SocketAddr ;
79use std:: sync:: { Arc , LazyLock } ;
810use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt } ;
911#[ cfg( unix) ]
@@ -63,36 +65,6 @@ async fn connect_ipc(path: &str) -> Result<IpcStream, String> {
6365 }
6466}
6567
66- /// Wraps TcpListener to set TCP_NODELAY on every accepted connection,
67- /// disabling Nagle's algorithm for low-latency frame delivery.
68- struct NoDelayListener ( tokio:: net:: TcpListener ) ;
69-
70- impl axum:: serve:: Listener for NoDelayListener {
71- type Io = tokio:: net:: TcpStream ;
72- type Addr = std:: net:: SocketAddr ;
73-
74- async fn accept ( & mut self ) -> ( Self :: Io , Self :: Addr ) {
75- {
76- loop {
77- match self . 0 . accept ( ) . await {
78- Ok ( ( stream, addr) ) => {
79- let _ = stream. set_nodelay ( true ) ;
80- return ( stream, addr) ;
81- }
82- Err ( e) => {
83- eprintln ! ( "accept error: {e}" ) ;
84- tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) . await ;
85- }
86- }
87- }
88- }
89- }
90-
91- fn local_addr ( & self ) -> std:: io:: Result < std:: net:: SocketAddr > {
92- self . 0 . local_addr ( )
93- }
94- }
95-
9668const INDEX_HTML_BR : & [ u8 ] = include_bytes ! ( "../../../js/ui/dist/index.html.br" ) ;
9769
9870static INDEX_ETAG : LazyLock < String > = LazyLock :: new ( || blit_webserver:: html_etag ( INDEX_HTML_BR ) ) ;
@@ -126,6 +98,8 @@ struct Config {
12698 /// Broadcast notification triggered on SIGINT/SIGTERM so active
12799 /// WebSocket/WebTransport handlers can send `S2C_QUIT` before exit.
128100 shutdown : Arc < tokio:: sync:: Notify > ,
101+ /// Shared auth throttle for config and gateway transports.
102+ auth_throttle : blit_webserver:: config:: AuthThrottle ,
129103}
130104
131105impl Config {
@@ -506,6 +480,7 @@ pub async fn run() {
506480 hub_url,
507481 webrtc_enabled,
508482 shutdown : shutdown. clone ( ) ,
483+ auth_throttle : blit_webserver:: config:: AuthThrottle :: new ( ) ,
509484 } ) ;
510485
511486 // --- Reconcile destinations whenever blit.remotes changes ---
@@ -551,7 +526,9 @@ pub async fn run() {
551526 eprintln ! ( "blit gateway: cannot bind to {addr}: {e}" ) ;
552527 std:: process:: exit ( 1 ) ;
553528 } ) ;
554- let listener = NoDelayListener ( tcp) ;
529+ let listener = tcp. tap_io ( |stream| {
530+ let _ = stream. set_nodelay ( true ) ;
531+ } ) ;
555532 eprintln ! (
556533 "listening on {addr} (WebSocket{}){}" ,
557534 if quic_enabled { " + WebTransport" } else { "" } ,
@@ -564,7 +541,11 @@ pub async fn run() {
564541
565542 blit_sd_notify:: notify_ready ( false ) ;
566543
567- let graceful = axum:: serve ( listener, app) . with_graceful_shutdown ( async move {
544+ let graceful = axum:: serve (
545+ listener,
546+ app. into_make_service_with_connect_info :: < SocketAddr > ( ) ,
547+ )
548+ . with_graceful_shutdown ( async move {
568549 #[ cfg( unix) ]
569550 {
570551 use tokio:: signal:: unix:: { SignalKind , signal} ;
@@ -697,6 +678,11 @@ fn mux_error(ch: u16, msg: &str) -> Vec<u8> {
697678}
698679
699680async fn root_handler ( State ( state) : State < AppState > , request : axum:: extract:: Request ) -> Response {
681+ let auth_peer = request
682+ . extensions ( )
683+ . get :: < ConnectInfo < SocketAddr > > ( )
684+ . map ( |ConnectInfo ( addr) | addr. ip ( ) . to_string ( ) )
685+ . unwrap_or_else ( || "unknown" . to_string ( ) ) ;
700686 let path = request. uri ( ) . path ( ) . to_string ( ) ;
701687
702688 if let Some ( resp) = blit_webserver:: try_font_route ( & path, state. cors_origin . as_deref ( ) ) {
@@ -726,6 +712,10 @@ async fn root_handler(State(state): State<AppState>, request: axum::extract::Req
726712 Some ( & state. remotes ) ,
727713 transform,
728714 & extra_init,
715+ blit_webserver:: config:: AuthContext {
716+ throttle : & state. auth_throttle ,
717+ peer : & auth_peer,
718+ } ,
729719 )
730720 . await ;
731721 } ) ,
@@ -735,15 +725,15 @@ async fn root_handler(State(state): State<AppState>, request: axum::extract::Req
735725 match WebSocketUpgrade :: from_request ( request, & state) . await {
736726 Ok ( ws) => ws
737727 . max_message_size ( MAX_FRAME_SIZE + 2 ) // +2 for channel ID prefix
738- . on_upgrade ( move |socket| handle_mux_ws ( socket, state) ) ,
728+ . on_upgrade ( move |socket| handle_mux_ws ( socket, state, auth_peer ) ) ,
739729 Err ( e) => e. into_response ( ) ,
740730 }
741731 } else if is_ws {
742732 let dest_name = resolve_destination_name ( & path) ;
743733 match WebSocketUpgrade :: from_request ( request, & state) . await {
744734 Ok ( ws) => ws
745735 . max_message_size ( MAX_FRAME_SIZE )
746- . on_upgrade ( move |socket| handle_ws ( socket, state, dest_name) ) ,
736+ . on_upgrade ( move |socket| handle_ws ( socket, state, dest_name, auth_peer ) ) ,
747737 Err ( e) => e. into_response ( ) ,
748738 }
749739 } else {
@@ -768,35 +758,21 @@ fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
768758 std:: hint:: black_box ( diff) == 0
769759}
770760
771- async fn handle_ws ( mut ws : WebSocket , state : AppState , dest_name : Option < String > ) {
772- let authed = match tokio:: time:: timeout ( std:: time:: Duration :: from_secs ( 30 ) , async {
773- loop {
774- match ws. recv ( ) . await {
775- Some ( Ok ( Message :: Text ( pass) ) ) => {
776- if constant_time_eq ( pass. trim ( ) . as_bytes ( ) , state. passphrase . as_bytes ( ) ) {
777- break true ;
778- } else {
779- let _ = ws. send ( Message :: Text ( "auth" . into ( ) ) ) . await ;
780- let _ = ws. close ( ) . await ;
781- break false ;
782- }
783- }
784- Some ( Ok ( Message :: Ping ( d) ) ) => {
785- let _ = ws. send ( Message :: Pong ( d) ) . await ;
786- }
787- _ => break false ,
788- }
789- }
790- } )
761+ async fn handle_ws (
762+ mut ws : WebSocket ,
763+ state : AppState ,
764+ dest_name : Option < String > ,
765+ auth_peer : String ,
766+ ) {
767+ if !blit_webserver:: config:: authenticate_text_ws (
768+ & mut ws,
769+ & state. passphrase ,
770+ & state. auth_throttle ,
771+ & auth_peer,
772+ None ,
773+ )
791774 . await
792775 {
793- Ok ( result) => result,
794- Err ( _) => {
795- let _ = ws. close ( ) . await ;
796- false
797- }
798- } ;
799- if !authed {
800776 return ;
801777 }
802778
@@ -906,36 +882,17 @@ impl MuxChannelState {
906882 }
907883}
908884
909- async fn handle_mux_ws ( mut ws : WebSocket , state : AppState ) {
885+ async fn handle_mux_ws ( mut ws : WebSocket , state : AppState , auth_peer : String ) {
910886 // --- Authentication (identical to handle_ws) ---
911- let authed = match tokio:: time:: timeout ( std:: time:: Duration :: from_secs ( 30 ) , async {
912- loop {
913- match ws. recv ( ) . await {
914- Some ( Ok ( Message :: Text ( pass) ) ) => {
915- if constant_time_eq ( pass. trim ( ) . as_bytes ( ) , state. passphrase . as_bytes ( ) ) {
916- break true ;
917- } else {
918- let _ = ws. send ( Message :: Text ( "auth" . into ( ) ) ) . await ;
919- let _ = ws. close ( ) . await ;
920- break false ;
921- }
922- }
923- Some ( Ok ( Message :: Ping ( d) ) ) => {
924- let _ = ws. send ( Message :: Pong ( d) ) . await ;
925- }
926- _ => break false ,
927- }
928- }
929- } )
887+ if !blit_webserver:: config:: authenticate_text_ws (
888+ & mut ws,
889+ & state. passphrase ,
890+ & state. auth_throttle ,
891+ & auth_peer,
892+ None ,
893+ )
930894 . await
931895 {
932- Ok ( result) => result,
933- Err ( _) => {
934- let _ = ws. close ( ) . await ;
935- false
936- }
937- } ;
938- if !authed {
939896 return ;
940897 }
941898
@@ -1410,6 +1367,7 @@ async fn wt_authenticate(
14101367 send : & mut wt:: SendStream ,
14111368 recv : & mut wt:: RecvStream ,
14121369 passphrase : & str ,
1370+ guard : blit_webserver:: config:: AuthAttemptGuard ,
14131371) -> Result < ( ) , Box < dyn std:: error:: Error + Send + Sync > > {
14141372 let auth_result = tokio:: time:: timeout ( std:: time:: Duration :: from_secs ( 30 ) , async {
14151373 let mut len_buf = [ 0u8 ; 2 ] ;
@@ -1418,7 +1376,7 @@ async fn wt_authenticate(
14181376 . map_err ( |e| format ! ( "auth read len: {e}" ) ) ?;
14191377 let pass_len = u16:: from_le_bytes ( len_buf) as usize ;
14201378 if pass_len > 4096 {
1421- return Err :: < ( ) , String > ( "passphrase too long" . into ( ) ) ;
1379+ return Err :: < bool , String > ( "passphrase too long" . into ( ) ) ;
14221380 }
14231381 let mut pass_buf = vec ! [ 0u8 ; pass_len] ;
14241382 recv. read_exact ( & mut pass_buf)
@@ -1428,16 +1386,26 @@ async fn wt_authenticate(
14281386
14291387 if !constant_time_eq ( pass. trim ( ) . as_bytes ( ) , passphrase. as_bytes ( ) ) {
14301388 send. write_all ( & [ 0 ] ) . await . ok ( ) ;
1431- return Err ( "authentication failed" . into ( ) ) ;
1389+ return Ok ( false ) ;
14321390 }
1433- Ok ( ( ) )
1391+ Ok ( true )
14341392 } )
14351393 . await ;
14361394
14371395 match auth_result {
1438- Ok ( Ok ( ( ) ) ) => { }
1439- Ok ( Err ( e) ) => return Err ( e. into ( ) ) ,
1440- Err ( _) => return Err ( "authentication timed out" . into ( ) ) ,
1396+ Ok ( Ok ( true ) ) => guard. record_success ( ) ,
1397+ Ok ( Ok ( false ) ) => {
1398+ guard. record_failure ( ) ;
1399+ return Err ( "authentication failed" . into ( ) ) ;
1400+ }
1401+ Ok ( Err ( e) ) => {
1402+ guard. record_failure ( ) ;
1403+ return Err ( e. into ( ) ) ;
1404+ }
1405+ Err ( _) => {
1406+ guard. record_failure ( ) ;
1407+ return Err ( "authentication timed out" . into ( ) ) ;
1408+ }
14411409 }
14421410 send. write_all ( & [ 1 ] )
14431411 . await
@@ -1450,13 +1418,32 @@ async fn handle_webtransport_session(
14501418 state : AppState ,
14511419) -> Result < ( ) , Box < dyn std:: error:: Error + Send + Sync > > {
14521420 let path = request. url . path ( ) . to_string ( ) ;
1421+ let auth_peer = request. conn ( ) . remote_address ( ) . ip ( ) . to_string ( ) ;
14531422 let is_mux = is_mux_path ( & path) ;
14541423 let dest_name = resolve_destination_name ( & path) ;
1424+ let Some ( auth_guard) = state. auth_throttle . begin ( auth_peer. clone ( ) ) else {
1425+ request
1426+ . reject ( axum:: http:: StatusCode :: TOO_MANY_REQUESTS )
1427+ . await ?;
1428+ return Ok ( ( ) ) ;
1429+ } ;
14551430 let session = request. ok ( ) . await ?;
14561431
1457- let ( mut send, mut recv) = session. accept_bi ( ) . await ?;
1432+ let ( mut send, mut recv) =
1433+ match tokio:: time:: timeout ( std:: time:: Duration :: from_secs ( 30 ) , session. accept_bi ( ) ) . await {
1434+ Ok ( Ok ( streams) ) => streams,
1435+ Ok ( Err ( e) ) => {
1436+ auth_guard. record_failure ( ) ;
1437+ return Err ( e. into ( ) ) ;
1438+ }
1439+ Err ( _) => {
1440+ auth_guard. record_failure ( ) ;
1441+ session. close ( 1 , b"authentication timed out" ) ;
1442+ return Err ( "authentication timed out" . into ( ) ) ;
1443+ }
1444+ } ;
14581445
1459- wt_authenticate ( & mut send, & mut recv, & state. passphrase ) . await ?;
1446+ wt_authenticate ( & mut send, & mut recv, & state. passphrase , auth_guard ) . await ?;
14601447
14611448 if is_mux {
14621449 return handle_mux_wt ( send, recv, state) . await ;
@@ -1716,6 +1703,7 @@ mod tests {
17161703 hub_url : blit_webrtc_forwarder:: normalize_hub ( blit_webrtc_forwarder:: DEFAULT_HUB_URL ) ,
17171704 webrtc_enabled : false ,
17181705 shutdown : Arc :: new ( tokio:: sync:: Notify :: new ( ) ) ,
1706+ auth_throttle : blit_webserver:: config:: AuthThrottle :: new ( ) ,
17191707 } )
17201708 }
17211709
0 commit comments