Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ flate2 = { version = "1.0.28", optional = true }
reqwest = { version = ">=0.11.22", default-features = false, features = ["multipart", "stream"], optional = true }
static_assertions = { version = "1.1.0", optional = true }
tokio-tungstenite = { version = "0.21.0", optional = true }
async-http-proxy = { version = "1.2", optional = true, features = ["runtime-tokio", "basic-auth"] }
typemap_rev = { version = "0.3.0", optional = true }
bytes = { version = "1.5.0", optional = true }
percent-encoding = { version = "2.3.0", optional = true }
Expand Down Expand Up @@ -143,13 +144,15 @@ absolute_ratelimits = []
rustls_backend = [
"reqwest/rustls-tls",
"tokio-tungstenite/rustls-tls-webpki-roots",
"async-http-proxy",
"bytes",
]

# - Native TLS Backends
native_tls_backend = [
"reqwest/native-tls",
"tokio-tungstenite/native-tls",
"async-http-proxy",
"bytes",
]

Expand Down
19 changes: 18 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use futures::StreamExt as _;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info, instrument};
use typemap_rev::{TypeMap, TypeMapKey};

use url::Url;
pub use self::context::Context;
pub use self::error::Error as ClientError;
#[cfg(feature = "gateway")]
Expand Down Expand Up @@ -76,6 +76,7 @@ pub struct ClientBuilder {
event_handlers: Vec<Arc<dyn EventHandler>>,
raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
presence: PresenceData,
ws_proxy: Option<String>,
}

#[cfg(feature = "gateway")]
Expand All @@ -94,6 +95,7 @@ impl ClientBuilder {
event_handlers: vec![],
raw_event_handlers: vec![],
presence: PresenceData::default(),
ws_proxy: None,
}
}

Expand Down Expand Up @@ -157,6 +159,17 @@ impl ClientBuilder {
&self.data
}

/// Sets http proxy for the websocket connection.
pub fn ws_proxy<T: Into<String>>(mut self, proxy: T) -> Self {
self.ws_proxy = Some(proxy.into());
self
}

/// Gets the websocket proxy. See [`Self::ws_proxy`] for more info.
pub fn get_ws_proxy(&self) -> Option<&str> {
self.ws_proxy.as_deref()
}

/// Insert a single `value` into the internal [`TypeMap`] that will be available in
/// [`Context::data`]. This method can be called multiple times in order to populate the
/// [`TypeMap`] with `value`s.
Expand Down Expand Up @@ -339,6 +352,7 @@ impl IntoFuture for ClientBuilder {
let raw_event_handlers = self.raw_event_handlers;
let intents = self.intents;
let presence = self.presence;
let ws_proxy = self.ws_proxy;

let mut http = self.http;

Expand Down Expand Up @@ -369,6 +383,8 @@ impl IntoFuture for ClientBuilder {
},
}));

let ws_proxy = Arc::new(Mutex::new(ws_proxy));

#[cfg(feature = "framework")]
let framework_cell = Arc::new(OnceLock::new());
let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
Expand All @@ -383,6 +399,7 @@ impl IntoFuture for ClientBuilder {
#[cfg(feature = "voice")]
voice_manager: voice_manager.clone(),
ws_url: Arc::clone(&ws_url),
ws_proxy: Arc::clone(&ws_proxy),
#[cfg(feature = "cache")]
cache: Arc::clone(&cache),
http: Arc::clone(&http),
Expand Down
2 changes: 2 additions & 0 deletions src/gateway/bridge/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ impl ShardManager {
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
ws_proxy: opt.ws_proxy,
#[cfg(feature = "cache")]
cache: opt.cache,
http: opt.http,
Expand Down Expand Up @@ -396,6 +397,7 @@ pub struct ShardManagerOptions {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
pub ws_url: Arc<Mutex<String>>,
pub ws_proxy: Arc<Mutex<Option<String>>>,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
pub http: Arc<Http>,
Expand Down
2 changes: 2 additions & 0 deletions src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct ShardQueuer {
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// A copy of the URL to use to connect to the gateway.
pub ws_url: Arc<Mutex<String>>,
pub ws_proxy: Arc<Mutex<Option<String>>>,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
pub http: Arc<Http>,
Expand Down Expand Up @@ -168,6 +169,7 @@ impl ShardQueuer {

let mut shard = Shard::new(
Arc::clone(&self.ws_url),
Arc::clone(&self.ws_proxy),
self.http.token(),
shard_info,
self.intents,
Expand Down
27 changes: 22 additions & 5 deletions src/gateway/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub struct Shard {
pub started: Instant,
pub token: String,
ws_url: Arc<Mutex<String>>,
ws_proxy: Arc<Mutex<Option<String>>>,
pub intents: GatewayIntents,
}

Expand Down Expand Up @@ -121,13 +122,15 @@ impl Shard {
/// TLS error.
pub async fn new(
ws_url: Arc<Mutex<String>>,
ws_proxy: Arc<Mutex<Option<String>>>,
token: &str,
info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
) -> Result<Shard> {
let url = ws_url.lock().await.clone();
let client = connect(&url).await?;
let proxy = ws_proxy.lock().await.clone();
let client = connect(&url, &proxy).await?;

let presence = presence.unwrap_or_default();
let last_heartbeat_sent = None;
Expand All @@ -153,6 +156,7 @@ impl Shard {
session_id,
info,
ws_url,
ws_proxy,
intents,
})
}
Expand Down Expand Up @@ -687,7 +691,8 @@ impl Shard {
self.stage = ConnectionStage::Connecting;
self.started = Instant::now();
let url = &self.ws_url.lock().await.clone();
let client = connect(url).await?;
let proxy = &self.ws_proxy.lock().await.clone();
let client = connect(url, proxy).await?;
self.stage = ConnectionStage::Handshake;

Ok(client)
Expand Down Expand Up @@ -744,13 +749,25 @@ impl Shard {
}
}

async fn connect(base_url: &str) -> Result<WsClient> {
let url =
async fn connect(base_url: &str, proxy_url: &Option<String>) -> Result<WsClient> {
let ws_url =
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;

WsClient::connect(url).await
let parsed_proxy = match proxy_url {
Some(proxy) => {
let parsed_proxy = Url::parse(&proxy).map_err(|why| {
warn!("Error building proxy URL with base `{}`: {:?}", proxy, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;
Some(parsed_proxy)
},
None => None,
};

WsClient::connect(ws_url, parsed_proxy).await
}
83 changes: 71 additions & 12 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::env::consts;
use std::io::ErrorKind;
#[cfg(feature = "client")]
use std::io::Read;
use std::time::SystemTime;
Expand All @@ -17,7 +18,9 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
#[cfg(feature = "client")]
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{
client_async_tls_with_config, connect_async_with_config, MaybeTlsStream, WebSocketStream,
};
#[cfg(feature = "client")]
use tracing::warn;
use tracing::{debug, instrument, trace};
Expand Down Expand Up @@ -101,17 +104,77 @@ const TIMEOUT: Duration = Duration::from_millis(500);
const DECOMPRESSION_MULTIPLIER: usize = 3;

impl WsClient {
pub(crate) async fn connect(url: Url) -> Result<Self> {
let config = WebSocketConfig {
max_message_size: None,
max_frame_size: None,
..Default::default()
pub(crate) async fn connect(url: Url, proxy: Option<Url>) -> Result<Self> {
let config =
WebSocketConfig { max_message_size: None, max_frame_size: None, ..Default::default() };
let (stream, _) = match proxy {
None => connect_async_with_config(url, Some(config), false).await?,
Some(proxy) => {
let tls_stream = Self::connect_with_proxy_async(&url, &proxy).await?;
tls_stream.set_nodelay(true)?;
client_async_tls_with_config(url, tls_stream, Some(config), None).await?
},
};
let (stream, _) = connect_async_with_config(url, Some(config), false).await?;

Ok(Self(stream))
}

async fn connect_with_proxy_async(
target_url: &Url,
proxy_url: &Url,
) -> std::result::Result<TcpStream, std::io::Error> {
let proxy_addr = &proxy_url[url::Position::BeforeHost..url::Position::AfterPort];
if proxy_url.scheme() != "http" && proxy_url.scheme() != "https" {
return Err(std::io::Error::new(ErrorKind::Unsupported, "unknown proxy scheme"));
}

let host = target_url
.host_str()
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target host"))?;
let port = target_url
.port()
.or_else(|| match target_url.scheme() {
"wss" => Some(443),
"ws" => Some(80),
_ => None,
})
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?;

let mut tcp_stream = TcpStream::connect(proxy_addr).await?;

let (username, password) = if let Some(pass) = proxy_url.password() {
let user = proxy_url.username();
(user, pass)
} else {
("", "")
};

if username.is_empty() {
// No auth: use the standard function
async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port).await.map_err(
|e| std::io::Error::new(ErrorKind::Other, format!("proxy connect failed: {e}")),
)?;
} else {
// With basic auth: use the auth variant
async_http_proxy::http_connect_tokio_with_basic_auth(
&mut tcp_stream,
host,
port,
username,
password,
)
.await
.map_err(|e| {
std::io::Error::new(
ErrorKind::Other,
format!("proxy connect with auth failed: {e}"),
)
})?;
}

Ok(tcp_stream)
}

#[cfg(feature = "client")]
pub(crate) async fn recv_json(&mut self) -> Result<Option<GatewayEvent>> {
let message = match timeout(TIMEOUT, self.0.next()).await {
Expand Down Expand Up @@ -310,11 +373,7 @@ impl WsClient {

self.send_json(&WebSocketMessage {
op: Opcode::Resume,
d: WebSocketMessageData::Resume {
session_id,
token,
seq,
},
d: WebSocketMessageData::Resume { session_id, token, seq },
})
.await
}
Expand Down
Loading