Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/dremioai/api/dremio/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from pydantic import BaseModel, Field
from typing import List, Dict, Union, Optional, Any
from dataclasses import dataclass

from enum import auto
from datetime import datetime
Expand Down Expand Up @@ -251,3 +252,88 @@ async def run_query(
deser=QuerySubmission,
)
return await get_results(project_id, qs, use_df=use_df, client=client)


@dataclass
class QueryResult:
df: pd.DataFrame
total_rows: int
returned_rows: int

@property
def truncated(self) -> bool:
return self.returned_rows < self.total_rows


async def run_query_capped(
query: Union[Query, str], max_rows: int = 500
) -> QueryResult:
"""Submit a query and fetch at most *max_rows* rows (0 = unlimited).

Returns a ``QueryResult`` with the DataFrame, total row count from the
job, the number of rows actually fetched, and a ``truncated`` flag.
"""
client = AsyncHttpClient()
if not isinstance(query, Query):
engine_name = (
settings.instance().dremio.wlm.engine_name
if settings.instance().dremio.wlm is not None
else None
)
query = Query(sql=query, engineName=engine_name)

project_id = settings.instance().dremio.project_id
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
qs: QuerySubmission = await client.post(
f"{endpoint}/sql",
body=query.model_dump(by_alias=True, exclude_none=True),
deser=QuerySubmission,
)

delay = settings.instance().dremio.api.polling_interval
job: Job = await client.get(f"{endpoint}/job/{qs.id}", deser=Job)
while not job.done:
await asyncio.sleep(delay)
job = await client.get(f"{endpoint}/job/{qs.id}", deser=Job)

if not job.succeeded:
emsg = (
job.error_message
if job.error_message
else (
job.cancellation_reason
if job.job_state == JobState.CANCELED
else "Unknown error"
)
)
raise RuntimeError(f"Job {qs.id} failed: {emsg}")

total_rows = job.row_count or 0
if total_rows == 0:
return QueryResult(df=pd.DataFrame(), total_rows=0, returned_rows=0)

fetch_rows = total_rows if max_rows == 0 else min(total_rows, max_rows)
page_size = min(500, fetch_rows)

results = await run_in_parallel(
[
_fetch_results(None, None, project_id, qs.id, off, page_size)
for off in range(0, fetch_rows, page_size)
]
)
jr = JobResultsWrapper(results)

all_rows = list(itertools.chain.from_iterable(jr_page.rows for jr_page in jr))
# The last page may return more rows than needed; trim to fetch_rows.
all_rows = all_rows[:fetch_rows]

if all_rows and jr[0].result_schema:
columns = [rs.name for rs in jr[0].result_schema]
df = pd.DataFrame(data=all_rows, columns=columns)
for rs in jr[0].result_schema:
if rs.type.name == "TIMESTAMP":
df[rs.name] = pd.to_datetime(df[rs.name])
else:
df = pd.DataFrame(data=all_rows)

return QueryResult(df=df, total_rows=total_rows, returned_rows=len(df))
8 changes: 8 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,14 @@ class Dremio(FlagAwareModel):
description="How long (seconds) to cache JWKS keys before refetching. Default: 3600 (1 hour).",
)
wlm: Optional[Wlm] = None
max_result_rows: Optional[int] = Field(
default=500,
description="Maximum number of rows returned by RunSqlQuery. Use 0 for unlimited.",
)
max_result_bytes: Optional[int] = Field(
default=204_800,
description="Maximum UTF-8 byte size of RunSqlQuery results. Enforced after row cap. Use 0 for unlimited.",
)
api: Optional[ApiSettings] = Field(default_factory=ApiSettings)
metrics: Optional[Metrics] = None

Expand Down
52 changes: 48 additions & 4 deletions src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from starlette.requests import Request

from dremioai import log
import json
import re
import functools

Expand Down Expand Up @@ -351,12 +352,19 @@ def ensure_query_allowed(s: str):

@secured
@with_metrics
async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]], str, bool, int]]:
"""Run a SQL query on the Dremio cluster and return the results.
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes.
DML statements (INSERT, UPDATE, DELETE, etc.) may or may not be permitted depending on project configuration.
If a DML query is not allowed, this will return an error.

Results are capped at max_result_rows rows (default 500) and max_result_bytes bytes
(default 200 KB). When truncated, the response includes 'truncated', 'total_rows',
'returned_rows', and 'truncation_reason' fields. To reduce result size, add LIMIT or
GROUP BY to your query. Configure limits via dremio.max_result_rows and
dremio.max_result_bytes settings (env vars DREMIOAI_DREMIO__MAX_RESULT_ROWS,
DREMIOAI_DREMIO__MAX_RESULT_BYTES). Set either to 0 for unlimited.

Args:
query: sql query
"""
Expand All @@ -367,9 +375,45 @@ async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str
"error": "Only SELECT queries are allowed. DML statements are not permitted.",
}
try:
query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}"
df = await sql.run_query(query=query, use_df=True)
return {"result": _df_to_json_records(df)}
tagged_query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}"
dremio_settings = settings.instance().dremio
max_rows = dremio_settings.get("max_result_rows")
if max_rows is None:
max_rows = 500
max_bytes = dremio_settings.get("max_result_bytes")
if max_bytes is None:
max_bytes = 204_800

qr = await sql.run_query_capped(query=tagged_query, max_rows=max_rows)
records = _df_to_json_records(qr.df)

truncation_reason = None
if qr.truncated:
truncation_reason = "row_limit"

# Enforce byte cap
if max_bytes > 0 and records:
kept = []
running_bytes = 0
for rec in records:
rec_bytes = len(json.dumps(rec).encode("utf-8"))
if running_bytes + rec_bytes > max_bytes:
truncation_reason = "byte_limit"
break
kept.append(rec)
running_bytes += rec_bytes
if truncation_reason == "byte_limit":
records = kept

if truncation_reason:
return {
"result": records,
"truncated": True,
"total_rows": qr.total_rows,
"returned_rows": len(records),
"truncation_reason": truncation_reason,
}
return {"result": records}
except RuntimeError as e:
return {
"error": str(e),
Expand Down
153 changes: 152 additions & 1 deletion tests/api/dremio/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
#

import pytest
from dremioai.api.dremio.sql import Job
from unittest.mock import AsyncMock, patch, MagicMock
import pandas as pd
from dremioai.api.dremio.sql import (
Job,
JobResults,
QueryResult,
QuerySubmission,
run_query_capped,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -52,3 +60,146 @@
)
def test_basic_job(js: str):
j = Job.model_validate_json(js)


# -- helpers for run_query_capped tests ----------------------------------------

def _make_completed_job(row_count: int) -> Job:
return Job.model_validate(
{"jobState": "COMPLETED", "rowCount": row_count, "queryType": "REST"}
)


def _make_job_results(rows, schema_names=None):
if schema_names is None:
schema_names = list(rows[0].keys()) if rows else []
return JobResults.model_validate(
{
"rowCount": len(rows),
"schema": [{"name": n, "type": {"name": "VARCHAR"}} for n in schema_names],
"rows": rows,
}
)


def _mock_settings(project_id=None, engine_name=None, polling_interval=0):
s = MagicMock()
s.dremio.project_id = project_id
s.dremio.wlm = None
s.dremio.api.polling_interval = polling_interval
return s


@pytest.mark.asyncio
async def test_run_query_capped_under_limit():
"""Rows below max_rows => truncated=False, all rows returned."""
rows = [{"a": str(i)} for i in range(3)]
job = _make_completed_job(3)
jr = _make_job_results(rows)

with (
patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient,
patch("dremioai.api.dremio.sql.settings") as mock_settings_mod,
):
mock_settings_mod.instance.return_value = _mock_settings()
client = MockClient.return_value
client.post = AsyncMock(return_value=QuerySubmission(id="j1"))
client.get = AsyncMock(return_value=job)
with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)):
qr = await run_query_capped("SELECT 1", max_rows=10)

assert qr.truncated is False
assert qr.returned_rows == 3
assert qr.total_rows == 3
assert len(qr.df) == 3


@pytest.mark.asyncio
async def test_run_query_capped_row_limit_hit():
"""Row limit fires => truncated=True."""
job = _make_completed_job(100)
rows = [{"a": str(i)} for i in range(10)]
jr = _make_job_results(rows)

with (
patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient,
patch("dremioai.api.dremio.sql.settings") as mock_settings_mod,
):
mock_settings_mod.instance.return_value = _mock_settings()
client = MockClient.return_value
client.post = AsyncMock(return_value=QuerySubmission(id="j1"))
client.get = AsyncMock(return_value=job)
with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)):
qr = await run_query_capped("SELECT 1", max_rows=10)

assert qr.truncated is True
assert qr.returned_rows == 10
assert qr.total_rows == 100


@pytest.mark.asyncio
async def test_run_query_capped_unlimited():
"""max_rows=0 fetches all rows."""
job = _make_completed_job(5)
rows = [{"a": str(i)} for i in range(5)]
jr = _make_job_results(rows)

with (
patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient,
patch("dremioai.api.dremio.sql.settings") as mock_settings_mod,
):
mock_settings_mod.instance.return_value = _mock_settings()
client = MockClient.return_value
client.post = AsyncMock(return_value=QuerySubmission(id="j1"))
client.get = AsyncMock(return_value=job)
with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)):
qr = await run_query_capped("SELECT 1", max_rows=0)

assert qr.truncated is False
assert qr.returned_rows == 5
assert qr.total_rows == 5


@pytest.mark.asyncio
async def test_run_query_capped_empty_result():
"""row_count=0 => empty DataFrame, truncated=False."""
job = _make_completed_job(0)

with (
patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient,
patch("dremioai.api.dremio.sql.settings") as mock_settings_mod,
):
mock_settings_mod.instance.return_value = _mock_settings()
client = MockClient.return_value
client.post = AsyncMock(return_value=QuerySubmission(id="j1"))
client.get = AsyncMock(return_value=job)

qr = await run_query_capped("SELECT 1", max_rows=10)

assert qr.truncated is False
assert qr.returned_rows == 0
assert qr.total_rows == 0
assert qr.df.empty


@pytest.mark.asyncio
async def test_run_query_capped_exact_boundary():
"""row_count == max_rows => truncated=False."""
job = _make_completed_job(5)
rows = [{"a": str(i)} for i in range(5)]
jr = _make_job_results(rows)

with (
patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient,
patch("dremioai.api.dremio.sql.settings") as mock_settings_mod,
):
mock_settings_mod.instance.return_value = _mock_settings()
client = MockClient.return_value
client.post = AsyncMock(return_value=QuerySubmission(id="j1"))
client.get = AsyncMock(return_value=job)
with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)):
qr = await run_query_capped("SELECT 1", max_rows=5)

assert qr.truncated is False
assert qr.returned_rows == 5
assert qr.total_rows == 5
2 changes: 2 additions & 0 deletions tests/config/golden_flag_keys.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ flag_keys:
- dremio.extract_org_id_from_jwt
- dremio.jwks_cache_lifespan
- dremio.jwks_uri
- dremio.max_result_bytes
- dremio.max_result_rows
- dremio.metrics.enabled
- dremio.metrics.port
- dremio.wlm.engine_name
Expand Down
5 changes: 3 additions & 2 deletions tests/test_simple_fastmcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ async def test_simple_tool_invocation(self):

# Test RunSqlQuery tool with proper mocking
with patch(
"dremioai.api.dremio.sql.run_query", new_callable=AsyncMock
"dremioai.api.dremio.sql.run_query_capped", new_callable=AsyncMock
) as mock_run_query:
from dremioai.api.dremio.sql import QueryResult
mock_df = pd.DataFrame([{"test_column": 1}])
mock_run_query.return_value = mock_df
mock_run_query.return_value = QueryResult(df=mock_df, total_rows=1, returned_rows=1)

# Call the tool
result = await fastmcp_server.call_tool(
Expand Down
6 changes: 4 additions & 2 deletions tests/tools/test_output_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from dremioai.config import settings
from dremioai.tools.tools import GetUsefulSystemTableNames, GetSchemaOfTable, RunSqlQuery
from dremioai.api.dremio.sql import QueryResult


async def mock_mcp_validate_tool_output(tool, *args, **kwargs):
Expand Down Expand Up @@ -82,8 +83,9 @@ async def test_run_sql_query_json_safe_output():
]
)

with patch("dremioai.tools.tools.sql.run_query", new_callable=AsyncMock) as mock_run_query:
mock_run_query.return_value = df
qr = QueryResult(df=df, total_rows=len(df), returned_rows=len(df))
with patch("dremioai.tools.tools.sql.run_query_capped", new_callable=AsyncMock) as mock_run_query_capped:
mock_run_query_capped.return_value = qr
token = settings._settings.set(
settings.Settings.model_validate({"dremio": {"uri": "https://test"}})
)
Expand Down
Loading
Loading