Skip to content

Commit 10bfbf1

Browse files
committed
Avoid deferring gateway event deserialization (#3119)
Followup to #3114. By using an intermediate untagged enum with an `Unknown` variant, we can process sequence numbers for unknown events without cloning the json payload.
1 parent 2fa0f0d commit 10bfbf1

File tree

5 files changed

+59
-45
lines changed

5 files changed

+59
-45
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ rust-version = "1.85"
2929
[dependencies]
3030
# Required dependencies
3131
bitflags = "2.4.2"
32-
serde_json = "1.0.108"
32+
serde_json = { version = "1.0.108", features = ["raw_value"] }
3333
async-trait = "0.1.74"
3434
tracing = { version = "0.1.40", features = ["log"] }
3535
serde = { version = "1.0.192", features = ["derive", "rc"] }

src/gateway/sharding/mod.rs

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub use self::shard_runner::{ShardRunner, ShardRunnerMessage, ShardRunnerOptions
5959
use super::{ActivityData, ChunkGuildFilter, GatewayError, PresenceData, WsClient};
6060
use crate::constants::{self, CloseCode};
6161
use crate::internal::prelude::*;
62-
use crate::model::event::{Event, GatewayEvent};
62+
use crate::model::event::{DeserializedEvent, Event, GatewayEvent, UnknownEvent};
6363
use crate::model::gateway::{GatewayIntents, ShardInfo};
6464
use crate::model::id::{ApplicationId, GuildId, ShardId};
6565
use crate::model::user::OnlineStatus;
@@ -312,13 +312,24 @@ impl Shard {
312312
}
313313

314314
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
315-
fn handle_gateway_dispatch(&mut self, seq: u64, event: &[u8]) -> Result<Event> {
315+
fn handle_gateway_dispatch(&mut self, seq: u64, event: DeserializedEvent) -> Option<Event> {
316316
if seq > self.seq + 1 {
317317
warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq);
318318
}
319319

320320
self.seq = seq;
321-
let event = deserialize_and_log_event(event)?;
321+
322+
let event = match event {
323+
DeserializedEvent::Success(event) => event,
324+
DeserializedEvent::Unknown(UnknownEvent {
325+
ty,
326+
ref data,
327+
}) => {
328+
debug!("Unknown event: {ty}");
329+
debug!("Failing event data: {data:?}");
330+
return None;
331+
},
332+
};
322333

323334
match &event {
324335
Event::Ready(ready) => {
@@ -345,7 +356,7 @@ impl Shard {
345356
_ => {},
346357
}
347358

348-
Ok(event)
359+
Some(event)
349360
}
350361

351362
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
@@ -442,9 +453,9 @@ impl Shard {
442453
Ok(GatewayEvent::Dispatch {
443454
seq,
444455
event,
445-
}) => self
446-
.handle_gateway_dispatch(seq, &event)
447-
.map(|e| Some(ShardAction::Dispatch(Box::new(e)))),
456+
}) => Ok(self
457+
.handle_gateway_dispatch(seq, *event)
458+
.map(|e| ShardAction::Dispatch(Box::new(e)))),
448459
Ok(GatewayEvent::Heartbeat) => {
449460
info!("[{:?}] Received shard heartbeat", self.shard_info);
450461

@@ -736,23 +747,6 @@ async fn connect(base_url: &str, compression: TransportCompression) -> Result<Ws
736747
WsClient::connect(url, compression).await
737748
}
738749

739-
fn deserialize_and_log_event(event: &[u8]) -> Result<Event> {
740-
serde_json::from_slice(event).map_err(|err| {
741-
let err_dbg = format!("{err:?}");
742-
if let Some((variant_name, _)) =
743-
err_dbg.strip_prefix(r#"Error("unknown variant `"#).and_then(|s| s.split_once('`'))
744-
{
745-
debug!("Unknown event: {variant_name}");
746-
} else {
747-
warn!("Err deserializing text: {err_dbg}");
748-
}
749-
750-
let event_str = String::from_utf8_lossy(event);
751-
debug!("Failing event data: {event_str}");
752-
Error::Json(err)
753-
})
754-
}
755-
756750
struct ResumeMetadata {
757751
session_id: FixedString,
758752
resume_ws_url: FixedString,

src/gateway/sharding/shard_runner.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,6 @@ impl ShardRunner {
404404

405405
return Err(Error::Gateway(why));
406406
},
407-
Err(Error::Json(_)) => return Ok(None),
408407
Err(why) => {
409408
error!("Shard handler recieved err: {why:?}");
410409
return Ok(None);

src/gateway/ws.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,16 +264,7 @@ impl WsClient {
264264
};
265265

266266
match serde_json::from_slice(json_bytes) {
267-
Ok(mut event) => {
268-
if let GatewayEvent::Dispatch {
269-
ref mut event, ..
270-
} = event
271-
{
272-
*event = json_bytes.to_vec();
273-
}
274-
275-
Ok(Some(event))
276-
},
267+
Ok(event) => Ok(Some(event)),
277268
Err(err) => {
278269
debug!("Failing text: {}", String::from_utf8_lossy(json_bytes));
279270
Err(Error::Json(err))

src/model/event.rs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -932,9 +932,7 @@ pub struct MessagePollVoteRemoveEvent {
932932
pub enum GatewayEvent {
933933
Dispatch {
934934
seq: u64,
935-
// Avoid deserialising straight away to handle errors and get access to `seq`.
936-
// This must be filled in with original data by the caller after deserialisation.
937-
event: Vec<u8>,
935+
event: Box<DeserializedEvent>,
938936
},
939937
Heartbeat,
940938
Reconnect,
@@ -944,6 +942,32 @@ pub enum GatewayEvent {
944942
HeartbeatAck,
945943
}
946944

945+
#[expect(clippy::large_enum_variant)]
946+
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
947+
#[derive(Clone, Debug, Serialize)]
948+
#[non_exhaustive]
949+
#[serde(untagged)]
950+
pub enum DeserializedEvent {
951+
Success(Event),
952+
Unknown(UnknownEvent),
953+
}
954+
955+
#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
956+
#[derive(Clone, Debug, Deserialize, Serialize)]
957+
#[non_exhaustive]
958+
pub struct UnknownEvent {
959+
#[serde(rename = "t")]
960+
pub ty: String,
961+
#[serde(rename = "d")]
962+
#[cfg_attr(feature = "typesize", typesize(with = raw_value_len))]
963+
pub data: Box<RawValue>,
964+
}
965+
966+
#[cfg(feature = "typesize")]
967+
fn raw_value_len(val: &RawValue) -> usize {
968+
val.get().len()
969+
}
970+
947971
// Manual impl needed to emulate integer enum tags
948972
impl<'de> Deserialize<'de> for GatewayEvent {
949973
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
@@ -968,21 +992,27 @@ impl<'de> Deserialize<'de> for GatewayEvent {
968992

969993
Self::Dispatch {
970994
seq: raw.seq.ok_or_else(|| DeError::missing_field("s"))?,
971-
event: Vec::new(),
995+
event: {
996+
Box::new(match Event::deserialize(raw.data) {
997+
Ok(event) => DeserializedEvent::Success(event),
998+
Err(_) => DeserializedEvent::Unknown(
999+
UnknownEvent::deserialize(raw.data).map_err(DeError::custom)?,
1000+
),
1001+
})
1002+
},
9721003
}
9731004
},
9741005
Opcode::Heartbeat => Self::Heartbeat,
975-
Opcode::InvalidSession => Self::InvalidateSession(
976-
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?,
977-
),
1006+
Opcode::InvalidSession => {
1007+
Self::InvalidateSession(bool::deserialize(raw.data).map_err(DeError::custom)?)
1008+
},
9781009
Opcode::Hello => {
9791010
#[derive(Deserialize)]
9801011
struct HelloPayload {
9811012
heartbeat_interval: u64,
9821013
}
9831014

984-
let inner: HelloPayload =
985-
serde_json::from_str(raw.data.get()).map_err(DeError::custom)?;
1015+
let inner = HelloPayload::deserialize(raw.data).map_err(DeError::custom)?;
9861016

9871017
Self::Hello(inner.heartbeat_interval)
9881018
},

0 commit comments

Comments
 (0)