1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515#
16-
16+ import json
17+ from contextvars import ContextVar
1718from typing import (
1819 List ,
1920 Dict ,
2021 Any ,
2122 Optional ,
2223 Literal ,
23- TypeAlias ,
2424 Union ,
2525 Annotated ,
2626 ClassVar ,
3232 Awaitable ,
3333)
3434
35- from pathlib import Path
3635from dataclasses import dataclass , asdict , field
37- from enum import auto , IntFlag
36+
37+ from starlette .datastructures import URL
38+ from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
39+ from starlette .requests import Request
40+
3841from dremioai import log
3942import re
4043import functools
4144
4245import pandas as pd
43-
44- from pathlib import Path
45-
46- from dremioai .api .dremio import sql , projects , usage , engines , search
46+ from dremioai .api .dremio import sql , usage , search
4747from dremioai .config import settings
4848from dremioai .config .tools import ToolType
4949from dremioai .api .prometheus import vm
5353from io import StringIO
5454from sqlglot import parse_one
5555from sqlglot import expressions
56- from mcp .server .fastmcp .server import Context
5756from mcp .server .auth .middleware .auth_context import get_access_token
5857from mcp .server .auth .provider import AccessToken
5958
@@ -106,6 +105,31 @@ async def invoke(self):
106105 raise NotImplementedError ("Subclasses should implement this method" )
107106
108107
108+ class ProjectIdMiddleware (BaseHTTPMiddleware ):
109+ pat = re .compile (r"/mcp/([\da-z-]+)(/?.*)" )
110+ logger = log .logger ("ProjectIdMiddleware" )
111+
112+ # Context variable to store the current project ID
113+ project_id_context : ContextVar [str | None ] = ContextVar ("project_id" , default = None )
114+
115+ @classmethod
116+ def get_project_id (cls ) -> Optional [str ]:
117+ return cls .project_id_context .get ()
118+
119+ async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
120+ ProjectIdMiddleware .logger .info (
121+ f"Request { request .url .path } body = { await request .body ()!s} "
122+ )
123+ if m := ProjectIdMiddleware .pat .search (request .url .path ):
124+ ProjectIdMiddleware .project_id_context .set (m .group (1 ))
125+ else :
126+ ProjectIdMiddleware .logger .debug (
127+ f"Path { request .url .path } ({ request .url !r} ) doesn't match"
128+ )
129+
130+ return await call_next (request )
131+
132+
109133# A decorator to ensure a tool that needs to access Dremio runs with the correct token
110134# if invoked through streamable HTTP transport _with_ a valid Dremio bearer token
111135# It is a no-op if the tool is invoked through stdio transport, as MCP server ensures
@@ -114,11 +138,22 @@ def secured(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
114138
115139 @functools .wraps (fn )
116140 async def _impl (self , * args : P .args , ** kw : P .kwargs ) -> T :
141+ overrides = {}
117142 if isinstance ((token := get_access_token ()), AccessToken ):
118- return await settings .run_with (
119- fn , {"dremio.pat" : token .token }, (self ,) + args , kw
143+ overrides ["dremio.pat" ] = token .token
144+ logger .debug (
145+ f"Overriding PAT with token from request: { token .token [:4 ]} ..."
120146 )
121- return await fn (self , * args , ** kw )
147+
148+ if project_id := ProjectIdMiddleware .get_project_id ():
149+ overrides ["dremio.project_id" ] = project_id
150+ logger .debug (f"Overriding project_id with { project_id } " )
151+
152+ return (
153+ await settings .run_with (fn , overrides , (self ,) + args , kw )
154+ if overrides
155+ else await fn (self , * args , ** kw )
156+ )
122157
123158 return _impl
124159
0 commit comments