Skip to content

Commit 6eace32

Browse files
committed
add proxy support for websocket connect
1 parent 709dbcf commit 6eace32

File tree

6 files changed

+118
-18
lines changed

6 files changed

+118
-18
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ flate2 = { version = "1.0.28", optional = true }
5151
reqwest = { version = ">=0.11.22", default-features = false, features = ["multipart", "stream"], optional = true }
5252
static_assertions = { version = "1.1.0", optional = true }
5353
tokio-tungstenite = { version = "0.21.0", optional = true }
54+
async-http-proxy = { version = "1.2", optional = true, features = ["runtime-tokio", "basic-auth"] }
5455
typemap_rev = { version = "0.3.0", optional = true }
5556
bytes = { version = "1.5.0", optional = true }
5657
percent-encoding = { version = "2.3.0", optional = true }
@@ -143,13 +144,15 @@ absolute_ratelimits = []
143144
rustls_backend = [
144145
"reqwest/rustls-tls",
145146
"tokio-tungstenite/rustls-tls-webpki-roots",
147+
"async-http-proxy",
146148
"bytes",
147149
]
148150

149151
# - Native TLS Backends
150152
native_tls_backend = [
151153
"reqwest/native-tls",
152154
"tokio-tungstenite/native-tls",
155+
"async-http-proxy",
153156
"bytes",
154157
]
155158

src/client/mod.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use futures::StreamExt as _;
3535
use tokio::sync::{Mutex, RwLock};
3636
use tracing::{debug, error, info, instrument};
3737
use typemap_rev::{TypeMap, TypeMapKey};
38-
38+
use url::Url;
3939
pub use self::context::Context;
4040
pub use self::error::Error as ClientError;
4141
#[cfg(feature = "gateway")]
@@ -76,6 +76,7 @@ pub struct ClientBuilder {
7676
event_handlers: Vec<Arc<dyn EventHandler>>,
7777
raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
7878
presence: PresenceData,
79+
ws_proxy: Option<String>,
7980
}
8081

8182
#[cfg(feature = "gateway")]
@@ -94,6 +95,7 @@ impl ClientBuilder {
9495
event_handlers: vec![],
9596
raw_event_handlers: vec![],
9697
presence: PresenceData::default(),
98+
ws_proxy: None,
9799
}
98100
}
99101

@@ -157,6 +159,17 @@ impl ClientBuilder {
157159
&self.data
158160
}
159161

162+
/// Sets http proxy for the websocket connection.
163+
pub fn ws_proxy<T: Into<String>>(mut self, proxy: T) -> Self {
164+
self.ws_proxy = Some(proxy.into());
165+
self
166+
}
167+
168+
/// Gets the websocket proxy. See [`Self::ws_proxy`] for more info.
169+
pub fn get_ws_proxy(&self) -> Option<&str> {
170+
self.ws_proxy.as_deref()
171+
}
172+
160173
/// Insert a single `value` into the internal [`TypeMap`] that will be available in
161174
/// [`Context::data`]. This method can be called multiple times in order to populate the
162175
/// [`TypeMap`] with `value`s.
@@ -339,6 +352,7 @@ impl IntoFuture for ClientBuilder {
339352
let raw_event_handlers = self.raw_event_handlers;
340353
let intents = self.intents;
341354
let presence = self.presence;
355+
let ws_proxy = self.ws_proxy;
342356

343357
let mut http = self.http;
344358

@@ -369,6 +383,8 @@ impl IntoFuture for ClientBuilder {
369383
},
370384
}));
371385

386+
let ws_proxy = Arc::new(Mutex::new(ws_proxy));
387+
372388
#[cfg(feature = "framework")]
373389
let framework_cell = Arc::new(OnceLock::new());
374390
let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
@@ -383,6 +399,7 @@ impl IntoFuture for ClientBuilder {
383399
#[cfg(feature = "voice")]
384400
voice_manager: voice_manager.clone(),
385401
ws_url: Arc::clone(&ws_url),
402+
ws_proxy: Arc::clone(&ws_proxy),
386403
#[cfg(feature = "cache")]
387404
cache: Arc::clone(&cache),
388405
http: Arc::clone(&http),

src/gateway/bridge/shard_manager.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ impl ShardManager {
155155
#[cfg(feature = "voice")]
156156
voice_manager: opt.voice_manager,
157157
ws_url: opt.ws_url,
158+
ws_proxy: opt.ws_proxy,
158159
#[cfg(feature = "cache")]
159160
cache: opt.cache,
160161
http: opt.http,
@@ -396,6 +397,7 @@ pub struct ShardManagerOptions {
396397
#[cfg(feature = "voice")]
397398
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
398399
pub ws_url: Arc<Mutex<String>>,
400+
pub ws_proxy: Arc<Mutex<Option<String>>>,
399401
#[cfg(feature = "cache")]
400402
pub cache: Arc<Cache>,
401403
pub http: Arc<Http>,

src/gateway/bridge/shard_queuer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ pub struct ShardQueuer {
7373
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
7474
/// A copy of the URL to use to connect to the gateway.
7575
pub ws_url: Arc<Mutex<String>>,
76+
pub ws_proxy: Arc<Mutex<Option<String>>>,
7677
#[cfg(feature = "cache")]
7778
pub cache: Arc<Cache>,
7879
pub http: Arc<Http>,
@@ -168,6 +169,7 @@ impl ShardQueuer {
168169

169170
let mut shard = Shard::new(
170171
Arc::clone(&self.ws_url),
172+
Arc::clone(&self.ws_proxy),
171173
self.http.token(),
172174
shard_info,
173175
self.intents,

src/gateway/shard.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pub struct Shard {
7575
pub started: Instant,
7676
pub token: String,
7777
ws_url: Arc<Mutex<String>>,
78+
ws_proxy: Arc<Mutex<Option<String>>>,
7879
pub intents: GatewayIntents,
7980
}
8081

@@ -121,13 +122,15 @@ impl Shard {
121122
/// TLS error.
122123
pub async fn new(
123124
ws_url: Arc<Mutex<String>>,
125+
ws_proxy: Arc<Mutex<Option<String>>>,
124126
token: &str,
125127
info: ShardInfo,
126128
intents: GatewayIntents,
127129
presence: Option<PresenceData>,
128130
) -> Result<Shard> {
129131
let url = ws_url.lock().await.clone();
130-
let client = connect(&url).await?;
132+
let proxy = ws_proxy.lock().await.clone();
133+
let client = connect(&url, &proxy).await?;
131134

132135
let presence = presence.unwrap_or_default();
133136
let last_heartbeat_sent = None;
@@ -153,6 +156,7 @@ impl Shard {
153156
session_id,
154157
info,
155158
ws_url,
159+
ws_proxy,
156160
intents,
157161
})
158162
}
@@ -687,7 +691,8 @@ impl Shard {
687691
self.stage = ConnectionStage::Connecting;
688692
self.started = Instant::now();
689693
let url = &self.ws_url.lock().await.clone();
690-
let client = connect(url).await?;
694+
let proxy = &self.ws_proxy.lock().await.clone();
695+
let client = connect(url, proxy).await?;
691696
self.stage = ConnectionStage::Handshake;
692697

693698
Ok(client)
@@ -744,13 +749,25 @@ impl Shard {
744749
}
745750
}
746751

747-
async fn connect(base_url: &str) -> Result<WsClient> {
748-
let url =
752+
async fn connect(base_url: &str, proxy_url: &Option<String>) -> Result<WsClient> {
753+
let ws_url =
749754
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
750755
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);
751756

752757
Error::Gateway(GatewayError::BuildingUrl)
753758
})?;
754759

755-
WsClient::connect(url).await
760+
let parsed_proxy = match proxy_url {
761+
Some(proxy) => {
762+
let parsed_proxy = Url::parse(&proxy).map_err(|why| {
763+
warn!("Error building proxy URL with base `{}`: {:?}", proxy, why);
764+
765+
Error::Gateway(GatewayError::BuildingUrl)
766+
})?;
767+
Some(parsed_proxy)
768+
},
769+
None => None,
770+
};
771+
772+
WsClient::connect(ws_url, parsed_proxy).await
756773
}

src/gateway/ws.rs

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::env::consts;
2+
use std::io::ErrorKind;
23
#[cfg(feature = "client")]
34
use std::io::Read;
45
use std::time::SystemTime;
@@ -17,7 +18,9 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
1718
#[cfg(feature = "client")]
1819
use tokio_tungstenite::tungstenite::Error as WsError;
1920
use tokio_tungstenite::tungstenite::Message;
20-
use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
21+
use tokio_tungstenite::{
22+
client_async_tls_with_config, connect_async_with_config, MaybeTlsStream, WebSocketStream,
23+
};
2124
#[cfg(feature = "client")]
2225
use tracing::warn;
2326
use tracing::{debug, instrument, trace};
@@ -101,17 +104,77 @@ const TIMEOUT: Duration = Duration::from_millis(500);
101104
const DECOMPRESSION_MULTIPLIER: usize = 3;
102105

103106
impl WsClient {
104-
pub(crate) async fn connect(url: Url) -> Result<Self> {
105-
let config = WebSocketConfig {
106-
max_message_size: None,
107-
max_frame_size: None,
108-
..Default::default()
107+
pub(crate) async fn connect(url: Url, proxy: Option<Url>) -> Result<Self> {
108+
let config =
109+
WebSocketConfig { max_message_size: None, max_frame_size: None, ..Default::default() };
110+
let (stream, _) = match proxy {
111+
None => connect_async_with_config(url, Some(config), false).await?,
112+
Some(proxy) => {
113+
let tls_stream = Self::connect_with_proxy_async(&url, &proxy).await?;
114+
tls_stream.set_nodelay(true)?;
115+
client_async_tls_with_config(url, tls_stream, Some(config), None).await?
116+
},
109117
};
110-
let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
111118

112119
Ok(Self(stream))
113120
}
114121

122+
async fn connect_with_proxy_async(
123+
target_url: &Url,
124+
proxy_url: &Url,
125+
) -> std::result::Result<TcpStream, std::io::Error> {
126+
let proxy_addr = &proxy_url[url::Position::BeforeHost..url::Position::AfterPort];
127+
if proxy_url.scheme() != "http" && proxy_url.scheme() != "https" {
128+
return Err(std::io::Error::new(ErrorKind::Unsupported, "unknown proxy scheme"));
129+
}
130+
131+
let host = target_url
132+
.host_str()
133+
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target host"))?;
134+
let port = target_url
135+
.port()
136+
.or_else(|| match target_url.scheme() {
137+
"wss" => Some(443),
138+
"ws" => Some(80),
139+
_ => None,
140+
})
141+
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?;
142+
143+
let mut tcp_stream = TcpStream::connect(proxy_addr).await?;
144+
145+
let (username, password) = if let Some(pass) = proxy_url.password() {
146+
let user = proxy_url.username();
147+
(user, pass)
148+
} else {
149+
("", "")
150+
};
151+
152+
if username.is_empty() {
153+
// No auth: use the standard function
154+
async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port).await.map_err(
155+
|e| std::io::Error::new(ErrorKind::Other, format!("proxy connect failed: {e}")),
156+
)?;
157+
} else {
158+
// With basic auth: use the auth variant
159+
async_http_proxy::http_connect_tokio_with_basic_auth(
160+
&mut tcp_stream,
161+
host,
162+
port,
163+
username,
164+
password,
165+
)
166+
.await
167+
.map_err(|e| {
168+
std::io::Error::new(
169+
ErrorKind::Other,
170+
format!("proxy connect with auth failed: {e}"),
171+
)
172+
})?;
173+
}
174+
175+
Ok(tcp_stream)
176+
}
177+
115178
#[cfg(feature = "client")]
116179
pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
117180
let message = match timeout(TIMEOUT, self.0.next()).await {
@@ -310,11 +373,7 @@ impl WsClient {
310373

311374
self.send_json(&WebSocketMessage {
312375
op: Opcode::Resume,
313-
d: WebSocketMessageData::Resume {
314-
session_id,
315-
token,
316-
seq,
317-
},
376+
d: WebSocketMessageData::Resume { session_id, token, seq },
318377
})
319378
.await
320379
}

0 commit comments

Comments
 (0)