Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
)

from dataclasses import dataclass, asdict, field
from datetime import datetime
from decimal import Decimal

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
Expand All @@ -41,6 +43,7 @@
import functools

import pandas as pd
import numpy as np
from dremioai.api.dremio import sql, usage, search
from dremioai.config import settings
from dremioai.config.tools import ToolType
Expand Down Expand Up @@ -104,6 +107,31 @@ async def invoke(self):
raise NotImplementedError("Subclasses should implement this method")


def _json_safe_value(value: Any) -> Any:
    """Convert one DataFrame cell into a value that ``json.dumps`` can encode.

    Conversions applied, in order:

    * ``None``, ``pd.NA`` and ``pd.NaT`` become ``None``.
    * NumPy scalars are unwrapped with ``.item()`` and the native result is
      normalised again, so e.g. a NumPy float NaN also ends up as ``None``.
    * Float NaN becomes ``None`` (``json.dumps`` would otherwise emit the
      non-standard ``NaN`` token). Infinities are passed through unchanged.
    * ``datetime`` (including ``pd.Timestamp``), ``date`` and ``time`` become
      their ``isoformat()`` string.
    * ``pd.Timedelta`` and ``Decimal`` become ``str(value)``.
    * Anything else is returned unchanged.

    Args:
        value: A single cell value taken from a query result.

    Returns:
        The converted value, or ``value`` itself when no rule matches.
    """
    # Local imports: only ``datetime`` is known to be imported from the
    # ``datetime`` module at file level, and ``math`` may not be imported.
    import math
    from datetime import date as _date, time as _time

    if value is None or value is pd.NA or value is pd.NaT:
        return None
    if isinstance(value, np.generic):
        native = value.item()
        if isinstance(native, np.generic):
            # Some NumPy scalars have no native equivalent and ``.item()``
            # hands back a NumPy scalar again; stop here to avoid recursing
            # forever.
            return native
        return _json_safe_value(native)
    if isinstance(value, float) and math.isnan(value):
        return None
    # ``datetime`` and ``pd.Timestamp`` are ``date`` subclasses; listing
    # ``datetime`` explicitly keeps the original intent visible.
    if isinstance(value, (datetime, _date, _time)):
        return value.isoformat()
    if isinstance(value, (pd.Timedelta, Decimal)):
        return str(value)
    return value


def _df_to_json_records(df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Turn a DataFrame into a list of per-row dicts with JSON-friendly values.

    Missing cells are replaced with ``None`` and every remaining cell is passed
    through ``_json_safe_value``.

    Args:
        df: The query result to convert.

    Returns:
        One dict per row keyed by column label; an empty list when ``df.empty``
        is true.
    """
    if df.empty:
        return []
    # Cast to object before masking: on numeric dtypes ``where(..., None)``
    # coerces the replacement back to NaN, so missing values would survive.
    as_objects = df.astype(object)
    as_objects = as_objects.where(pd.notnull(as_objects), None)
    return [
        {key: _json_safe_value(value) for key, value in row.items()}
        for row in as_objects.to_dict(orient="records")
    ]


class ProjectIdMiddleware(BaseHTTPMiddleware):
pat = re.compile(r"/mcp/([\da-z-]+)(/?.*)")
logger = log.logger("ProjectIdMiddleware")
Expand Down Expand Up @@ -333,7 +361,7 @@ async def invoke(self, s: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
try:
s = f"/* dremioai: submitter={self.__class__.__name__} */\n{s}"
df = await sql.run_query(query=s, use_df=True)
return {"result": df.to_dict(orient="records")}
return {"result": _df_to_json_records(df)}
except RuntimeError as e:
return {
"error": str(e),
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def http_streamable_mcp_server(
logging_level: str,
project_id: str = None,
wlm_engine: str = None,
) -> AsyncGenerator[StreamableMcpServerFixture]:
) -> AsyncGenerator[StreamableMcpServerFixture, None]:
old = settings.instance()
sf = None
try:
Expand Down Expand Up @@ -279,7 +279,7 @@ async def http_streamable_mcp_server(
@contextlib.asynccontextmanager
async def http_streamable_client_server(
sf: ServerFixture, token=None
) -> AsyncGenerator[ClientSession]:
) -> AsyncGenerator[ClientSession, None]:
headers = {"Authorization": f"Bearer {token}"} if token is not None else None
async with streamablehttp_client(url=sf.url, headers=headers) as (
read_stream,
Expand Down
41 changes: 39 additions & 2 deletions tests/tools/test_output_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
# limitations under the License.
#

import json
from decimal import Decimal

import numpy as np
import pandas as pd
import pytest
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from dremioai.tools.tools import GetUsefulSystemTableNames, GetSchemaOfTable
from dremioai.config import settings
from dremioai.tools.tools import GetUsefulSystemTableNames, GetSchemaOfTable, RunSqlQuery


async def mock_mcp_validate_tool_output(tool, *args, **kwargs):
Expand Down Expand Up @@ -59,3 +65,34 @@ async def test_get_schema_of_table_validation():

with patch("dremioai.tools.tools.get_schema", return_value=mock_schema_result):
await mock_mcp_validate_tool_output(tool, "sys.jobs")


@pytest.mark.asyncio
async def test_run_sql_query_json_safe_output():
    """RunSqlQuery.invoke must return a payload that ``json.dumps`` can encode.

    ``sql.run_query`` is patched to return a one-row DataFrame holding a
    timestamp, NumPy integer and float scalars, a ``Decimal`` and ``pd.NA``.
    The result is serialised, parsed back, and every field is checked, not
    only the timestamp.
    """
    tool = RunSqlQuery()
    df = pd.DataFrame(
        [
            {
                "ts": pd.Timestamp("2024-01-02T03:04:05"),
                "latency_ms": np.int64(150),
                "ratio": np.float64(0.75),
                "amount": Decimal("10.25"),
                "maybe_null": pd.NA,
            }
        ]
    )

    with patch(
        "dremioai.tools.tools.sql.run_query", new_callable=AsyncMock
    ) as mock_run_query:
        mock_run_query.return_value = df
        # NOTE(review): ``settings._settings`` looks like a ContextVar
        # (set() returns a token, reset() takes it back) - confirm.
        token = settings._settings.set(
            settings.Settings.model_validate({"dremio": {"uri": "https://test"}})
        )
        try:
            result = await tool.invoke("SELECT 1")
        finally:
            # Always restore the previous settings, even if invoke() raises.
            settings._settings.reset(token)

    assert isinstance(result, dict)
    assert "result" in result
    payload = json.dumps(result)
    assert "2024-01-02T03:04:05" in payload

    # Round-trip through JSON so each column's conversion is verified.
    assert json.loads(payload)["result"] == [
        {
            "ts": "2024-01-02T03:04:05",
            "latency_ms": 150,
            "ratio": 0.75,
            "amount": "10.25",
            "maybe_null": None,
        }
    ]