Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions safe_eth/eth/clients/etherscan_client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin

import aiohttp
import requests

from safe_eth.eth import EthereumNetwork
Expand All @@ -12,6 +11,7 @@
EtherscanClient,
EtherscanRateLimitError,
)
from safe_eth.eth.clients.rate_limiter import get_client_rate_limited


class EtherscanClientV2(EtherscanClient):
Expand Down Expand Up @@ -97,19 +97,15 @@ def __init__(
max_requests: int = int(os.environ.get("ETHERSCAN_CLIENT_MAX_REQUESTS", 100)),
):
super().__init__(network, api_key, request_timeout)
self.async_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit_per_host=max_requests)
)
self.client = get_client_rate_limited(self.base_api_url, 5) # 5 per second

async def _async_do_request(
self, url: str
) -> Optional[Union[Dict[str, Any], List[Any], str]]:
"""
Async version of _do_request
"""
async with self.async_session.get(
url, timeout=self.request_timeout
) as response:
async with await self.client.get(url, timeout=self.request_timeout) as response:
if response.ok:
response_json = await response.json()
result = response_json["result"]
Expand Down
74 changes: 74 additions & 0 deletions safe_eth/eth/clients/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
from functools import cache
from logging import getLogger

import aiohttp

logger = getLogger(__name__)


class RateLimiter:
"""
Class to limit the number of requests per second
"""

def __init__(self, client, rate):
self.client = client
self.rate = rate
self.available_conns = rate # Initialize available conns
self._waiters = [] # List of tasks that are waiting for a connection
self.loop = asyncio.get_event_loop()
self._schedule_next_release_connections() # Schedule first release connections

async def get(self, *args, **kwargs):
await self._wait_for_available_conn()
return self.client.get(*args, **kwargs)

async def post(self, *args, **kwargs):
await self._wait_for_available_conn()
return self.client.post(*args, **kwargs)

def _wakeup_waiters(self):
"""
Unblock tasks waiting for connections
"""
while self.available_conns > 0 and self._waiters:
future = self._waiters.pop(0)
future.set_result(None) # Release await
self.available_conns -= 1

async def _wait_for_available_conn(self):
if self.available_conns < 1:
future = asyncio.Future()
self._waiters.append(future)
await future
else:
self.available_conns -= 1

def _release_available_conns(self):
"""
Release new connections
"""
self.available_conns += self.rate - self.available_conns
self._wakeup_waiters()
self._schedule_next_release_connections()

def _schedule_next_release_connections(self):
"""
Schedule next release connections
"""
self.loop.call_later(1, self._release_available_conns)


@cache
def get_client_rate_limited(host: str, rate: int) -> "RateLimiter":
"""
Get a rate limited client by host
Host parameter is just being used to store in cache different instance by host

:param host:
:param rate: number of requests allowed per second
"""
logger.info(f"Initializing rate limiter for {host} by {rate}/s")
async_session = aiohttp.ClientSession()
return RateLimiter(async_session, rate)
27 changes: 27 additions & 0 deletions safe_eth/eth/tests/clients/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio
import unittest

from safe_eth.eth.clients.rate_limiter import get_client_rate_limited


class TestRateLimit(unittest.IsolatedAsyncioTestCase):
async def _make_request(self, rate_limited):
"""Helper to fetch the status code of a request."""
try:
async with await rate_limited.get(
"https://safe-transaction-sepolia.safe.global/api/v1/about", timeout=5
) as response:
return response
except asyncio.TimeoutError:
return None

async def test_rate_limiter(self):
rate_limited = get_client_rate_limited(
"https://safe-transaction-sepolia.safe.global/", 5
)
tasks = [self._make_request(rate_limited) for _ in range(20)]
responses = await asyncio.gather(*tasks)
self.assertEqual(len(responses), 20)
for response in responses:
self.assertEqual(response.status, 200) # Check the status code
self.assertTrue(response.ok)
Loading