|
9 | 9 | import logging
|
10 | 10 | import ssl
|
11 | 11 | import time
|
| 12 | +import warnings |
12 | 13 | from unittest.mock import AsyncMock
|
13 | 14 | from hashlib import blake2b
|
14 | 15 | from typing import (
|
@@ -531,15 +532,31 @@ def __init__(
|
531 | 532 | self._exit_task = None
|
532 | 533 | self._open_subscriptions = 0
|
533 | 534 | self._options = options if options else {}
|
534 |
| - self.last_received = time.time() |
| 535 | + try: |
| 536 | + now = asyncio.get_running_loop().time() |
| 537 | + except RuntimeError: |
| 538 | + warnings.warn( |
| 539 | + "You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. " |
| 540 | + "Verify this is intended." |
| 541 | + ) |
| 542 | + now = asyncio.new_event_loop().time() |
| 543 | + self.last_received = now |
| 544 | + self.last_sent = now |
535 | 545 |
|
536 | 546 | async def __aenter__(self):
|
537 | 547 | async with self._lock:
|
538 | 548 | self._in_use += 1
|
539 | 549 | await self.connect()
|
540 | 550 | return self
|
541 | 551 |
|
| 552 | + @staticmethod |
| 553 | + async def loop_time() -> float: |
| 554 | + return asyncio.get_running_loop().time() |
| 555 | + |
542 | 556 | async def connect(self, force=False):
|
| 557 | + now = await self.loop_time() |
| 558 | + self.last_received = now |
| 559 | + self.last_sent = now |
543 | 560 | if self._exit_task:
|
544 | 561 | self._exit_task.cancel()
|
545 | 562 | if not self._initialized or force:
|
@@ -595,7 +612,7 @@ async def _recv(self) -> None:
|
595 | 612 | try:
|
596 | 613 | # TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
|
597 | 614 | response = json.loads(await self.ws.recv(decode=False))
|
598 |
| - self.last_received = time.time() |
| 615 | + self.last_received = await self.loop_time() |
599 | 616 | async with self._lock:
|
600 | 617 | # note that these 'subscriptions' are all waiting sent messages which have not received
|
601 | 618 | # responses, and thus are not the same as RPC 'subscriptions', which are unique
|
@@ -631,12 +648,12 @@ async def send(self, payload: dict) -> int:
|
631 | 648 | Returns:
|
632 | 649 | id: the internal ID of the request (incremented int)
|
633 | 650 | """
|
634 |
| - # async with self._lock: |
635 | 651 | original_id = get_next_id()
|
636 | 652 | # self._open_subscriptions += 1
|
637 | 653 | await self.max_subscriptions.acquire()
|
638 | 654 | try:
|
639 | 655 | await self.ws.send(json.dumps({**payload, **{"id": original_id}}))
|
| 656 | + self.last_sent = await self.loop_time() |
640 | 657 | return original_id
|
641 | 658 | except (ConnectionClosed, ssl.SSLError, EOFError):
|
642 | 659 | async with self._lock:
|
@@ -2126,7 +2143,11 @@ async def _make_rpc_request(
|
2126 | 2143 |
|
2127 | 2144 | if request_manager.is_complete:
|
2128 | 2145 | break
|
2129 |
| - if time.time() - self.ws.last_received >= self.retry_timeout: |
| 2146 | + if ( |
| 2147 | + (current_time := await self.ws.loop_time()) - self.ws.last_received |
| 2148 | + >= self.retry_timeout |
| 2149 | + and current_time - self.ws.last_sent >= self.retry_timeout |
| 2150 | + ): |
2130 | 2151 | if attempt >= self.max_retries:
|
2131 | 2152 | logger.warning(
|
2132 | 2153 | f"Timed out waiting for RPC requests {attempt} times. Exiting."
|
|
0 commit comments