Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dremioai/api/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,5 @@ def __init__(self):
pat = dremio.pat

if uri is None or pat is None:
raise RuntimeError(f"uri={uri} pat={pat} are required")
raise RuntimeError("Dremio connection is not properly configured. Both URI and authentication token are required.")
super().__init__(uri, pat)
7 changes: 7 additions & 0 deletions src/dremioai/servers/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mcp.server.fastmcp.resources import FunctionResource
from mcp.cli.claude import get_claude_config_path
from mcp.shared.auth import OAuthMetadata
from mcp.types import ToolAnnotations
from pydantic import AnyHttpUrl
from pydantic.networks import AnyUrl

Expand Down Expand Up @@ -145,12 +146,18 @@ def init(
if transport == Transports.streamable_http and support_project_id_endpoints:
mcp.support_project_id_endpoints = support_project_id_endpoints
mode = reduce(ior, mode) if mode is not None else None
allow_dml = settings.instance().dremio and settings.instance().dremio.allow_dml
for tool in tools.get_tools(For=mode):
tool_instance = tool()
is_sql_tool = tool is tools.RunSqlQuery
mcp.add_tool(
tool_instance.invoke,
name=tool.__name__,
description=tool_instance.invoke.__doc__,
annotations=ToolAnnotations(
readOnlyHint=not (is_sql_tool and allow_dml),
destructiveHint=bool(is_sql_tool and allow_dml),
),
)

for resource in tools.get_resources(For=mode):
Expand Down
45 changes: 35 additions & 10 deletions src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,23 @@ def ensure_query_allowed(s: str):

@secured
@with_metrics
async def invoke(self, s: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
"""Run a SELECT sql query on the Dremio cluster and return the results.
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
You are premitted to run only SELECT queries. No DML statements are allowed.

Args:
s: sql query
query: sql query
"""
RunSqlQuery.ensure_query_allowed(s)
try:
s = f"/* dremioai: submitter={self.__class__.__name__} */\n{s}"
df = await sql.run_query(query=s, use_df=True)
RunSqlQuery.ensure_query_allowed(query)
except ValueError:
return {
"error": "Only SELECT queries are allowed. DML statements are not permitted.",
}
try:
query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}"
df = await sql.run_query(query=query, use_df=True)
return {"result": _df_to_json_records(df)}
except RuntimeError as e:
return {
Expand Down Expand Up @@ -426,11 +431,16 @@ async def invoke(self) -> Dict[str, str]:
"""Gets the names of system tables in the dremio cluster, useful for various analysis.
Use Get Schema of Table tool to get the schema of the table"""
return {
f'information_schema."tables"': (
"Information about tables in this cluster."
"Be sure to filter out SYSTEM_TABLE for looking at user tables."
'INFORMATION_SCHEMA."TABLES"': (
"Information about tables in this cluster. "
"Be sure to filter out SYSTEM_TABLE for looking at user tables. "
"You must encapsulate TABLES in double quotes."
),
'sys.project.jobs_recent': "Recent job execution history including status, duration, user, and error details.",
'sys.project.engines': "Engine configuration and status for the project.",
'sys.organization.users': "Organization user information.",
'INFORMATION_SCHEMA."COLUMNS"': "Column-level metadata for all tables and views.",
'INFORMATION_SCHEMA."VIEWS"': "View definitions and metadata.",
}


Expand All @@ -443,16 +453,22 @@ async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
"""Gets the schema of the given table.

Args:
table_name: string with the name of the table, including the schema. Or list of paths that make up the table
table_name: The fully qualified table name. Accepts either:
- A dot-separated string: '"source"."schema"."table"'
- A list of path components: ["source", "schema", "table"]

Returns:
A dictionary with information about the table. The field "fields" is a list of dictionaries
that give column names and types. Optionally :"text" field and "tag" filed can provide more
information about the table
"""
if isinstance(table_name, list):
if not table_name:
return {"error": "table_name must not be empty. Provide a list of path components, e.g. ['source', 'schema', 'table']."}
paths = table_name
else:
if not table_name or not table_name.strip():
return {"error": "table_name must not be empty. Provide a dot-separated name, e.g. '\"source\".\"schema\".\"table\"'."}
paths = list(reader(StringIO(table_name), delimiter="."))
result = await get_schema(paths[0], include_tags=True)
if result and "sql" in result:
Expand All @@ -474,7 +490,14 @@ async def invoke(self, table_name: Union[str, List[str]]) -> Dict[str, Any]:
Returns:
A json representation with the lineage of the table or view.
"""
return await get_lineage(table_name)
try:
return await get_lineage(table_name)
except Exception as e:
logger.error(f"Lineage lookup failed for {table_name}: {e}")
return {
"error": "Unable to retrieve lineage for the specified table or view.",
"message": "The lineage lookup failed. Please verify the table name and try again.",
}


class SearchTableAndViews(Tools):
Expand Down Expand Up @@ -626,6 +649,8 @@ async def invoke(self, promql_query: str) -> Dict[str, Any]:
class GetDescriptionOfTableOrSchema(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]

@secured
@with_metrics
async def invoke(self, name: Union[List[str], str]) -> Dict[str, Any]:
"""
Given one or more table names or schema names, this will return the description of the table or schema, if any exists
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_e2e_pat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_tool_pat(mock_config_dir, logging_server, logging_level, project_
sf.mcp_server, token="my-token"
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
"RunSqlQuery", {"query": "SELECT 1"}
)
assert (
result is not None and result.structuredContent is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_mcp_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def test_wlm_engine_name(
sf.mcp_server, token="my-token"
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
"RunSqlQuery", {"query": "SELECT 1"}
)
assert (
result is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_metrics_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def test_metrics_with_tool_invocation(
sf.mcp_server, token="test-token"
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
"RunSqlQuery", {"query": "SELECT 1"}
)
assert result is not None and result.structuredContent is not None

Expand Down
Loading