Skip to content

Commit 6822e28

Browse files
Recognize bearer auth tokens for streamable http (#43)
* Accepting bearer auth header * unneeded tests
1 parent 61d8b38 commit 6822e28

5 files changed

Lines changed: 173 additions & 572 deletions

File tree

src/dremioai/servers/mcp.py

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,57 @@
3838
from yaml import dump, add_representer
3939
import sys
4040

41+
from mcp.server.auth.middleware.auth_context import (
42+
AuthContextMiddleware,
43+
get_access_token,
44+
)
45+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
46+
from mcp.server.auth.provider import AccessToken, TokenVerifier
47+
from starlette.middleware import Middleware
48+
from starlette.middleware.authentication import AuthenticationMiddleware
49+
50+
51+
class Transports(StrEnum):
52+
stdio = auto()
53+
streamable_http = "streamable-http"
54+
55+
56+
class FastMCPServerWithAuthToken(FastMCP):
57+
class DelegatingTokenVerifier(TokenVerifier):
58+
async def verify_token(self, token: str) -> AccessToken | None:
59+
log.logger("verify_token").info(f"Verifying token: {token}")
60+
return (
61+
AccessToken(
62+
token=token, # Include the token itself
63+
client_id="unused-client",
64+
scopes=["read"],
65+
)
66+
if token
67+
else None
68+
)
69+
70+
def streamable_http_app(self):
71+
token_verifier = FastMCPServerWithAuthToken.DelegatingTokenVerifier()
72+
app = super().streamable_http_app()
73+
app.add_middleware(AuthContextMiddleware)
74+
app.add_middleware(
75+
AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)
76+
)
77+
log.logger("streamable_http_app").info(
78+
f"Adding auth middleware {app.user_middleware}"
79+
)
80+
return app
81+
4182

4283
def init(
43-
uri: str = None,
44-
pat: str = None,
45-
project_id: str = None,
4684
mode: Union[tools.ToolType, List[tools.ToolType]] = None,
85+
transport: Transports = Transports.stdio,
4786
) -> FastMCP:
48-
mcp = FastMCP("Dremio", level="DEBUG")
87+
mcp_cls = FastMCP if transport == Transports.stdio else FastMCPServerWithAuthToken
88+
log.logger("init").info(
89+
f"Initializing MCP server with mode={mode}, class={mcp_cls.__name__}"
90+
)
91+
mcp = mcp_cls("Dremio", level="DEBUG")
4992
mode = reduce(ior, mode) if mode is not None else None
5093
for tool in tools.get_tools(For=mode):
5194
tool_instance = tool()
@@ -74,10 +117,6 @@ def init(
74117

75118

76119
app = None
77-
# if __name__ != "__main__":
78-
# if mode := os.environ.get("MODE"):
79-
# mode = [tools.ToolType[m.upper()] for m in ",".split(mode)]
80-
# app = init(mode=mode)
81120

82121

83122
def _mode() -> List[str]:
@@ -89,53 +128,26 @@ def _mode() -> List[str]:
89128

90129
@ty.command(name="run", help="Run the DremioAI MCP server")
91130
def main(
92-
dremio_uri: Annotated[Optional[str], Option(help="Dremio URI")] = None,
93-
dremio_pat: Annotated[Optional[str], Option(help="Dremio PAT")] = None,
94-
dremio_project_id: Annotated[
95-
Optional[str], Option(help="Dremio Project Id")
96-
] = None,
97131
config_file: Annotated[
98132
Optional[Path],
99133
Option("-c", "--cfg", help="The config yaml for various options"),
100134
] = None,
101-
mode: Annotated[
102-
Optional[List[str]],
103-
Option("-m", "--mode", help="MCP server mode", click_type=Choice(_mode())),
104-
] = None,
105-
list_tools: Annotated[
106-
bool, Option(help="List available tools for this mode and exit")
107-
] = False,
108135
log_to_file: Annotated[Optional[bool], Option(help="Log to file")] = True,
109136
enable_json_logging: Annotated[
110137
Optional[bool], Option(help="Enable JSON logs")
111138
] = False,
139+
enable_streaming_http: Annotated[
140+
Optional[bool], Option(help="Run MCP as streaming HTTP")
141+
] = False,
112142
):
113143
log.configure(enable_json_logging=enable_json_logging, to_file=log_to_file)
114144
log.set_level("DEBUG")
145+
if enable_streaming_http:
146+
transport = Transports.streamable_http
147+
else:
148+
transport = Transports.stdio
115149

116-
if mode is not None:
117-
mode = [tools.ToolType[m.upper()] for m in mode]
118-
119-
cfg = (
120-
settings.configure(config_file)
121-
.get()
122-
.with_overrides(
123-
{
124-
"dremio.uri": dremio_uri,
125-
"dremio.pat": dremio_pat,
126-
"dremio.project_id": dremio_project_id,
127-
"tools.server_mode": mode,
128-
}
129-
)
130-
)
131-
if list_tools:
132-
log.logger().info(f"Starting Dremio tools with {cfg}")
133-
mode = reduce(ior, mode) if mode is not None else None
134-
log.logger().info(f"Listing available tools for mode={mode}")
135-
for tool in tools.get_tools(For=mode):
136-
print(tool.__name__)
137-
return
138-
150+
cfg = settings.configure(config_file).get()
139151
dremio = settings.instance().dremio
140152
if (
141153
dremio.oauth_supported
@@ -146,12 +158,10 @@ def main(
146158
oauth.update_settings()
147159

148160
app = init(
149-
uri=cfg.dremio.uri,
150-
pat=cfg.dremio.pat,
151-
project_id=cfg.dremio.project_id,
152161
mode=cfg.tools.server_mode,
162+
transport=transport,
153163
)
154-
app.run()
164+
app.run(transport=transport.value)
155165

156166

157167
tc = Typer(

src/dremioai/tools/tools.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@
2626
ClassVar,
2727
get_args,
2828
get_type_hints,
29+
Callable,
30+
TypeVar,
31+
ParamSpec,
32+
Awaitable,
2933
)
3034

3135
from pathlib import Path
3236
from dataclasses import dataclass, asdict, field
3337
from enum import auto, IntFlag
3438
from dremioai import log
3539
import re
40+
import functools
3641

3742
import pandas as pd
3843

@@ -48,9 +53,16 @@
4853
from io import StringIO
4954
from sqlglot import parse_one
5055
from sqlglot import expressions
56+
from mcp.server.fastmcp.server import Context
57+
from mcp.server.auth.middleware.auth_context import get_access_token
58+
from mcp.server.auth.provider import AccessToken
5159

5260
logger = log.logger(__name__)
5361

62+
# Type variables for the secured decorator
63+
P = ParamSpec("P")
64+
T = TypeVar("T")
65+
5466

5567
@dataclass
5668
class Property:
@@ -90,44 +102,25 @@ def as_dict(self) -> Dict[str, Any]:
90102

91103

92104
class Tools:
93-
def __init__(self, uri=None, pat=None, project_id=None):
94-
settings.instance().with_overrides(
95-
{"dremio.uri": uri, "dremio.pat": pat, "dremio.project_id": project_id}
96-
)
97-
98-
@property
99-
def dremio_uri(self):
100-
return settings.instance().dremio.uri
101-
102-
@property
103-
def pat(self):
104-
return settings.instance().dremio.pat
105-
106-
@property
107-
def project_id(self):
108-
return settings.instance().dremio.project_id
109-
110105
async def invoke(self):
111106
raise NotImplementedError("Subclasses should implement this method")
112107

113-
def get_parameters(self):
114-
return Parameters()
115108

116-
# support for LangChain tools as compatiblity
117-
def as_tool(self):
118-
return Tool(
119-
function=Function(
120-
name=self.__class__.__name__,
121-
description=self.invoke.__doc__,
122-
parameters=self.get_parameters(),
123-
)
124-
)
109+
# A decorator to ensure a tool that needs to access Dremio runs with the correct token
110+
# if invoked through streamable HTTP transport _with_ a valid Dremio bearer token
111+
# It is a no-op if the tool is invoked through stdio transport, as MCP server ensures
112+
# proper PAT is used for all requests.
113+
def secured(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
125114

115+
@functools.wraps(fn)
116+
async def _impl(self, *args: P.args, **kw: P.kwargs) -> T:
117+
if isinstance((token := get_access_token()), AccessToken):
118+
return await settings.run_with(
119+
fn, {"dremio.pat": token.token}, (self,) + args, kw
120+
)
121+
return await fn(self, *args, **kw)
126122

127-
JobType: TypeAlias = Union[
128-
List[Literal["UI", "ACCELERATION", "INTERNAL", "EXTERNAL"]], str
129-
]
130-
StatusType: TypeAlias = Union[List[Literal["COMPLETED", "CANCELED", "FAILED"]], str]
123+
return _impl
131124

132125

133126
def _get_class_var_hints(tool: Tools, name: str) -> bool:
@@ -164,6 +157,7 @@ class GetFailedJobDetails(Tools):
164157
def group_by(self, df, by):
165158
return df.groupby(by).size().reset_index(name="count").to_dict(orient="records")
166159

160+
@secured
167161
async def invoke(self) -> Dict[str, Any]:
168162
"""Get the stats and details of failed or canceled jobs executed in the Dremio cluster in the past 7 days
169163
along with a split by job type
@@ -273,6 +267,7 @@ def ensure_query_allowed(s: str):
273267
"The query contains a DML statement. Only select queries are allowed"
274268
)
275269

270+
@secured
276271
async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
277272
"""Run a SELECT sql query on the Dremio cluster and return the results.
278273
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
@@ -292,19 +287,12 @@ async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
292287
"message": "The query failed. Please check the syntax and try again",
293288
}
294289

295-
def get_parameters(self):
296-
return Parameters(
297-
properties={
298-
"sql": Property(type="string", description="The sql query to run")
299-
},
300-
required=["sql"],
301-
)
302-
303290

304291
class BuildUsageReport(Tools):
305292
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF]]
306293
project_id_required: ClassVar[Annotated[bool, True]]
307294

295+
@secured
308296
async def invoke(
309297
self, by: Optional[Literal["PROJECT", "ENGINE"]] = "ENGINE"
310298
) -> Dict[str, Any]:
@@ -367,6 +355,7 @@ async def invoke(self) -> Dict[str, str]:
367355
class GetSchemaOfTable(Tools):
368356
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]
369357

358+
@secured
370359
async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
371360
"""Gets the schema of the given table.
372361
@@ -391,6 +380,7 @@ async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
391380
class GetTableOrViewLineage(Tools):
392381
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]
393382

383+
@secured
394384
async def invoke(self, table_name: Union[str, List[str]]) -> Dict[str, Any]:
395385
"""Finds the lineage of a table or view in the Dremio cluster
396386
@@ -411,6 +401,7 @@ class SearchTableAndViews(Tools):
411401
]
412402
]
413403

404+
@secured
414405
async def invoke(self, query: str) -> Dict[str, Any]:
415406
"""Runs a semantic search on the Dremio cluster to find tables and views that match the query.
416407

tests/stremable_http_cli.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
MCP HTTP Streamable Client Example using Python SDK
3+
4+
This example demonstrates how to create an MCP client that connects to a server
5+
using the Streamable HTTP transport protocol.
6+
"""
7+
8+
import asyncio
9+
10+
from mcp import ClientSession, types
11+
from mcp.client.streamable_http import streamablehttp_client
12+
from typer import Typer, Option
13+
from rich import print as pp
14+
from typing import Annotated, Optional
15+
16+
app = Typer(
17+
no_args_is_help=True,
18+
name="mcp-client",
19+
help="Run simple mcp client",
20+
context_settings=dict(help_option_names=["-h", "--help"]),
21+
)
22+
23+
24+
# Example usage and demonstration
25+
async def cli(url, token):
26+
async with streamablehttp_client(
27+
url=url, headers={"Authorization": f"Bearer {token}"}
28+
) as (read_stream, write_stream, gid):
29+
async with ClientSession(read_stream, write_stream) as session:
30+
await session.initialize()
31+
for t in await session.list_tools():
32+
pp(t)
33+
pp(await session.call_tool("RunSqlQuery", {"s": "SELECT 1"}))
34+
35+
36+
@app.command()
37+
def main(
38+
token: Annotated[Optional[str], Option(help="The authorization token to use")],
39+
url: Annotated[
40+
Optional[str], Option(help="The URL of the MCP server")
41+
] = "http://127.0.0.1:8000/mcp",
42+
):
43+
asyncio.run(cli(url, token))
44+
45+
46+
if __name__ == "__main__":
47+
app()

0 commit comments

Comments
 (0)