Skip to content

Commit b20151f

Browse files
fgardtGnomedDev
andauthored
Transport compression support (#3036)
Co-authored-by: GnomedDev <[email protected]>
1 parent 2e12663 commit b20151f

File tree

8 files changed

+309
-43
lines changed

8 files changed

+309
-43
lines changed

Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ bytes = "1.5.0"
4343
fxhash = { version = "0.2.1", optional = true }
4444
chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true }
4545
flate2 = { version = "1.0.28", optional = true }
46+
zstd-safe = { version = "7.2.1", optional = true }
4647
reqwest = { version = "0.12.2", default-features = false, features = ["multipart", "stream", "json"], optional = true }
4748
tokio-tungstenite = { version = "0.24.0", features = ["url"], optional = true }
4849
percent-encoding = { version = "2.3.0", optional = true }
@@ -70,6 +71,8 @@ default_no_backend = [
7071
"cache",
7172
"chrono",
7273
"framework",
74+
"transport_compression_zlib",
75+
"transport_compression_zstd",
7376
]
7477

7578
# Enables builder structs to configure Discord HTTP requests. Without this feature, you have to
@@ -93,6 +96,10 @@ http = ["dashmap", "mime_guess", "percent-encoding"]
9396
# TODO: remove dependeny on utils feature
9497
model = ["builder", "http", "utils"]
9598
voice_model = ["serenity-voice-model"]
99+
# Enables zlib-stream transport compression of incoming gateway events.
100+
transport_compression_zlib = ["flate2", "gateway"]
101+
# Enables zstd-stream transport compression of incoming gateway events.
102+
transport_compression_zstd = ["zstd-safe", "gateway"]
96103
# Enables support for Discord API functionality that's not stable yet, as well as serenity APIs that
97104
# are allowed to change even in semver non-breaking updates.
98105
unstable = []

src/gateway/client/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use tracing::{debug, warn};
4242

4343
pub use self::context::Context;
4444
pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler};
45+
use super::TransportCompression;
4546
#[cfg(feature = "cache")]
4647
use crate::cache::Cache;
4748
#[cfg(feature = "cache")]
@@ -82,6 +83,7 @@ pub struct ClientBuilder {
8283
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
8384
presence: PresenceData,
8485
wait_time_between_shard_start: Duration,
86+
compression: TransportCompression,
8587
}
8688

8789
impl ClientBuilder {
@@ -116,6 +118,7 @@ impl ClientBuilder {
116118
raw_event_handler: None,
117119
presence: PresenceData::default(),
118120
wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START,
121+
compression: TransportCompression::None,
119122
}
120123
}
121124

@@ -176,6 +179,12 @@ impl ClientBuilder {
176179
self
177180
}
178181

182+
/// Sets the compression method to be used when receiving data from the gateway.
183+
pub fn compression(mut self, compression: TransportCompression) -> Self {
184+
self.compression = compression;
185+
self
186+
}
187+
179188
/// Sets the voice gateway handler to be used. It will receive voice events sent over the
180189
/// gateway and then consider - based on its settings - whether to dispatch a command.
181190
#[cfg(feature = "voice")]
@@ -342,6 +351,7 @@ impl IntoFuture for ClientBuilder {
342351
presence: Some(presence),
343352
max_concurrency,
344353
wait_time_between_shard_start: self.wait_time_between_shard_start,
354+
compression: self.compression,
345355
});
346356

347357
let client = Client {

src/gateway/error.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use tokio_tungstenite::tungstenite::protocol::CloseFrame;
77
///
88
/// Note that - from a user standpoint - there should be no situation in which you manually handle
99
/// these.
10-
#[derive(Clone, Debug)]
10+
#[derive(Debug)]
1111
#[non_exhaustive]
1212
pub enum Error {
1313
/// There was an error building a URL.
@@ -50,6 +50,17 @@ pub enum Error {
5050
/// If an connection has been established but privileged gateway intents were provided without
5151
/// enabling them prior.
5252
DisallowedGatewayIntents,
53+
#[cfg(feature = "transport_compression_zlib")]
54+
/// A decompression error from the `flate2` crate.
55+
DecompressZlib(flate2::DecompressError),
56+
#[cfg(feature = "transport_compression_zstd")]
57+
/// A decompression error from zstd.
58+
DecompressZstd(usize),
59+
/// When zstd decompression fails due to corrupted data.
60+
#[cfg(feature = "transport_compression_zstd")]
61+
DecompressZstdCorrupted,
62+
/// When decompressed gateway data is not valid UTF-8.
63+
DecompressUtf8(std::string::FromUtf8Error),
5364
}
5465

5566
impl fmt::Display for Error {
@@ -70,6 +81,15 @@ impl fmt::Display for Error {
7081
Self::DisallowedGatewayIntents => {
7182
f.write_str("Disallowed gateway intents were provided")
7283
},
84+
#[cfg(feature = "transport_compression_zlib")]
85+
Self::DecompressZlib(inner) => fmt::Display::fmt(&inner, f),
86+
#[cfg(feature = "transport_compression_zstd")]
87+
Self::DecompressZstd(code) => write!(f, "Zstd decompression error: {code}"),
88+
#[cfg(feature = "transport_compression_zstd")]
89+
Self::DecompressZstdCorrupted => {
90+
f.write_str("Zstd decompression error: corrupted data")
91+
},
92+
Self::DecompressUtf8(inner) => fmt::Display::fmt(&inner, f),
7393
}
7494
}
7595
}

src/gateway/sharding/mod.rs

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ use std::fmt;
4444
use std::sync::Arc;
4545
use std::time::{Duration as StdDuration, Instant};
4646

47-
use aformat::{aformat, CapStr};
47+
#[cfg(feature = "transport_compression_zlib")]
48+
use aformat::aformat_into;
49+
use aformat::{aformat, ArrayString, CapStr};
4850
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
4951
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
5052
use tracing::{debug, error, info, trace, warn};
@@ -113,6 +115,7 @@ pub struct Shard {
113115
token: SecretString,
114116
ws_url: Arc<str>,
115117
resume_ws_url: Option<FixedString>,
118+
compression: TransportCompression,
116119
pub intents: GatewayIntents,
117120
}
118121

@@ -129,7 +132,7 @@ impl Shard {
129132
/// use std::num::NonZeroU16;
130133
/// use std::sync::Arc;
131134
///
132-
/// use serenity::gateway::Shard;
135+
/// use serenity::gateway::{Shard, TransportCompression};
133136
/// use serenity::model::gateway::{GatewayIntents, ShardInfo};
134137
/// use serenity::model::id::ShardId;
135138
/// use serenity::secret_string::SecretString;
@@ -147,7 +150,15 @@ impl Shard {
147150
///
148151
/// // retrieve the gateway response, which contains the URL to connect to
149152
/// let gateway = Arc::from(http.get_gateway().await?.url);
150-
/// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?;
153+
/// let shard = Shard::new(
154+
/// gateway,
155+
/// token,
156+
/// shard_info,
157+
/// GatewayIntents::all(),
158+
/// None,
159+
/// TransportCompression::None,
160+
/// )
161+
/// .await?;
151162
///
152163
/// // at this point, you can create a `loop`, and receive events and match
153164
/// // their variants
@@ -165,8 +176,9 @@ impl Shard {
165176
shard_info: ShardInfo,
166177
intents: GatewayIntents,
167178
presence: Option<PresenceData>,
179+
compression: TransportCompression,
168180
) -> Result<Shard> {
169-
let client = connect(&ws_url).await?;
181+
let client = connect(&ws_url, compression).await?;
170182

171183
let presence = presence.unwrap_or_default();
172184
let last_heartbeat_sent = None;
@@ -193,6 +205,7 @@ impl Shard {
193205
shard_info,
194206
ws_url,
195207
resume_ws_url: None,
208+
compression,
196209
intents,
197210
})
198211
}
@@ -748,7 +761,7 @@ impl Shard {
748761
// Hello is received.
749762
self.stage = ConnectionStage::Connecting;
750763
self.started = Instant::now();
751-
let client = connect(ws_url).await?;
764+
let client = connect(ws_url, self.compression).await?;
752765
self.stage = ConnectionStage::Handshake;
753766

754767
Ok(client)
@@ -807,14 +820,19 @@ impl Shard {
807820
}
808821
}
809822

810-
async fn connect(base_url: &str) -> Result<WsClient> {
811-
let url = Url::parse(&aformat!("{}?v={}", CapStr::<64>(base_url), constants::GATEWAY_VERSION))
812-
.map_err(|why| {
813-
warn!("Error building gateway URL with base `{base_url}`: {why:?}");
814-
Error::Gateway(GatewayError::BuildingUrl)
815-
})?;
816-
817-
WsClient::connect(url).await
823+
async fn connect(base_url: &str, compression: TransportCompression) -> Result<WsClient> {
824+
let url = Url::parse(&aformat!(
825+
"{}?v={}{}",
826+
CapStr::<64>(base_url),
827+
constants::GATEWAY_VERSION,
828+
compression.query_param()
829+
))
830+
.map_err(|why| {
831+
warn!("Error building gateway URL with base `{base_url}`: {why:?}");
832+
Error::Gateway(GatewayError::BuildingUrl)
833+
})?;
834+
835+
WsClient::connect(url, compression).await
818836
}
819837

820838
#[derive(Debug)]
@@ -954,3 +972,41 @@ impl PartialEq for CollectorCallback {
954972
Arc::ptr_eq(&self.0, &other.0)
955973
}
956974
}
975+
976+
/// The transport compression method to use.
977+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
978+
#[non_exhaustive]
979+
pub enum TransportCompression {
980+
/// No transport compression. Payload compression will be used instead.
981+
None,
982+
983+
#[cfg(feature = "transport_compression_zlib")]
984+
/// Use zlib-stream transport compression.
985+
Zlib,
986+
987+
#[cfg(feature = "transport_compression_zstd")]
988+
/// Use zstd-stream transport compression.
989+
Zstd,
990+
}
991+
992+
impl TransportCompression {
993+
fn query_param(self) -> ArrayString<21> {
994+
#[cfg_attr(
995+
not(any(
996+
feature = "transport_compression_zlib",
997+
feature = "transport_compression_zstd"
998+
)),
999+
expect(unused_mut)
1000+
)]
1001+
let mut res = ArrayString::new();
1002+
match self {
1003+
Self::None => {},
1004+
#[cfg(feature = "transport_compression_zlib")]
1005+
Self::Zlib => aformat_into!(res, "&compress=zlib-stream"),
1006+
#[cfg(feature = "transport_compression_zstd")]
1007+
Self::Zstd => aformat_into!(res, "&compress=zstd-stream"),
1008+
}
1009+
1010+
res
1011+
}
1012+
}

src/gateway/sharding/shard_manager.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@ use tokio::sync::Mutex;
1111
use tokio::time::timeout;
1212
use tracing::{info, warn};
1313

14-
use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo};
14+
use super::{
15+
ShardId,
16+
ShardQueue,
17+
ShardQueuer,
18+
ShardQueuerMessage,
19+
ShardRunnerInfo,
20+
TransportCompression,
21+
};
1522
#[cfg(feature = "cache")]
1623
use crate::cache::Cache;
1724
#[cfg(feature = "framework")]
@@ -53,7 +60,12 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5);
5360
/// use std::sync::{Arc, OnceLock};
5461
///
5562
/// use serenity::gateway::client::EventHandler;
56-
/// use serenity::gateway::{ShardManager, ShardManagerOptions, DEFAULT_WAIT_BETWEEN_SHARD_START};
63+
/// use serenity::gateway::{
64+
/// ShardManager,
65+
/// ShardManagerOptions,
66+
/// TransportCompression,
67+
/// DEFAULT_WAIT_BETWEEN_SHARD_START,
68+
/// };
5769
/// use serenity::http::Http;
5870
/// use serenity::model::gateway::GatewayIntents;
5971
/// use serenity::prelude::*;
@@ -88,6 +100,7 @@ pub const DEFAULT_WAIT_BETWEEN_SHARD_START: Duration = Duration::from_secs(5);
88100
/// presence: None,
89101
/// max_concurrency,
90102
/// wait_time_between_shard_start: DEFAULT_WAIT_BETWEEN_SHARD_START,
103+
/// compression: TransportCompression::None,
91104
/// });
92105
/// # Ok(())
93106
/// # }
@@ -144,6 +157,7 @@ impl ShardManager {
144157
#[cfg(feature = "voice")]
145158
voice_manager: opt.voice_manager,
146159
ws_url: opt.ws_url,
160+
compression: opt.compression,
147161
shard_total: opt.shard_total,
148162
#[cfg(feature = "cache")]
149163
cache: opt.cache,
@@ -379,4 +393,5 @@ pub struct ShardManagerOptions {
379393
pub max_concurrency: NonZeroU16,
380394
/// Number of seconds to wait between starting each shard/set of shards start
381395
pub wait_time_between_shard_start: Duration,
396+
pub compression: TransportCompression,
382397
}

src/gateway/sharding/shard_queuer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use super::{
1717
ShardRunner,
1818
ShardRunnerInfo,
1919
ShardRunnerOptions,
20+
TransportCompression,
2021
};
2122
#[cfg(feature = "cache")]
2223
use crate::cache::Cache;
@@ -64,6 +65,8 @@ pub struct ShardQueuer {
6465
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
6566
/// A copy of the URL to use to connect to the gateway.
6667
pub ws_url: Arc<str>,
68+
/// The compression method to use for the WebSocket connection.
69+
pub compression: TransportCompression,
6770
/// The total amount of shards to start.
6871
pub shard_total: NonZeroU16,
6972
/// Number of seconds to wait between each start
@@ -216,6 +219,7 @@ impl ShardQueuer {
216219
shard_info,
217220
self.intents,
218221
self.presence.clone(),
222+
self.compression,
219223
)
220224
.await?;
221225

src/gateway/sharding/shard_runner.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,16 @@ impl ShardRunner {
451451
)) => {
452452
error!("Shard handler received fatal err: {why:?}");
453453

454-
self.manager.return_with_value(Err(why.clone())).await;
454+
let why_clone = match why {
455+
GatewayError::InvalidAuthentication => GatewayError::InvalidAuthentication,
456+
GatewayError::InvalidGatewayIntents => GatewayError::InvalidGatewayIntents,
457+
GatewayError::DisallowedGatewayIntents => {
458+
GatewayError::DisallowedGatewayIntents
459+
},
460+
_ => unreachable!(),
461+
};
462+
463+
self.manager.return_with_value(Err(why_clone)).await;
455464
return Err(Error::Gateway(why));
456465
},
457466
Err(Error::Json(_)) => return Ok((None, None, true)),

0 commit comments

Comments
 (0)