@@ -27,7 +27,7 @@ use http::{header::HeaderMap, Uri};
2727use hyper_rustls:: HttpsConnector ;
2828use hyper_util:: {
2929 client:: legacy:: { connect:: HttpConnector , Client } ,
30- rt:: { TokioExecutor , TokioIo } ,
30+ rt:: TokioExecutor ,
3131} ;
3232use linkup:: {
3333 allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore , NameKind ,
@@ -40,10 +40,12 @@ use std::{
4040} ;
4141use std:: { path:: Path , sync:: Arc } ;
4242use tokio:: { net:: UdpSocket , signal} ;
43+ use tokio_tungstenite:: tungstenite:: client:: IntoClientRequest ;
4344use tower:: ServiceBuilder ;
4445use tower_http:: trace:: { DefaultOnRequest , DefaultOnResponse , TraceLayer } ;
4546
4647pub mod certificates;
48+ mod ws;
4749
4850type HttpsClient = Client < HttpsConnector < HttpConnector > , Body > ;
4951
@@ -182,6 +184,7 @@ pub async fn start_dns_server(linkup_session_name: String, domains: Vec<String>)
182184async fn linkup_request_handler (
183185 Extension ( store) : Extension < MemoryStringStore > ,
184186 Extension ( client) : Extension < HttpsClient > ,
187+ ws : ws:: ExtractOptionalWebSocketUpgrade ,
185188 req : Request ,
186189) -> Response {
187190 let sessions = SessionAllocator :: new ( & store) ;
@@ -224,15 +227,58 @@ async fn linkup_request_handler(
224227
225228 let extra_headers = get_additional_headers ( & url, & headers, & session_name, & target_service) ;
226229
227- if req
228- . headers ( )
229- . get ( "upgrade" )
230- . map ( |v| v == "websocket" )
231- . unwrap_or ( false )
232- {
233- handle_ws_req ( req, target_service, extra_headers, client) . await
234- } else {
235- handle_http_req ( req, target_service, extra_headers, client) . await
230+ match ws. 0 {
231+ Some ( downstream_upgrade) => {
232+ let mut url = target_service. url ;
233+ if url. starts_with ( "http://" ) {
234+ url = url. replace ( "http://" , "ws://" ) ;
235+ } else if url. starts_with ( "https://" ) {
236+ url = url. replace ( "https://" , "wss://" ) ;
237+ }
238+
239+ let uri = url. parse :: < Uri > ( ) . unwrap ( ) ;
240+ let mut upstream_request = uri. into_client_request ( ) . unwrap ( ) ;
241+
242+ let extra_http_headers: HeaderMap = extra_headers. into ( ) ;
243+ for ( key, value) in extra_http_headers. iter ( ) {
244+ upstream_request. headers_mut ( ) . insert ( key, value. clone ( ) ) ;
245+ }
246+
247+ let ( upstream_ws_stream, upstream_response) =
248+ match tokio_tungstenite:: connect_async ( upstream_request) . await {
249+ Ok ( connection) => connection,
250+ Err ( error) => match error {
251+ tokio_tungstenite:: tungstenite:: Error :: Http ( response) => {
252+ let ( parts, body) = response. into_parts ( ) ;
253+ let body = body. unwrap_or_default ( ) ;
254+
255+ return Response :: from_parts ( parts, Body :: from ( body) ) ;
256+ }
257+ error => {
258+ return Response :: builder ( )
259+ . status ( StatusCode :: BAD_GATEWAY )
260+ . body ( Body :: from ( error. to_string ( ) ) )
261+ . unwrap ( )
262+ }
263+ } ,
264+ } ;
265+
266+ let mut upstream_upgrade_response =
267+ downstream_upgrade. on_upgrade ( ws:: context_handle_socket ( upstream_ws_stream) ) ;
268+
269+ let websocket_upgrade_response_headers = upstream_upgrade_response. headers_mut ( ) ;
270+ for upstream_header in upstream_response. headers ( ) {
271+ if !websocket_upgrade_response_headers. contains_key ( upstream_header. 0 ) {
272+ websocket_upgrade_response_headers
273+ . append ( upstream_header. 0 , upstream_header. 1 . clone ( ) ) ;
274+ }
275+ }
276+
277+ websocket_upgrade_response_headers. extend ( allow_all_cors ( ) ) ;
278+
279+ upstream_upgrade_response
280+ }
281+ None => handle_http_req ( req, target_service, extra_headers, client) . await ,
236282 }
237283}
238284
@@ -272,119 +318,6 @@ async fn handle_http_req(
272318 resp. into_response ( )
273319}
274320
275- async fn handle_ws_req (
276- req : Request ,
277- target_service : TargetService ,
278- extra_headers : linkup:: HeaderMap ,
279- client : HttpsClient ,
280- ) -> Response {
281- let extra_http_headers: HeaderMap = extra_headers. into ( ) ;
282-
283- let target_ws_req_result = Request :: builder ( )
284- . uri ( target_service. url )
285- . method ( req. method ( ) . clone ( ) )
286- . body ( Body :: empty ( ) ) ;
287-
288- let mut target_ws_req = match target_ws_req_result {
289- Ok ( request) => request,
290- Err ( e) => {
291- return ApiError :: new (
292- format ! ( "Failed to build request: {}" , e) ,
293- StatusCode :: INTERNAL_SERVER_ERROR ,
294- )
295- . into_response ( ) ;
296- }
297- } ;
298-
299- target_ws_req. headers_mut ( ) . extend ( req. headers ( ) . clone ( ) ) ;
300- target_ws_req. headers_mut ( ) . extend ( extra_http_headers) ;
301- target_ws_req. headers_mut ( ) . remove ( http:: header:: HOST ) ;
302-
303- // Send the modified request to the target service.
304- let target_ws_resp = match client. request ( target_ws_req) . await {
305- Ok ( resp) => resp,
306- Err ( e) => {
307- return ApiError :: new (
308- format ! ( "Failed to proxy request: {}" , e) ,
309- StatusCode :: BAD_GATEWAY ,
310- )
311- . into_response ( )
312- }
313- } ;
314-
315- let status = target_ws_resp. status ( ) ;
316- if status != 101 {
317- return ApiError :: new (
318- format ! (
319- "Failed to proxy request: expected 101 Switching Protocols, got {}" ,
320- status
321- ) ,
322- StatusCode :: BAD_GATEWAY ,
323- )
324- . into_response ( ) ;
325- }
326-
327- let target_ws_resp_headers = target_ws_resp. headers ( ) . clone ( ) ;
328-
329- let upgraded_target = match hyper:: upgrade:: on ( target_ws_resp) . await {
330- Ok ( upgraded) => upgraded,
331- Err ( e) => {
332- return ApiError :: new (
333- format ! ( "Failed to upgrade connection: {}" , e) ,
334- StatusCode :: BAD_GATEWAY ,
335- )
336- . into_response ( )
337- }
338- } ;
339-
340- tokio:: spawn ( async move {
341- // We won't get passed this until the 101 response returns to the client
342- let upgraded_incoming = match hyper:: upgrade:: on ( req) . await {
343- Ok ( upgraded) => upgraded,
344- Err ( e) => {
345- println ! ( "Failed to upgrade incoming connection: {}" , e) ;
346- return ;
347- }
348- } ;
349-
350- let mut incoming_stream = TokioIo :: new ( upgraded_incoming) ;
351- let mut target_stream = TokioIo :: new ( upgraded_target) ;
352-
353- let res = tokio:: io:: copy_bidirectional ( & mut incoming_stream, & mut target_stream) . await ;
354-
355- match res {
356- Ok ( ( incoming_to_target, target_to_incoming) ) => {
357- println ! (
358- "Copied {} bytes from incoming to target and {} bytes from target to incoming" ,
359- incoming_to_target, target_to_incoming
360- ) ;
361- }
362- Err ( e) => {
363- eprintln ! ( "Error copying between incoming and target: {}" , e) ;
364- }
365- }
366- } ) ;
367-
368- let mut resp_builder = Response :: builder ( ) . status ( 101 ) ;
369- let resp_headers_result = resp_builder. headers_mut ( ) ;
370- if let Some ( resp_headers) = resp_headers_result {
371- for ( header, value) in target_ws_resp_headers {
372- if let Some ( header_name) = header {
373- resp_headers. append ( header_name, value) ;
374- }
375- }
376- }
377-
378- match resp_builder. body ( Body :: empty ( ) ) {
379- Ok ( response) => response,
380- Err ( e) => ApiError :: new (
381- format ! ( "Failed to build response: {}" , e) ,
382- StatusCode :: INTERNAL_SERVER_ERROR ,
383- )
384- . into_response ( ) ,
385- }
386- }
387-
388321async fn linkup_config_handler (
389322 Extension ( store) : Extension < MemoryStringStore > ,
390323 Json ( update_req) : Json < UpdateSessionRequest > ,
0 commit comments