Skip to content

Commit 241c660

Browse files
authored
DX-118683: Access tools from Dremio repository (#99)
* First commit * Fixed tests * Fixing tests * Fixing tests * Fixed minor errors * Fix small problems * Wired tools to MCP * Finished review * Fixed feedback * Addressed feedback * Fix test failure * Fix test * Addressed feedback * Cleanup * Addressed feedback * Added flag * Addressed feedback
1 parent 74dc0de commit 241c660

12 files changed

Lines changed: 529 additions & 14 deletions

File tree

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#
2+
# Copyright (C) 2017-2025 Dremio Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from pydantic import BaseModel, Field, ConfigDict
18+
from typing import Dict, List, Optional, Any
19+
from urllib.parse import quote
20+
21+
from aiohttp import ClientResponseError
22+
from dremioai.api.transport import DremioAsyncHttpClient as AsyncHttpClient
23+
from dremioai.log import logger
24+
25+
log = logger(__name__)
26+
27+
28+
class AiTool(BaseModel):
29+
30+
name: str
31+
description: Optional[str] = None
32+
input_schema: Dict[str, Any] = Field(
33+
default_factory=lambda: {"type": "object"}, alias="inputSchema"
34+
)
35+
model_config = ConfigDict(extra="allow", populate_by_name=True)
36+
37+
38+
class ListToolsResponse(BaseModel):
39+
tools: List[AiTool] = Field(default_factory=list)
40+
error: Optional[str] = None
41+
42+
def __bool__(self):
43+
return self.error is None
44+
45+
46+
class InvokeToolResponse(BaseModel):
47+
48+
result: Optional[Any] = None
49+
error: Optional[str] = None
50+
51+
def __bool__(self):
52+
return self.error is None
53+
54+
@property
55+
def is_empty(self) -> bool:
56+
"""True when the response carries neither a result nor an error.
57+
58+
This can happen when Dremio returns a 200 with an empty body for a
59+
void tool. Callers may choose to treat this as a successful no-op.
60+
"""
61+
return self.result is None and self.error is None
62+
63+
64+
async def list_tools() -> ListToolsResponse:
65+
try:
66+
client = AsyncHttpClient()
67+
return await client.get(
68+
"/api/v4/ai/tools",
69+
deser=ListToolsResponse,
70+
)
71+
except ClientResponseError as e:
72+
log.exception("Failed to list AI tools")
73+
return ListToolsResponse(error=f"HTTP {e.status} {e.message}")
74+
except Exception:
75+
log.exception("Failed to list AI tools")
76+
return ListToolsResponse(error="Unexpected error listing AI tools")
77+
78+
79+
async def invoke_tool(tool_name: str, args: Dict[str, Any]) -> InvokeToolResponse:
80+
safe_name = quote(tool_name, safe="")
81+
try:
82+
client = AsyncHttpClient()
83+
return await client.post(
84+
f"/api/v4/ai/tools/{safe_name}:invoke",
85+
body={"args": args},
86+
deser=InvokeToolResponse,
87+
)
88+
except ClientResponseError as e:
89+
log.exception("Failed to invoke AI tool '%s'", tool_name)
90+
return InvokeToolResponse(error=f"HTTP {e.status} {e.message}")
91+
except Exception:
92+
log.exception("Failed to invoke AI tool '%s'", tool_name)
93+
return InvokeToolResponse(error=f"Unexpected error invoking tool '{tool_name}'")

src/dremioai/config/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ class Dremio(FlagAwareModel):
274274
wlm: Optional[Wlm] = None
275275
api: Optional[ApiSettings] = Field(default_factory=ApiSettings)
276276
metrics: Optional[Metrics] = None
277+
enable_remote_tools: Optional[bool] = Field(
278+
default=False,
279+
description="Enable dynamic registration of remote tools from Dremio's Java-side tool registry",
280+
)
277281

278282
@field_serializer("raw_pat")
279283
def serialize_pat(self, pat: str):

src/dremioai/servers/mcp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from click import Choice
3535
from mcp.cli.claude import get_claude_config_path
3636
from mcp.server.auth.json_response import PydanticJSONResponse
37-
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
37+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
3838
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
3939
from mcp.server.auth.provider import AccessToken, TokenVerifier
4040
from mcp.server.fastmcp import FastMCP
@@ -59,6 +59,7 @@
5959
from dremioai.config import settings
6060
from dremioai.config.feature_flags import FeatureFlagManager
6161
from dremioai.metrics.registry import get_metrics_app
62+
from dremioai.metrics.tool_metrics import invocation_counter, invocation_duration
6263
from dremioai.servers.jwks_verifier import JWKSVerifier, TokenExpiredError
6364
from dremioai.tools import tools
6465
from dremioai.tools.tools import ProjectIdMiddleware

src/dremioai/tools/tools.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@
4444

4545
import pandas as pd
4646
import numpy as np
47-
from dremioai.api.dremio import sql, usage, search
47+
from dremioai.api.dremio import sql, usage, search, ai_tools
4848
from dremioai.config import settings
4949
from dremioai.config.tools import ToolType
5050
from dremioai.api.prometheus import vm
5151
from dremioai.api.dremio.catalog import get_schema, get_lineage, get_descriptions
5252
from dremioai.api.util import run_in_parallel
5353
from csv import reader
5454
from io import StringIO
55+
import json
5556
from sqlglot import parse_one
5657
from sqlglot import expressions
5758
from mcp.server.auth.middleware.auth_context import get_access_token
@@ -537,6 +538,47 @@ async def invoke(self, query: str) -> Dict[str, Any]:
537538
return {"results": res.to_dict(orient="records")}
538539

539540

541+
class DiscoverDynamicTools(Tools):
542+
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]
543+
544+
@secured
545+
@with_metrics
546+
async def invoke(self) -> str:
547+
"""Discover additional tools available from the Dremio server.
548+
Call this tool to get a list of dynamically available tools with their
549+
names, descriptions, and input schemas."""
550+
if not settings.instance().dremio.get("enable_remote_tools"):
551+
return "Remote tools are not enabled."
552+
result = await ai_tools.list_tools()
553+
return result.model_dump_json()
554+
555+
556+
class CallDynamicTool(Tools):
557+
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]
558+
559+
@secured
560+
@with_metrics
561+
async def invoke(self, tool_name: str, tool_arguments: Union[str, dict]) -> str:
562+
"""Invoke a dynamically discovered tool on the Dremio server.
563+
564+
Args:
565+
tool_name: The name of the tool to invoke, as returned by DiscoverDynamicTools.
566+
tool_arguments: The arguments to pass to the tool, either as a JSON string or a dict.
567+
"""
568+
if not settings.instance().dremio.get("enable_remote_tools"):
569+
return "Remote tools are not enabled."
570+
if isinstance(tool_arguments, str):
571+
try:
572+
args = json.loads(tool_arguments)
573+
except json.JSONDecodeError as exc:
574+
return f"Invalid JSON in tool_arguments: {exc}"
575+
else:
576+
args = tool_arguments
577+
578+
result = await ai_tools.invoke_tool(tool_name, args)
579+
return result.model_dump_json(exclude_none=True)
580+
581+
540582
def _subclasses(cls):
541583
for sub in cls.__subclasses__():
542584
yield from _subclasses(sub)

tests/api/dremio/test_ai_tools.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#
2+
# Copyright (C) 2017-2025 Dremio Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import pytest
18+
from dremioai.api.dremio.ai_tools import (
19+
AiTool,
20+
InvokeToolResponse,
21+
list_tools,
22+
invoke_tool,
23+
)
24+
from mocks.http_mock import HttpMockFramework
25+
26+
27+
# --- list_tools tests ---
28+
29+
@pytest.mark.asyncio
30+
async def test_list_tools_returns_tools(mock_settings_instance):
31+
with HttpMockFramework() as mock:
32+
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
33+
result = await list_tools()
34+
assert bool(result)
35+
assert len(result.tools) == 3
36+
names = [t.name for t in result.tools]
37+
assert "runSql" in names
38+
assert "getTableOrViewSchema" in names
39+
assert "listEngines" in names
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_list_tools_returns_input_schema(mock_settings_instance):
44+
with HttpMockFramework() as mock:
45+
mock.load_mock_data(r"/api/v4/ai/tools$", "ai_tools/list_tools.json")
46+
result = await list_tools()
47+
run_sql = next(t for t in result.tools if t.name == "runSql")
48+
assert run_sql.input_schema["type"] == "object"
49+
assert "sqlText" in run_sql.input_schema["properties"]
50+
51+
52+
@pytest.mark.asyncio
53+
async def test_list_tools_empty_registry(mock_settings_instance):
54+
with HttpMockFramework() as mock:
55+
mock.add_mock_response(r"/api/v4/ai/tools$", {"tools": []})
56+
result = await list_tools()
57+
assert result.tools == []
58+
assert bool(result)
59+
60+
61+
# --- invoke_tool tests ---
62+
63+
@pytest.mark.asyncio
64+
async def test_invoke_tool_success(mock_settings_instance):
65+
with HttpMockFramework() as mock:
66+
mock.load_mock_data(r"/api/v4/ai/tools/runSql:invoke$", "ai_tools/invoke_result.json")
67+
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
68+
assert bool(result)
69+
assert result.result["columns"] == ["id", "name"]
70+
assert result.error is None
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_invoke_tool_error_response(mock_settings_instance):
75+
with HttpMockFramework() as mock:
76+
mock.load_mock_data(r"/api/v4/ai/tools/unknownTool:invoke$", "ai_tools/invoke_error.json")
77+
result = await invoke_tool("unknownTool", {})
78+
assert result.error is not None
79+
assert "not found" in result.error
80+
assert result.result is None
81+
82+
83+
# --- Pydantic model unit tests (no HTTP) ---
84+
85+
def test_ai_tool_model_validation():
86+
raw = {
87+
"name": "runSql",
88+
"description": "runSql",
89+
"inputSchema": {
90+
"type": "object",
91+
"properties": {"sqlText": {"type": "string"}},
92+
"required": ["sqlText"],
93+
},
94+
}
95+
tool = AiTool.model_validate(raw)
96+
assert tool.name == "runSql"
97+
assert tool.input_schema["type"] == "object"
98+
assert tool.input_schema["required"] == ["sqlText"]
99+
100+
101+
def test_ai_tool_model_minimal_schema():
102+
"""Tools with an empty inputSchema (e.g. listEngines) should deserialize cleanly."""
103+
raw = {"name": "listEngines", "description": "listEngines", "inputSchema": {"type": "object"}}
104+
tool = AiTool.model_validate(raw)
105+
assert tool.name == "listEngines"
106+
assert tool.input_schema == {"type": "object"}
107+
108+
109+
def test_invoke_tool_response_succeeded():
110+
resp = InvokeToolResponse.model_validate({"result": {"sql": "SELECT 1"}})
111+
assert bool(resp) is True
112+
assert resp.result == {"sql": "SELECT 1"}
113+
assert resp.error is None
114+
115+
116+
def test_invoke_tool_response_failed():
117+
resp = InvokeToolResponse.model_validate({"error": "Tool not found"})
118+
assert bool(resp) is False
119+
assert resp.result is None
120+
assert resp.error == "Tool not found"
121+
122+
123+
def test_invoke_tool_response_empty():
124+
resp = InvokeToolResponse.model_validate({})
125+
assert bool(resp) is True
126+
assert resp.result is None
127+
assert resp.error is None
128+
assert resp.is_empty is True
129+
130+
131+
def test_invoke_tool_response_is_empty_false_when_result():
132+
"""is_empty should be False when a result is present."""
133+
resp = InvokeToolResponse.model_validate({"result": {"sql": "SELECT 1"}})
134+
assert resp.is_empty is False
135+
136+
137+
def test_invoke_tool_response_is_empty_false_when_error():
138+
"""is_empty should be False when an error is present."""
139+
resp = InvokeToolResponse.model_validate({"error": "Tool not found"})
140+
assert resp.is_empty is False
141+
142+
143+
# --- HTTP error scenario tests ---
144+
145+
@pytest.mark.asyncio
146+
async def test_list_tools_http_error(mock_settings_instance):
147+
"""list_tools should return a response with error set on HTTP 4xx/5xx."""
148+
with HttpMockFramework() as mock:
149+
mock.add_mock_response(r"/api/v4/ai/tools$", {"error": "Unauthorized"}, status=401)
150+
result = await list_tools()
151+
assert not bool(result)
152+
assert result.error is not None
153+
assert "401" in result.error
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_invoke_tool_http_error(mock_settings_instance):
158+
"""invoke_tool should return a response with error set on HTTP 500."""
159+
with HttpMockFramework() as mock:
160+
mock.add_mock_response(r"/api/v4/ai/tools/runSql:invoke$", {"error": "Internal Server Error"}, status=500)
161+
result = await invoke_tool("runSql", {"sqlText": "SELECT 1"})
162+
assert not bool(result)
163+
assert result.error is not None
164+
assert "500" in result.error
165+
166+
167+
@pytest.mark.asyncio
168+
async def test_invoke_tool_url_encodes_name(mock_settings_instance):
169+
"""tool_name with special characters should be URL-encoded."""
170+
with HttpMockFramework() as mock:
171+
# The encoded name "my%2Ftool" should appear in the URL
172+
mock.add_mock_response(r"/api/v4/ai/tools/my%2Ftool:invoke$", {"result": "ok", "error": None})
173+
result = await invoke_tool("my/tool", {})
174+
assert result.result == "ok"

tests/config/golden_flag_keys.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ flag_keys:
66
- dremio.api.http_retry.max_retries
77
- dremio.api.polling_interval
88
- dremio.auth_issuer_uri_override
9+
- dremio.enable_remote_tools
910
- dremio.enable_search
1011
- dremio.extract_org_id_from_jwt
1112
- dremio.jwks_cache_lifespan

0 commit comments

Comments
 (0)