diff --git a/src/dremioai/api/dremio/sql.py b/src/dremioai/api/dremio/sql.py index 02f5793..d677293 100644 --- a/src/dremioai/api/dremio/sql.py +++ b/src/dremioai/api/dremio/sql.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from typing import List, Dict, Union, Optional, Any +from dataclasses import dataclass from enum import auto from datetime import datetime @@ -251,3 +252,88 @@ async def run_query( deser=QuerySubmission, ) return await get_results(project_id, qs, use_df=use_df, client=client) + + +@dataclass +class QueryResult: + df: pd.DataFrame + total_rows: int + returned_rows: int + + @property + def truncated(self) -> bool: + return self.returned_rows < self.total_rows + + +async def run_query_capped( + query: Union[Query, str], max_rows: int = 500 +) -> QueryResult: + """Submit a query and fetch at most *max_rows* rows (0 = unlimited). + + Returns a ``QueryResult`` with the DataFrame, total row count from the + job, the number of rows actually fetched, and a ``truncated`` flag. + """ + client = AsyncHttpClient() + if not isinstance(query, Query): + engine_name = ( + settings.instance().dremio.wlm.engine_name + if settings.instance().dremio.wlm is not None + else None + ) + query = Query(sql=query, engineName=engine_name) + + project_id = settings.instance().dremio.project_id + endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3" + qs: QuerySubmission = await client.post( + f"{endpoint}/sql", + body=query.model_dump(by_alias=True, exclude_none=True), + deser=QuerySubmission, + ) + + delay = settings.instance().dremio.api.polling_interval + job: Job = await client.get(f"{endpoint}/job/{qs.id}", deser=Job) + while not job.done: + await asyncio.sleep(delay) + job = await client.get(f"{endpoint}/job/{qs.id}", deser=Job) + + if not job.succeeded: + emsg = ( + job.error_message + if job.error_message + else ( + job.cancellation_reason + if job.job_state == JobState.CANCELED + else "Unknown error" + ) + ) + raise RuntimeError(f"Job {qs.id} failed: {emsg}") + + total_rows = job.row_count or 0 + if total_rows == 0: + return QueryResult(df=pd.DataFrame(), total_rows=0, returned_rows=0) + + fetch_rows = total_rows if max_rows == 0 else min(total_rows, max_rows) + page_size = min(500, fetch_rows) + + results = await run_in_parallel( + [ + _fetch_results(None, None, project_id, qs.id, off, page_size) + for off in range(0, fetch_rows, page_size) + ] + ) + jr = JobResultsWrapper(results) + + all_rows = list(itertools.chain.from_iterable(jr_page.rows for jr_page in jr)) + # The last page may return more rows than needed; trim to fetch_rows. + all_rows = all_rows[:fetch_rows] + + if all_rows and jr[0].result_schema: + columns = [rs.name for rs in jr[0].result_schema] + df = pd.DataFrame(data=all_rows, columns=columns) + for rs in jr[0].result_schema: + if rs.type.name == "TIMESTAMP": + df[rs.name] = pd.to_datetime(df[rs.name]) + else: + df = pd.DataFrame(data=all_rows) + + return QueryResult(df=df, total_rows=total_rows, returned_rows=len(df)) diff --git a/src/dremioai/config/settings.py b/src/dremioai/config/settings.py index 6f230fb..36e2839 100644 --- a/src/dremioai/config/settings.py +++ b/src/dremioai/config/settings.py @@ -265,6 +265,14 @@ class Dremio(FlagAwareModel): description="How long (seconds) to cache JWKS keys before refetching. Default: 3600 (1 hour).", ) wlm: Optional[Wlm] = None + max_result_rows: Optional[int] = Field( + default=500, + description="Maximum number of rows returned by RunSqlQuery. Use 0 for unlimited.", + ) + max_result_bytes: Optional[int] = Field( + default=204_800, + description="Maximum UTF-8 byte size of RunSqlQuery results. Enforced after row cap. Use 0 for unlimited.", + ) api: Optional[ApiSettings] = Field(default_factory=ApiSettings) metrics: Optional[Metrics] = None diff --git a/src/dremioai/tools/tools.py b/src/dremioai/tools/tools.py index 95cf107..0b2761a 100644 --- a/src/dremioai/tools/tools.py +++ b/src/dremioai/tools/tools.py @@ -39,6 +39,7 @@ from starlette.requests import Request from dremioai import log +import json import re import functools @@ -351,12 +352,19 @@ def ensure_query_allowed(s: str): @secured @with_metrics - async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]: + async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]], str, bool, int]]: """Run a SQL query on the Dremio cluster and return the results. Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes. DML statements (INSERT, UPDATE, DELETE, etc.) may or may not be permitted depending on project configuration. If a DML query is not allowed, this will return an error. + Results are capped at max_result_rows rows (default 500) and max_result_bytes bytes + (default 200 KB). When truncated, the response includes 'truncated', 'total_rows', + 'returned_rows', and 'truncation_reason' fields. To reduce result size, add LIMIT or + GROUP BY to your query. Configure limits via dremio.max_result_rows and + dremio.max_result_bytes settings (env vars DREMIOAI_DREMIO__MAX_RESULT_ROWS, + DREMIOAI_DREMIO__MAX_RESULT_BYTES). Set either to 0 for unlimited. + Args: query: sql query """ @@ -367,9 +375,45 @@ async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str "error": "Only SELECT queries are allowed. DML statements are not permitted.", } try: - query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}" - df = await sql.run_query(query=query, use_df=True) - return {"result": _df_to_json_records(df)} + tagged_query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}" + dremio_settings = settings.instance().dremio + max_rows = dremio_settings.get("max_result_rows") + if max_rows is None: + max_rows = 500 + max_bytes = dremio_settings.get("max_result_bytes") + if max_bytes is None: + max_bytes = 204_800 + + qr = await sql.run_query_capped(query=tagged_query, max_rows=max_rows) + records = _df_to_json_records(qr.df) + + truncation_reason = None + if qr.truncated: + truncation_reason = "row_limit" + + # Enforce byte cap + if max_bytes > 0 and records: + kept = [] + running_bytes = 0 + for rec in records: + rec_bytes = len(json.dumps(rec).encode("utf-8")) + if running_bytes + rec_bytes > max_bytes: + truncation_reason = "byte_limit" + break + kept.append(rec) + running_bytes += rec_bytes + if truncation_reason == "byte_limit": + records = kept + + if truncation_reason: + return { + "result": records, + "truncated": True, + "total_rows": qr.total_rows, + "returned_rows": len(records), + "truncation_reason": truncation_reason, + } + return {"result": records} except RuntimeError as e: return { "error": str(e), diff --git a/tests/api/dremio/test_sql.py b/tests/api/dremio/test_sql.py index c0faf21..5a34384 100644 --- a/tests/api/dremio/test_sql.py +++ b/tests/api/dremio/test_sql.py @@ -15,7 +15,15 @@ # import pytest -from dremioai.api.dremio.sql import Job +from unittest.mock import AsyncMock, patch, MagicMock +import pandas as pd +from dremioai.api.dremio.sql import ( + Job, + JobResults, + QueryResult, + QuerySubmission, + run_query_capped, +) @pytest.mark.parametrize( @@ -52,3 +60,146 @@ ) def test_basic_job(js: str): j = Job.model_validate_json(js) + + +# -- helpers for run_query_capped tests ---------------------------------------- + +def _make_completed_job(row_count: int) -> Job: + return Job.model_validate( + {"jobState": "COMPLETED", "rowCount": row_count, "queryType": "REST"} + ) + + +def _make_job_results(rows, schema_names=None): + if schema_names is None: + schema_names = list(rows[0].keys()) if rows else [] + return JobResults.model_validate( + { + "rowCount": len(rows), + "schema": [{"name": n, "type": {"name": "VARCHAR"}} for n in schema_names], + "rows": rows, + } + ) + + +def _mock_settings(project_id=None, engine_name=None, polling_interval=0): + s = MagicMock() + s.dremio.project_id = project_id + s.dremio.wlm = None + s.dremio.api.polling_interval = polling_interval + return s + + +@pytest.mark.asyncio +async def test_run_query_capped_under_limit(): + """Rows below max_rows => truncated=False, all rows returned.""" + rows = [{"a": str(i)} for i in range(3)] + job = _make_completed_job(3) + jr = _make_job_results(rows) + + with ( + patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient, + patch("dremioai.api.dremio.sql.settings") as mock_settings_mod, + ): + mock_settings_mod.instance.return_value = _mock_settings() + client = MockClient.return_value + client.post = AsyncMock(return_value=QuerySubmission(id="j1")) + client.get = AsyncMock(return_value=job) + with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)): + qr = await run_query_capped("SELECT 1", max_rows=10) + + assert qr.truncated is False + assert qr.returned_rows == 3 + assert qr.total_rows == 3 + assert len(qr.df) == 3 + + +@pytest.mark.asyncio +async def test_run_query_capped_row_limit_hit(): + """Row limit fires => truncated=True.""" + job = _make_completed_job(100) + rows = [{"a": str(i)} for i in range(10)] + jr = _make_job_results(rows) + + with ( + patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient, + patch("dremioai.api.dremio.sql.settings") as mock_settings_mod, + ): + mock_settings_mod.instance.return_value = _mock_settings() + client = MockClient.return_value + client.post = AsyncMock(return_value=QuerySubmission(id="j1")) + client.get = AsyncMock(return_value=job) + with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)): + qr = await run_query_capped("SELECT 1", max_rows=10) + + assert qr.truncated is True + assert qr.returned_rows == 10 + assert qr.total_rows == 100 + + +@pytest.mark.asyncio +async def test_run_query_capped_unlimited(): + """max_rows=0 fetches all rows.""" + job = _make_completed_job(5) + rows = [{"a": str(i)} for i in range(5)] + jr = _make_job_results(rows) + + with ( + patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient, + patch("dremioai.api.dremio.sql.settings") as mock_settings_mod, + ): + mock_settings_mod.instance.return_value = _mock_settings() + client = MockClient.return_value + client.post = AsyncMock(return_value=QuerySubmission(id="j1")) + client.get = AsyncMock(return_value=job) + with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)): + qr = await run_query_capped("SELECT 1", max_rows=0) + + assert qr.truncated is False + assert qr.returned_rows == 5 + assert qr.total_rows == 5 + + +@pytest.mark.asyncio +async def test_run_query_capped_empty_result(): + """row_count=0 => empty DataFrame, truncated=False.""" + job = _make_completed_job(0) + + with ( + patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient, + patch("dremioai.api.dremio.sql.settings") as mock_settings_mod, + ): + mock_settings_mod.instance.return_value = _mock_settings() + client = MockClient.return_value + client.post = AsyncMock(return_value=QuerySubmission(id="j1")) + client.get = AsyncMock(return_value=job) + + qr = await run_query_capped("SELECT 1", max_rows=10) + + assert qr.truncated is False + assert qr.returned_rows == 0 + assert qr.total_rows == 0 + assert qr.df.empty + + +@pytest.mark.asyncio +async def test_run_query_capped_exact_boundary(): + """row_count == max_rows => truncated=False.""" + job = _make_completed_job(5) + rows = [{"a": str(i)} for i in range(5)] + jr = _make_job_results(rows) + + with ( + patch("dremioai.api.dremio.sql.AsyncHttpClient") as MockClient, + patch("dremioai.api.dremio.sql.settings") as mock_settings_mod, + ): + mock_settings_mod.instance.return_value = _mock_settings() + client = MockClient.return_value + client.post = AsyncMock(return_value=QuerySubmission(id="j1")) + client.get = AsyncMock(return_value=job) + with patch("dremioai.api.dremio.sql._fetch_results", AsyncMock(return_value=jr)): + qr = await run_query_capped("SELECT 1", max_rows=5) + + assert qr.truncated is False + assert qr.returned_rows == 5 + assert qr.total_rows == 5 diff --git a/tests/config/golden_flag_keys.yaml b/tests/config/golden_flag_keys.yaml index 65df8a7..62a4013 100644 --- a/tests/config/golden_flag_keys.yaml +++ b/tests/config/golden_flag_keys.yaml @@ -10,6 +10,8 @@ flag_keys: - dremio.extract_org_id_from_jwt - dremio.jwks_cache_lifespan - dremio.jwks_uri +- dremio.max_result_bytes +- dremio.max_result_rows - dremio.metrics.enabled - dremio.metrics.port - dremio.wlm.engine_name diff --git a/tests/test_simple_fastmcp_server.py b/tests/test_simple_fastmcp_server.py index 5ad4ad3..9b91fa6 100644 --- a/tests/test_simple_fastmcp_server.py +++ b/tests/test_simple_fastmcp_server.py @@ -91,10 +91,11 @@ async def test_simple_tool_invocation(self): # Test RunSqlQuery tool with proper mocking with patch( - "dremioai.api.dremio.sql.run_query", new_callable=AsyncMock + "dremioai.api.dremio.sql.run_query_capped", new_callable=AsyncMock ) as mock_run_query: + from dremioai.api.dremio.sql import QueryResult mock_df = pd.DataFrame([{"test_column": 1}]) - mock_run_query.return_value = mock_df + mock_run_query.return_value = QueryResult(df=mock_df, total_rows=1, returned_rows=1) # Call the tool result = await fastmcp_server.call_tool( diff --git a/tests/tools/test_output_validation.py b/tests/tools/test_output_validation.py index 609d316..ac4f1ef 100644 --- a/tests/tools/test_output_validation.py +++ b/tests/tools/test_output_validation.py @@ -24,6 +24,7 @@ from mcp.server.fastmcp.utilities.func_metadata import func_metadata from dremioai.config import settings from dremioai.tools.tools import GetUsefulSystemTableNames, GetSchemaOfTable, RunSqlQuery +from dremioai.api.dremio.sql import QueryResult async def mock_mcp_validate_tool_output(tool, *args, **kwargs): @@ -82,8 +83,9 @@ async def test_run_sql_query_json_safe_output(): ] ) - with patch("dremioai.tools.tools.sql.run_query", new_callable=AsyncMock) as mock_run_query: - mock_run_query.return_value = df + qr = QueryResult(df=df, total_rows=len(df), returned_rows=len(df)) + with patch("dremioai.tools.tools.sql.run_query_capped", new_callable=AsyncMock) as mock_run_query_capped: + mock_run_query_capped.return_value = qr token = settings._settings.set( settings.Settings.model_validate({"dremio": {"uri": "https://test"}}) ) diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 6d0597d..ff1da24 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -14,9 +14,13 @@ # limitations under the License. # +import json import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import pandas as pd from dremioai.tools import tools from dremioai.config import settings +from dremioai.api.dremio.sql import QueryResult from typing import Dict, Union from contextlib import contextmanager @@ -479,3 +483,121 @@ def normalize(expected: MCPTool, actual: MCPTool): pp("Actual:") pp(actual) assert False + + +# -- RunSqlQuery.invoke() tests ------------------------------------------------ + +def _mock_dremio_settings(max_rows=500, max_bytes=204_800, allow_dml=False): + s = MagicMock() + s.dremio.get.side_effect = lambda key: { + "max_result_rows": max_rows, + "max_result_bytes": max_bytes, + "allow_dml": allow_dml, + }.get(key) + s.dremio.project_id = None + return s + + +@pytest.mark.asyncio +async def test_invoke_no_truncation(): + """When results fit within both limits, response has only 'result' key.""" + df = pd.DataFrame({"col": ["a", "b"]}) + qr = QueryResult(df=df, total_rows=2, returned_rows=2) + + with ( + patch("dremioai.tools.tools.settings") as mock_s, + patch("dremioai.tools.tools.sql") as mock_sql, + ): + mock_s.instance.return_value = _mock_dremio_settings() + mock_sql.run_query_capped = AsyncMock(return_value=qr) + + instance = tools.RunSqlQuery() + result = await instance.invoke.__wrapped__(instance, "SELECT 1") + + assert "result" in result + assert "truncated" not in result + assert len(result["result"]) == 2 + + +@pytest.mark.asyncio +async def test_invoke_row_limit_hit(): + """Row limit triggers truncation metadata with reason 'row_limit'.""" + df = pd.DataFrame({"col": [str(i) for i in range(10)]}) + qr = QueryResult(df=df, total_rows=100, returned_rows=10) + + with ( + patch("dremioai.tools.tools.settings") as mock_s, + patch("dremioai.tools.tools.sql") as mock_sql, + ): + mock_s.instance.return_value = _mock_dremio_settings(max_rows=10) + mock_sql.run_query_capped = AsyncMock(return_value=qr) + + instance = tools.RunSqlQuery() + result = await instance.invoke.__wrapped__(instance, "SELECT 1") + + assert result["truncated"] is True + assert result["truncation_reason"] == "row_limit" + assert result["total_rows"] == 100 + assert result["returned_rows"] == 10 + + +@pytest.mark.asyncio +async def test_invoke_byte_limit_hit(): + """Byte limit fires mid-result, reason is 'byte_limit'.""" + # Each record is ~20+ bytes as JSON; set a very small byte budget + df = pd.DataFrame({"col": ["x" * 50 for _ in range(10)]}) + qr = QueryResult(df=df, total_rows=10, returned_rows=10) + + with ( + patch("dremioai.tools.tools.settings") as mock_s, + patch("dremioai.tools.tools.sql") as mock_sql, + ): + mock_s.instance.return_value = _mock_dremio_settings(max_rows=500, max_bytes=100) + mock_sql.run_query_capped = AsyncMock(return_value=qr) + + instance = tools.RunSqlQuery() + result = await instance.invoke.__wrapped__(instance, "SELECT 1") + + assert result["truncated"] is True + assert result["truncation_reason"] == "byte_limit" + assert len(result["result"]) < 10 + + +@pytest.mark.asyncio +async def test_invoke_both_limits_zero(): + """max_rows=0 and max_bytes=0 => no truncation, all rows returned.""" + df = pd.DataFrame({"col": [str(i) for i in range(20)]}) + qr = QueryResult(df=df, total_rows=20, returned_rows=20) + + with ( + patch("dremioai.tools.tools.settings") as mock_s, + patch("dremioai.tools.tools.sql") as mock_sql, + ): + mock_s.instance.return_value = _mock_dremio_settings(max_rows=0, max_bytes=0) + mock_sql.run_query_capped = AsyncMock(return_value=qr) + + instance = tools.RunSqlQuery() + result = await instance.invoke.__wrapped__(instance, "SELECT 1") + + assert "truncated" not in result + assert len(result["result"]) == 20 + + +@pytest.mark.asyncio +async def test_invoke_settings_override(): + """Verify correct max_rows is passed to run_query_capped from settings.""" + df = pd.DataFrame({"col": ["a"]}) + qr = QueryResult(df=df, total_rows=1, returned_rows=1) + + with ( + patch("dremioai.tools.tools.settings") as mock_s, + patch("dremioai.tools.tools.sql") as mock_sql, + ): + mock_s.instance.return_value = _mock_dremio_settings(max_rows=42) + mock_sql.run_query_capped = AsyncMock(return_value=qr) + + instance = tools.RunSqlQuery() + await instance.invoke.__wrapped__(instance, "SELECT 1") + + call_kwargs = mock_sql.run_query_capped.call_args + assert call_kwargs.kwargs.get("max_rows") == 42