Skip to content

Commit c57c51d

Browse files
committed
gateway session resumption
1 parent b0e5721 commit c57c51d

File tree

2 files changed

+72
-12
lines changed

2 files changed

+72
-12
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ edition = "2024"
66
[dependencies]
77
anyhow = "1"
88
rustls = "0.23"
9-
tokio = { version = "1", features = ["macros", "rt", "signal"] }
9+
serde = { version = "1", features = ["derive"] }
10+
serde_json = "1"
11+
tokio = { version = "1", features = ["fs", "macros", "rt", "signal"] }
1012
tokio-util = { version = "0.7", features = ["rt"] }
1113
tracing = "0.1"
1214
tracing-subscriber = "0.3"

src/main.rs

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,26 @@ mod context;
22

33
use crate::context::CONTEXT;
44
use anyhow::Context;
5+
use serde::{Deserialize, Serialize};
56
use std::{env, error::Error, pin::pin};
6-
use tokio::signal;
7+
use tokio::{fs, signal};
78
use tokio_util::task::TaskTracker;
8-
use twilight_gateway::{CloseFrame, Config, Event, EventTypeFlags, Intents, Shard, StreamExt as _};
9+
use twilight_gateway::{
10+
CloseFrame, Config, ConfigBuilder, Event, EventTypeFlags, Intents, Session, Shard, ShardId,
11+
StreamExt as _,
12+
};
913
use twilight_http::Client;
1014
use twilight_model::gateway::payload::incoming::MessageCreate;
1115

1216
const EVENT_TYPES: EventTypeFlags = EventTypeFlags::MESSAGE_CREATE;
1317
const INTENTS: Intents = Intents::GUILD_MESSAGES.union(Intents::MESSAGE_CONTENT);
18+
const RESUME_INFO_FILE: &str = "resume-info.json";
19+
20+
#[derive(Debug, Deserialize, Serialize)]
21+
struct ResumeInfo {
22+
resume_url: Option<String>,
23+
session: Option<Session>,
24+
}
1425

1526
#[tokio::main(flavor = "current_thread")]
1627
async fn main() -> anyhow::Result<()> {
@@ -20,31 +31,73 @@ async fn main() -> anyhow::Result<()> {
2031

2132
let config = Config::new(token.clone(), INTENTS);
2233
let http = Client::new(token);
23-
let shards = twilight_gateway::create_recommended(&http, config, |_, builder| builder.build())
34+
let info = async { Ok::<_, anyhow::Error>(http.gateway().authed().await?.model().await?) }
2435
.await
25-
.context("creating shards")?;
36+
.context("getting info")?;
37+
let resume_info = async {
38+
let contents = fs::read(RESUME_INFO_FILE).await?;
39+
Ok::<_, anyhow::Error>(serde_json::from_slice::<Vec<ResumeInfo>>(&contents)?)
40+
}
41+
.await
42+
.ok();
43+
let shard_ids = (0..info.shards).map(|shard| ShardId::new(shard, info.shards));
44+
let shards: Vec<_> = if let Some(mut resume_info) = resume_info
45+
&& resume_info.len() == info.shards as usize
46+
{
47+
tracing::info!("resuming previous gateway sessions");
48+
let mut resume_info_iter = resume_info.drain(..);
49+
shard_ids
50+
.map(|shard_id| {
51+
let resume_info = resume_info_iter.next().unwrap();
52+
let mut builder = ConfigBuilder::from(config.clone());
53+
54+
if let Some(resume_url) = resume_info.resume_url {
55+
builder = builder.resume_url(resume_url);
56+
}
57+
if let Some(session) = resume_info.session {
58+
builder = builder.session(session);
59+
}
60+
61+
Shard::with_config(shard_id, builder.build())
62+
})
63+
.collect()
64+
} else {
65+
shard_ids
66+
.map(|shard_id| Shard::with_config(shard_id, config.clone()))
67+
.collect()
68+
};
69+
_ = fs::remove_file(RESUME_INFO_FILE).await;
2670
context::initialize(http);
2771

28-
let tracker = TaskTracker::new();
29-
for shard in shards {
30-
tracker.spawn(dispatcher(shard));
72+
let tasks = shards
73+
.into_iter()
74+
.map(|shard| tokio::spawn(dispatcher(shard)))
75+
.collect::<Vec<_>>();
76+
let mut resume_info = Vec::new();
77+
for task in tasks {
78+
resume_info.push(task.await?);
3179
}
32-
tracker.close();
33-
tracker.wait().await;
80+
81+
async {
82+
let contents = serde_json::to_vec(&resume_info)?;
83+
fs::write(RESUME_INFO_FILE, contents).await
84+
}
85+
.await
86+
.context("persisting resume info")?;
3487

3588
Ok(())
3689
}
3790

3891
#[tracing::instrument(fields(shard = %shard.id(), skip_all))]
39-
async fn dispatcher(mut shard: Shard) {
92+
async fn dispatcher(mut shard: Shard) -> ResumeInfo {
4093
let mut is_shutdown = false;
4194
let tracker = TaskTracker::new();
4295
let mut shutdown_fut = pin!(signal::ctrl_c());
4396

4497
loop {
4598
tokio::select! {
4699
_ = &mut shutdown_fut, if !is_shutdown => {
47-
shard.close(CloseFrame::NORMAL);
100+
shard.close(CloseFrame::RESUME);
48101
is_shutdown = true;
49102
},
50103
Some(item) = shard.next_event(EVENT_TYPES) => {
@@ -74,6 +127,11 @@ async fn dispatcher(mut shard: Shard) {
74127

75128
tracker.close();
76129
tracker.wait().await;
130+
131+
ResumeInfo {
132+
resume_url: shard.resume_url().map(ToOwned::to_owned),
133+
session: shard.session().cloned(),
134+
}
77135
}
78136

79137
#[tracing::instrument(fields(id = %event.id), skip_all)]

0 commit comments

Comments
 (0)