2626 ClassVar ,
2727 get_args ,
2828 get_type_hints ,
29+ Callable ,
30+ TypeVar ,
31+ ParamSpec ,
32+ Awaitable ,
2933)
3034
3135from pathlib import Path
3236from dataclasses import dataclass , asdict , field
3337from enum import auto , IntFlag
3438from dremioai import log
3539import re
40+ import functools
3641
3742import pandas as pd
3843
4853from io import StringIO
4954from sqlglot import parse_one
5055from sqlglot import expressions
56+ from mcp .server .fastmcp .server import Context
57+ from mcp .server .auth .middleware .auth_context import get_access_token
58+ from mcp .server .auth .provider import AccessToken
5159
5260logger = log .logger (__name__ )
5361
62+ # Type variables for the secured decorator
63+ P = ParamSpec ("P" )
64+ T = TypeVar ("T" )
65+
5466
5567@dataclass
5668class Property :
@@ -90,44 +102,25 @@ def as_dict(self) -> Dict[str, Any]:
90102
91103
92104class Tools :
93- def __init__ (self , uri = None , pat = None , project_id = None ):
94- settings .instance ().with_overrides (
95- {"dremio.uri" : uri , "dremio.pat" : pat , "dremio.project_id" : project_id }
96- )
97-
98- @property
99- def dremio_uri (self ):
100- return settings .instance ().dremio .uri
101-
102- @property
103- def pat (self ):
104- return settings .instance ().dremio .pat
105-
106- @property
107- def project_id (self ):
108- return settings .instance ().dremio .project_id
109-
110105 async def invoke (self ):
111106 raise NotImplementedError ("Subclasses should implement this method" )
112107
113- def get_parameters (self ):
114- return Parameters ()
115108
116- # support for LangChain tools as compatiblity
117- def as_tool (self ):
118- return Tool (
119- function = Function (
120- name = self .__class__ .__name__ ,
121- description = self .invoke .__doc__ ,
122- parameters = self .get_parameters (),
123- )
124- )
109+ # A decorator to ensure a tool that needs to access Dremio runs with the correct token
110+ # if invoked through streamable HTTP transport _with_ a valid Dremio bearer token
111+ # It is a no-op if the tool is invoked through stdio transport, as MCP server ensures
112+ # proper PAT is used for all requests.
113+ def secured (fn : Callable [P , Awaitable [T ]]) -> Callable [P , Awaitable [T ]]:
125114
115+ @functools .wraps (fn )
116+ async def _impl (self , * args : P .args , ** kw : P .kwargs ) -> T :
117+ if isinstance ((token := get_access_token ()), AccessToken ):
118+ return await settings .run_with (
119+ fn , {"dremio.pat" : token .token }, (self ,) + args , kw
120+ )
121+ return await fn (self , * args , ** kw )
126122
127- JobType : TypeAlias = Union [
128- List [Literal ["UI" , "ACCELERATION" , "INTERNAL" , "EXTERNAL" ]], str
129- ]
130- StatusType : TypeAlias = Union [List [Literal ["COMPLETED" , "CANCELED" , "FAILED" ]], str ]
123+ return _impl
131124
132125
133126def _get_class_var_hints (tool : Tools , name : str ) -> bool :
@@ -164,6 +157,7 @@ class GetFailedJobDetails(Tools):
164157 def group_by (self , df , by ):
165158 return df .groupby (by ).size ().reset_index (name = "count" ).to_dict (orient = "records" )
166159
160+ @secured
167161 async def invoke (self ) -> Dict [str , Any ]:
168162 """Get the stats and details of failed or canceled jobs executed in the Dremio cluster in the past 7 days
169163 along with a split by job type
@@ -273,6 +267,7 @@ def ensure_query_allowed(s: str):
273267 "The query contains a DML statement. Only select queries are allowed"
274268 )
275269
270+ @secured
276271 async def invoke (self , s : str ) -> Dict [str , List [Dict [Any , Any ]]]:
277272 """Run a SELECT sql query on the Dremio cluster and return the results.
278273 Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
@@ -292,19 +287,12 @@ async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
292287 "message" : "The query failed. Please check the syntax and try again" ,
293288 }
294289
295- def get_parameters (self ):
296- return Parameters (
297- properties = {
298- "sql" : Property (type = "string" , description = "The sql query to run" )
299- },
300- required = ["sql" ],
301- )
302-
303290
304291class BuildUsageReport (Tools ):
305292 For : ClassVar [Annotated [ToolType , ToolType .FOR_SELF ]]
306293 project_id_required : ClassVar [Annotated [bool , True ]]
307294
295+ @secured
308296 async def invoke (
309297 self , by : Optional [Literal ["PROJECT" , "ENGINE" ]] = "ENGINE"
310298 ) -> Dict [str , Any ]:
@@ -367,6 +355,7 @@ async def invoke(self) -> Dict[str, str]:
367355class GetSchemaOfTable (Tools ):
368356 For : ClassVar [Annotated [ToolType , ToolType .FOR_SELF | ToolType .FOR_DATA_PATTERNS ]]
369357
358+ @secured
370359 async def invoke (self , table_name : Union [str | List [str ]]) -> Dict [str , Any ]:
371360 """Gets the schema of the given table.
372361
@@ -391,6 +380,7 @@ async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
391380class GetTableOrViewLineage (Tools ):
392381 For : ClassVar [Annotated [ToolType , ToolType .FOR_SELF | ToolType .FOR_DATA_PATTERNS ]]
393382
383+ @secured
394384 async def invoke (self , table_name : Union [str , List [str ]]) -> Dict [str , Any ]:
395385 """Finds the lineage of a table or view in the Dremio cluster
396386
@@ -411,6 +401,7 @@ class SearchTableAndViews(Tools):
411401 ]
412402 ]
413403
404+ @secured
414405 async def invoke (self , query : str ) -> Dict [str , Any ]:
415406 """Runs a semantic search on the Dremio cluster to find tables and views that match the query.
416407
0 commit comments