diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index ac879f35..d6e58a32 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -33,6 +33,7 @@ MovedError, NoPermissionError, ResponseError, + TimeoutError, ValkeyClusterException, ValkeyError, ) @@ -2802,6 +2803,23 @@ async def parse_response( with pytest.raises(ConnectionError): await pipe.get(key).get(key).execute(raise_on_error=True) + async def test_timeout_error_retried(self, r: ValkeyCluster) -> None: + key = "foo" + await r.set(key, "value") + execute_pipeline = ClusterNode.execute_pipeline + attempts = 0 + + async def raise_timeout_once(self, commands): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise TimeoutError("error") + return await execute_pipeline(self, commands) + + with mock.patch.object(ClusterNode, "execute_pipeline", new=raise_timeout_once): + assert await r.pipeline().get(key).execute() == [b"value"] + assert attempts == 2 + async def test_asking_error(self, r: ValkeyCluster) -> None: """Test AskError handling.""" key = "foo" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 83d6e96c..8bfb7f62 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -29,6 +29,7 @@ from valkey.crc import key_slot from valkey.exceptions import ( AskError, + AuthenticationError, ClusterDownError, ConnectionError, DataError, @@ -36,6 +37,7 @@ NoPermissionError, ResponseError, TimeoutError, + TryAgainError, ValkeyClusterException, ValkeyError, ) @@ -3213,6 +3215,74 @@ def raise_connection_error(): with pytest.raises(ConnectionError): pipe.get(key).get(key).execute(raise_on_error=True) + def test_timeout_error_get_connection_retried(self, r): + key = "foo" + r.set(key, "value") + orig_get_connection = valkey.cluster.get_connection + attempts = 0 + + def raise_timeout_once(*args, **kwargs): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise TimeoutError("error") + return orig_get_connection(*args, **kwargs) + + with patch("valkey.cluster.get_connection", side_effect=raise_timeout_once): + with patch.object( + r.nodes_manager, "initialize", wraps=r.nodes_manager.initialize + ) as initialize: + assert r.pipeline().get(key).execute() == [b"value"] + assert attempts == 2 + assert initialize.call_count == 1 + + def test_timeout_error_get_connection_raised(self, r): + key = "foo" + + with ( + patch("valkey.cluster.get_connection", side_effect=TimeoutError("error")), + patch.object( + r.nodes_manager, "initialize", wraps=r.nodes_manager.initialize + ) as initialize, + pytest.raises(TimeoutError), + ): + r.pipeline().get(key).execute() + assert initialize.call_count == r.cluster_error_retry_attempts + 1 + + def test_annotate_exception_handles_empty_args(self, r): + pipe = r.pipeline() + exception = TryAgainError() + + pipe.annotate_exception(exception, 1, ("GET", "foo")) + + assert exception.args == ( + "Command # 1 (GET foo) of pipeline caused error: TryAgainError", + ) + + def test_non_retryable_get_connection_error_releases_connections(self, r): + # in order to ensure that a pipeline will make use of connections + # from different nodes + assert r.keyslot("a") != r.keyslot("b") + + orig_get_connection = valkey.cluster.get_connection + + with patch("valkey.cluster.get_connection") as get_connection: + + def raise_non_retryable(target_node, *args, **kwargs): + if get_connection.call_count == 2: + raise AuthenticationError("mocked auth error") + return orig_get_connection(target_node, *args, **kwargs) + + get_connection.side_effect = raise_non_retryable + + with pytest.raises(AuthenticationError): + r.pipeline().get("a").get("b").execute() + + for cluster_node in r.nodes_manager.nodes_cache.values(): + connection_pool = cluster_node.valkey_connection.connection_pool + num_of_conns = len(connection_pool._available_connections) + assert num_of_conns == connection_pool._created_connections + def test_asking_error(self, r): """ Test redirection on ASK error diff --git a/valkey/cluster.py b/valkey/cluster.py index fb508637..d7714cb7 100644 --- a/valkey/cluster.py +++ b/valkey/cluster.py @@ -422,7 +422,8 @@ class AbstractValkeyCluster: list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), ) - ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + REINITIALIZE_ERRORS = (ConnectionError, TimeoutError, ClusterDownError) + ERRORS_ALLOW_RETRY = REINITIALIZE_ERRORS def replace_default_node(self, target_node: "ClusterNode" = None) -> None: """Replace the default cluster node. @@ -1616,7 +1617,7 @@ def initialize(self): if len(disagreements) > 5: raise ValkeyClusterException( f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' + f"slots cache: {', '.join(disagreements)}" ) fully_covered = self.check_slots_coverage(tmp_slots) @@ -1933,9 +1934,8 @@ class ClusterPipeline(ValkeyCluster): in cluster mode """ - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, + REINITIALIZE_ERRORS = AbstractValkeyCluster.REINITIALIZE_ERRORS + ERRORS_ALLOW_RETRY = REINITIALIZE_ERRORS + ( MovedError, AskError, TryAgainError, @@ -2035,10 +2035,10 @@ def annotate_exception(self, exception, number, command): Provides extra context to the exception prior to it being handled """ cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" - ) + error_message = exception.args[0] if exception.args else str(exception) + if not error_message: + error_message = exception.__class__.__name__ + msg = f"Command # {number} ({cmd}) of pipeline caused error: {error_message}" exception.args = (msg,) + exception.args[1:] def execute(self, raise_on_error=True): @@ -2111,14 +2111,14 @@ def send_cluster_commands( raise_on_error=raise_on_error, allow_redirections=allow_redirections, ) - except (ClusterDownError, ConnectionError) as e: - if retry_attempts > 0: + except Exception as e: + if retry_attempts > 0 and type(e) in self.__class__.REINITIALIZE_ERRORS: # Try again with the new cluster setup. All other errors # should be raised. retry_attempts -= 1 pass else: - raise e + raise def _send_cluster_commands( self, stack, raise_on_error=True, allow_redirections=True @@ -2176,9 +2176,11 @@ def _send_cluster_commands( valkey_node = self.get_valkey_connection(node) try: connection = get_connection(valkey_node, c.args) - except ConnectionError: + except Exception as e: for n in nodes.values(): n.connection_pool.release(n.connection) + if type(e) not in self.__class__.REINITIALIZE_ERRORS: + raise # Connection retries are being handled in the node's # Retry object. Reinitialize the node -> slot table. self.nodes_manager.initialize()