Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions sdks/python/hatchet_sdk/clients/rest/tenacity_utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
from __future__ import annotations

import re
from collections.abc import Callable
from typing import ParamSpec, TypeVar
from typing import TYPE_CHECKING, ParamSpec, TypeVar

import grpc
import tenacity

from hatchet_sdk.clients.rest.exceptions import NotFoundException, ServiceException
from hatchet_sdk.config import TenacityConfig
from hatchet_sdk.clients.rest.exceptions import (
NotFoundException,
RestTransportError,
ServiceException,
)
from hatchet_sdk.logger import logger

if TYPE_CHECKING:
from hatchet_sdk.config import TenacityConfig

P = ParamSpec("P")
R = TypeVar("R")

# Pattern to extract HTTP method from exception reason
_METHOD_PATTERN = re.compile(r"\bmethod=(\w+)\b", re.IGNORECASE)


def tenacity_retry(func: Callable[P, R], config: TenacityConfig) -> Callable[P, R]:
    """Wrap ``func`` so that retryable failures are retried.

    Args:
        func: The callable to wrap.
        config: Retry settings. ``max_attempts`` bounds the number of attempts,
            and the whole config is forwarded to ``tenacity_should_retry`` so
            it can apply the transport-error settings.

    Returns:
        ``func`` itself when ``config.max_attempts`` is zero or negative,
        otherwise ``func`` decorated with ``tenacity.retry``.
    """
    # A non-positive attempt budget disables retrying entirely.
    if config.max_attempts <= 0:
        return func

    def should_retry(ex: BaseException) -> bool:
        # The retry predicate is called with the exception only, so the config
        # is bound here through the closure.
        return tenacity_should_retry(ex, config)

    # Exactly one ``retry=`` argument: the config-aware predicate. Passing the
    # bare ``tenacity_should_retry`` as well would repeat the keyword (a
    # SyntaxError) and would ignore the transport-error settings.
    return tenacity.retry(
        reraise=True,
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(config.max_attempts),
        before_sleep=tenacity_alert_retry,
        retry=tenacity.retry_if_exception(should_retry),
    )(func)


Expand All @@ -33,10 +48,23 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None:
)


def tenacity_should_retry(ex: BaseException) -> bool:
def tenacity_should_retry(
ex: BaseException, config: TenacityConfig | None = None
) -> bool:
"""Determine if an exception should trigger a retry.

Args:
ex: The exception to evaluate.
config: Optional tenacity config for transport error settings.

Returns:
True if the exception should be retried, False otherwise.
"""
# HTTP errors: ServiceException (5xx) and NotFoundException (404) are retried
if isinstance(ex, ServiceException | NotFoundException):
return True

# gRPC errors: retry most, except specific permanent failure codes
if isinstance(ex, grpc.aio.AioRpcError | grpc.RpcError):
return ex.code() not in [
grpc.StatusCode.UNIMPLEMENTED,
Expand All @@ -47,4 +75,24 @@ def tenacity_should_retry(ex: BaseException) -> bool:
grpc.StatusCode.PERMISSION_DENIED,
]

# REST transport errors: opt-in retry for configured HTTP methods
if isinstance(ex, RestTransportError):
if config is not None and config.retry_transport_errors:
method = _extract_method_from_reason(ex.reason)
if method is not None:
allowed_methods = {m.upper() for m in config.retry_transport_methods}
return method.upper() in allowed_methods
return False

return False


def _extract_method_from_reason(reason: str | None) -> str | None:
    """Pull the HTTP method out of an exception ``reason`` string.

    Searches ``reason`` with ``_METHOD_PATTERN`` for a ``method=<WORD>`` token
    (for example ``method=GET``) and returns the captured word with its
    original casing.

    Returns:
        The captured method, or ``None`` when ``reason`` is ``None``/empty or
        contains no such token.
    """
    if reason:
        found = _METHOD_PATTERN.search(reason)
        if found is not None:
            return found.group(1)
    return None
9 changes: 9 additions & 0 deletions sdks/python/hatchet_sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ class TenacityConfig(BaseSettings):

max_attempts: int = 5

retry_transport_errors: bool = Field(
default=False,
description="Enable retries for REST transport errors (timeout, connection, TLS). Default: off.",
)
retry_transport_methods: list[str] = Field(
default_factory=lambda: ["GET", "DELETE"],
description="HTTP methods to retry on transport errors when retry_transport_errors is enabled; excludes POST/PUT/PATCH by default due to idempotency concerns.",
)


DEFAULT_HOST_PORT = "localhost:7070"

Expand Down
179 changes: 179 additions & 0 deletions sdks/python/tests/test_tenacity_transport_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Unit tests for tenacity transport error retry behavior.

Tests verify:
1. Default behavior: RestTransportError is NOT retried (even for GET)
2. Opt-in behavior: RestTransportError retried for configured methods only
3. Existing HTTP error retry behavior unchanged
4. Method extraction from exception reason strings
"""

import pytest

from hatchet_sdk.clients.rest.exceptions import (
NotFoundException,
RestConnectionError,
RestProtocolError,
RestTimeoutError,
RestTLSError,
RestTransportError,
ServiceException,
)
from hatchet_sdk.clients.rest.tenacity_utils import (
_extract_method_from_reason,
tenacity_should_retry,
)
from hatchet_sdk.config import TenacityConfig

# --- Default behavior tests (transport errors NOT retried) ---


@pytest.mark.parametrize(
    "exc_class",
    [RestTransportError, RestTimeoutError],
    ids=["base-class", "subclass"],
)
def test_default__transport_errors_not_retried(exc_class: type) -> None:
    """With a default config, neither the base transport error nor a subclass is retried."""
    default_config = TenacityConfig()
    transport_error = exc_class(status=0, reason="method=GET\nurl=http://test")

    decision = tenacity_should_retry(transport_error, default_config)

    assert decision is False


# --- Opt-in behavior tests (transport errors retried for allowed methods) ---


@pytest.mark.parametrize("method", ["GET", "DELETE"], ids=["get", "delete"])
def test_optin__idempotent_methods_retried(method: str) -> None:
    """Opting in makes transport errors on GET and DELETE requests retryable."""
    opted_in = TenacityConfig(retry_transport_errors=True)
    timeout_error = RestTimeoutError(
        status=0,
        reason=f"method={method}\nurl=http://test",
    )

    decision = tenacity_should_retry(timeout_error, opted_in)

    assert decision is True


@pytest.mark.parametrize("method", ["POST", "PUT", "PATCH"], ids=["post", "put", "patch"])
def test_optin__non_idempotent_methods_not_retried(method: str) -> None:
    """POST, PUT and PATCH stay non-retryable even after opting in to transport retries."""
    opted_in = TenacityConfig(retry_transport_errors=True)
    timeout_error = RestTimeoutError(
        status=0,
        reason=f"method={method}\nurl=http://test",
    )

    decision = tenacity_should_retry(timeout_error, opted_in)

    assert decision is False


def test_optin__custom_methods_list() -> None:
    """A caller-supplied retry_transport_methods list is respected."""
    post_only = TenacityConfig(
        retry_transport_errors=True,
        retry_transport_methods=["POST"],
    )
    timeout_error = RestTimeoutError(status=0, reason="method=POST\nurl=http://test")

    decision = tenacity_should_retry(timeout_error, post_only)

    assert decision is True


def test_optin__custom_methods_excludes_default() -> None:
    """Overriding retry_transport_methods can drop a default method such as GET."""
    delete_only = TenacityConfig(
        retry_transport_errors=True,
        retry_transport_methods=["DELETE"],
    )
    timeout_error = RestTimeoutError(status=0, reason="method=GET\nurl=http://test")

    decision = tenacity_should_retry(timeout_error, delete_only)

    assert decision is False


# --- Regression tests: existing HTTP error retry behavior unchanged ---


@pytest.mark.parametrize(
    ("exc", "desc"),
    [
        (ServiceException(status=500, reason="Internal Server Error"), "5xx"),
        (NotFoundException(status=404, reason="Not Found"), "404"),
    ],
    ids=["service-exception", "not-found"],
)
def test_regression__http_errors_still_retried(exc: Exception, desc: str) -> None:
    """HTTP errors that were retried before (ServiceException, NotFoundException) still are."""
    # ``desc`` only labels the case; it is not used in the assertion.
    default_config = TenacityConfig()

    decision = tenacity_should_retry(exc, default_config)

    assert decision is True


def test_regression__backward_compat_no_config() -> None:
    """Calling without a config (the old one-argument form) still retries ServiceException."""
    server_error = ServiceException(status=500, reason="Internal Server Error")

    decision = tenacity_should_retry(server_error)

    assert decision is True


# --- Unit tests for _extract_method_from_reason ---


@pytest.mark.parametrize(
    ("reason", "expected"),
    [
        ("method=GET\nurl=http://test", "GET"),
        ("method=POST\nurl=http://test", "POST"),
        ("method=delete\nurl=http://test", "delete"),
        ("prefix method=PUT suffix", "PUT"),
        ("some error without method", None),
        ("method=\nurl=http://test", None),
        ("", None),
        (None, None),
    ],
    ids=[
        "get-uppercase",
        "post-uppercase",
        "lowercase-preserved",
        "embedded-in-text",
        "no-method-field",
        "empty-method-value",
        "empty-string",
        "none",
    ],
)
def test_extract_method__parses_reason(
    reason: str | None, expected: str | None
) -> None:
    """_extract_method_from_reason returns the method token, or None when there is none."""
    extracted = _extract_method_from_reason(reason)

    assert extracted == expected


# --- Edge cases for retry behavior ---


@pytest.mark.parametrize(
    "reason",
    ["some error without method", "", None],
    ids=["no-method-field", "empty-string", "none"],
)
def test_edge__unparseable_reason_not_retried(reason: str | None) -> None:
    """A transport error whose reason yields no HTTP method is never retried."""
    opted_in = TenacityConfig(retry_transport_errors=True)
    timeout_error = RestTimeoutError(status=0, reason=reason)

    decision = tenacity_should_retry(timeout_error, opted_in)

    assert decision is False


def test_edge__case_insensitive_method_matching() -> None:
    """A lowercase method in the reason still matches the configured methods."""
    opted_in = TenacityConfig(retry_transport_errors=True)
    timeout_error = RestTimeoutError(status=0, reason="method=get\nurl=http://test")

    decision = tenacity_should_retry(timeout_error, opted_in)

    assert decision is True


# --- Config defaults tests ---


def test_config__default_retry_transport_errors_is_false() -> None:
    """A freshly built TenacityConfig has transport-error retries switched off."""
    flag = TenacityConfig().retry_transport_errors

    assert flag is False


def test_config__default_retry_transport_methods() -> None:
    """A freshly built TenacityConfig allows exactly GET and DELETE for transport retries."""
    default_methods = set(TenacityConfig().retry_transport_methods)

    assert default_methods == {"GET", "DELETE"}
Loading