Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/dremioai/api/dremio/ai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from aiohttp import ClientResponseError
from dremioai.api.transport import DremioAsyncHttpClient as AsyncHttpClient
from dremioai.config import settings
from dremioai.log import logger

log = logger(__name__)
Expand Down Expand Up @@ -64,10 +65,9 @@ def is_empty(self) -> bool:
async def list_tools() -> ListToolsResponse:
try:
client = AsyncHttpClient()
return await client.get(
"/api/v4/ai/tools",
deser=ListToolsResponse,
)
project_id = settings.instance().dremio.project_id
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
Comment thread
AfonsoP1nto marked this conversation as resolved.
Outdated
Comment thread
AfonsoP1nto marked this conversation as resolved.
Outdated
return await client.get(f"{endpoint}/ai/tools", deser=ListToolsResponse)
except ClientResponseError as e:
log.exception("Failed to list AI tools")
return ListToolsResponse(error=f"HTTP {e.status} {e.message}")
Expand All @@ -80,8 +80,10 @@ async def invoke_tool(tool_name: str, args: Dict[str, Any]) -> InvokeToolRespons
safe_name = quote(tool_name, safe="")
try:
client = AsyncHttpClient()
project_id = settings.instance().dremio.project_id
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
Comment thread
AfonsoP1nto marked this conversation as resolved.
Outdated
Comment thread
AfonsoP1nto marked this conversation as resolved.
Outdated
return await client.post(
f"/api/v4/ai/tools/{safe_name}:invoke",
f"{endpoint}/ai/tools/{safe_name}:invoke",
body={"args": args},
deser=InvokeToolResponse,
)
Expand Down
63 changes: 55 additions & 8 deletions tests/api/dremio/test_ai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,19 @@

# --- list_tools tests ---


@pytest.fixture
def no_project_id_settings(mock_settings_instance):
old_project_id = mock_settings_instance.dremio.project_id
mock_settings_instance.dremio.raw_project_id = None
yield mock_settings_instance
if old_project_id is not None:
mock_settings_instance.dremio.project_id = old_project_id

@pytest.mark.asyncio
async def test_list_tools_returns_tools(mock_settings_instance):
with HttpMockFramework() as mock:
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
mock.load_mock_data(r"/v0/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json")
result = await list_tools()
assert bool(result)
assert len(result.tools) == 3
Expand All @@ -42,7 +51,7 @@ async def test_list_tools_returns_tools(mock_settings_instance):
@pytest.mark.asyncio
async def test_list_tools_returns_input_schema(mock_settings_instance):
with HttpMockFramework() as mock:
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
mock.load_mock_data(r"/v0/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json")
Comment thread
AfonsoP1nto marked this conversation as resolved.
Outdated
result = await list_tools()
run_sql = next(t for t in result.tools if t.name == "runSql")
assert run_sql.input_schema["type"] == "object"
Expand All @@ -52,7 +61,7 @@ async def test_list_tools_returns_input_schema(mock_settings_instance):
@pytest.mark.asyncio
async def test_list_tools_empty_registry(mock_settings_instance):
with HttpMockFramework() as mock:
mock.add_mock_response(r"/api/v4/ai/tools$", {"tools": []})
mock.add_mock_response(r"/v0/projects/[^/]+/ai/tools$", {"tools": []})
result = await list_tools()
assert result.tools == []
assert bool(result)
Expand All @@ -63,7 +72,10 @@ async def test_list_tools_empty_registry(mock_settings_instance):
@pytest.mark.asyncio
async def test_invoke_tool_success(mock_settings_instance):
with HttpMockFramework() as mock:
mock.load_mock_data(r"/api/v4/ai/tools/runSql:invoke$", "ai_tools/invoke_result.json")
mock.load_mock_data(
r"/v0/projects/[^/]+/ai/tools/runSql:invoke$",
"ai_tools/invoke_result.json",
)
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
assert bool(result)
assert result.result["columns"] == ["id", "name"]
Expand All @@ -73,7 +85,10 @@ async def test_invoke_tool_success(mock_settings_instance):
@pytest.mark.asyncio
async def test_invoke_tool_error_response(mock_settings_instance):
with HttpMockFramework() as mock:
mock.load_mock_data(r"/api/v4/ai/tools/unknownTool:invoke$", "ai_tools/invoke_error.json")
mock.load_mock_data(
r"/v0/projects/[^/]+/ai/tools/unknownTool:invoke$",
"ai_tools/invoke_error.json",
)
result = await invoke_tool("unknownTool", {})
assert result.error is not None
assert "not found" in result.error
Expand Down Expand Up @@ -146,7 +161,11 @@ def test_invoke_tool_response_is_empty_false_when_error():
async def test_list_tools_http_error(mock_settings_instance):
"""list_tools should return a response with error set on HTTP 4xx/5xx."""
with HttpMockFramework() as mock:
mock.add_mock_response(r"/api/v4/ai/tools$", {"error": "Unauthorized"}, status=401)
mock.add_mock_response(
r"/v0/projects/[^/]+/ai/tools$",
{"error": "Unauthorized"},
status=401,
)
result = await list_tools()
assert not bool(result)
assert result.error is not None
Expand All @@ -157,7 +176,11 @@ async def test_list_tools_http_error(mock_settings_instance):
async def test_invoke_tool_http_error(mock_settings_instance):
"""invoke_tool should return a response with error set on HTTP 500."""
with HttpMockFramework() as mock:
mock.add_mock_response(r"/api/v4/ai/tools/runSql:invoke$", {"error": "Internal Server Error"}, status=500)
mock.add_mock_response(
r"/v0/projects/[^/]+/ai/tools/runSql:invoke$",
{"error": "Internal Server Error"},
status=500,
)
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
assert not bool(result)
assert result.error is not None
Expand All @@ -169,6 +192,30 @@ async def test_invoke_tool_url_encodes_name(mock_settings_instance):
"""tool_name with special characters should be URL-encoded."""
with HttpMockFramework() as mock:
# The encoded name "my%2Ftool" should appear in the URL
mock.add_mock_response(r"/api/v4/ai/tools/my%2Ftool:invoke$", {"result": "ok", "error": None})
mock.add_mock_response(
r"/v0/projects/[^/]+/ai/tools/my%2Ftool:invoke$",
{"result": "ok", "error": None},
)
result = await invoke_tool("my/tool", {})
assert result.result == "ok"


@pytest.mark.asyncio
async def test_list_tools_uses_api_v3_without_project_id(no_project_id_settings):
with HttpMockFramework() as mock:
mock.load_mock_data(r"/api/v3/ai/tools$", "ai_tools/list_tools.json")
result = await list_tools()
assert bool(result)
assert len(result.tools) == 3


@pytest.mark.asyncio
async def test_invoke_tool_uses_api_v3_without_project_id(no_project_id_settings):
with HttpMockFramework() as mock:
mock.load_mock_data(
r"/api/v3/ai/tools/runSql:invoke$",
"ai_tools/invoke_result.json",
)
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
assert bool(result)
assert result.result["columns"] == ["id", "name"]
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@
from prometheus_client import CollectorRegistry


@pytest.fixture(autouse=True)
def reset_uvicorn_logger_propagation():
"""Reset uvicorn logger state between tests.

Uvicorn's configure_logging() sets uvicorn.access.propagate=False via its
default LOGGING_CONFIG when a server starts. This leaks into subsequent tests
that assert stdlib loggers propagate to the root handler.
"""
yield
import logging

for name in ("uvicorn.access", "uvicorn.error", "uvicorn"):
lg = logging.getLogger(name)
lg.propagate = True
for h in lg.handlers[:]:
lg.removeHandler(h)


@pytest.fixture(autouse=True)
def reset_sse_starlette_app_status():
"""
Expand Down
Loading