@@ -44,7 +44,9 @@ use std::fmt;
4444use std:: sync:: Arc ;
4545use std:: time:: { Duration as StdDuration , Instant } ;
4646
47- use aformat:: { aformat, CapStr } ;
47+ #[ cfg( feature = "transport_compression_zlib" ) ]
48+ use aformat:: aformat_into;
49+ use aformat:: { aformat, ArrayString , CapStr } ;
4850use tokio_tungstenite:: tungstenite:: error:: Error as TungsteniteError ;
4951use tokio_tungstenite:: tungstenite:: protocol:: frame:: CloseFrame ;
5052use tracing:: { debug, error, info, trace, warn} ;
@@ -113,6 +115,7 @@ pub struct Shard {
113115 token : SecretString ,
114116 ws_url : Arc < str > ,
115117 resume_ws_url : Option < FixedString > ,
118+ compression : TransportCompression ,
116119 pub intents : GatewayIntents ,
117120}
118121
@@ -129,7 +132,7 @@ impl Shard {
129132 /// use std::num::NonZeroU16;
130133 /// use std::sync::Arc;
131134 ///
132- /// use serenity::gateway::Shard;
135+ /// use serenity::gateway::{ Shard, TransportCompression} ;
133136 /// use serenity::model::gateway::{GatewayIntents, ShardInfo};
134137 /// use serenity::model::id::ShardId;
135138 /// use serenity::secret_string::SecretString;
@@ -147,7 +150,15 @@ impl Shard {
147150 ///
148151 /// // retrieve the gateway response, which contains the URL to connect to
149152 /// let gateway = Arc::from(http.get_gateway().await?.url);
150- /// let shard = Shard::new(gateway, token, shard_info, GatewayIntents::all(), None).await?;
153+ /// let shard = Shard::new(
154+ /// gateway,
155+ /// token,
156+ /// shard_info,
157+ /// GatewayIntents::all(),
158+ /// None,
159+ /// TransportCompression::None,
160+ /// )
161+ /// .await?;
151162 ///
152163 /// // at this point, you can create a `loop`, and receive events and match
153164 /// // their variants
@@ -165,8 +176,9 @@ impl Shard {
165176 shard_info : ShardInfo ,
166177 intents : GatewayIntents ,
167178 presence : Option < PresenceData > ,
179+ compression : TransportCompression ,
168180 ) -> Result < Shard > {
169- let client = connect ( & ws_url) . await ?;
181+ let client = connect ( & ws_url, compression ) . await ?;
170182
171183 let presence = presence. unwrap_or_default ( ) ;
172184 let last_heartbeat_sent = None ;
@@ -193,6 +205,7 @@ impl Shard {
193205 shard_info,
194206 ws_url,
195207 resume_ws_url : None ,
208+ compression,
196209 intents,
197210 } )
198211 }
@@ -748,7 +761,7 @@ impl Shard {
748761 // Hello is received.
749762 self . stage = ConnectionStage :: Connecting ;
750763 self . started = Instant :: now ( ) ;
751- let client = connect ( ws_url) . await ?;
764+ let client = connect ( ws_url, self . compression ) . await ?;
752765 self . stage = ConnectionStage :: Handshake ;
753766
754767 Ok ( client)
@@ -807,14 +820,19 @@ impl Shard {
807820 }
808821}
809822
810- async fn connect ( base_url : & str ) -> Result < WsClient > {
811- let url = Url :: parse ( & aformat ! ( "{}?v={}" , CapStr :: <64 >( base_url) , constants:: GATEWAY_VERSION ) )
812- . map_err ( |why| {
813- warn ! ( "Error building gateway URL with base `{base_url}`: {why:?}" ) ;
814- Error :: Gateway ( GatewayError :: BuildingUrl )
815- } ) ?;
816-
817- WsClient :: connect ( url) . await
823+ async fn connect ( base_url : & str , compression : TransportCompression ) -> Result < WsClient > {
824+ let url = Url :: parse ( & aformat ! (
825+ "{}?v={}{}" ,
826+ CapStr :: <64 >( base_url) ,
827+ constants:: GATEWAY_VERSION ,
828+ compression. query_param( )
829+ ) )
830+ . map_err ( |why| {
831+ warn ! ( "Error building gateway URL with base `{base_url}`: {why:?}" ) ;
832+ Error :: Gateway ( GatewayError :: BuildingUrl )
833+ } ) ?;
834+
835+ WsClient :: connect ( url, compression) . await
818836}
819837
820838#[ derive( Debug ) ]
@@ -954,3 +972,41 @@ impl PartialEq for CollectorCallback {
954972 Arc :: ptr_eq ( & self . 0 , & other. 0 )
955973 }
956974}
975+
976+ /// The transport compression method to use.
977+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
978+ #[ non_exhaustive]
979+ pub enum TransportCompression {
980+ /// No transport compression. Payload compression will be used instead.
981+ None ,
982+
983+ #[ cfg( feature = "transport_compression_zlib" ) ]
984+ /// Use zlib-stream transport compression.
985+ Zlib ,
986+
987+ #[ cfg( feature = "transport_compression_zstd" ) ]
988+ /// Use zstd-stream transport compression.
989+ Zstd ,
990+ }
991+
992+ impl TransportCompression {
993+ fn query_param ( self ) -> ArrayString < 21 > {
994+ #[ cfg_attr(
995+ not( any(
996+ feature = "transport_compression_zlib" ,
997+ feature = "transport_compression_zstd"
998+ ) ) ,
999+ expect( unused_mut)
1000+ ) ]
1001+ let mut res = ArrayString :: new ( ) ;
1002+ match self {
1003+ Self :: None => { } ,
1004+ #[ cfg( feature = "transport_compression_zlib" ) ]
1005+ Self :: Zlib => aformat_into ! ( res, "&compress=zlib-stream" ) ,
1006+ #[ cfg( feature = "transport_compression_zstd" ) ]
1007+ Self :: Zstd => aformat_into ! ( res, "&compress=zstd-stream" ) ,
1008+ }
1009+
1010+ res
1011+ }
1012+ }
0 commit comments