Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions estuary-cdk/estuary_cdk/http.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from dataclasses import dataclass
from logging import Logger
from aiohttp.client_exceptions import ConnectionTimeoutError
from estuary_cdk.incremental_json_processor import Remainder
from pydantic import BaseModel
from typing import AsyncGenerator, Any, TypeVar, Union, Callable
from typing import AsyncGenerator, Any, TypeVar, Callable, Awaitable
import abc
import aiohttp
import asyncio
Expand All @@ -27,7 +25,7 @@

DEFAULT_AUTHORIZATION_HEADER = "Authorization"

StreamedObject = TypeVar("StreamedObject", bound=BaseModel)
T = TypeVar("T")

class Headers(dict[str, Any]):
    """String-keyed dictionary used to carry HTTP header values.

    Adds no behavior on top of ``dict``; it exists as a distinct named type
    for signatures that return or accept headers.
    """
Expand Down Expand Up @@ -79,34 +77,23 @@ async def request(
) -> bytes:
"""Request a url and return its body as bytes"""

max_attempts = 3
attempt = 1
async def _do_request() -> bytes:
chunks: list[bytes] = []
_, body_generator = await self._request_stream(
log, url, method, params, json, form, _with_token, headers
)

while True:
try:
chunks: list[bytes] = []
_, body_generator = await self._request_stream(
log, url, method, params, json, form, _with_token, headers
)
async for chunk in body_generator():
chunks.append(chunk)

async for chunk in body_generator():
chunks.append(chunk)
if len(chunks) == 0:
return b""
elif len(chunks) == 1:
return chunks[0]
else:
return b"".join(chunks)

if len(chunks) == 0:
return b""
elif len(chunks) == 1:
return chunks[0]
else:
return b"".join(chunks)
except ConnectionTimeoutError as e:
if attempt <= max_attempts:
log.warning(
f"Connection timeout error (will retry)",
{"url": url, "method": method, "attempt": attempt, "error": str(e)}
)
attempt += 1
else:
raise
return await self._retry_on_timeout(log, url, method, _do_request)

async def request_lines(
self,
Expand All @@ -121,9 +108,12 @@ async def request_lines(
) -> tuple[Headers, BodyGeneratorFunction]:
"""Request a url and return its response as streaming lines, as they arrive"""

headers, body = await self._request_stream(
log, url, method, params, json, form, True, headers
)
async def _do_request() -> tuple[Headers, BodyGeneratorFunction]:
return await self._request_stream(
log, url, method, params, json, form, True, headers
)

resp_headers, body = await self._retry_on_timeout(log, url, method, _do_request)

async def gen() -> AsyncGenerator[bytes, None]:
buffer = b""
Expand All @@ -136,7 +126,7 @@ async def gen() -> AsyncGenerator[bytes, None]:
if buffer:
yield buffer

return (headers, gen)
return (resp_headers, gen)

async def request_stream(
self,
Expand All @@ -150,8 +140,33 @@ async def request_stream(
) -> tuple[Headers, BodyGeneratorFunction]:
"""Request a url and and return the raw response as a stream of bytes"""

headers, body = await self._request_stream(log, url, method, params, json, form, True, headers)
return (headers, body)
async def _do_request() -> tuple[Headers, BodyGeneratorFunction]:
return await self._request_stream(log, url, method, params, json, form, True, headers)

return await self._retry_on_timeout(log, url, method, _do_request)

async def _retry_on_timeout(
    self,
    log: Logger,
    url: str,
    method: str,
    operation: Callable[[], Awaitable[T]],
) -> T:
    """Await ``operation``, retrying it when it raises ``asyncio.TimeoutError``.

    ``operation`` is invoked afresh for every attempt. Each of the first
    three timeouts is logged as a warning and followed immediately (no
    delay) by another attempt, so ``operation`` is called at most four
    times. A timeout on the fourth call propagates to the caller, as does
    any other exception on any call.

    ``url`` and ``method`` are only used as context in the warning log.
    """
    retry_budget = 3

    for attempt_number in range(1, retry_budget + 1):
        try:
            return await operation()
        except asyncio.TimeoutError as timeout_err:
            log.warning(
                "Connection timeout error (will retry)",
                {
                    "url": url,
                    "method": method,
                    "attempt": attempt_number,
                    "error": str(timeout_err),
                },
            )

    # Retry budget exhausted: make the final attempt unguarded so that a
    # timeout here surfaces to the caller without another warning.
    return await operation()

@abc.abstractmethod
async def _request_stream(
Expand Down
Loading