Skip to content

Commit f585af5

Browse files
DX-115085: MCP server hardening — auth, validation, annotations, and smoketests (#86)
Critical fixes: - Add @secured and @with_metrics decorators to GetDescriptionOfTableOrSchema so OAuth tokens are properly injected - Add input validation to GetSchemaOfTable for empty string/list inputs - Add MCP ToolAnnotations (readOnlyHint/destructiveHint) to all tools; RunSqlQuery destructiveHint is conditional on allow_dml setting Medium fixes: - Sanitize error message in DremioAsyncHttpClient to not leak URI/PAT - Wrap GetTableOrViewLineage in try/except with sanitized error response - Return clean error dict for DML rejection in RunSqlQuery - Rename RunSqlQuery parameter from 's' to 'query' for clarity - Update GetSchemaOfTable docstring with explicit format examples - Expand GetUsefulSystemTableNames from 1 to 6 entries (jobs_recent, engines, users, COLUMNS, VIEWS) Smoketest improvements: - Add 11 smoketest cases to stremable_http_cli.py covering all fixes - Add --token, --local, --check-annotations, --check-new-contract flags - Add _local_mcp_server context manager with ContextVar propagation to server thread via settings.with_overrides()
1 parent b4b76e9 commit f585af5

9 files changed

Lines changed: 440 additions & 40 deletions

File tree

src/dremioai/api/transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,5 +222,5 @@ def __init__(self):
222222
pat = dremio.pat
223223

224224
if uri is None or pat is None:
225-
raise RuntimeError(f"uri={uri} pat={pat} are required")
225+
raise RuntimeError("Dremio connection is not properly configured. Both URI and authentication token are required.")
226226
super().__init__(uri, pat)

src/dremioai/servers/mcp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mcp.server.fastmcp.resources import FunctionResource
2020
from mcp.cli.claude import get_claude_config_path
2121
from mcp.shared.auth import OAuthMetadata
22+
from mcp.types import ToolAnnotations
2223
from pydantic import AnyHttpUrl
2324
from pydantic.networks import AnyUrl
2425

@@ -145,12 +146,18 @@ def init(
145146
if transport == Transports.streamable_http and support_project_id_endpoints:
146147
mcp.support_project_id_endpoints = support_project_id_endpoints
147148
mode = reduce(ior, mode) if mode is not None else None
149+
allow_dml = settings.instance().dremio and settings.instance().dremio.allow_dml
148150
for tool in tools.get_tools(For=mode):
149151
tool_instance = tool()
152+
is_sql_tool = tool is tools.RunSqlQuery
150153
mcp.add_tool(
151154
tool_instance.invoke,
152155
name=tool.__name__,
153156
description=tool_instance.invoke.__doc__,
157+
annotations=ToolAnnotations(
158+
readOnlyHint=not (is_sql_tool and allow_dml),
159+
destructiveHint=bool(is_sql_tool and allow_dml),
160+
),
154161
)
155162

156163
for resource in tools.get_resources(For=mode):

src/dremioai/tools/tools.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,18 +349,23 @@ def ensure_query_allowed(s: str):
349349

350350
@secured
351351
@with_metrics
352-
async def invoke(self, s: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
352+
async def invoke(self, query: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
353353
"""Run a SELECT sql query on the Dremio cluster and return the results.
354354
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
355355
You are premitted to run only SELECT queries. No DML statements are allowed.
356356
357357
Args:
358-
s: sql query
358+
query: sql query
359359
"""
360-
RunSqlQuery.ensure_query_allowed(s)
361360
try:
362-
s = f"/* dremioai: submitter={self.__class__.__name__} */\n{s}"
363-
df = await sql.run_query(query=s, use_df=True)
361+
RunSqlQuery.ensure_query_allowed(query)
362+
except ValueError:
363+
return {
364+
"error": "Only SELECT queries are allowed. DML statements are not permitted.",
365+
}
366+
try:
367+
query = f"/* dremioai: submitter={self.__class__.__name__} */\n{query}"
368+
df = await sql.run_query(query=query, use_df=True)
364369
return {"result": _df_to_json_records(df)}
365370
except RuntimeError as e:
366371
return {
@@ -426,11 +431,16 @@ async def invoke(self) -> Dict[str, str]:
426431
"""Gets the names of system tables in the dremio cluster, useful for various analysis.
427432
Use Get Schema of Table tool to get the schema of the table"""
428433
return {
429-
f'information_schema."tables"': (
430-
"Information about tables in this cluster."
431-
"Be sure to filter out SYSTEM_TABLE for looking at user tables."
434+
'INFORMATION_SCHEMA."TABLES"': (
435+
"Information about tables in this cluster. "
436+
"Be sure to filter out SYSTEM_TABLE for looking at user tables. "
432437
"You must encapsulate TABLES in double quotes."
433438
),
439+
'sys.project.jobs_recent': "Recent job execution history including status, duration, user, and error details.",
440+
'sys.project.engines': "Engine configuration and status for the project.",
441+
'sys.organization.users': "Organization user information.",
442+
'INFORMATION_SCHEMA."COLUMNS"': "Column-level metadata for all tables and views.",
443+
'INFORMATION_SCHEMA."VIEWS"': "View definitions and metadata.",
434444
}
435445

436446

@@ -443,16 +453,22 @@ async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
443453
"""Gets the schema of the given table.
444454
445455
Args:
446-
table_name: string with the name of the table, including the schema. Or list of paths that make up the table
456+
table_name: The fully qualified table name. Accepts either:
457+
- A dot-separated string: '"source"."schema"."table"'
458+
- A list of path components: ["source", "schema", "table"]
447459
448460
Returns:
449461
A dictionary with information about the table. The field "fields" is a list of dictionaries
450462
that give column names and types. Optionally :"text" field and "tag" filed can provide more
451463
information about the table
452464
"""
453465
if isinstance(table_name, list):
466+
if not table_name:
467+
return {"error": "table_name must not be empty. Provide a list of path components, e.g. ['source', 'schema', 'table']."}
454468
paths = table_name
455469
else:
470+
if not table_name or not table_name.strip():
471+
return {"error": "table_name must not be empty. Provide a dot-separated name, e.g. '\"source\".\"schema\".\"table\"'."}
456472
paths = list(reader(StringIO(table_name), delimiter="."))
457473
result = await get_schema(paths[0], include_tags=True)
458474
if result and "sql" in result:
@@ -474,7 +490,14 @@ async def invoke(self, table_name: Union[str, List[str]]) -> Dict[str, Any]:
474490
Returns:
475491
A json representation with the lineage of the table or view.
476492
"""
477-
return await get_lineage(table_name)
493+
try:
494+
return await get_lineage(table_name)
495+
except Exception as e:
496+
logger.error(f"Lineage lookup failed for {table_name}: {e}")
497+
return {
498+
"error": "Unable to retrieve lineage for the specified table or view.",
499+
"message": "The lineage lookup failed. Please verify the table name and try again.",
500+
}
478501

479502

480503
class SearchTableAndViews(Tools):
@@ -626,6 +649,8 @@ async def invoke(self, promql_query: str) -> Dict[str, Any]:
626649
class GetDescriptionOfTableOrSchema(Tools):
627650
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]
628651

652+
@secured
653+
@with_metrics
629654
async def invoke(self, name: Union[List[str], str]) -> Dict[str, Any]:
630655
"""
631656
Given one or more table names or schema names, this will return the description of the table or schema, if any exists

tests/e2e/test_e2e_pat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def test_tool_pat(mock_config_dir, logging_server, logging_level, project_
4141
sf.mcp_server, token="my-token"
4242
) as session:
4343
result: CallToolResult = await session.call_tool(
44-
"RunSqlQuery", {"s": "SELECT 1"}
44+
"RunSqlQuery", {"query": "SELECT 1"}
4545
)
4646
assert (
4747
result is not None and result.structuredContent is not None

tests/e2e/test_mcp_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def test_wlm_engine_name(
6969
sf.mcp_server, token="my-token"
7070
) as session:
7171
result: CallToolResult = await session.call_tool(
72-
"RunSqlQuery", {"s": "SELECT 1"}
72+
"RunSqlQuery", {"query": "SELECT 1"}
7373
)
7474
assert (
7575
result is not None

tests/e2e/test_metrics_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def test_metrics_with_tool_invocation(
107107
sf.mcp_server, token="test-token"
108108
) as session:
109109
result: CallToolResult = await session.call_tool(
110-
"RunSqlQuery", {"s": "SELECT 1"}
110+
"RunSqlQuery", {"query": "SELECT 1"}
111111
)
112112
assert result is not None and result.structuredContent is not None
113113

0 commit comments

Comments
 (0)