11use std:: env:: consts;
2+ use std:: io:: ErrorKind ;
23#[ cfg( feature = "client" ) ]
34use std:: io:: Read ;
45use std:: time:: SystemTime ;
@@ -17,7 +18,9 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
1718#[ cfg( feature = "client" ) ]
1819use tokio_tungstenite:: tungstenite:: Error as WsError ;
1920use tokio_tungstenite:: tungstenite:: Message ;
20- use tokio_tungstenite:: { connect_async_with_config, MaybeTlsStream , WebSocketStream } ;
21+ use tokio_tungstenite:: {
22+ client_async_tls_with_config, connect_async_with_config, MaybeTlsStream , WebSocketStream ,
23+ } ;
2124#[ cfg( feature = "client" ) ]
2225use tracing:: warn;
2326use tracing:: { debug, instrument, trace} ;
@@ -101,17 +104,77 @@ const TIMEOUT: Duration = Duration::from_millis(500);
101104const DECOMPRESSION_MULTIPLIER : usize = 3 ;
102105
103106impl WsClient {
104- pub ( crate ) async fn connect ( url : Url ) -> Result < Self > {
105- let config = WebSocketConfig {
106- max_message_size : None ,
107- max_frame_size : None ,
108- ..Default :: default ( )
107+ pub ( crate ) async fn connect ( url : Url , proxy : Option < Url > ) -> Result < Self > {
108+ let config =
109+ WebSocketConfig { max_message_size : None , max_frame_size : None , ..Default :: default ( ) } ;
110+ let ( stream, _) = match proxy {
111+ None => connect_async_with_config ( url, Some ( config) , false ) . await ?,
112+ Some ( proxy) => {
113+ let tls_stream = Self :: connect_with_proxy_async ( & url, & proxy) . await ?;
114+ tls_stream. set_nodelay ( true ) ?;
115+ client_async_tls_with_config ( url, tls_stream, Some ( config) , None ) . await ?
116+ } ,
109117 } ;
110- let ( stream, _) = connect_async_with_config ( url, Some ( config) , false ) . await ?;
111118
112119 Ok ( Self ( stream) )
113120 }
114121
122+ async fn connect_with_proxy_async (
123+ target_url : & Url ,
124+ proxy_url : & Url ,
125+ ) -> std:: result:: Result < TcpStream , std:: io:: Error > {
126+ let proxy_addr = & proxy_url[ url:: Position :: BeforeHost ..url:: Position :: AfterPort ] ;
127+ if proxy_url. scheme ( ) != "http" && proxy_url. scheme ( ) != "https" {
128+ return Err ( std:: io:: Error :: new ( ErrorKind :: Unsupported , "unknown proxy scheme" ) ) ;
129+ }
130+
131+ let host = target_url
132+ . host_str ( )
133+ . ok_or_else ( || std:: io:: Error :: new ( ErrorKind :: Unsupported , "unknown target host" ) ) ?;
134+ let port = target_url
135+ . port ( )
136+ . or_else ( || match target_url. scheme ( ) {
137+ "wss" => Some ( 443 ) ,
138+ "ws" => Some ( 80 ) ,
139+ _ => None ,
140+ } )
141+ . ok_or_else ( || std:: io:: Error :: new ( ErrorKind :: Unsupported , "unknown target scheme" ) ) ?;
142+
143+ let mut tcp_stream = TcpStream :: connect ( proxy_addr) . await ?;
144+
145+ let ( username, password) = if let Some ( pass) = proxy_url. password ( ) {
146+ let user = proxy_url. username ( ) ;
147+ ( user, pass)
148+ } else {
149+ ( "" , "" )
150+ } ;
151+
152+ if username. is_empty ( ) {
153+ // No auth: use the standard function
154+ async_http_proxy:: http_connect_tokio ( & mut tcp_stream, host, port) . await . map_err (
155+ |e| std:: io:: Error :: new ( ErrorKind :: Other , format ! ( "proxy connect failed: {e}" ) ) ,
156+ ) ?;
157+ } else {
158+ // With basic auth: use the auth variant
159+ async_http_proxy:: http_connect_tokio_with_basic_auth (
160+ & mut tcp_stream,
161+ host,
162+ port,
163+ username,
164+ password,
165+ )
166+ . await
167+ . map_err ( |e| {
168+ std:: io:: Error :: new (
169+ ErrorKind :: Other ,
170+ format ! ( "proxy connect with auth failed: {e}" ) ,
171+ )
172+ } ) ?;
173+ }
174+
175+ Ok ( tcp_stream)
176+ }
177+
115178 #[ cfg( feature = "client" ) ]
116179 pub ( crate ) async fn recv_json ( & mut self ) -> Result < Option < GatewayEvent > > {
117180 let message = match timeout ( TIMEOUT , self . 0 . next ( ) ) . await {
@@ -310,11 +373,7 @@ impl WsClient {
310373
311374 self . send_json ( & WebSocketMessage {
312375 op : Opcode :: Resume ,
313- d : WebSocketMessageData :: Resume {
314- session_id,
315- token,
316- seq,
317- } ,
376+ d : WebSocketMessageData :: Resume { session_id, token, seq } ,
318377 } )
319378 . await
320379 }
0 commit comments