Skip to content

Commit 658a661

Browse files
authored
DX-120975: Fix endpoint (#113)
* Fix endpoint * Fix test * Fixed pre-existing error * Fixed endpoint
1 parent a13a32b commit 658a661

3 files changed

Lines changed: 80 additions & 13 deletions

File tree

src/dremioai/api/dremio/ai_tools.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from aiohttp import ClientResponseError
2222
from dremioai.api.transport import DremioAsyncHttpClient as AsyncHttpClient
23+
from dremioai.config import settings
2324
from dremioai.log import logger
2425

2526
log = logger(__name__)
@@ -64,10 +65,9 @@ def is_empty(self) -> bool:
6465
async def list_tools() -> ListToolsResponse:
6566
try:
6667
client = AsyncHttpClient()
67-
return await client.get(
68-
"/api/v4/ai/tools",
69-
deser=ListToolsResponse,
70-
)
68+
project_id = settings.instance().dremio.project_id
69+
endpoint = f"/v1/projects/{project_id}" if project_id else "/api/v4"
70+
return await client.get(f"{endpoint}/ai/tools", deser=ListToolsResponse)
7171
except ClientResponseError as e:
7272
log.exception("Failed to list AI tools")
7373
return ListToolsResponse(error=f"HTTP {e.status} {e.message}")
@@ -80,8 +80,10 @@ async def invoke_tool(tool_name: str, args: Dict[str, Any]) -> InvokeToolRespons
8080
safe_name = quote(tool_name, safe="")
8181
try:
8282
client = AsyncHttpClient()
83+
project_id = settings.instance().dremio.project_id
84+
endpoint = f"/v1/projects/{project_id}" if project_id else "/api/v4"
8385
return await client.post(
84-
f"/api/v4/ai/tools/{safe_name}:invoke",
86+
f"{endpoint}/ai/tools/{safe_name}:invoke",
8587
body={"args": args},
8688
deser=InvokeToolResponse,
8789
)

tests/api/dremio/test_ai_tools.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,19 @@
2626

2727
# --- list_tools tests ---
2828

29+
30+
@pytest.fixture
31+
def no_project_id_settings(mock_settings_instance):
32+
old_project_id = mock_settings_instance.dremio.project_id
33+
mock_settings_instance.dremio.raw_project_id = None
34+
yield mock_settings_instance
35+
if old_project_id is not None:
36+
mock_settings_instance.dremio.project_id = old_project_id
37+
2938
@pytest.mark.asyncio
3039
async def test_list_tools_returns_tools(mock_settings_instance):
3140
with HttpMockFramework() as mock:
32-
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
41+
mock.load_mock_data(r"/v1/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json")
3342
result = await list_tools()
3443
assert bool(result)
3544
assert len(result.tools) == 3
@@ -42,7 +51,7 @@ async def test_list_tools_returns_tools(mock_settings_instance):
4251
@pytest.mark.asyncio
4352
async def test_list_tools_returns_input_schema(mock_settings_instance):
4453
with HttpMockFramework() as mock:
45-
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
54+
mock.load_mock_data(r"/v1/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json")
4655
result = await list_tools()
4756
run_sql = next(t for t in result.tools if t.name == "runSql")
4857
assert run_sql.input_schema["type"] == "object"
@@ -52,7 +61,7 @@ async def test_list_tools_returns_input_schema(mock_settings_instance):
5261
@pytest.mark.asyncio
5362
async def test_list_tools_empty_registry(mock_settings_instance):
5463
with HttpMockFramework() as mock:
55-
mock.add_mock_response(r"/api/v4/ai/tools$", {"tools": []})
64+
mock.add_mock_response(r"/v1/projects/[^/]+/ai/tools$", {"tools": []})
5665
result = await list_tools()
5766
assert result.tools == []
5867
assert bool(result)
@@ -63,7 +72,10 @@ async def test_list_tools_empty_registry(mock_settings_instance):
6372
@pytest.mark.asyncio
6473
async def test_invoke_tool_success(mock_settings_instance):
6574
with HttpMockFramework() as mock:
66-
mock.load_mock_data(r"/api/v4/ai/tools/runSql:invoke$", "ai_tools/invoke_result.json")
75+
mock.load_mock_data(
76+
r"/v1/projects/[^/]+/ai/tools/runSql:invoke$",
77+
"ai_tools/invoke_result.json",
78+
)
6779
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
6880
assert bool(result)
6981
assert result.result["columns"] == ["id", "name"]
@@ -73,7 +85,10 @@ async def test_invoke_tool_success(mock_settings_instance):
7385
@pytest.mark.asyncio
7486
async def test_invoke_tool_error_response(mock_settings_instance):
7587
with HttpMockFramework() as mock:
76-
mock.load_mock_data(r"/api/v4/ai/tools/unknownTool:invoke$", "ai_tools/invoke_error.json")
88+
mock.load_mock_data(
89+
r"/v1/projects/[^/]+/ai/tools/unknownTool:invoke$",
90+
"ai_tools/invoke_error.json",
91+
)
7792
result = await invoke_tool("unknownTool", {})
7893
assert result.error is not None
7994
assert "not found" in result.error
@@ -146,7 +161,11 @@ def test_invoke_tool_response_is_empty_false_when_error():
146161
async def test_list_tools_http_error(mock_settings_instance):
147162
"""list_tools should return a response with error set on HTTP 4xx/5xx."""
148163
with HttpMockFramework() as mock:
149-
mock.add_mock_response(r"/api/v4/ai/tools$", {"error": "Unauthorized"}, status=401)
164+
mock.add_mock_response(
165+
r"/v1/projects/[^/]+/ai/tools$",
166+
{"error": "Unauthorized"},
167+
status=401,
168+
)
150169
result = await list_tools()
151170
assert not bool(result)
152171
assert result.error is not None
@@ -157,7 +176,11 @@ async def test_list_tools_http_error(mock_settings_instance):
157176
async def test_invoke_tool_http_error(mock_settings_instance):
158177
"""invoke_tool should return a response with error set on HTTP 500."""
159178
with HttpMockFramework() as mock:
160-
mock.add_mock_response(r"/api/v4/ai/tools/runSql:invoke$", {"error": "Internal Server Error"}, status=500)
179+
mock.add_mock_response(
180+
r"/v1/projects/[^/]+/ai/tools/runSql:invoke$",
181+
{"error": "Internal Server Error"},
182+
status=500,
183+
)
161184
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
162185
assert not bool(result)
163186
assert result.error is not None
@@ -169,6 +192,30 @@ async def test_invoke_tool_url_encodes_name(mock_settings_instance):
169192
"""tool_name with special characters should be URL-encoded."""
170193
with HttpMockFramework() as mock:
171194
# The encoded name "my%2Ftool" should appear in the URL
172-
mock.add_mock_response(r"/api/v4/ai/tools/my%2Ftool:invoke$", {"result": "ok", "error": None})
195+
mock.add_mock_response(
196+
r"/v1/projects/[^/]+/ai/tools/my%2Ftool:invoke$",
197+
{"result": "ok", "error": None},
198+
)
173199
result = await invoke_tool("my/tool", {})
174200
assert result.result == "ok"
201+
202+
203+
@pytest.mark.asyncio
204+
async def test_list_tools_uses_api_v4_without_project_id(no_project_id_settings):
205+
with HttpMockFramework() as mock:
206+
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
207+
result = await list_tools()
208+
assert bool(result)
209+
assert len(result.tools) == 3
210+
211+
212+
@pytest.mark.asyncio
213+
async def test_invoke_tool_uses_api_v4_without_project_id(no_project_id_settings):
214+
with HttpMockFramework() as mock:
215+
mock.load_mock_data(
216+
r"/api/v4/ai/tools/runSql:invoke$",
217+
"ai_tools/invoke_result.json",
218+
)
219+
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
220+
assert bool(result)
221+
assert result.result["columns"] == ["id", "name"]

tests/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,24 @@
5050
from prometheus_client import CollectorRegistry
5151

5252

53+
@pytest.fixture(autouse=True)
54+
def reset_uvicorn_logger_propagation():
55+
"""Reset uvicorn logger state between tests.
56+
57+
Uvicorn's configure_logging() sets uvicorn.access.propagate=False via its
58+
default LOGGING_CONFIG when a server starts. This leaks into subsequent tests
59+
that assert stdlib loggers propagate to the root handler.
60+
"""
61+
yield
62+
import logging
63+
64+
for name in ("uvicorn.access", "uvicorn.error", "uvicorn"):
65+
lg = logging.getLogger(name)
66+
lg.propagate = True
67+
for h in lg.handlers[:]:
68+
lg.removeHandler(h)
69+
70+
5371
@pytest.fixture(autouse=True)
5472
def reset_sse_starlette_app_status():
5573
"""

0 commit comments

Comments
 (0)