@@ -7,9 +7,7 @@ use std::collections::HashMap;
77use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
88use std:: sync:: Arc ;
99use tokio:: net:: TcpStream ;
10- use tokio_tungstenite:: {
11- tungstenite:: Message , MaybeTlsStream , WebSocketStream ,
12- } ;
10+ use tokio_tungstenite:: { tungstenite:: Message , MaybeTlsStream , WebSocketStream } ;
1311use url:: Url ;
1412
1513pub struct WebSocketProxyHandler {
@@ -130,11 +128,7 @@ impl WebSocketProxyHandler {
130128 }
131129
132130 /// Handle WebSocket proxy connection with full bidirectional relay
133- pub async fn handle_websocket_proxy (
134- & self ,
135- session : & mut Session ,
136- path : & str ,
137- ) -> Result < bool > {
131+ pub async fn handle_websocket_proxy ( & self , session : & mut Session , path : & str ) -> Result < bool > {
138132 // Find matching WebSocket route
139133 if let Some ( route) = self . find_websocket_route ( path) {
140134 info ! (
@@ -175,7 +169,10 @@ impl WebSocketProxyHandler {
175169 if let Err ( e) = session. write_response_header ( Box :: new ( resp) , false ) . await {
176170 error ! ( "Failed to send error response: {}" , e) ;
177171 }
178- if let Err ( e) = session. write_response_body ( Some ( "WebSocket proxy error" . into ( ) ) , true ) . await {
172+ if let Err ( e) = session
173+ . write_response_body ( Some ( "WebSocket proxy error" . into ( ) ) , true )
174+ . await
175+ {
179176 error ! ( "Failed to send error body: {}" , e) ;
180177 }
181178 }
@@ -209,7 +206,7 @@ impl WebSocketProxyHandler {
209206 let name_str = name. as_str ( ) ;
210207 match name_str. to_lowercase ( ) . as_str ( ) {
211208 "sec-websocket-key"
212- | "sec-websocket-version"
209+ | "sec-websocket-version"
213210 | "sec-websocket-protocol"
214211 | "sec-websocket-extensions"
215212 | "origin"
@@ -231,20 +228,24 @@ impl WebSocketProxyHandler {
231228 }
232229
233230 // Connect to upstream WebSocket
234- let ( _upstream_ws, response) = match self . connect_upstream_websocket ( ws_url, headers) . await {
231+ let ( _upstream_ws, response) = match self . connect_upstream_websocket ( ws_url, headers) . await
232+ {
235233 Ok ( result) => result,
236234 Err ( e) => {
237235 error ! ( "Failed to connect to upstream WebSocket: {}" , e) ;
238236 return Err ( Error :: new_str ( "Upstream WebSocket connection failed" ) ) ;
239237 }
240238 } ;
241239
242- info ! ( "Connected to upstream WebSocket, status: {}" , response. status( ) ) ;
240+ info ! (
241+ "Connected to upstream WebSocket, status: {}" ,
242+ response. status( )
243+ ) ;
243244
244245 // Extract headers we need from the response before building the client response
245246 let mut ws_protocol = None ;
246247 let mut ws_extensions = None ;
247-
248+
248249 for ( name, value) in response. headers ( ) . iter ( ) {
249250 if let Ok ( value_str) = value. to_str ( ) {
250251 match name. as_str ( ) . to_lowercase ( ) . as_str ( ) {
@@ -266,56 +267,60 @@ impl WebSocketProxyHandler {
266267 let mut resp_builder = pingora:: http:: ResponseHeader :: build ( 101 , None ) . unwrap ( ) ;
267268 resp_builder. insert_header ( "Upgrade" , "websocket" ) . unwrap ( ) ;
268269 resp_builder. insert_header ( "Connection" , "Upgrade" ) . unwrap ( ) ;
269- resp_builder. insert_header ( "Sec-WebSocket-Accept" , & ws_accept) . unwrap ( ) ;
270-
270+ resp_builder
271+ . insert_header ( "Sec-WebSocket-Accept" , & ws_accept)
272+ . unwrap ( ) ;
273+
271274 // Add optional headers from upstream response
272275 if let Some ( protocol) = ws_protocol {
273276 if let Err ( e) = resp_builder. insert_header ( "Sec-WebSocket-Protocol" , & protocol) {
274277 warn ! ( "Failed to set WebSocket protocol header: {}" , e) ;
275278 }
276279 }
277-
280+
278281 if let Some ( extensions) = ws_extensions {
279282 if let Err ( e) = resp_builder. insert_header ( "Sec-WebSocket-Extensions" , & extensions) {
280283 warn ! ( "Failed to set WebSocket extensions header: {}" , e) ;
281284 }
282285 }
283286
284287 // Send upgrade response to client
285- session. write_response_header ( Box :: new ( resp_builder) , false ) . await ?;
288+ session
289+ . write_response_header ( Box :: new ( resp_builder) , false )
290+ . await ?;
286291
287292 info ! ( "WebSocket upgrade successful, starting message relay simulation" ) ;
288293
289294 // At this point in a real implementation, we would:
290295 // 1. Take ownership of the raw TCP stream from the session
291296 // 2. Wrap it in a WebSocket stream
292297 // 3. Use relay_websocket_messages to handle bidirectional communication
293-
298+
294299 // For now, we simulate the connection being established and then closed
295300 // This allows the WebSocket framework to work correctly
296-
301+
297302 // Simulate the WebSocket connection being active
298303 info ! ( "Simulating WebSocket connection active state" ) ;
299304 tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 500 ) ) . await ;
300-
305+
301306 // In a real implementation, we would spawn:
302307 // tokio::spawn(Self::relay_websocket_messages(client_ws, upstream_ws));
303-
308+
304309 info ! ( "WebSocket proxy session completed" ) ;
305310 Ok ( ( ) )
306311 }
307312
308313 /// Calculate Sec-WebSocket-Accept header value
309314 fn calculate_websocket_accept ( & self , ws_key : & str ) -> String {
310- use sha1:: { Digest , Sha1 } ;
311315 use base64:: prelude:: * ;
312-
316+ use sha1:: { Digest , Sha1 } ;
317+
313318 const WS_GUID : & str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" ;
314319 let mut hasher = Sha1 :: new ( ) ;
315320 hasher. update ( ws_key. as_bytes ( ) ) ;
316321 hasher. update ( WS_GUID . as_bytes ( ) ) ;
317322 let result = hasher. finalize ( ) ;
318- BASE64_STANDARD . encode ( & result)
323+ BASE64_STANDARD . encode ( result)
319324 }
320325
321326 /// Convert HTTP upstream URL to WebSocket URL
@@ -325,8 +330,8 @@ impl WebSocketProxyHandler {
325330 route : & ProxyRoute ,
326331 path : & str ,
327332 ) -> Result < String > {
328- let upstream_url = Url :: parse ( & upstream . url )
329- . map_err ( |_| Error :: new_str ( "Invalid upstream URL" ) ) ?;
333+ let upstream_url =
334+ Url :: parse ( & upstream . url ) . map_err ( |_| Error :: new_str ( "Invalid upstream URL" ) ) ?;
330335
331336 let scheme = match upstream_url. scheme ( ) {
332337 "http" => "ws" ,
@@ -373,17 +378,21 @@ impl WebSocketProxyHandler {
373378 & self ,
374379 ws_url : & str ,
375380 _headers : Vec < ( & str , & str ) > ,
376- ) -> Result < ( WebSocketStream < MaybeTlsStream < TcpStream > > , tokio_tungstenite:: tungstenite:: handshake:: client:: Response ) > {
381+ ) -> Result < (
382+ WebSocketStream < MaybeTlsStream < TcpStream > > ,
383+ tokio_tungstenite:: tungstenite:: handshake:: client:: Response ,
384+ ) > {
377385 // For now, use the simple connect_async approach
378386 // In a production environment, you'd want to handle custom headers
379387 // by building a proper request with tokio_tungstenite::client_async
380-
381- let ( ws_stream, response) = tokio_tungstenite:: connect_async ( ws_url)
382- . await
383- . map_err ( |e| {
384- error ! ( "WebSocket connection error: {}" , e) ;
385- Error :: new_str ( "WebSocket connection failed" )
386- } ) ?;
388+
389+ let ( ws_stream, response) =
390+ tokio_tungstenite:: connect_async ( ws_url)
391+ . await
392+ . map_err ( |e| {
393+ error ! ( "WebSocket connection error: {}" , e) ;
394+ Error :: new_str ( "WebSocket connection failed" )
395+ } ) ?;
387396
388397 debug ! ( "Successfully connected to upstream WebSocket" ) ;
389398 Ok ( ( ws_stream, response) )
@@ -462,7 +471,7 @@ impl WebSocketProxyHandler {
462471mod tests {
463472 use super :: * ;
464473 use crate :: config:: site:: { LoadBalancingConfig , ProxyHeadersConfig , TimeoutConfig } ;
465- use pingora:: http:: { RequestHeader , Method } ;
474+ use pingora:: http:: { Method , RequestHeader } ;
466475 use std:: collections:: HashMap ;
467476
468477 fn create_test_config ( ) -> ProxyConfig {
@@ -521,14 +530,15 @@ mod tests {
521530 #[ test]
522531 fn test_websocket_upgrade_detection ( ) {
523532 let mut req = RequestHeader :: build ( Method :: GET , b"/ws" , None ) . unwrap ( ) ;
524-
533+
525534 // Missing headers - should not be WebSocket
526535 assert ! ( !WebSocketProxyHandler :: is_websocket_upgrade_request( & req) ) ;
527536
528537 // Add WebSocket headers
529538 req. insert_header ( "Upgrade" , "websocket" ) . unwrap ( ) ;
530539 req. insert_header ( "Connection" , "Upgrade" ) . unwrap ( ) ;
531- req. insert_header ( "Sec-WebSocket-Key" , "dGhlIHNhbXBsZSBub25jZQ==" ) . unwrap ( ) ;
540+ req. insert_header ( "Sec-WebSocket-Key" , "dGhlIHNhbXBsZSBub25jZQ==" )
541+ . unwrap ( ) ;
532542
533543 // Now should be detected as WebSocket
534544 assert ! ( WebSocketProxyHandler :: is_websocket_upgrade_request( & req) ) ;
@@ -570,7 +580,9 @@ mod tests {
570580 websocket : true ,
571581 } ;
572582
573- let ws_url = handler. get_websocket_url ( upstream, route, "/ws/chat" ) . unwrap ( ) ;
583+ let ws_url = handler
584+ . get_websocket_url ( upstream, route, "/ws/chat" )
585+ . unwrap ( ) ;
574586 assert_eq ! ( ws_url, "ws://localhost:3001/chat" ) ;
575587
576588 // Test with HTTPS upstream
@@ -581,7 +593,9 @@ mod tests {
581593 max_conns : None ,
582594 } ;
583595
584- let wss_url = handler. get_websocket_url ( https_upstream, route, "/ws/chat" ) . unwrap ( ) ;
596+ let wss_url = handler
597+ . get_websocket_url ( https_upstream, route, "/ws/chat" )
598+ . unwrap ( ) ;
585599 assert_eq ! ( wss_url, "wss://localhost:3001/chat" ) ;
586600 }
587601
0 commit comments