Skip to content

Commit d0ab19b

Browse files
Support dynamic project id (#50)
* Supporting e2e server cleanup and middleware for dynamic project id * Support dynamic_project id as a setting and project id validations * Correcting env * Correcting env
1 parent 30a432b commit d0ab19b

14 files changed

Lines changed: 293 additions & 125 deletions

File tree

.github/workflows/docker.yaml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ jobs:
2020
tags: dremio-mcp-server:${{ github.sha }}
2121
- name: Validate docker
2222
run: |
23-
docker run -e TOOLS_MODE=FOR_DATA_PATTERNS \
24-
-e DREMIO_URI=https://fake \
25-
-e DREMIO_OAUTH_SUPPORTED=false \
23+
docker run \
24+
-e DREMIOAI_TOOLS__SERVER_MODE=FOR_DATA_PATTERNS \
25+
-e DREMIOAI_DREMIO__URI=https://fake \
26+
-e DREMIOAI_DREMIO__OAUTH_SUPPORTED=false \
2627
dremio-mcp-server:${{ github.sha }} \
2728
dremio-mcp-server tools list
2829
- name: Install uv
@@ -37,9 +38,10 @@ jobs:
3738

3839
- name: Start container
3940
run: |
40-
docker run -e TOOLS_MODE=FOR_DATA_PATTERNS \
41-
-e DREMIO_URI=https://fake \
42-
-e DREMIO_OAUTH_SUPPORTED=false \
41+
docker run \
42+
-e DREMIOAI_TOOLS__SERVER_MODE=FOR_DATA_PATTERNS \
43+
-e DREMIOAI_DREMIO__URI=https://fake \
44+
-e DREMIOAI_DREMIO__OAUTH_SUPPORTED=false \
4345
-p 6789:6789 \
4446
--name mcp --rm \
4547
--network host \

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ jobs:
2626

2727
- name: Run tests
2828
run: |
29-
make test
29+
uv run pytest tests

Makefile

Lines changed: 0 additions & 30 deletions
This file was deleted.

pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ asyncio_default_fixture_loop_scope = function
88

99
# Display summary info for skipped, xfailed, xpassed tests
1010
# along with the percentage of passing tests at the end
11-
addopts = -v --showlocals
11+
addopts = -v --showlocals -x

src/dremioai/config/settings.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import uuid
17+
from uuid import UUID
1618

1719
from pydantic import (
1820
Field,
@@ -24,7 +26,7 @@
2426
AliasChoices,
2527
)
2628
from pydantic_settings import BaseSettings, SettingsConfigDict
27-
from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable
29+
from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable, Literal
2830
from dremioai.config.tools import ToolType
2931
from enum import auto, StrEnum
3032
from pathlib import Path
@@ -37,6 +39,8 @@
3739
from importlib.util import find_spec
3840
from datetime import datetime
3941

42+
ProjectId = Union[UUID, Literal["DREMIO_DYNAMIC"]]
43+
4044

4145
def _resolve_tools_settings(server_mode: Union[ToolType, int, str]) -> ToolType:
4246
if isinstance(server_mode, str):
@@ -120,7 +124,7 @@ class Dremio(BaseModel):
120124
Union[str, HttpUrl, DremioCloudUri], AfterValidator(_resolve_dremio_uri)
121125
]
122126
raw_pat: Optional[str] = Field(default=None, alias="pat")
123-
project_id: Optional[str] = None
127+
raw_project_id: Optional[ProjectId] = Field(default=None, alias="project_id")
124128
enable_search: Optional[bool] = Field(
125129
default=False,
126130
alias=AliasChoices("enable_search", "enable_experimental"),
@@ -142,11 +146,13 @@ def oauth_configured(self) -> bool:
142146
def oauth_supported(self) -> bool:
143147
return self.project_id is not None
144148

145-
# @field_validator("_pat", mode="wrap")
146-
# @classmethod
147-
# def validate_pat(cls, v: str, handler: ValidatorFunctionWrapHandler) -> str:
148-
# v = _resolve_token_file(v)
149-
# return handler(v)
149+
@property
150+
def project_id(self) -> Optional[str]:
151+
return str(self.raw_project_id) if self.raw_project_id else None
152+
153+
@project_id.setter
154+
def project_id(self, v: str):
155+
self.raw_project_id = uuid.UUID(v)
150156

151157
@property
152158
def pat(self) -> str:
@@ -229,7 +235,8 @@ class Settings(BaseSettings):
229235
beeai: Optional[BeeAI] = Field(default=None)
230236
model_config = SettingsConfigDict(
231237
env_file=".env",
232-
env_nested_delimiter="_",
238+
env_nested_delimiter="__",
239+
env_prefix="DREMIOAI_",
233240
env_extra="allow",
234241
use_enum_values=True,
235242
)

src/dremioai/servers/mcp.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from mcp.server.auth.provider import AccessToken, TokenVerifier
4747
from starlette.middleware.authentication import AuthenticationMiddleware
4848

49+
from dremioai.tools.tools import ProjectIdMiddleware
50+
4951

5052
class Transports(StrEnum):
5153
stdio = auto()
@@ -73,16 +75,23 @@ def streamable_http_app(self):
7375
app.add_middleware(
7476
AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)
7577
)
76-
log.logger("streamable_http_app").info(
77-
f"Adding auth middleware {app.user_middleware}"
78-
)
78+
if self.support_project_id_endpoints:
79+
# this means, dynamically allow endpoints
80+
# like ../mcp/{project_id}/.. and extract that project id as
81+
# context var
82+
app.add_middleware(ProjectIdMiddleware)
7983
return app
8084

85+
def __init__(self, *args, **kwargs):
86+
super().__init__(*args, **kwargs)
87+
self.support_project_id_endpoints = False
88+
8189

8290
def init(
8391
mode: Union[tools.ToolType, List[tools.ToolType]] = None,
8492
transport: Transports = Transports.stdio,
8593
port: int = None,
94+
support_project_id_endpoints: bool = False,
8695
) -> FastMCP:
8796
mcp_cls = FastMCP if transport == Transports.stdio else FastMCPServerWithAuthToken
8897
log.logger("init").info(
@@ -92,6 +101,8 @@ def init(
92101
if port is not None:
93102
opts["port"] = port
94103
mcp = mcp_cls("Dremio", **opts)
104+
if transport == Transports.streamable_http and support_project_id_endpoints:
105+
mcp.support_project_id_endpoints = support_project_id_endpoints
95106
mode = reduce(ior, mode) if mode is not None else None
96107
for tool in tools.get_tools(For=mode):
97108
tool_instance = tool()

src/dremioai/tools/tools.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
16+
import json
17+
from contextvars import ContextVar
1718
from typing import (
1819
List,
1920
Dict,
2021
Any,
2122
Optional,
2223
Literal,
23-
TypeAlias,
2424
Union,
2525
Annotated,
2626
ClassVar,
@@ -32,18 +32,18 @@
3232
Awaitable,
3333
)
3434

35-
from pathlib import Path
3635
from dataclasses import dataclass, asdict, field
37-
from enum import auto, IntFlag
36+
37+
from starlette.datastructures import URL
38+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
39+
from starlette.requests import Request
40+
3841
from dremioai import log
3942
import re
4043
import functools
4144

4245
import pandas as pd
43-
44-
from pathlib import Path
45-
46-
from dremioai.api.dremio import sql, projects, usage, engines, search
46+
from dremioai.api.dremio import sql, usage, search
4747
from dremioai.config import settings
4848
from dremioai.config.tools import ToolType
4949
from dremioai.api.prometheus import vm
@@ -53,7 +53,6 @@
5353
from io import StringIO
5454
from sqlglot import parse_one
5555
from sqlglot import expressions
56-
from mcp.server.fastmcp.server import Context
5756
from mcp.server.auth.middleware.auth_context import get_access_token
5857
from mcp.server.auth.provider import AccessToken
5958

@@ -106,6 +105,31 @@ async def invoke(self):
106105
raise NotImplementedError("Subclasses should implement this method")
107106

108107

108+
class ProjectIdMiddleware(BaseHTTPMiddleware):
109+
pat = re.compile(r"/mcp/([\da-z-]+)(/?.*)")
110+
logger = log.logger("ProjectIdMiddleware")
111+
112+
# Context variable to store the current project ID
113+
project_id_context: ContextVar[str | None] = ContextVar("project_id", default=None)
114+
115+
@classmethod
116+
def get_project_id(cls) -> Optional[str]:
117+
return cls.project_id_context.get()
118+
119+
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
120+
ProjectIdMiddleware.logger.info(
121+
f"Request {request.url.path} body = {await request.body()!s}"
122+
)
123+
if m := ProjectIdMiddleware.pat.search(request.url.path):
124+
ProjectIdMiddleware.project_id_context.set(m.group(1))
125+
else:
126+
ProjectIdMiddleware.logger.debug(
127+
f"Path {request.url.path} ({request.url!r}) doesn't match"
128+
)
129+
130+
return await call_next(request)
131+
132+
109133
# A decorator to ensure a tool that needs to access Dremio runs with the correct token
110134
# if invoked through streamable HTTP transport _with_ a valid Dremio bearer token
111135
# It is a no-op if the tool is invoked through stdio transport, as MCP server ensures
@@ -114,11 +138,22 @@ def secured(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
114138

115139
@functools.wraps(fn)
116140
async def _impl(self, *args: P.args, **kw: P.kwargs) -> T:
141+
overrides = {}
117142
if isinstance((token := get_access_token()), AccessToken):
118-
return await settings.run_with(
119-
fn, {"dremio.pat": token.token}, (self,) + args, kw
143+
overrides["dremio.pat"] = token.token
144+
logger.debug(
145+
f"Overriding PAT with token from request: {token.token[:4]}..."
120146
)
121-
return await fn(self, *args, **kw)
147+
148+
if project_id := ProjectIdMiddleware.get_project_id():
149+
overrides["dremio.project_id"] = project_id
150+
logger.debug(f"Overriding project_id with {project_id}")
151+
152+
return (
153+
await settings.run_with(fn, overrides, (self,) + args, kw)
154+
if overrides
155+
else await fn(self, *args, **kw)
156+
)
122157

123158
return _impl
124159

tests/config/test_settings.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
#
1616

1717
import os
18+
import uuid
19+
20+
import pydantic
1821
import pytest
1922
import yaml
2023
from pathlib import Path
2124
from tempfile import TemporaryDirectory
2225
from unittest.mock import patch
2326

27+
from pydantic_core import ValidationError
28+
2429
from dremioai.config import settings
2530
from dremioai.config.tools import ToolType
2631

@@ -48,7 +53,7 @@ def test_configure_creates_default_config(mock_config_dir):
4853
def test_create_default_config(mock_config_dir):
4954
uri = settings.DremioCloudUri.PRODEMEA.value
5055
pat = "test-pat"
51-
project_id = "test-project"
56+
project_id = uuid.uuid4()
5257
mode = ToolType.FOR_DATA_PATTERNS
5358
settings.configure(force=True)
5459
settings._settings.set(
@@ -70,7 +75,7 @@ def test_create_default_config(mock_config_dir):
7075
assert (
7176
dremio.uri == "https://api.eu.dremio.cloud"
7277
and dremio.pat == pat
73-
and dremio.project_id == project_id
78+
and dremio.project_id == str(project_id)
7479
)
7580
tools = settings.instance().tools
7681
assert tools.server_mode == mode
@@ -89,3 +94,43 @@ def test_experimental_rename(name: str, value: bool):
8994
{name: value, "uri": "https://foo", "pat": "bar"}
9095
)
9196
assert d.enable_search == value
97+
98+
99+
@pytest.mark.parametrize(
100+
"name,project_id,error",
101+
[
102+
["valid project id", str(uuid.uuid4()), False],
103+
["no project id", None, False],
104+
["invalid project id", "asdfsa safsa", True],
105+
["invalid project id", str(uuid.uuid4())[:-1] + "a", True],
106+
["dynamic project id", "DREMIO_DYNAMIC", False],
107+
],
108+
)
109+
def test_projects(name: str, project_id: str | None, error: bool):
110+
val = {"uri": "https://foo", "project_id": project_id}
111+
if error:
112+
try:
113+
settings.Dremio.model_validate(val)
114+
assert False
115+
except:
116+
pass
117+
else:
118+
d = settings.Dremio.model_validate(val)
119+
assert d.project_id == project_id or d.project_id is None and project_id is None
120+
121+
122+
def test_env_file(mock_config_dir):
123+
try:
124+
os.environ["DREMIOAI_DREMIO__URI"] = "https://foo"
125+
os.environ["DREMIOAI_DREMIO__PAT"] = "bar"
126+
os.environ["DREMIOAI_TOOLS__SERVER_MODE"] = "FOR_DATA_PATTERNS"
127+
settings.configure(force=True)
128+
from rich import print as pp
129+
130+
pp(settings.instance().model_dump())
131+
assert settings.instance().dremio.uri == "https://foo"
132+
assert settings.instance().dremio.pat == "bar"
133+
assert settings.instance().tools.server_mode == ToolType.FOR_DATA_PATTERNS
134+
finally:
135+
os.environ.pop("DREMIOAI_DREMIO_URI", None)
136+
os.environ.pop("DREMIOAI_DREMIO_PAT", None)

0 commit comments

Comments
 (0)