diff --git a/src/dremioai/api/dremio/ai_tools.py b/src/dremioai/api/dremio/ai_tools.py index 2b88182..9090f8f 100644 --- a/src/dremioai/api/dremio/ai_tools.py +++ b/src/dremioai/api/dremio/ai_tools.py @@ -20,6 +20,7 @@ from aiohttp import ClientResponseError from dremioai.api.transport import DremioAsyncHttpClient as AsyncHttpClient +from dremioai.config import settings from dremioai.log import logger log = logger(__name__) @@ -64,10 +65,9 @@ def is_empty(self) -> bool: async def list_tools() -> ListToolsResponse: try: client = AsyncHttpClient() - return await client.get( - "/api/v4/ai/tools", - deser=ListToolsResponse, - ) + project_id = settings.instance().dremio.project_id + endpoint = f"/v1/projects/{project_id}" if project_id else "/api/v4" + return await client.get(f"{endpoint}/ai/tools", deser=ListToolsResponse) except ClientResponseError as e: log.exception("Failed to list AI tools") return ListToolsResponse(error=f"HTTP {e.status} {e.message}") @@ -80,8 +80,10 @@ async def invoke_tool(tool_name: str, args: Dict[str, Any]) -> InvokeToolRespons safe_name = quote(tool_name, safe="") try: client = AsyncHttpClient() + project_id = settings.instance().dremio.project_id + endpoint = f"/v1/projects/{project_id}" if project_id else "/api/v4" return await client.post( - f"/api/v4/ai/tools/{safe_name}:invoke", + f"{endpoint}/ai/tools/{safe_name}:invoke", body={"args": args}, deser=InvokeToolResponse, ) diff --git a/tests/api/dremio/test_ai_tools.py b/tests/api/dremio/test_ai_tools.py index 7300a4b..d49039a 100644 --- a/tests/api/dremio/test_ai_tools.py +++ b/tests/api/dremio/test_ai_tools.py @@ -26,10 +26,19 @@ # --- list_tools tests --- + +@pytest.fixture +def no_project_id_settings(mock_settings_instance): + old_project_id = mock_settings_instance.dremio.project_id + mock_settings_instance.dremio.raw_project_id = None + yield mock_settings_instance + if old_project_id is not None: + mock_settings_instance.dremio.project_id = old_project_id + @pytest.mark.asyncio async def test_list_tools_returns_tools(mock_settings_instance): with HttpMockFramework() as mock: - mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json") + mock.load_mock_data(r"/v1/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json") result = await list_tools() assert bool(result) assert len(result.tools) == 3 @@ -42,7 +51,7 @@ async def test_list_tools_returns_tools(mock_settings_instance): @pytest.mark.asyncio async def test_list_tools_returns_input_schema(mock_settings_instance): with HttpMockFramework() as mock: - mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json") + mock.load_mock_data(r"/v1/projects/[^/]+/ai/tools$", "ai_tools/list_tools.json") result = await list_tools() run_sql = next(t for t in result.tools if t.name == "runSql") assert run_sql.input_schema["type"] == "object" @@ -52,7 +61,7 @@ async def test_list_tools_returns_input_schema(mock_settings_instance): @pytest.mark.asyncio async def test_list_tools_empty_registry(mock_settings_instance): with HttpMockFramework() as mock: - mock.add_mock_response(r"/api/v4/ai/tools$", {"tools": []}) + mock.add_mock_response(r"/v1/projects/[^/]+/ai/tools$", {"tools": []}) result = await list_tools() assert result.tools == [] assert bool(result) @@ -63,7 +72,10 @@ async def test_list_tools_empty_registry(mock_settings_instance): @pytest.mark.asyncio async def test_invoke_tool_success(mock_settings_instance): with HttpMockFramework() as mock: - mock.load_mock_data(r"/api/v4/ai/tools/runSql:invoke$", "ai_tools/invoke_result.json") + mock.load_mock_data( + r"/v1/projects/[^/]+/ai/tools/runSql:invoke$", + "ai_tools/invoke_result.json", + ) result = await invoke_tool("runSql", {"sqlText": "SELECT 1"}) assert bool(result) assert result.result["columns"] == ["id", "name"] @@ -73,7 +85,10 @@ async def test_invoke_tool_success(mock_settings_instance): @pytest.mark.asyncio async def test_invoke_tool_error_response(mock_settings_instance): with HttpMockFramework() as mock: - mock.load_mock_data(r"/api/v4/ai/tools/unknownTool:invoke$", "ai_tools/invoke_error.json") + mock.load_mock_data( + r"/v1/projects/[^/]+/ai/tools/unknownTool:invoke$", + "ai_tools/invoke_error.json", + ) result = await invoke_tool("unknownTool", {}) assert result.error is not None assert "not found" in result.error @@ -146,7 +161,11 @@ def test_invoke_tool_response_is_empty_false_when_error(): async def test_list_tools_http_error(mock_settings_instance): """list_tools should return a response with error set on HTTP 4xx/5xx.""" with HttpMockFramework() as mock: - mock.add_mock_response(r"/api/v4/ai/tools$", {"error": "Unauthorized"}, status=401) + mock.add_mock_response( + r"/v1/projects/[^/]+/ai/tools$", + {"error": "Unauthorized"}, + status=401, + ) result = await list_tools() assert not bool(result) assert result.error is not None @@ -157,7 +176,11 @@ async def test_list_tools_http_error(mock_settings_instance): async def test_invoke_tool_http_error(mock_settings_instance): """invoke_tool should return a response with error set on HTTP 500.""" with HttpMockFramework() as mock: - mock.add_mock_response(r"/api/v4/ai/tools/runSql:invoke$", {"error": "Internal Server Error"}, status=500) + mock.add_mock_response( + r"/v1/projects/[^/]+/ai/tools/runSql:invoke$", + {"error": "Internal Server Error"}, + status=500, + ) result = await invoke_tool("runSql", {"sqlText": "SELECT 1"}) assert not bool(result) assert result.error is not None @@ -169,6 +192,30 @@ async def test_invoke_tool_url_encodes_name(mock_settings_instance): """tool_name with special characters should be URL-encoded.""" with HttpMockFramework() as mock: # The encoded name "my%2Ftool" should appear in the URL - mock.add_mock_response(r"/api/v4/ai/tools/my%2Ftool:invoke$", {"result": "ok", "error": None}) + mock.add_mock_response( + r"/v1/projects/[^/]+/ai/tools/my%2Ftool:invoke$", + {"result": "ok", "error": None}, + ) result = await invoke_tool("my/tool", {}) assert result.result == "ok" + + +@pytest.mark.asyncio +async def test_list_tools_uses_api_v4_without_project_id(no_project_id_settings): + with HttpMockFramework() as mock: + mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json") + result = await list_tools() + assert bool(result) + assert len(result.tools) == 3 + + +@pytest.mark.asyncio +async def test_invoke_tool_uses_api_v4_without_project_id(no_project_id_settings): + with HttpMockFramework() as mock: + mock.load_mock_data( + r"/api/v4/ai/tools/runSql:invoke$", + "ai_tools/invoke_result.json", + ) + result = await invoke_tool("runSql", {"sqlText": "SELECT 1"}) + assert bool(result) + assert result.result["columns"] == ["id", "name"] diff --git a/tests/conftest.py b/tests/conftest.py index 3280fae..2f4626a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,24 @@ from prometheus_client import CollectorRegistry +@pytest.fixture(autouse=True) +def reset_uvicorn_logger_propagation(): + """Reset uvicorn logger state between tests. + + Uvicorn's configure_logging() sets uvicorn.access.propagate=False via its + default LOGGING_CONFIG when a server starts. This leaks into subsequent tests + that assert stdlib loggers propagate to the root handler. + """ + yield + import logging + + for name in ("uvicorn.access", "uvicorn.error", "uvicorn"): + lg = logging.getLogger(name) + lg.propagate = True + for h in lg.handlers[:]: + lg.removeHandler(h) + + @pytest.fixture(autouse=True) def reset_sse_starlette_app_status(): """