Skip to content

Commit 7da8e2a

Browse files
authored
Avoid temporarily deserializing gateway messages to an untyped enum+map tree (#3114)
Avoid temporarily deserializing gateway messages to a `serde_json::Map<String, Value>` before typed deserialization to an `Event`. Previously, every handled gateway event was first parsed into a `serde_json::Value` tree, which meant creating and then immediately destroying hundreds to thousands of owned strings and `BTreeMap`s per event.
1 parent e5f8014 commit 7da8e2a

File tree

4 files changed

+48
-47
lines changed

4 files changed

+48
-47
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ rust-version = "1.82"
2929
[dependencies]
3030
# Required dependencies
3131
bitflags = "2.4.2"
32-
serde_json = "1.0.108"
32+
serde_json = { version = "1.0.108", features = ["raw_value"] }
3333
async-trait = "0.1.74"
3434
tracing = { version = "0.1.40", features = ["log"] }
3535
serde = { version = "1.0.192", features = ["derive", "rc"] }

src/gateway/sharding/mod.rs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ use std::time::{Duration as StdDuration, Instant};
4747
#[cfg(any(feature = "transport_compression_zlib", feature = "transport_compression_zstd"))]
4848
use aformat::aformat_into;
4949
use aformat::{aformat, ArrayString, CapStr};
50-
use serde::Deserialize;
5150
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
5251
use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
5352
#[cfg(feature = "tracing_instrument")]
@@ -319,18 +318,13 @@ impl Shard {
319318
}
320319

321320
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
322-
fn handle_gateway_dispatch(
323-
&mut self,
324-
seq: u64,
325-
event: JsonMap,
326-
original_str: &str,
327-
) -> Result<Event> {
321+
fn handle_gateway_dispatch(&mut self, seq: u64, event: &[u8]) -> Result<Event> {
328322
if seq > self.seq + 1 {
329323
warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq);
330324
}
331325

332326
self.seq = seq;
333-
let event = deserialize_and_log_event(event, original_str)?;
327+
let event = deserialize_and_log_event(event)?;
334328

335329
match &event {
336330
Event::Ready(ready) => {
@@ -453,11 +447,8 @@ impl Shard {
453447
match event {
454448
Ok(GatewayEvent::Dispatch {
455449
seq,
456-
data,
457-
original_str,
458-
}) => self
459-
.handle_gateway_dispatch(seq, data, &original_str)
460-
.map(|e| Some(ShardAction::Dispatch(e))),
450+
event,
451+
}) => self.handle_gateway_dispatch(seq, &event).map(|e| Some(ShardAction::Dispatch(e))),
461452
Ok(GatewayEvent::Heartbeat) => {
462453
info!("[{:?}] Received shard heartbeat", self.shard_info);
463454

@@ -749,9 +740,8 @@ async fn connect(base_url: &str, compression: TransportCompression) -> Result<Ws
749740
WsClient::connect(url, compression).await
750741
}
751742

752-
fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event> {
753-
Event::deserialize(Value::Object(map)).map_err(|err| {
754-
let err = serde::de::Error::custom(err);
743+
fn deserialize_and_log_event(event: &[u8]) -> Result<Event> {
744+
serde_json::from_slice(event).map_err(|err| {
755745
let err_dbg = format!("{err:?}");
756746
if let Some((variant_name, _)) =
757747
err_dbg.strip_prefix(r#"Error("unknown variant `"#).and_then(|s| s.split_once('`'))
@@ -760,7 +750,9 @@ fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result<Event>
760750
} else {
761751
warn!("Err deserializing text: {err_dbg}");
762752
}
763-
debug!("Failing text: {original_str}");
753+
754+
let event_str = String::from_utf8_lossy(event);
755+
debug!("Failing event data: {event_str}");
764756
Error::Json(err)
765757
})
766758
}

src/gateway/ws.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::env::consts;
32
use std::io::Read;
43
use std::time::SystemTime;
@@ -7,7 +6,6 @@ use flate2::read::ZlibDecoder;
76
#[cfg(feature = "transport_compression_zlib")]
87
use flate2::Decompress as ZlibInflater;
98
use futures::{SinkExt, StreamExt};
10-
use small_fixed_array::FixedString;
119
use tokio::net::TcpStream;
1210
use tokio::time::{timeout, Duration};
1311
use tokio_tungstenite::tungstenite::protocol::{CloseFrame, WebSocketConfig};
@@ -254,35 +252,30 @@ impl WsClient {
254252
};
255253

256254
let json_bytes = match message {
257-
Message::Text(payload) => Cow::Owned(payload.as_bytes().to_vec()),
258-
Message::Binary(bytes) => {
259-
let Some(decompressed) = self.compression.inflate(&bytes)? else {
260-
return Ok(None);
261-
};
262-
263-
Cow::Borrowed(decompressed)
255+
Message::Text(ref payload) => payload.as_bytes(),
256+
Message::Binary(ref bytes) => match self.compression.inflate(bytes)? {
257+
Some(decompressed) => decompressed,
258+
None => return Ok(None),
264259
},
265260
Message::Close(Some(frame)) => {
266261
return Err(Error::Gateway(GatewayError::Closed(Some(frame))));
267262
},
268263
_ => return Ok(None),
269264
};
270265

271-
// TODO: Use `String::from_utf8_lossy_owned` when stable.
272-
let json_str = || String::from_utf8_lossy(&json_bytes);
273-
match serde_json::from_slice(&json_bytes) {
266+
match serde_json::from_slice(json_bytes) {
274267
Ok(mut event) => {
275268
if let GatewayEvent::Dispatch {
276-
original_str, ..
277-
} = &mut event
269+
ref mut event, ..
270+
} = event
278271
{
279-
*original_str = FixedString::from_string_trunc(json_str().into_owned());
272+
*event = json_bytes.to_vec();
280273
}
281274

282275
Ok(Some(event))
283276
},
284277
Err(err) => {
285-
debug!("Failing text: {}", json_str());
278+
debug!("Failing text: {}", String::from_utf8_lossy(json_bytes));
286279
Err(Error::Json(err))
287280
},
288281
}

src/model/event.rs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
66
use serde::de::Error as DeError;
77
use serde::Serialize;
8+
use serde_json::value::RawValue;
89
use strum::{EnumCount, IntoStaticStr, VariantNames};
910

1011
use crate::constants::Opcode;
1112
use crate::internal::utils::lending_for_each;
1213
use crate::model::prelude::*;
13-
use crate::model::utils::remove_from_map;
1414

1515
/// Requires no gateway intents.
1616
///
@@ -933,9 +933,8 @@ pub enum GatewayEvent {
933933
Dispatch {
934934
seq: u64,
935935
// Avoid deserialising straight away to handle errors and get access to `seq`.
936-
data: JsonMap,
937-
// Used for debugging, if the data cannot be deserialised.
938-
original_str: FixedString,
936+
// This must be filled in with original data by the caller after deserialisation.
937+
event: Vec<u8>,
939938
},
940939
Heartbeat,
941940
Reconnect,
@@ -948,26 +947,43 @@ pub enum GatewayEvent {
948947
// Manual impl needed to emulate integer enum tags
949948
impl<'de> Deserialize<'de> for GatewayEvent {
950949
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
951-
let mut map = JsonMap::deserialize(deserializer)?;
952-
953-
Ok(match remove_from_map(&mut map, "op")? {
950+
#[derive(Debug, Clone, Deserialize)]
951+
struct GatewayEventRaw<'a> {
952+
op: Opcode,
953+
#[serde(rename = "s")]
954+
seq: Option<u64>,
955+
#[serde(rename = "d")]
956+
data: &'a RawValue,
957+
#[serde(rename = "t")]
958+
ty: Option<&'a str>,
959+
}
960+
961+
let raw = GatewayEventRaw::deserialize(deserializer)?;
962+
963+
Ok(match raw.op {
954964
Opcode::Dispatch => {
965+
if raw.ty.is_none() {
966+
return Err(DeError::missing_field("t"));
967+
}
968+
955969
Self::Dispatch {
956-
seq: remove_from_map(&mut map, "s")?,
957-
// Filled in in recv_event
958-
original_str: FixedString::new(),
959-
data: map,
970+
seq: raw.seq.ok_or_else(|| DeError::missing_field("s"))?,
971+
event: Vec::new(),
960972
}
961973
},
962974
Opcode::Heartbeat => Self::Heartbeat,
963-
Opcode::InvalidSession => Self::InvalidateSession(remove_from_map(&mut map, "d")?),
975+
Opcode::InvalidSession => Self::InvalidateSession(
976+
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?,
977+
),
964978
Opcode::Hello => {
965979
#[derive(Deserialize)]
966980
struct HelloPayload {
967981
heartbeat_interval: u64,
968982
}
969983

970-
let inner: HelloPayload = remove_from_map(&mut map, "d")?;
984+
let inner: HelloPayload =
985+
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?;
986+
971987
Self::Hello(inner.heartbeat_interval)
972988
},
973989
Opcode::Reconnect => Self::Reconnect,

0 commit comments

Comments
 (0)