11mod context;
2+ mod resume;
23
34use crate :: context:: CONTEXT ;
4- use anyhow:: Context ;
5- use std:: { env, error:: Error , pin:: pin} ;
5+ use anyhow:: Context as _ ;
6+ use std:: { env, error:: Error , pin:: pin, time :: Duration } ;
67use tokio:: signal;
78use tokio_util:: task:: TaskTracker ;
8- use twilight_gateway:: { CloseFrame , Config , Event , EventTypeFlags , Intents , Shard , StreamExt as _} ;
9+ use twilight_gateway:: {
10+ CloseFrame , ConfigBuilder , Event , EventTypeFlags , Intents , Shard , StreamExt as _,
11+ queue:: InMemoryQueue ,
12+ } ;
913use twilight_http:: Client ;
1014use twilight_model:: gateway:: payload:: incoming:: MessageCreate ;
1115
@@ -18,13 +22,24 @@ async fn main() -> anyhow::Result<()> {
1822
1923 let token = env:: var ( "TOKEN" ) . context ( "reading `TOKEN`" ) ?;
2024
21- let config = Config :: new ( token. clone ( ) , INTENTS ) ;
22- let http = Client :: new ( token) ;
23- let shards = twilight_gateway:: create_recommended ( & http, config, |_, builder| builder. build ( ) )
25+ let http = Client :: new ( token. clone ( ) ) ;
26+ let info = async { Ok :: < _ , anyhow:: Error > ( http. gateway ( ) . authed ( ) . await ?. model ( ) . await ?) }
2427 . await
25- . context ( "creating shards " ) ?;
28+ . context ( "getting info " ) ?;
2629 context:: initialize ( http) ;
2730
31+ // The queue defaults are static and may be incorrect for large or newly
32+ // restarted bots.
33+ let queue = InMemoryQueue :: new (
34+ info. session_start_limit . max_concurrency ,
35+ info. session_start_limit . remaining ,
36+ Duration :: from_millis ( info. session_start_limit . reset_after ) ,
37+ info. session_start_limit . total ,
38+ ) ;
39+ let config = ConfigBuilder :: new ( token, INTENTS ) . queue ( queue) . build ( ) ;
40+
41+ let shards = resume:: restore ( config, info. shards ) . await ;
42+
2843 let tasks = shards
2944 . into_iter ( )
3045 . map ( |shard| tokio:: spawn ( dispatcher ( shard) ) )
@@ -34,29 +49,35 @@ async fn main() -> anyhow::Result<()> {
3449 tracing:: info!( "shutting down; press CTRL-C to abort" ) ;
3550
3651 let join_all_tasks = async {
52+ let mut resume_info = Vec :: new ( ) ;
3753 for task in tasks {
38- task. await ?;
54+ resume_info . push ( task. await ?) ;
3955 }
40- Ok :: < _ , anyhow:: Error > ( ( ) )
56+ Ok :: < _ , anyhow:: Error > ( resume_info )
4157 } ;
42- tokio:: select! {
43- _ = signal:: ctrl_c( ) => { } ,
44- _ = join_all_tasks => { } ,
58+ let resume_info = tokio:: select! {
59+ _ = signal:: ctrl_c( ) => Vec :: new ( ) ,
60+ resume_info = join_all_tasks => resume_info? ,
4561 } ;
4662
63+ // Save shard information to be restored.
64+ resume:: save ( & resume_info)
65+ . await
66+ . context ( "saving resume info" ) ?;
67+
4768 Ok ( ( ) )
4869}
4970
5071#[ tracing:: instrument( fields( shard = %shard. id( ) , skip_all) ) ]
51- async fn dispatcher ( mut shard : Shard ) {
72+ async fn dispatcher ( mut shard : Shard ) -> resume :: Info {
5273 let mut is_shutdown = false ;
5374 let tracker = TaskTracker :: new ( ) ;
5475 let mut shutdown_fut = pin ! ( signal:: ctrl_c( ) ) ;
5576
5677 loop {
5778 tokio:: select! {
5879 _ = & mut shutdown_fut, if !is_shutdown => {
59- shard. close( CloseFrame :: NORMAL ) ;
80+ shard. close( CloseFrame :: RESUME ) ;
6081 is_shutdown = true ;
6182 } ,
6283 Some ( item) = shard. next_event( EVENT_TYPES ) => {
@@ -86,6 +107,8 @@ async fn dispatcher(mut shard: Shard) {
86107
87108 tracker. close ( ) ;
88109 tracker. wait ( ) . await ;
110+
111+ resume:: Info :: from ( & shard)
89112}
90113
91114#[ tracing:: instrument( fields( id = %event. id) , skip_all) ]
0 commit comments