Skip to content

Commit 20e24f9

Browse files
author
Nekokatt
authored
Merge pull request #102 from nekokatt/bugfix/closures
Bugfix/closures
2 parents 3e9f15f + d818f30 commit 20e24f9

10 files changed

Lines changed: 491 additions & 313 deletions

File tree

hikari/impl/bot.py

Lines changed: 61 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
from hikari.events import lifetime_events
4848
from hikari.impl import entity_factory as entity_factory_impl
4949
from hikari.impl import event_factory as event_factory_impl
50-
from hikari.impl import rate_limits
5150
from hikari.impl import rest as rest_client_impl
5251
from hikari.impl import shard as gateway_shard_impl
5352
from hikari.impl import stateful_cache as cache_impl
@@ -201,20 +200,20 @@ class BotApp(
201200

202201
__slots__: typing.Sequence[str] = (
203202
"_cache",
204-
"_guild_chunker",
205203
"_connector_factory",
206204
"_debug",
207205
"_entity_factory",
208206
"_event_manager",
209207
"_event_factory",
210208
"_executor",
211-
"_global_ratelimit",
209+
"_guild_chunker",
212210
"_http_settings",
213211
"_initial_activity",
214212
"_initial_idle_since",
215213
"_initial_is_afk",
216214
"_initial_status",
217215
"_intents",
216+
"_has_aborted",
218217
"_large_threshold",
219218
"_max_concurrency",
220219
"_proxy_settings",
@@ -273,13 +272,13 @@ def __init__(
273272
self._entity_factory = entity_factory_impl.EntityFactoryImpl(app=self)
274273
self._event_factory = event_factory_impl.EventFactoryImpl(app=self)
275274
self._executor = executor
276-
self._global_ratelimit = rate_limits.ManualRateLimiter()
277275
self._http_settings = config.HTTPSettings() if http_settings is None else http_settings
278276
self._initial_activity = initial_activity
279277
self._initial_idle_since = initial_idle_since
280278
self._initial_is_afk = initial_is_afk
281279
self._initial_status = initial_status
282280
self._intents = intents
281+
self._has_aborted = False
283282
self._large_threshold = large_threshold
284283
self._max_concurrency = 1
285284
self._proxy_settings = config.ProxySettings() if proxy_settings is None else proxy_settings
@@ -508,7 +507,11 @@ async def start(self) -> None:
508507
self._tasks.clear()
509508
self._shard_gather_task = None
510509

511-
await self._init()
510+
try:
511+
await self._init()
512+
except Exception:
513+
await self.close()
514+
raise
512515

513516
self._request_close_event.clear()
514517

@@ -537,7 +540,15 @@ async def start(self) -> None:
537540
window[shard_id] = asyncio.create_task(shard_obj.start(), name=f"start gateway shard {shard_id}")
538541

539542
# Wait for the group to start.
540-
await asyncio.gather(*window.values())
543+
gatherer = asyncio.gather(*window.values())
544+
waiter = asyncio.create_task(self._request_close_event.wait(), name="listen for bot closure events")
545+
546+
await asyncio.wait((gatherer, waiter), return_when=asyncio.FIRST_COMPLETED)
547+
548+
if not waiter.done():
549+
waiter.cancel()
550+
else:
551+
gatherer.cancel()
541552

542553
# Store the keep-alive tasks and continue.
543554
for shard_id, start_task in window.items():
@@ -546,6 +557,7 @@ async def start(self) -> None:
546557
finally:
547558
if len(self._tasks) != len(self._shards):
548559
_LOGGER.warning("application was aborted midway through initialization, so never managed to start")
560+
await self.close()
549561
raise errors.GatewayClientClosedError("Client was aborted midway through initialization")
550562

551563
finish_time = date.monotonic()
@@ -609,32 +621,23 @@ def dispatch(self, event: base_events.Event) -> asyncio.Future[typing.Any]:
609621
return self.dispatcher.dispatch(event)
610622

611623
async def close(self) -> None:
612-
"""Request that all shards disconnect and the application shuts down.
624+
"""Immediately destroy all shards that are running and stop."""
625+
self._request_close_event.set()
613626

614-
This will close all shards that are running, and then close any
615-
REST components and connectors.
616-
"""
617-
self._guild_chunker.close()
627+
# Prevent calling this multiple times.
628+
if self._has_aborted:
629+
return
618630

619-
try:
620-
# This way if we cancel the stopping task, we still shut down properly.
621-
self._request_close_event.set()
622-
_LOGGER.info("stopping %s shard(s)", len(self._shards))
623-
624-
try:
625-
if self._shards:
626-
await self.dispatch(lifetime_events.StoppingEvent(app=self))
627-
await self._abort_shards()
628-
finally:
629-
# The starting event occurs before the bot starts, regardless of if
630-
# it had started or not, so it seems sensible stopped event has the
631-
# same semantics.
632-
self._tasks.clear()
633-
await self.dispatch(lifetime_events.StoppedEvent(app=self))
634-
finally:
635-
await self._rest.close()
636-
await self._connector_factory.close()
637-
self._global_ratelimit.close()
631+
self._has_aborted = True
632+
self._guild_chunker.close()
633+
await self.dispatch(lifetime_events.StoppingEvent(app=self))
634+
await self._abort_shards()
635+
self._tasks.clear()
636+
await self.dispatch(lifetime_events.StoppedEvent(app=self))
637+
await self._rest.close()
638+
await self._connector_factory.close()
639+
self._shard_gather_task = None
640+
self._request_close_event.clear()
638641

639642
def run(
640643
self,
@@ -744,7 +747,7 @@ def run(
744747

745748
def die() -> None:
746749
_LOGGER.info("received signal to shut down client")
747-
asyncio.ensure_future(self.close())
750+
self._request_close_event.set()
748751

749752
for signum in kill_signals:
750753
# Windows is dumb and doesn't support signals properly.
@@ -758,30 +761,33 @@ def die() -> None:
758761
finally:
759762
loop.run_until_complete(self.join())
760763
except errors.GatewayClientClosedError as ex:
761-
_LOGGER.info(str(ex))
764+
_LOGGER.info("client closed with reason: %s", ex)
762765
finally:
763766
for signum in kill_signals:
764767
# Windows is dumb and doesn't support signals properly.
765768
with contextlib.suppress(NotImplementedError):
766769
loop.remove_signal_handler(signum)
767770

768771
if finalize_loop_on_close:
769-
_LOGGER.debug("closing asyncgens for event loop %s", loop)
770-
loop.run_until_complete(loop.shutdown_asyncgens())
772+
remaining_tasks = [t for t in asyncio.all_tasks(loop) if not t.done()]
771773

772-
remaining_tasks = asyncio.all_tasks(loop)
773774
if remaining_tasks:
774-
_LOGGER.warning("forcefully stopping %s remaining tasks", len(remaining_tasks))
775+
_LOGGER.debug("forcefully stopping %s remaining tasks", len(remaining_tasks))
776+
775777
for task in remaining_tasks:
776778
task.cancel()
779+
loop.run_until_complete(asyncio.gather(*remaining_tasks, return_exceptions=True))
777780

778-
# Don't warn that these were never retrieved.
779-
with contextlib.suppress(asyncio.InvalidStateError):
780-
task.exception()
781-
781+
for task in remaining_tasks:
782+
if not task.cancelled():
783+
exception = task.exception()
784+
if exception is not None:
785+
_LOGGER.warning("unhandled exception during shutdown", exc_info=exception)
782786
else:
783787
_LOGGER.debug("no tasks are running, congratulations on writing a tidy application")
784788

789+
_LOGGER.debug("closing asyncgens for event loop %s", loop)
790+
loop.run_until_complete(loop.shutdown_asyncgens())
785791
loop.close()
786792

787793
async def join(self) -> None:
@@ -939,9 +945,8 @@ def _max_concurrency_chunker(self) -> typing.Iterator[typing.Iterator[int]]:
939945
async def _abort_shards(self) -> None:
940946
"""Close all shards and wait for them to terminate."""
941947
for shard_id in self._shards:
942-
if self._shards[shard_id].is_alive:
943-
_LOGGER.debug("stopping shard %s", shard_id)
944-
await self._shards[shard_id].close()
948+
_LOGGER.debug("stopping shard %s", shard_id)
949+
await self._shards[shard_id].close()
945950
await asyncio.gather(*self._tasks.values(), return_exceptions=True)
946951

947952
async def _gather_shard_lifecycles(self) -> None:
@@ -950,12 +955,22 @@ async def _gather_shard_lifecycles(self) -> None:
950955
Ensure shards are requested to close before the coroutine function
951956
completes.
952957
"""
958+
_LOGGER.debug("gathering shards")
959+
gatherer = asyncio.gather(*self._tasks.values())
960+
waiter = asyncio.create_task(self._request_close_event.wait(), name="listen for bot closure events")
961+
953962
try:
954-
_LOGGER.debug("gathering shards")
955-
await asyncio.gather(*self._tasks.values())
963+
await asyncio.wait([gatherer, waiter], return_when=asyncio.FIRST_COMPLETED)
964+
965+
if not waiter.done():
966+
waiter.cancel()
956967
finally:
957968
_LOGGER.debug("gather terminated, shutting down shard(s)")
958-
await asyncio.shield(self.close())
969+
aborter = asyncio.shield(self.close())
970+
try:
971+
await gatherer
972+
finally:
973+
await aborter
959974

960975
async def _shard_management_lifecycle(self) -> None:
961976
"""Start all shards and then wait for them to finish."""

hikari/impl/buckets.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def start(self, poll_period: float = _POLL_PERIOD, expire_after: float = _EXPIRE
382382
as the rate limit has reset. Defaults to `10` seconds.
383383
"""
384384
if not self.gc_task:
385-
self.gc_task = asyncio.get_running_loop().create_task(self.gc(poll_period, expire_after))
385+
self.gc_task = asyncio.create_task(self.gc(poll_period, expire_after))
386386

387387
def close(self) -> None:
388388
"""Close the garbage collector and kill any tasks waiting on ratelimits.
@@ -396,6 +396,10 @@ def close(self) -> None:
396396
self.real_hashes_to_buckets.clear()
397397
self.routes_to_hashes.clear()
398398

399+
if self.gc_task is not None:
400+
self.gc_task.cancel()
401+
self.gc_task = None
402+
399403
# Ignore docstring not starting in an imperative mood
400404
async def gc(self, poll_period: float, expire_after: float) -> None: # noqa: D401
401405
"""The garbage collector loop.

hikari/impl/rate_limits.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def close(self) -> None:
129129

130130
if self.throttle_task is not None:
131131
self.throttle_task.cancel()
132+
self.throttle_task = None
132133

133134
failed_tasks = 0
134135
while self.queue:

hikari/impl/rest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class BasicLazyCachedTCPConnectorFactory(rest_api.ConnectorFactory):
9696

9797
def __init__(self, **kwargs: typing.Any) -> None:
9898
self.connector: typing.Optional[aiohttp.TCPConnector] = None
99+
kwargs.setdefault("force_close", True)
100+
kwargs.setdefault("enable_cleanup_closed", True)
99101
self.connector_kwargs = kwargs
100102

101103
async def close(self) -> None:
@@ -424,6 +426,7 @@ async def close(self) -> None:
424426
"""Close the HTTP client and any open HTTP connections."""
425427
if self._client_session is not None:
426428
await self._client_session.close()
429+
await self._connector_factory.close()
427430
self.global_rate_limit.close()
428431
self.buckets.close()
429432
self._closed_event.set()
@@ -444,9 +447,8 @@ def _acquire_client_session(self) -> aiohttp.ClientSession:
444447
if self._client_session is None:
445448
self._closed_event.clear()
446449
self._client_session = aiohttp.ClientSession(
447-
# Should not need a lock, since we don't technically await anything.
448450
connector=self._connector_factory.acquire(),
449-
connector_owner=self._connector_owner,
451+
connector_owner=False,
450452
version=aiohttp.HttpVersion11,
451453
timeout=aiohttp.ClientTimeout(
452454
total=self._http_settings.timeouts.total,

0 commit comments

Comments
 (0)