@@ -2,15 +2,26 @@ mod context;
22
33use crate :: context:: CONTEXT ;
44use anyhow:: Context ;
5+ use serde:: { Deserialize , Serialize } ;
56use std:: { env, error:: Error , pin:: pin} ;
6- use tokio:: signal;
7+ use tokio:: { fs , signal} ;
78use tokio_util:: task:: TaskTracker ;
8- use twilight_gateway:: { CloseFrame , Config , Event , EventTypeFlags , Intents , Shard , StreamExt as _} ;
9+ use twilight_gateway:: {
10+ CloseFrame , Config , ConfigBuilder , Event , EventTypeFlags , Intents , Session , Shard , ShardId ,
11+ StreamExt as _,
12+ } ;
913use twilight_http:: Client ;
1014use twilight_model:: gateway:: payload:: incoming:: MessageCreate ;
1115
1216const EVENT_TYPES : EventTypeFlags = EventTypeFlags :: MESSAGE_CREATE ;
1317const INTENTS : Intents = Intents :: GUILD_MESSAGES . union ( Intents :: MESSAGE_CONTENT ) ;
18+ const RESUME_INFO_FILE : & str = "resume-info.json" ;
19+
20+ #[ derive( Debug , Deserialize , Serialize ) ]
21+ struct ResumeInfo {
22+ resume_url : Option < String > ,
23+ session : Option < Session > ,
24+ }
1425
1526#[ tokio:: main( flavor = "current_thread" ) ]
1627async fn main ( ) -> anyhow:: Result < ( ) > {
@@ -20,31 +31,73 @@ async fn main() -> anyhow::Result<()> {
2031
2132 let config = Config :: new ( token. clone ( ) , INTENTS ) ;
2233 let http = Client :: new ( token) ;
23- let shards = twilight_gateway :: create_recommended ( & http , config , | _, builder| builder . build ( ) )
34+ let info = async { Ok :: < _ , anyhow :: Error > ( http . gateway ( ) . authed ( ) . await ? . model ( ) . await ? ) }
2435 . await
25- . context ( "creating shards" ) ?;
36+ . context ( "getting info" ) ?;
37+ let resume_info = async {
38+ let contents = fs:: read ( RESUME_INFO_FILE ) . await ?;
39+ Ok :: < _ , anyhow:: Error > ( serde_json:: from_slice :: < Vec < ResumeInfo > > ( & contents) ?)
40+ }
41+ . await
42+ . ok ( ) ;
43+ let shard_ids = ( 0 ..info. shards ) . map ( |shard| ShardId :: new ( shard, info. shards ) ) ;
44+ let shards: Vec < _ > = if let Some ( mut resume_info) = resume_info
45+ && resume_info. len ( ) == info. shards as usize
46+ {
47+ tracing:: info!( "resuming previous gateway sessions" ) ;
48+ let mut resume_info_iter = resume_info. drain ( ..) ;
49+ shard_ids
50+ . map ( |shard_id| {
51+ let resume_info = resume_info_iter. next ( ) . unwrap ( ) ;
52+ let mut builder = ConfigBuilder :: from ( config. clone ( ) ) ;
53+
54+ if let Some ( resume_url) = resume_info. resume_url {
55+ builder = builder. resume_url ( resume_url) ;
56+ }
57+ if let Some ( session) = resume_info. session {
58+ builder = builder. session ( session) ;
59+ }
60+
61+ Shard :: with_config ( shard_id, builder. build ( ) )
62+ } )
63+ . collect ( )
64+ } else {
65+ shard_ids
66+ . map ( |shard_id| Shard :: with_config ( shard_id, config. clone ( ) ) )
67+ . collect ( )
68+ } ;
69+ _ = fs:: remove_file ( RESUME_INFO_FILE ) . await ;
2670 context:: initialize ( http) ;
2771
28- let tracker = TaskTracker :: new ( ) ;
29- for shard in shards {
30- tracker. spawn ( dispatcher ( shard) ) ;
72+ let tasks = shards
73+ . into_iter ( )
74+ . map ( |shard| tokio:: spawn ( dispatcher ( shard) ) )
75+ . collect :: < Vec < _ > > ( ) ;
76+ let mut resume_info = Vec :: new ( ) ;
77+ for task in tasks {
78+ resume_info. push ( task. await ?) ;
3179 }
32- tracker. close ( ) ;
33- tracker. wait ( ) . await ;
80+
81+ async {
82+ let contents = serde_json:: to_vec ( & resume_info) ?;
83+ fs:: write ( RESUME_INFO_FILE , contents) . await
84+ }
85+ . await
86+ . context ( "persisting resume info" ) ?;
3487
3588 Ok ( ( ) )
3689}
3790
3891#[ tracing:: instrument( fields( shard = %shard. id( ) , skip_all) ) ]
39- async fn dispatcher ( mut shard : Shard ) {
92+ async fn dispatcher ( mut shard : Shard ) -> ResumeInfo {
4093 let mut is_shutdown = false ;
4194 let tracker = TaskTracker :: new ( ) ;
4295 let mut shutdown_fut = pin ! ( signal:: ctrl_c( ) ) ;
4396
4497 loop {
4598 tokio:: select! {
4699 _ = & mut shutdown_fut, if !is_shutdown => {
47- shard. close( CloseFrame :: NORMAL ) ;
100+ shard. close( CloseFrame :: RESUME ) ;
48101 is_shutdown = true ;
49102 } ,
50103 Some ( item) = shard. next_event( EVENT_TYPES ) => {
@@ -74,6 +127,11 @@ async fn dispatcher(mut shard: Shard) {
74127
75128 tracker. close ( ) ;
76129 tracker. wait ( ) . await ;
130+
131+ ResumeInfo {
132+ resume_url : shard. resume_url ( ) . map ( ToOwned :: to_owned) ,
133+ session : shard. session ( ) . cloned ( ) ,
134+ }
77135}
78136
79137#[ tracing:: instrument( fields( id = %event. id) , skip_all) ]
0 commit comments