Skip to content

Commit 97ee4a0

Browse files
committed
estuary-cdk: allow callers to pass a should_retry() handler for 5xx errors
Some APIs return 5xx errors when too much data is requested. This commit allows connectors to tell the CDK to bubble up 5xx errors instead of transparently retrying them. The CDK's current behavior is maintained for existing connectors; if should_retry is not provided, the request is retried.
1 parent 4faba80 commit 97ee4a0

File tree

1 file changed

+55
-8
lines changed

1 file changed

+55
-8
lines changed

estuary-cdk/estuary_cdk/http.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from logging import Logger
33
from pydantic import BaseModel
4-
from typing import AsyncGenerator, Any, TypeVar, Callable, Awaitable
4+
from typing import AsyncGenerator, Any, Awaitable, TypeVar, Callable, Protocol
55
import abc
66
import aiohttp
77
import asyncio
@@ -43,6 +43,35 @@
4343
BodyGeneratorFunction = Callable[[], AsyncGenerator[bytes, None]]
4444
HeadersAndBodyGenerator = tuple[Headers, BodyGeneratorFunction]
4545

46+
class ShouldRetryProtocol(Protocol):
47+
"""
48+
ShouldRetryProtocol defines a callback function signature for custom retry logic.
49+
50+
Implementations should return True if the HTTP request should be retried, or False
51+
if the request should fail immediately with an HTTPError. This allows connectors
52+
to implement custom retry strategies for server errors (5xx status codes) based
53+
on the response status, headers, body, and current attempt number.
54+
55+
Parameters:
56+
* `status` is the HTTP status code of the response
57+
* `headers` are the HTTP response headers as a dict
58+
* `body` is the response body as bytes
59+
* `attempt` is the current attempt number (starts at 1)
60+
61+
Example usage:
62+
def custom_retry(status: int, headers: Headers, body: bytes, attempt: int) -> bool:
63+
if status == 503 and attempt > 2:
64+
return False
65+
return status >= 500
66+
"""
67+
def __call__(
68+
self,
69+
status: int,
70+
headers: Headers,
71+
body: bytes,
72+
attempt: int
73+
) -> bool: ...
74+
4675

4776
class HTTPError(RuntimeError):
4877
"""
@@ -71,6 +100,7 @@ class HTTPSession(abc.ABC):
71100
* `params` are encoded as URL parameters of the query
72101
* `json` is a JSON-encoded request body (if set, `form` cannot be)
73102
* `form` is a form URL-encoded request body (if set, `json` cannot be)
103+
* `should_retry` determines whether 5xx errors are retried or bubbled up. If not provided, all 5xx errors are retried.
74104
"""
75105

76106
async def request(
@@ -83,12 +113,13 @@ async def request(
83113
form: dict[str, Any] | None = None,
84114
_with_token: bool = True, # Unstable internal API.
85115
headers: dict[str, Any] | None = None,
116+
should_retry: ShouldRetryProtocol | None = None,
86117
) -> bytes:
87118
"""Request a url and return its body as bytes"""
88119

89120
chunks: list[bytes] = []
90121
_, body_generator = await self._request_stream(
91-
log, url, method, params, json, form, _with_token, headers
122+
log, url, method, params, json, form, _with_token, headers, should_retry,
92123
)
93124

94125
async for chunk in body_generator():
@@ -111,11 +142,12 @@ async def request_lines(
111142
form: dict[str, Any] | None = None,
112143
delim: bytes = b"\n",
113144
headers: dict[str, Any] | None = None,
145+
should_retry: ShouldRetryProtocol | None = None,
114146
) -> tuple[Headers, BodyGeneratorFunction]:
115147
"""Request a url and return its response as streaming lines, as they arrive"""
116148

117149
resp_headers, body = await self._request_stream(
118-
log, url, method, params, json, form, True, headers
150+
log, url, method, params, json, form, True, headers, should_retry,
119151
)
120152

121153
async def gen() -> AsyncGenerator[bytes, None]:
@@ -141,10 +173,11 @@ async def request_stream(
141173
form: dict[str, Any] | None = None,
142174
_with_token: bool = True, # Unstable internal API.
143175
headers: dict[str, Any] | None = None,
176+
should_retry: ShouldRetryProtocol | None = None,
144177
) -> tuple[Headers, BodyGeneratorFunction]:
145178
"""Request a url and return the raw response as a stream of bytes"""
146179

147-
return await self._request_stream(log, url, method, params, json, form, _with_token, headers)
180+
return await self._request_stream(log, url, method, params, json, form, _with_token, headers, should_retry)
148181

149182

150183
@abc.abstractmethod
@@ -158,6 +191,7 @@ async def _request_stream(
158191
form: dict[str, Any] | None,
159192
_with_token: bool,
160193
headers: dict[str, Any] | None = None,
194+
should_retry: ShouldRetryProtocol | None = None,
161195
) -> HeadersAndBodyGenerator: ...
162196

163197
# TODO(johnny): This is an unstable API.
@@ -484,14 +518,17 @@ async def _request_stream(
484518
form: dict[str, Any] | None,
485519
_with_token: bool,
486520
headers: dict[str, Any] | None = None,
521+
should_retry: ShouldRetryProtocol | None = None,
487522
) -> HeadersAndBodyGenerator:
488523
if headers is None:
489524
headers = {}
525+
attempt = 0
490526

491527
while True:
492528
cur_delay = self.rate_limiter.delay
493529
await asyncio.sleep(cur_delay)
494530

531+
attempt += 1
495532
resp = await self._retry_on_connection_error(
496533
log, url, method,
497534
lambda: self._establish_connection_and_get_response(
@@ -516,10 +553,20 @@ async def _request_stream(
516553

517554
elif resp.status >= 500 and resp.status < 600:
518555
body = await resp.read()
519-
log.warning(
520-
"server internal error (will retry)",
521-
{"body": body.decode("utf-8")},
522-
)
556+
557+
if (
558+
should_retry is None
559+
or should_retry(resp.status, dict(resp.headers), body, attempt)
560+
):
561+
log.warning(
562+
"server internal error (will retry)",
563+
{"body": body.decode("utf-8")},
564+
)
565+
else:
566+
raise HTTPError(
567+
f"Encountered HTTP error status {resp.status}.\nURL: {url}\nResponse:\n{body.decode('utf-8')}",
568+
resp.status,
569+
)
523570
elif resp.status >= 400 and resp.status < 500:
524571
body = await resp.read()
525572
raise HTTPError(

0 commit comments

Comments
 (0)