1 change: 1 addition & 0 deletions newsfragments/3642.bugfix.rst
@@ -0,0 +1 @@
Checks that a ``PersistentConnectionProvider`` response cache value is a dict before attempting to access it like one. Also ensures ``make_batch_request`` puts the provider into batching mode for the duration of the call and always takes it back out afterward, even if the request raises.
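
For context, a minimal sketch of the failure mode the first sentence describes, with an illustrative cache layout that stands in for the provider's internal response cache (not web3.py code): a batch response is cached as a list, and indexing a list with a string key raises ``TypeError``.

    # Illustrative only: stands in for the provider's response cache.
    cache = {
        "single": {"jsonrpc": "2.0", "id": 1, "result": "0x1"},   # dict entry
        "batch": [{"jsonrpc": "2.0", "id": 1, "result": "0x1"}],  # list entry
    }

    for response in cache.values():
        # Before the fix, code like response["id"] assumed a dict, so the
        # list entry raised TypeError; the fix skips non-dict entries.
        if isinstance(response, dict) and "error" in response:
            print("stray error response:", response["error"])
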
18 changes: 18 additions & 0 deletions tests/core/providers/test_async_http_provider.py
@@ -130,3 +130,21 @@ async def test_async_http_empty_batch_response(mock_async_post):

# assert that even though there was an error, we have reset the batching state
assert not async_w3.provider._is_batching


@patch(
"web3._utils.http_session_manager.HTTPSessionManager.async_make_post_request",
new_callable=AsyncMock,
)
@pytest.mark.asyncio
async def test_async_provider_is_batching_when_make_batch_request(mock_post):
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
assert provider._is_batching
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'

mock_post.side_effect = assert_is_batching_and_return_response
provider = AsyncHTTPProvider()

assert not provider._is_batching
await provider.make_batch_request([("eth_blockNumber", [])])
assert not provider._is_batching
18 changes: 18 additions & 0 deletions tests/core/providers/test_http_provider.py
@@ -132,3 +132,21 @@ def test_http_empty_batch_response(mock_post):

# assert that even though there was an error, we have reset the batching state
assert not w3.provider._is_batching


@patch(
"web3._utils.http_session_manager.HTTPSessionManager.make_post_request",
new_callable=Mock,
)
def test_sync_provider_is_batching_when_make_batch_request(mock_post):
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
assert provider._is_batching
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'

provider = HTTPProvider()
assert not provider._is_batching

mock_post.side_effect = assert_is_batching_and_return_response

provider.make_batch_request([("eth_blockNumber", [])])
assert not provider._is_batching
15 changes: 15 additions & 0 deletions tests/core/providers/test_ipc_provider.py
@@ -199,3 +199,18 @@ def test_ipc_provider_write_messages_end_with_new_line_delimiter(jsonrpc_ipc_pip

request_data = b'{"jsonrpc": "2.0", "method": "method", "params": [], "id": 0}'
provider._socket.sock.sendall.assert_called_with(request_data + b"\n")


def test_ipc_provider_is_batching_when_make_batch_request(jsonrpc_ipc_pipe_path):
    def assert_is_batching_and_return_response(*_args, **_kwargs) -> list:
assert provider._is_batching
return [{"id": 0, "jsonrpc": "2.0", "result": {}}]

provider = IPCProvider(pathlib.Path(jsonrpc_ipc_pipe_path), timeout=3)
provider._make_request = Mock()
provider._make_request.side_effect = assert_is_batching_and_return_response

assert not provider._is_batching

provider.make_batch_request([("eth_blockNumber", [])])
assert not provider._is_batching
32 changes: 32 additions & 0 deletions tests/core/providers/test_websocket_provider.py
@@ -534,3 +534,35 @@ async def test_req_info_cache_size_can_be_set_and_warns_when_full(caplog):
"behavior. Consider increasing the ``request_information_cache_size`` "
"on the provider."
) in caplog.text


@pytest.mark.asyncio
async def test_raise_stray_errors_from_cache_handles_list_response_without_error():
provider = WebSocketProvider("ws://mocked")
_mock_ws(provider)

bad_response = [
{"id": None, "jsonrpc": "2.0", "error": {"code": 21, "message": "oops"}}
]
provider._request_processor._request_response_cache._data["bad_key"] = bad_response

# assert no errors raised
provider._raise_stray_errors_from_cache()


@pytest.mark.asyncio
async def test_websocket_provider_is_batching_when_make_batch_request():
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
assert provider._is_batching
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'

provider = WebSocketProvider("ws://mocked")
_mock_ws(provider)
provider._get_response_for_request_id = AsyncMock()
provider._get_response_for_request_id.side_effect = (
assert_is_batching_and_return_response
)

assert not provider._is_batching
await provider.make_batch_request([("eth_blockNumber", [])])
assert not provider._is_batching
43 changes: 42 additions & 1 deletion web3/_utils/batching.py
@@ -1,6 +1,9 @@
from copy import (
copy,
)
from functools import (
wraps,
)
from types import (
TracebackType,
)
@@ -12,9 +15,11 @@
Dict,
Generic,
List,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
@@ -33,6 +38,7 @@
Web3ValueError,
)
from web3.types import (
RPCEndpoint,
TFunc,
TReturn,
)
@@ -55,7 +61,6 @@
JSONBaseProvider,
)
from web3.types import ( # noqa: F401
-    RPCEndpoint,
RPCResponse,
)

@@ -215,3 +220,39 @@ def sort_batch_response_by_response_ids(
stacklevel=2,
)
return responses


class SupportsBatching(Protocol):
_is_batching: bool


R = TypeVar("R")
T = TypeVar("T", bound=SupportsBatching)


def async_batching_context(
method: Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]]
) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]]:
@wraps(method)
async def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R:
self._is_batching = True
try:
return await method(self, requests)
finally:
self._is_batching = False

return wrapper


def batching_context(
method: Callable[[T, List[Tuple[RPCEndpoint, Any]]], R]
) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], R]:
@wraps(method)
def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R:
self._is_batching = True
try:
return method(self, requests)
finally:
self._is_batching = False

return wrapper
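
A quick usage sketch of the sync decorator above, with a toy class standing in for a real provider (illustrative only): ``_is_batching`` is ``True`` only while the wrapped method runs, and the ``finally`` clause resets it even if the method raises.

    class ToyProvider:
        _is_batching = False

        @batching_context
        def make_batch_request(self, requests):
            # Batching mode is active while the wrapped call executes.
            assert self._is_batching
            return [{"jsonrpc": "2.0", "id": 0, "result": None}]

    provider = ToyProvider()
    provider.make_batch_request([("eth_blockNumber", [])])
    assert not provider._is_batching  # reset on exit, including error paths
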
2 changes: 2 additions & 0 deletions web3/providers/ipc.py
Expand Up @@ -30,6 +30,7 @@
)

from .._utils.batching import (
batching_context,
sort_batch_response_by_response_ids,
)
from .._utils.caching import (
@@ -201,6 +202,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
request = self.encode_rpc_request(method, params)
return self._make_request(request)

@batching_context
def make_batch_request(
self, requests: List[Tuple[RPCEndpoint, Any]]
) -> List[RPCResponse]:
2 changes: 2 additions & 0 deletions web3/providers/legacy_websocket.py
Expand Up @@ -27,6 +27,7 @@
)

from web3._utils.batching import (
batching_context,
sort_batch_response_by_response_ids,
)
from web3._utils.caching import (
@@ -143,6 +144,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
)
return future.result()

@batching_context
def make_batch_request(
self, requests: List[Tuple[RPCEndpoint, Any]]
) -> List[RPCResponse]:
23 changes: 13 additions & 10 deletions web3/providers/persistent/persistent.py
@@ -24,6 +24,7 @@

from web3._utils.batching import (
BATCH_REQUEST_ID,
async_batching_context,
sort_batch_response_by_response_ids,
)
from web3._utils.caching import (
@@ -237,14 +238,17 @@ async def make_request(
rpc_request = await self.send_request(method, params)
return await self.recv_for_request(rpc_request)

@async_batching_context
async def make_batch_request(
self, requests: List[Tuple[RPCEndpoint, Any]]
) -> List[RPCResponse]:
request_data = self.encode_batch_rpc_request(requests)
await self.socket_send(request_data)

        response = cast(
-            List[RPCResponse], await self._get_response_for_request_id(BATCH_REQUEST_ID)
+            List[RPCResponse],
+            await self._get_response_for_request_id(BATCH_REQUEST_ID),
        )
return response

@@ -320,17 +324,16 @@ def _raise_stray_errors_from_cache(self) -> None:
for (
response
) in self._request_processor._request_response_cache._data.values():
-            request = (
-                self._request_processor._request_information_cache.get_cache_entry(
-                    generate_cache_key(response["id"])
-                )
-            )
-            if "error" in response and request is None:
-                # if we find an error response in the cache without a corresponding
-                # request, raise the error
-                validate_rpc_response_and_raise_if_error(
-                    response, None, logger=self.logger
-                )
+            if isinstance(response, dict):
+                request = self._request_processor._request_information_cache.get_cache_entry(  # noqa: E501
+                    generate_cache_key(response["id"])
+                )
+                if "error" in response and request is None:
+                    # if we find an error response in the cache without a
+                    # corresponding request, raise the error
+                    validate_rpc_response_and_raise_if_error(
+                        cast(RPCResponse, response), None, logger=self.logger
+                    )

async def _message_listener(self) -> None:
self.logger.info(
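
The async counterpart behaves the same way. A minimal sketch, again with a toy class rather than a real persistent provider, assuming ``async_batching_context`` from the batching module is in scope:

    import asyncio

    class ToyAsyncProvider:
        _is_batching = False

        @async_batching_context
        async def make_batch_request(self, requests):
            # Batching mode stays active for the duration of the awaited call.
            assert self._is_batching
            return [{"jsonrpc": "2.0", "id": 0, "result": None}]

    async def main():
        provider = ToyAsyncProvider()
        await provider.make_batch_request([("eth_blockNumber", [])])
        assert not provider._is_batching  # reset after the coroutine finishes

    asyncio.run(main())
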
2 changes: 2 additions & 0 deletions web3/providers/rpc/async_rpc.py
Expand Up @@ -36,6 +36,7 @@
)

from ..._utils.batching import (
async_batching_context,
sort_batch_response_by_response_ids,
)
from ..._utils.caching import (
@@ -166,6 +167,7 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
)
return response

@async_batching_context
async def make_batch_request(
self, batch_requests: List[Tuple[RPCEndpoint, Any]]
) -> Union[List[RPCResponse], RPCResponse]:
2 changes: 2 additions & 0 deletions web3/providers/rpc/rpc.py
Expand Up @@ -34,6 +34,7 @@
)

from ..._utils.batching import (
batching_context,
sort_batch_response_by_response_ids,
)
from ..._utils.caching import (
@@ -174,6 +175,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
)
return response

@batching_context
def make_batch_request(
self, batch_requests: List[Tuple[RPCEndpoint, Any]]
) -> Union[List[RPCResponse], RPCResponse]: