Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 57 additions & 47 deletions src/dremioai/servers/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,57 @@
from yaml import dump, add_representer
import sys

from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
get_access_token,
)
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken, TokenVerifier
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware


class Transports(StrEnum):
stdio = auto()
streamable_http = "streamable-http"


class FastMCPServerWithAuthToken(FastMCP):
    """FastMCP server whose streamable-HTTP app carries bearer-token auth middleware.

    The ASGI app built by ``FastMCP.streamable_http_app`` is extended with
    ``AuthenticationMiddleware`` (bearer backend) and ``AuthContextMiddleware``.
    NOTE(review): presumably this is what lets request handlers read the
    caller's token via ``get_access_token()`` - confirm against the MCP SDK.
    """

    class DelegatingTokenVerifier(TokenVerifier):
        """Verifier that accepts any non-empty bearer token without inspecting it.

        No validation is performed here; the token is only wrapped in an
        ``AccessToken``. NOTE(review): the name suggests the real check is
        delegated to whatever later consumes the token - confirm.
        """

        async def verify_token(self, token: str) -> AccessToken | None:
            """Wrap ``token`` in an ``AccessToken``.

            Args:
                token: Bearer token taken from the request.

            Returns:
                An ``AccessToken`` carrying the token, or ``None`` when the
                token is empty.
            """
            # The bearer token is a credential: never write it to the logs.
            # Only record whether one was supplied.
            log.logger("verify_token").info(
                f"Verifying token (present={bool(token)})"
            )
            if not token:
                return None
            return AccessToken(
                token=token,  # Include the token itself
                client_id="unused-client",
                scopes=["read"],
            )

    def streamable_http_app(self):
        """Build the streamable-HTTP ASGI app and attach the auth middleware.

        Returns:
            The app returned by the parent class, with both middleware added.
        """
        token_verifier = FastMCPServerWithAuthToken.DelegatingTokenVerifier()
        app = super().streamable_http_app()
        # Registration order matters: Starlette puts the most recently added
        # middleware outermost, so AuthenticationMiddleware (added last) sees
        # the request before AuthContextMiddleware does.
        app.add_middleware(AuthContextMiddleware)
        app.add_middleware(
            AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)
        )
        log.logger("streamable_http_app").info(
            f"Adding auth middleware {app.user_middleware}"
        )
        return app


def init(
uri: str = None,
pat: str = None,
project_id: str = None,
mode: Union[tools.ToolType, List[tools.ToolType]] = None,
transport: Transports = Transports.stdio,
) -> FastMCP:
mcp = FastMCP("Dremio", level="DEBUG")
mcp_cls = FastMCP if transport == Transports.stdio else FastMCPServerWithAuthToken
log.logger("init").info(
f"Initializing MCP server with mode={mode}, class={mcp_cls.__name__}"
)
mcp = mcp_cls("Dremio", level="DEBUG")
mode = reduce(ior, mode) if mode is not None else None
for tool in tools.get_tools(For=mode):
tool_instance = tool()
Expand Down Expand Up @@ -74,10 +117,6 @@ def init(


app = None
# if __name__ != "__main__":
# if mode := os.environ.get("MODE"):
# mode = [tools.ToolType[m.upper()] for m in ",".split(mode)]
# app = init(mode=mode)


def _mode() -> List[str]:
Expand All @@ -89,53 +128,26 @@ def _mode() -> List[str]:

@ty.command(name="run", help="Run the DremioAI MCP server")
def main(
dremio_uri: Annotated[Optional[str], Option(help="Dremio URI")] = None,
dremio_pat: Annotated[Optional[str], Option(help="Dremio PAT")] = None,
dremio_project_id: Annotated[
Optional[str], Option(help="Dremio Project Id")
] = None,
config_file: Annotated[
Optional[Path],
Option("-c", "--cfg", help="The config yaml for various options"),
] = None,
mode: Annotated[
Optional[List[str]],
Option("-m", "--mode", help="MCP server mode", click_type=Choice(_mode())),
] = None,
list_tools: Annotated[
bool, Option(help="List available tools for this mode and exit")
] = False,
log_to_file: Annotated[Optional[bool], Option(help="Log to file")] = True,
enable_json_logging: Annotated[
Optional[bool], Option(help="Enable JSON logs")
] = False,
enable_streaming_http: Annotated[
Optional[bool], Option(help="Run MCP as streaming HTTP")
] = False,
):
log.configure(enable_json_logging=enable_json_logging, to_file=log_to_file)
log.set_level("DEBUG")
if enable_streaming_http:
transport = Transports.streamable_http
else:
transport = Transports.stdio

if mode is not None:
mode = [tools.ToolType[m.upper()] for m in mode]

cfg = (
settings.configure(config_file)
.get()
.with_overrides(
{
"dremio.uri": dremio_uri,
"dremio.pat": dremio_pat,
"dremio.project_id": dremio_project_id,
"tools.server_mode": mode,
}
)
)
if list_tools:
log.logger().info(f"Starting Dremio tools with {cfg}")
mode = reduce(ior, mode) if mode is not None else None
log.logger().info(f"Listing available tools for mode={mode}")
for tool in tools.get_tools(For=mode):
print(tool.__name__)
return

cfg = settings.configure(config_file).get()
dremio = settings.instance().dremio
if (
dremio.oauth_supported
Expand All @@ -146,12 +158,10 @@ def main(
oauth.update_settings()

app = init(
uri=cfg.dremio.uri,
pat=cfg.dremio.pat,
project_id=cfg.dremio.project_id,
mode=cfg.tools.server_mode,
transport=transport,
)
app.run()
app.run(transport=transport.value)


tc = Typer(
Expand Down
71 changes: 31 additions & 40 deletions src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@
ClassVar,
get_args,
get_type_hints,
Callable,
TypeVar,
ParamSpec,
Awaitable,
)

from pathlib import Path
from dataclasses import dataclass, asdict, field
from enum import auto, IntFlag
from dremioai import log
import re
import functools

import pandas as pd

Expand All @@ -48,9 +53,16 @@
from io import StringIO
from sqlglot import parse_one
from sqlglot import expressions
from mcp.server.fastmcp.server import Context
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.provider import AccessToken

logger = log.logger(__name__)

# Type variables for the secured decorator
P = ParamSpec("P")
T = TypeVar("T")


@dataclass
class Property:
Expand Down Expand Up @@ -90,44 +102,25 @@ def as_dict(self) -> Dict[str, Any]:


class Tools:
def __init__(self, uri=None, pat=None, project_id=None):
settings.instance().with_overrides(
{"dremio.uri": uri, "dremio.pat": pat, "dremio.project_id": project_id}
)

@property
def dremio_uri(self):
return settings.instance().dremio.uri

@property
def pat(self):
return settings.instance().dremio.pat

@property
def project_id(self):
return settings.instance().dremio.project_id

async def invoke(self):
raise NotImplementedError("Subclasses should implement this method")

def get_parameters(self):
return Parameters()

# support for LangChain tools as compatibility
def as_tool(self):
return Tool(
function=Function(
name=self.__class__.__name__,
description=self.invoke.__doc__,
parameters=self.get_parameters(),
)
)
# A decorator that ensures a tool which needs to access Dremio runs with the correct token
# when it is invoked through the streamable HTTP transport _with_ a valid Dremio bearer token.
# It is a no-op if the tool is invoked through the stdio transport, as the MCP server ensures
# the proper PAT is used for all requests.
def secured(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:

@functools.wraps(fn)
async def _impl(self, *args: P.args, **kw: P.kwargs) -> T:
if isinstance((token := get_access_token()), AccessToken):
return await settings.run_with(
fn, {"dremio.pat": token.token}, (self,) + args, kw
)
return await fn(self, *args, **kw)

JobType: TypeAlias = Union[
List[Literal["UI", "ACCELERATION", "INTERNAL", "EXTERNAL"]], str
]
StatusType: TypeAlias = Union[List[Literal["COMPLETED", "CANCELED", "FAILED"]], str]
return _impl


def _get_class_var_hints(tool: Tools, name: str) -> bool:
Expand Down Expand Up @@ -164,6 +157,7 @@ class GetFailedJobDetails(Tools):
def group_by(self, df, by):
return df.groupby(by).size().reset_index(name="count").to_dict(orient="records")

@secured
async def invoke(self) -> Dict[str, Any]:
"""Get the stats and details of failed or canceled jobs executed in the Dremio cluster in the past 7 days
along with a split by job type
Expand Down Expand Up @@ -273,6 +267,7 @@ def ensure_query_allowed(s: str):
"The query contains a DML statement. Only select queries are allowed"
)

@secured
async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
"""Run a SELECT sql query on the Dremio cluster and return the results.
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
Expand All @@ -292,19 +287,12 @@ async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
"message": "The query failed. Please check the syntax and try again",
}

def get_parameters(self):
return Parameters(
properties={
"sql": Property(type="string", description="The sql query to run")
},
required=["sql"],
)


class BuildUsageReport(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF]]
project_id_required: ClassVar[Annotated[bool, True]]

@secured
async def invoke(
self, by: Optional[Literal["PROJECT", "ENGINE"]] = "ENGINE"
) -> Dict[str, Any]:
Expand Down Expand Up @@ -367,6 +355,7 @@ async def invoke(self) -> Dict[str, str]:
class GetSchemaOfTable(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]

@secured
async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
"""Gets the schema of the given table.

Expand All @@ -391,6 +380,7 @@ async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
class GetTableOrViewLineage(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]

@secured
async def invoke(self, table_name: Union[str, List[str]]) -> Dict[str, Any]:
"""Finds the lineage of a table or view in the Dremio cluster

Expand All @@ -411,6 +401,7 @@ class SearchTableAndViews(Tools):
]
]

@secured
async def invoke(self, query: str) -> Dict[str, Any]:
"""Runs a semantic search on the Dremio cluster to find tables and views that match the query.

Expand Down
47 changes: 47 additions & 0 deletions tests/stremable_http_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
MCP HTTP Streamable Client Example using Python SDK

This example demonstrates how to create an MCP client that connects to a server
using the Streamable HTTP transport protocol.
"""

import asyncio

from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from typer import Typer, Option
from rich import print as pp
from typing import Annotated, Optional

# Typer application object; the command below is registered on it.
app = Typer(
    no_args_is_help=True,
    name="mcp-client",
    help="Run simple mcp client",
    context_settings={"help_option_names": ["-h", "--help"]},
)


# Example usage and demonstration
async def cli(url, token):
    """Connect to an MCP server over streamable HTTP, print its tools and run a sample query.

    Args:
        url: Endpoint of the MCP server.
        token: Bearer token sent in the ``Authorization`` header.
    """
    headers = {"Authorization": f"Bearer {token}"}
    # The third item yielded by the transport is not used by this example.
    async with streamablehttp_client(url=url, headers=headers) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            # list_tools() returns a result object whose ``tools`` attribute
            # holds the tool definitions. Iterating the result itself would
            # walk the model's (field name, value) pairs, not the tools.
            listing = await session.list_tools()
            for tool in listing.tools:
                pp(tool)
            pp(await session.call_tool("RunSqlQuery", {"s": "SELECT 1"}))


@app.command()
def main(
    token: Annotated[Optional[str], Option(help="The authorization token to use")],
    url: Annotated[
        Optional[str], Option(help="The URL of the MCP server")
    ] = "http://127.0.0.1:8000/mcp",
):
    # Command entry point: drives the async client to completion on a fresh
    # event loop. Documented with comments rather than a docstring because
    # Typer would surface a docstring as the command's help text.
    client_run = cli(url, token)
    asyncio.run(client_run)


# Run the Typer app only when this file is executed as a script.
if __name__ == "__main__":
    app()
Loading