diff --git a/newsfragments/3642.bugfix.rst b/newsfragments/3642.bugfix.rst new file mode 100644 index 0000000000..f742e55fdc --- /dev/null +++ b/newsfragments/3642.bugfix.rst @@ -0,0 +1 @@ +Checks that ``PersistentConnectionProvider`` response cache value is a dict before attempting to access it like one. Also adds checks to ``make_batch_request`` to make sure it is in batching mode before being called and is not after. diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 4ae962247a..21715d9e3a 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -130,3 +130,21 @@ async def test_async_http_empty_batch_response(mock_async_post): # assert that even though there was an error, we have reset the batching state assert not async_w3.provider._is_batching + + +@patch( + "web3._utils.http_session_manager.HTTPSessionManager.async_make_post_request", + new_callable=AsyncMock, +) +@pytest.mark.asyncio +async def test_async_provider_is_batching_when_make_batch_request(mock_post): + def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes: + assert provider._is_batching + return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}' + + mock_post.side_effect = assert_is_batching_and_return_response + provider = AsyncHTTPProvider() + + assert not provider._is_batching + await provider.make_batch_request([("eth_blockNumber", [])]) + assert not provider._is_batching diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 49b1b2badc..c587716a47 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -132,3 +132,21 @@ def test_http_empty_batch_response(mock_post): # assert that even though there was an error, we have reset the batching state assert not w3.provider._is_batching + + +@patch( + 
"web3._utils.http_session_manager.HTTPSessionManager.make_post_request", + new_callable=Mock, +) +def test_sync_provider_is_batching_when_make_batch_request(mock_post): + def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes: + assert provider._is_batching + return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}' + + provider = HTTPProvider() + assert not provider._is_batching + + mock_post.side_effect = assert_is_batching_and_return_response + + provider.make_batch_request([("eth_blockNumber", [])]) + assert not provider._is_batching diff --git a/tests/core/providers/test_ipc_provider.py b/tests/core/providers/test_ipc_provider.py index 77b62370d7..7fd8474ea6 100644 --- a/tests/core/providers/test_ipc_provider.py +++ b/tests/core/providers/test_ipc_provider.py @@ -199,3 +199,18 @@ def test_ipc_provider_write_messages_end_with_new_line_delimiter(jsonrpc_ipc_pip request_data = b'{"jsonrpc": "2.0", "method": "method", "params": [], "id": 0}' provider._socket.sock.sendall.assert_called_with(request_data + b"\n") + + +def test_ipc_provider_is_batching_when_make_batch_request(jsonrpc_ipc_pipe_path): + def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes: + assert provider._is_batching + return [{"id": 0, "jsonrpc": "2.0", "result": {}}] + + provider = IPCProvider(pathlib.Path(jsonrpc_ipc_pipe_path), timeout=3) + provider._make_request = Mock() + provider._make_request.side_effect = assert_is_batching_and_return_response + + assert not provider._is_batching + + provider.make_batch_request([("eth_blockNumber", [])]) + assert not provider._is_batching diff --git a/tests/core/providers/test_websocket_provider.py b/tests/core/providers/test_websocket_provider.py index 2965addb0a..10f16d42f2 100644 --- a/tests/core/providers/test_websocket_provider.py +++ b/tests/core/providers/test_websocket_provider.py @@ -534,3 +534,35 @@ async def test_req_info_cache_size_can_be_set_and_warns_when_full(caplog): "behavior. 
Consider increasing the ``request_information_cache_size`` " "on the provider." ) in caplog.text + + +@pytest.mark.asyncio +async def test_raise_stray_errors_from_cache_handles_list_response_without_error(): + provider = WebSocketProvider("ws://mocked") + _mock_ws(provider) + + bad_response = [ + {"id": None, "jsonrpc": "2.0", "error": {"code": 21, "message": "oops"}} + ] + provider._request_processor._request_response_cache._data["bad_key"] = bad_response + + # assert no errors raised + provider._raise_stray_errors_from_cache() + + +@pytest.mark.asyncio +async def test_websocket_provider_is_batching_when_make_batch_request(): + def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes: + assert provider._is_batching + return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}' + + provider = WebSocketProvider("ws://mocked") + _mock_ws(provider) + provider._get_response_for_request_id = AsyncMock() + provider._get_response_for_request_id.side_effect = ( + assert_is_batching_and_return_response + ) + + assert not provider._is_batching + await provider.make_batch_request([("eth_blockNumber", [])]) + assert not provider._is_batching diff --git a/web3/_utils/batching.py b/web3/_utils/batching.py index 9155b98c66..9c574da0af 100644 --- a/web3/_utils/batching.py +++ b/web3/_utils/batching.py @@ -1,6 +1,9 @@ from copy import ( copy, ) +from functools import ( + wraps, +) from types import ( TracebackType, ) @@ -12,9 +15,11 @@ Dict, Generic, List, + Protocol, Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -33,6 +38,7 @@ Web3ValueError, ) from web3.types import ( + RPCEndpoint, TFunc, TReturn, ) @@ -55,7 +61,6 @@ JSONBaseProvider, ) from web3.types import ( # noqa: F401 - RPCEndpoint, RPCResponse, ) @@ -215,3 +220,39 @@ def sort_batch_response_by_response_ids( stacklevel=2, ) return responses + + +class SupportsBatching(Protocol): + _is_batching: bool + + +R = TypeVar("R") +T = TypeVar("T", bound=SupportsBatching) + + +def async_batching_context( + method: 
Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]] +) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]]: + @wraps(method) + async def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R: + self._is_batching = True + try: + return await method(self, requests) + finally: + self._is_batching = False + + return wrapper + + +def batching_context( + method: Callable[[T, List[Tuple[RPCEndpoint, Any]]], R] +) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], R]: + @wraps(method) + def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R: + self._is_batching = True + try: + return method(self, requests) + finally: + self._is_batching = False + + return wrapper diff --git a/web3/providers/ipc.py b/web3/providers/ipc.py index 3a4f2b8dff..23a133843a 100644 --- a/web3/providers/ipc.py +++ b/web3/providers/ipc.py @@ -30,6 +30,7 @@ ) from .._utils.batching import ( + batching_context, sort_batch_response_by_response_ids, ) from .._utils.caching import ( @@ -201,6 +202,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: request = self.encode_rpc_request(method, params) return self._make_request(request) + @batching_context def make_batch_request( self, requests: List[Tuple[RPCEndpoint, Any]] ) -> List[RPCResponse]: diff --git a/web3/providers/legacy_websocket.py b/web3/providers/legacy_websocket.py index e2574c8da5..1f978e865f 100644 --- a/web3/providers/legacy_websocket.py +++ b/web3/providers/legacy_websocket.py @@ -27,6 +27,7 @@ ) from web3._utils.batching import ( + batching_context, sort_batch_response_by_response_ids, ) from web3._utils.caching import ( @@ -143,6 +144,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: ) return future.result() + @batching_context def make_batch_request( self, requests: List[Tuple[RPCEndpoint, Any]] ) -> List[RPCResponse]: diff --git a/web3/providers/persistent/persistent.py b/web3/providers/persistent/persistent.py index 
5d7385bbb9..4f70af6deb 100644 --- a/web3/providers/persistent/persistent.py +++ b/web3/providers/persistent/persistent.py @@ -24,6 +24,7 @@ from web3._utils.batching import ( BATCH_REQUEST_ID, + async_batching_context, sort_batch_response_by_response_ids, ) from web3._utils.caching import ( @@ -237,14 +238,16 @@ async def make_request( rpc_request = await self.send_request(method, params) return await self.recv_for_request(rpc_request) + @async_batching_context async def make_batch_request( self, requests: List[Tuple[RPCEndpoint, Any]] ) -> List[RPCResponse]: request_data = self.encode_batch_rpc_request(requests) await self.socket_send(request_data) response = cast( - List[RPCResponse], await self._get_response_for_request_id(BATCH_REQUEST_ID) + List[RPCResponse], + await self._get_response_for_request_id(BATCH_REQUEST_ID), ) return response @@ -320,17 +324,16 @@ def _raise_stray_errors_from_cache(self) -> None: for ( response ) in self._request_processor._request_response_cache._data.values(): - request = ( - self._request_processor._request_information_cache.get_cache_entry( + if isinstance(response, dict): + request = self._request_processor._request_information_cache.get_cache_entry( # noqa: E501 generate_cache_key(response["id"]) ) - ) - if "error" in response and request is None: - # if we find an error response in the cache without a corresponding - # request, raise the error - validate_rpc_response_and_raise_if_error( - response, None, logger=self.logger - ) + if "error" in response and request is None: + # if we find an error response in the cache without a + # corresponding request, raise the error + validate_rpc_response_and_raise_if_error( - cast(RPCResponse, response), None, logger=self.logger - ) + cast(RPCResponse, response), None, logger=self.logger + ) async def _message_listener(self) -> None: self.logger.info( diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index 85e27c33fb..6e66bae938 100644 --- a/web3/providers/rpc/async_rpc.py +++ 
b/web3/providers/rpc/async_rpc.py @@ -36,6 +36,7 @@ ) from ..._utils.batching import ( + async_batching_context, sort_batch_response_by_response_ids, ) from ..._utils.caching import ( @@ -166,6 +167,7 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: ) return response + @async_batching_context async def make_batch_request( self, batch_requests: List[Tuple[RPCEndpoint, Any]] ) -> Union[List[RPCResponse], RPCResponse]: diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index 4c41d46e29..d60ce8c14c 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -34,6 +34,7 @@ ) from ..._utils.batching import ( + batching_context, sort_batch_response_by_response_ids, ) from ..._utils.caching import ( @@ -174,6 +175,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: ) return response + @batching_context def make_batch_request( self, batch_requests: List[Tuple[RPCEndpoint, Any]] ) -> Union[List[RPCResponse], RPCResponse]: