Skip to content

Commit cc82ecc

Browse files
committed
gateway session resumption
1 parent 7c73f67 commit cc82ecc

File tree

3 files changed

+121
-15
lines changed

3 files changed

+121
-15
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ edition = "2024"
66
[dependencies]
77
anyhow = "1"
88
rustls = "0.23"
9-
tokio = { version = "1", features = ["macros", "rt", "signal"] }
9+
serde = { version = "1", features = ["derive"] }
10+
serde_json = "1"
11+
tokio = { version = "1", features = ["fs", "macros", "rt", "signal"] }
1012
tokio-util = { version = "0.7", features = ["rt"] }
1113
tracing = "0.1"
1214
tracing-subscriber = "0.3"

src/main.rs

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
mod context;
2+
mod resume;
23

34
use crate::context::CONTEXT;
4-
use anyhow::Context;
5-
use std::{env, error::Error, pin::pin};
5+
use anyhow::Context as _;
6+
use std::{env, error::Error, pin::pin, time::Duration};
67
use tokio::signal;
78
use tokio_util::task::TaskTracker;
8-
use twilight_gateway::{CloseFrame, Config, Event, EventTypeFlags, Intents, Shard, StreamExt as _};
9+
use twilight_gateway::{
10+
CloseFrame, ConfigBuilder, Event, EventTypeFlags, Intents, Shard, StreamExt as _,
11+
queue::InMemoryQueue,
12+
};
913
use twilight_http::Client;
1014
use twilight_model::gateway::payload::incoming::MessageCreate;
1115

@@ -18,13 +22,24 @@ async fn main() -> anyhow::Result<()> {
1822

1923
let token = env::var("TOKEN").context("reading `TOKEN`")?;
2024

21-
let config = Config::new(token.clone(), INTENTS);
22-
let http = Client::new(token);
23-
let shards = twilight_gateway::create_recommended(&http, config, |_, builder| builder.build())
25+
let http = Client::new(token.clone());
26+
let info = async { Ok::<_, anyhow::Error>(http.gateway().authed().await?.model().await?) }
2427
.await
25-
.context("creating shards")?;
28+
.context("getting info")?;
2629
context::initialize(http);
2730

31+
// The queue defaults are static and may be incorrect for large or newly
32+
// restarted bots.
33+
let queue = InMemoryQueue::new(
34+
info.session_start_limit.max_concurrency,
35+
info.session_start_limit.remaining,
36+
Duration::from_millis(info.session_start_limit.reset_after),
37+
info.session_start_limit.total,
38+
);
39+
let config = ConfigBuilder::new(token, INTENTS).queue(queue).build();
40+
41+
let shards = resume::restore(config, info.shards).await;
42+
2843
let tasks = shards
2944
.into_iter()
3045
.map(|shard| tokio::spawn(dispatcher(shard)))
@@ -34,29 +49,35 @@ async fn main() -> anyhow::Result<()> {
3449
tracing::info!("shutting down; press CTRL-C to abort");
3550

3651
let join_all_tasks = async {
52+
let mut resume_info = Vec::new();
3753
for task in tasks {
38-
task.await?;
54+
resume_info.push(task.await?);
3955
}
40-
Ok::<_, anyhow::Error>(())
56+
Ok::<_, anyhow::Error>(resume_info)
4157
};
42-
tokio::select! {
43-
_ = signal::ctrl_c() => {},
44-
_ = join_all_tasks => {},
58+
let resume_info = tokio::select! {
59+
_ = signal::ctrl_c() => Vec::new(),
60+
resume_info = join_all_tasks => resume_info?,
4561
};
4662

63+
// Save shard information to be restored.
64+
resume::save(&resume_info)
65+
.await
66+
.context("saving resume info")?;
67+
4768
Ok(())
4869
}
4970

5071
#[tracing::instrument(fields(shard = %shard.id(), skip_all))]
51-
async fn dispatcher(mut shard: Shard) {
72+
async fn dispatcher(mut shard: Shard) -> resume::Info {
5273
let mut is_shutdown = false;
5374
let tracker = TaskTracker::new();
5475
let mut shutdown_fut = pin!(signal::ctrl_c());
5576

5677
loop {
5778
tokio::select! {
5879
_ = &mut shutdown_fut, if !is_shutdown => {
59-
shard.close(CloseFrame::NORMAL);
80+
shard.close(CloseFrame::RESUME);
6081
is_shutdown = true;
6182
},
6283
Some(item) = shard.next_event(EVENT_TYPES) => {
@@ -86,6 +107,8 @@ async fn dispatcher(mut shard: Shard) {
86107

87108
tracker.close();
88109
tracker.wait().await;
110+
111+
resume::Info::from(&shard)
89112
}
90113

91114
#[tracing::instrument(fields(id = %event.id), skip_all)]

src/resume.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use serde::{Deserialize, Serialize};
2+
use tokio::fs;
3+
use twilight_gateway::{Config, ConfigBuilder, Session, Shard, ShardId};
4+
5+
const INFO_FILE: &str = "resume-info.json";
6+
7+
/// [`Shard`] session resumption information.
8+
#[derive(Debug, Deserialize, Serialize)]
9+
pub struct Info {
10+
resume_url: Option<String>,
11+
session: Option<Session>,
12+
}
13+
14+
impl Info {
15+
fn is_none(&self) -> bool {
16+
self.resume_url.is_none() && self.session.is_none()
17+
}
18+
}
19+
20+
impl From<&Shard> for Info {
21+
fn from(value: &Shard) -> Self {
22+
Self {
23+
resume_url: value.resume_url().map(ToOwned::to_owned),
24+
session: value.session().cloned(),
25+
}
26+
}
27+
}
28+
29+
/// Saves shard resumption information to the file system.
30+
pub async fn save(info: &[Info]) -> anyhow::Result<()> {
31+
if !info.iter().all(Info::is_none) {
32+
let contents = serde_json::to_vec(&info)?;
33+
fs::write(INFO_FILE, contents).await?;
34+
}
35+
36+
Ok(())
37+
}
38+
39+
/// Restores shard resumption information from the file system.
40+
pub async fn restore(config: Config, shards: u32) -> Vec<Shard> {
41+
let resume_info = async {
42+
let contents = fs::read(INFO_FILE).await?;
43+
Ok::<_, anyhow::Error>(serde_json::from_slice::<Vec<Info>>(&contents)?)
44+
}
45+
.await;
46+
47+
let shard_ids = (0..shards).map(|shard| ShardId::new(shard, shards));
48+
49+
// A session may only successfully be resumed if it retains its shard ID,
50+
// but Discord may have recommend a different shard count (producing
51+
// different shard IDs).
52+
let shards: Vec<_> = if let Ok(resume_info) = resume_info
53+
&& resume_info.len() == shards as usize
54+
{
55+
tracing::info!("resuming previous gateway sessions");
56+
shard_ids
57+
.zip(resume_info)
58+
.map(|(shard_id, resume_info)| {
59+
let mut builder = ConfigBuilder::from(config.clone());
60+
61+
if let Some(resume_url) = resume_info.resume_url {
62+
builder = builder.resume_url(resume_url);
63+
}
64+
if let Some(session) = resume_info.session {
65+
builder = builder.session(session);
66+
}
67+
68+
Shard::with_config(shard_id, builder.build())
69+
})
70+
.collect()
71+
} else {
72+
shard_ids
73+
.map(|shard_id| Shard::with_config(shard_id, config.clone()))
74+
.collect()
75+
};
76+
77+
// Resumed or not, the saved resume info is now stale.
78+
_ = fs::remove_file(INFO_FILE).await;
79+
80+
shards
81+
}

0 commit comments

Comments
 (0)