4747from hikari .events import lifetime_events
4848from hikari .impl import entity_factory as entity_factory_impl
4949from hikari .impl import event_factory as event_factory_impl
50- from hikari .impl import rate_limits
5150from hikari .impl import rest as rest_client_impl
5251from hikari .impl import shard as gateway_shard_impl
5352from hikari .impl import stateful_cache as cache_impl
@@ -201,20 +200,20 @@ class BotApp(
201200
202201 __slots__ : typing .Sequence [str ] = (
203202 "_cache" ,
204- "_guild_chunker" ,
205203 "_connector_factory" ,
206204 "_debug" ,
207205 "_entity_factory" ,
208206 "_event_manager" ,
209207 "_event_factory" ,
210208 "_executor" ,
211- "_global_ratelimit " ,
209+ "_guild_chunker " ,
212210 "_http_settings" ,
213211 "_initial_activity" ,
214212 "_initial_idle_since" ,
215213 "_initial_is_afk" ,
216214 "_initial_status" ,
217215 "_intents" ,
216+ "_has_aborted" ,
218217 "_large_threshold" ,
219218 "_max_concurrency" ,
220219 "_proxy_settings" ,
@@ -273,13 +272,13 @@ def __init__(
273272 self ._entity_factory = entity_factory_impl .EntityFactoryImpl (app = self )
274273 self ._event_factory = event_factory_impl .EventFactoryImpl (app = self )
275274 self ._executor = executor
276- self ._global_ratelimit = rate_limits .ManualRateLimiter ()
277275 self ._http_settings = config .HTTPSettings () if http_settings is None else http_settings
278276 self ._initial_activity = initial_activity
279277 self ._initial_idle_since = initial_idle_since
280278 self ._initial_is_afk = initial_is_afk
281279 self ._initial_status = initial_status
282280 self ._intents = intents
281+ self ._has_aborted = False
283282 self ._large_threshold = large_threshold
284283 self ._max_concurrency = 1
285284 self ._proxy_settings = config .ProxySettings () if proxy_settings is None else proxy_settings
@@ -508,7 +507,11 @@ async def start(self) -> None:
508507 self ._tasks .clear ()
509508 self ._shard_gather_task = None
510509
511- await self ._init ()
510+ try :
511+ await self ._init ()
512+ except Exception :
513+ await self .close ()
514+ raise
512515
513516 self ._request_close_event .clear ()
514517
@@ -537,7 +540,15 @@ async def start(self) -> None:
537540 window [shard_id ] = asyncio .create_task (shard_obj .start (), name = f"start gateway shard { shard_id } " )
538541
539542 # Wait for the group to start.
540- await asyncio .gather (* window .values ())
543+ gatherer = asyncio .gather (* window .values ())
544+ waiter = asyncio .create_task (self ._request_close_event .wait (), name = "listen for bot closure events" )
545+
546+ await asyncio .wait ((gatherer , waiter ), return_when = asyncio .FIRST_COMPLETED )
547+
548+ if not waiter .done ():
549+ waiter .cancel ()
550+ else :
551+ gatherer .cancel ()
541552
542553 # Store the keep-alive tasks and continue.
543554 for shard_id , start_task in window .items ():
@@ -546,6 +557,7 @@ async def start(self) -> None:
546557 finally :
547558 if len (self ._tasks ) != len (self ._shards ):
548559 _LOGGER .warning ("application was aborted midway through initialization, so never managed to start" )
560+ await self .close ()
549561 raise errors .GatewayClientClosedError ("Client was aborted midway through initialization" )
550562
551563 finish_time = date .monotonic ()
@@ -609,32 +621,23 @@ def dispatch(self, event: base_events.Event) -> asyncio.Future[typing.Any]:
609621 return self .dispatcher .dispatch (event )
610622
611623 async def close (self ) -> None :
612- """Request that all shards disconnect and the application shuts down.
624+ """Immediately destroy all shards that are running and stop."""
625+ self ._request_close_event .set ()
613626
614- This will close all shards that are running, and then close any
615- REST components and connectors.
616- """
617- self ._guild_chunker .close ()
627+ # Prevent calling this multiple times.
628+ if self ._has_aborted :
629+ return
618630
619- try :
620- # This way if we cancel the stopping task, we still shut down properly.
621- self ._request_close_event .set ()
622- _LOGGER .info ("stopping %s shard(s)" , len (self ._shards ))
623-
624- try :
625- if self ._shards :
626- await self .dispatch (lifetime_events .StoppingEvent (app = self ))
627- await self ._abort_shards ()
628- finally :
629- # The starting event occurs before the bot starts, regardless of if
630- # it had started or not, so it seems sensible stopped event has the
631- # same semantics.
632- self ._tasks .clear ()
633- await self .dispatch (lifetime_events .StoppedEvent (app = self ))
634- finally :
635- await self ._rest .close ()
636- await self ._connector_factory .close ()
637- self ._global_ratelimit .close ()
631+ self ._has_aborted = True
632+ self ._guild_chunker .close ()
633+ await self .dispatch (lifetime_events .StoppingEvent (app = self ))
634+ await self ._abort_shards ()
635+ self ._tasks .clear ()
636+ await self .dispatch (lifetime_events .StoppedEvent (app = self ))
637+ await self ._rest .close ()
638+ await self ._connector_factory .close ()
639+ self ._shard_gather_task = None
640+ self ._request_close_event .clear ()
638641
639642 def run (
640643 self ,
@@ -744,7 +747,7 @@ def run(
744747
745748 def die () -> None :
746749 _LOGGER .info ("received signal to shut down client" )
747- asyncio . ensure_future ( self .close () )
750+ self ._request_close_event . set ( )
748751
749752 for signum in kill_signals :
750753 # Windows is dumb and doesn't support signals properly.
@@ -758,30 +761,33 @@ def die() -> None:
758761 finally :
759762 loop .run_until_complete (self .join ())
760763 except errors .GatewayClientClosedError as ex :
761- _LOGGER .info (str ( ex ) )
764+ _LOGGER .info ("client closed with reason: %s" , ex )
762765 finally :
763766 for signum in kill_signals :
764767 # Windows is dumb and doesn't support signals properly.
765768 with contextlib .suppress (NotImplementedError ):
766769 loop .remove_signal_handler (signum )
767770
768771 if finalize_loop_on_close :
769- _LOGGER .debug ("closing asyncgens for event loop %s" , loop )
770- loop .run_until_complete (loop .shutdown_asyncgens ())
772+ remaining_tasks = [t for t in asyncio .all_tasks (loop ) if not t .done ()]
771773
772- remaining_tasks = asyncio .all_tasks (loop )
773774 if remaining_tasks :
774- _LOGGER .warning ("forcefully stopping %s remaining tasks" , len (remaining_tasks ))
775+ _LOGGER .debug ("forcefully stopping %s remaining tasks" , len (remaining_tasks ))
776+
775777 for task in remaining_tasks :
776778 task .cancel ()
779+ loop .run_until_complete (asyncio .gather (* remaining_tasks , return_exceptions = True ))
777780
778- # Don't warn that these were never retrieved.
779- with contextlib .suppress (asyncio .InvalidStateError ):
780- task .exception ()
781-
781+ for task in remaining_tasks :
782+ if not task .cancelled ():
783+ exception = task .exception ()
784+ if exception is not None :
785+ _LOGGER .warning ("unhandled exception during shutdown" , exc_info = exception )
782786 else :
783787 _LOGGER .debug ("no tasks are running, congratulations on writing a tidy application" )
784788
789+ _LOGGER .debug ("closing asyncgens for event loop %s" , loop )
790+ loop .run_until_complete (loop .shutdown_asyncgens ())
785791 loop .close ()
786792
787793 async def join (self ) -> None :
@@ -939,9 +945,8 @@ def _max_concurrency_chunker(self) -> typing.Iterator[typing.Iterator[int]]:
939945 async def _abort_shards (self ) -> None :
940946 """Close all shards and wait for them to terminate."""
941947 for shard_id in self ._shards :
942- if self ._shards [shard_id ].is_alive :
943- _LOGGER .debug ("stopping shard %s" , shard_id )
944- await self ._shards [shard_id ].close ()
948+ _LOGGER .debug ("stopping shard %s" , shard_id )
949+ await self ._shards [shard_id ].close ()
945950 await asyncio .gather (* self ._tasks .values (), return_exceptions = True )
946951
947952 async def _gather_shard_lifecycles (self ) -> None :
@@ -950,12 +955,22 @@ async def _gather_shard_lifecycles(self) -> None:
950955 Ensure shards are requested to close before the coroutine function
951956 completes.
952957 """
958+ _LOGGER .debug ("gathering shards" )
959+ gatherer = asyncio .gather (* self ._tasks .values ())
960+ waiter = asyncio .create_task (self ._request_close_event .wait (), name = "listen for bot closure events" )
961+
953962 try :
954- _LOGGER .debug ("gathering shards" )
955- await asyncio .gather (* self ._tasks .values ())
963+ await asyncio .wait ([gatherer , waiter ], return_when = asyncio .FIRST_COMPLETED )
964+
965+ if not waiter .done ():
966+ waiter .cancel ()
956967 finally :
957968 _LOGGER .debug ("gather terminated, shutting down shard(s)" )
958- await asyncio .shield (self .close ())
969+ aborter = asyncio .shield (self .close ())
970+ try :
971+ await gatherer
972+ finally :
973+ await aborter
959974
960975 async def _shard_management_lifecycle (self ) -> None :
961976 """Start all shards and then wait for them to finish."""
0 commit comments