Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Foundation for llm integration with Dremio"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"aiohttp>=3.11.12",
"aiohttp>=3.12.15",
"beeai-framework>=0.1.8",
"black>=25.1.0",
"click>=8.1.8",
Expand Down
94 changes: 88 additions & 6 deletions src/dremioai/api/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import asyncio

from aiohttp import ClientSession, ClientResponse, ClientResponseError
from typing import (
AnyStr,
Callable,
Optional,
Dict,
TypeAlias,
Union,
TextIO,
Awaitable,
Any,
)
from pathlib import Path
from typing import AnyStr, Callable, Optional, Dict, TypeAlias, Union, TextIO
from dremioai.log import logger
from json import loads
from pydantic import BaseModel, ValidationError
from http import HTTPStatus

from dremioai.config import settings
from dremioai.api.oauth2 import get_oauth2_tokens

DeserializationStrategy: TypeAlias = Union[Callable, BaseModel]


class RetryConfig:
    """Resolve the HTTP retry settings and compute per-attempt retry delays.

    The settings are read from ``settings.instance().dremio.http_retry`` when
    available; otherwise a default ``settings.HttpRetry()`` is used.
    """

    def __init__(self):
        instance = settings.instance()
        dremio = instance.dremio if instance else None
        configured = dremio.http_retry if dremio else None
        # http_retry may be unset (None) even when a Dremio section exists, so
        # fall back to the defaults rather than storing None.
        self.config = configured if configured is not None else settings.HttpRetry()

    @property
    def max_retries(self) -> int:
        """Expose max_retries from config for convenience"""
        return self.config.max_retries

    def get_config_delay(self, attempt_number: int = 0) -> float:
        """Return the uncapped exponential back-off delay in seconds.

        Computed as ``initial_delay * backoff_multiplier ** attempt_number``,
        where ``attempt_number`` is 0 for the first retry.
        """
        return self.config.initial_delay * (
            self.config.backoff_multiplier**attempt_number
        )

    @staticmethod
    def _parse_retry_after(value: str) -> float:
        """Convert a ``Retry-After`` header value into seconds from now.

        Accepts both forms allowed by HTTP: a non-negative integer number of
        seconds, or an HTTP-date. The result may be negative for a date in the
        past; the caller clamps it.

        Raises:
            ValueError, TypeError: if the value is in neither form.
        """
        text = str(value).strip()
        if text.isascii() and text.isdigit():
            return float(text)

        # Local imports keep this block self-contained.
        from datetime import datetime, timezone
        from email.utils import parsedate_to_datetime

        when = parsedate_to_datetime(text)
        if when.tzinfo is None:
            # A "-0000" zone yields a naive datetime; treat it as UTC.
            when = when.replace(tzinfo=timezone.utc)
        return (when - datetime.now(timezone.utc)).total_seconds()

    def get_delay(
        self,
        response: "ClientResponse",
        attempt_number: int,
    ) -> float:
        """Return how long to wait before retrying after ``response``.

        A valid ``Retry-After`` header takes precedence over the configured
        exponential back-off, because retrying earlier than the server asked
        is likely to be rate limited again. An invalid header is logged and
        ignored. The result is never negative when derived from the header and
        is always capped at ``max_delay``.
        """
        delay = self.get_config_delay(attempt_number=attempt_number)
        retry_after = response.headers.get("Retry-After")
        if retry_after is not None:
            try:
                delay = max(0.0, self._parse_retry_after(retry_after))
            except (ValueError, TypeError) as e:
                logger().debug(
                    f"Invalid Retry-After header, using exponential backoff - {e}"
                )

        return min(delay, self.config.max_delay)


async def retry_middleware(
    req, handler: Callable[[Any], Awaitable[ClientResponse]]
) -> ClientResponse:
    """
    Middleware that automatically retries requests on 429 (rate limit) errors.

    The request is sent once and then retried up to ``max_retries`` times
    while the response status is 429. The wait between attempts comes from
    ``RetryConfig.get_delay``. The last response is returned as-is, whatever
    its status, without sleeping after the final attempt.
    """
    retry_config = RetryConfig()
    # Treat a missing or negative setting as "no retries" so that the request
    # is always sent at least once.
    max_retries = max(0, retry_config.max_retries or 0)

    attempt = 0
    while True:
        response = await handler(req)
        if response.status != HTTPStatus.TOO_MANY_REQUESTS or attempt >= max_retries:
            return response

        delay = retry_config.get_delay(response, attempt)
        logger(f"{__name__}.retry").warning(
            f"Rate limited (429) on {req.method} {req.url.path}. "
            f"Retry {attempt + 1}/{max_retries} after {delay:.2f}s"
        )
        # This response is being discarded; hand its connection back before
        # waiting instead of holding it until garbage collection.
        response.release()
        await asyncio.sleep(delay)
        attempt += 1


class AsyncHttpClient:
def __init__(self, uri: AnyStr, token: AnyStr):
self.uri = uri
Expand Down Expand Up @@ -83,6 +154,18 @@ async def handle_response(
)
await self.download(response, file)

def log_request(
    self, method: str, endpoint: str, params: Optional[Dict[AnyStr, Any]] = None
):
    """Log an outgoing request at DEBUG level with credentials redacted.

    Args:
        method: HTTP method name, used only in the log line.
        endpoint: Path appended to ``self.uri`` in the log line.
        params: Optional query parameters, logged as-is.

    Nothing is built or logged unless DEBUG logging is enabled.
    """
    if not logger().isEnabledFor(logging.DEBUG):
        return

    # Header names are case-insensitive, so match "Authorization" regardless
    # of how the key is spelled to avoid writing the token to the log.
    sanitized_headers = {
        k: ("Bearer <redacted>" if str(k).lower() == "authorization" else v)
        for k, v in self.headers.items()
    }
    logger().debug(
        f"{method} {self.uri}{endpoint}, headers={sanitized_headers}, params={params}"
    )

async def get(
self,
endpoint: AnyStr,
Expand All @@ -92,10 +175,8 @@ async def get(
file: Optional[TextIO] = None,
top_level_list: bool = False,
):
async with ClientSession() as session:
logger().info(
f"{self.uri}{endpoint}', headers={self.headers}, params={params}"
)
async with ClientSession(middlewares=(retry_middleware,)) as session:
self.log_request("GET", endpoint, params)
async with session.get(
f"{self.uri}{endpoint}",
headers=self.headers,
Expand All @@ -115,7 +196,8 @@ async def post(
file: Optional[TextIO] = None,
top_level_list: bool = False,
):
async with ClientSession() as session:
async with ClientSession(middlewares=(retry_middleware,)) as session:
self.log_request("POST", endpoint)
async with session.post(
f"{self.uri}{endpoint}", headers=self.headers, json=body, ssl=False
) as response:
Expand Down
21 changes: 21 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@ class Metrics(BaseModel):
model_config = ConfigDict(validate_assignment=True)


class HttpRetry(BaseModel):
    """Configuration for HTTP retry behavior with exponential backoff"""

    # Bounds are validated here so that a bad value is rejected when the
    # configuration is loaded or assigned, not when a request is retried.
    max_retries: Optional[int] = Field(
        default=3,
        ge=0,
        description="Maximum number of retry attempts for rate-limited requests",
    )
    initial_delay: Optional[float] = Field(
        default=1.0,
        ge=0,
        description="Initial delay in seconds before first retry",
    )
    max_delay: Optional[float] = Field(
        default=60.0,
        ge=0,
        description="Maximum delay in seconds between retries",
    )
    backoff_multiplier: Optional[float] = Field(
        default=2.0,
        gt=0,
        description="Multiplier for exponential backoff",
    )
    model_config = ConfigDict(validate_assignment=True)


class Dremio(BaseModel):
uri: Annotated[
Union[str, HttpUrl, DremioCloudUri], AfterValidator(_resolve_dremio_uri)
Expand All @@ -160,6 +179,8 @@ class Dremio(BaseModel):
wlm: Optional[Wlm] = None
# Metrics server configuration
metrics: Optional[Metrics] = None
# HTTP retry configuration
http_retry: Optional[HttpRetry] = Field(default_factory=HttpRetry)
model_config = ConfigDict(validate_assignment=True)

@field_serializer("raw_pat")
Expand Down
Loading