Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#
import uuid
from uuid import UUID
from urllib.parse import urlparse

from pydantic import (
Field,
Expand All @@ -26,7 +27,18 @@
AliasChoices,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable, Literal
from typing import (
Optional,
Union,
Annotated,
Self,
List,
Dict,
Any,
Callable,
Literal,
Tuple,
)
from dremioai.config.tools import ToolType
from enum import auto, StrEnum
from pathlib import Path
Expand Down Expand Up @@ -168,6 +180,28 @@ def pat(self, v: str):
self.raw_pat = v
self._pat_resolved = None

@property
def is_cloud(self) -> bool:
return self.project_id is not None

@property
def auth_issuer_uri(self) -> Optional[str]:
if self.is_cloud:
uri = urlparse(self.uri)
if uri.netloc.startswith("api."):
uri = uri._replace(netloc=f"login.{uri.netloc[4:]}")
return uri.geturl()
return None

@property
def auth_endpoints(self) -> Optional[Tuple[str, str]]:
if issuer_uri := self.auth_issuer_uri:
return (
f"{issuer_uri}/oauth/authorize",
f"{issuer_uri}/oauth/token",
)
return None


class OpenAi(BaseModel):
api_key: Annotated[str, AfterValidator(_resolve_token_file)] = None
Expand Down
74 changes: 65 additions & 9 deletions src/dremioai/servers/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.prompts import Prompt
from mcp.server.fastmcp.resources import FunctionResource
from mcp.cli.claude import get_claude_config_path
from mcp.shared.auth import OAuthMetadata
from pydantic import AnyHttpUrl
from pydantic.networks import AnyUrl
from starlette.requests import Request
from starlette.responses import Response

from dremioai.tools import tools
import os
from typing import List, Union, Annotated, Optional, Tuple, Dict, Any
Expand All @@ -42,13 +47,43 @@
from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
)
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.middleware.bearer_auth import (
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.auth.provider import AccessToken, TokenVerifier
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response as StarletteResponse

from dremioai.tools.tools import ProjectIdMiddleware


class RequireAuthWithWWWAuthenticateMiddleware(BaseHTTPMiddleware):
"""
Custom middleware that requires authentication and returns WWW-Authenticate header
for unauthorized requests. This middleware should be placed AFTER AuthenticationMiddleware
so that request.user is available.
"""

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
# Check if user is authenticated (request.user is available after AuthenticationMiddleware)
if (
not hasattr(request, "user")
or not request.user.is_authenticated
and request.url.path.startswith("/mcp")
):
# Return 401 with WWW-Authenticate header
return StarletteResponse(
content="Unauthorized",
status_code=401,
headers={"WWW-Authenticate": "Bearer"},
)

# User is authenticated, proceed with the request
return await call_next(request)


class Transports(StrEnum):
stdio = auto()
streamable_http = "streamable-http"
Expand All @@ -57,24 +92,26 @@ class Transports(StrEnum):
class FastMCPServerWithAuthToken(FastMCP):
class DelegatingTokenVerifier(TokenVerifier):
async def verify_token(self, token: str) -> AccessToken | None:
log.logger("verify_token").info(f"Verifying token: {token}")
return (
AccessToken(
if token:
log.logger("verify_token").info(f"Token verified: {token}")
return AccessToken(
token=token, # Include the token itself
client_id="unused-client",
scopes=["read"],
)
if token
else None
)
else:
log.logger("verify_token").info(f"Token not provided: {token}")
return None

def streamable_http_app(self):
token_verifier = FastMCPServerWithAuthToken.DelegatingTokenVerifier()
app = super().streamable_http_app()
app.add_middleware(RequireAuthWithWWWAuthenticateMiddleware)
app.add_middleware(AuthContextMiddleware)
app.add_middleware(
AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)
)
# Add middleware in reverse order (last added = first executed)
if self.support_project_id_endpoints:
# this means, dynamically allow endpoints
# like ../mcp/{project_id}/.. and extract that project id as
Expand All @@ -97,7 +134,7 @@ def init(
log.logger("init").info(
f"Initializing MCP server with mode={mode}, class={mcp_cls.__name__}"
)
opts = {"log_level": "DEBUG"}
opts = {"log_level": "DEBUG", "debug": True}
if port is not None:
opts["port"] = port
mcp = mcp_cls("Dremio", **opts)
Expand Down Expand Up @@ -127,6 +164,24 @@ def init(
mcp.add_prompt(
Prompt.from_function(tools.system_prompt, "System Prompt", "System Prompt")
)

@mcp.custom_route("/.well-known/oauth-authorization-server", methods=["GET"])
async def authorization_server_metadata(request: Request) -> Response:
if issuer := settings.instance().dremio.auth_issuer_uri:
auth, tok = settings.instance().dremio.auth_endpoints
md = OAuthMetadata(
issuer=AnyHttpUrl(issuer),
authorization_endpoint=auth,
token_endpoint=tok,
scopes_supported=["dremio.all", "offline_access"],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
token_endpoint_auth_methods_supported=["client_secret_post"],
)
return PydanticJSONResponse(md)
return Response(status_code=404)

return mcp


Expand Down Expand Up @@ -182,6 +237,7 @@ def main(
mode=cfg.tools.server_mode,
transport=transport,
port=port,
support_project_id_endpoints=True,
)
app.run(transport=transport.value)

Expand Down
52 changes: 42 additions & 10 deletions tests/config/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@ def test_experimental_rename(name: str, value: bool):


@pytest.mark.parametrize(
"name,project_id,error",
"project_id,error",
[
["valid project id", str(uuid.uuid4()), False],
["no project id", None, False],
["invalid project id", "asdfsa safsa", True],
["invalid project id", str(uuid.uuid4())[:-1] + "a", True],
["dynamic project id", "DREMIO_DYNAMIC", False],
pytest.param(str(uuid.uuid4()), False, id="valid project id"),
pytest.param(None, False, id="no project id"),
pytest.param("asdfsa safsa", True, id="invalid project id"),
pytest.param(str(uuid.uuid4())[:-1] + "a", True, id="invalid project id"),
pytest.param("DREMIO_DYNAMIC", False, id="dynamic project id"),
],
)
def test_projects(name: str, project_id: str | None, error: bool):
def test_projects(project_id: str | None, error: bool):
val = {"uri": "https://foo", "project_id": project_id}
if error:
try:
Expand All @@ -125,12 +125,44 @@ def test_env_file(mock_config_dir):
os.environ["DREMIOAI_DREMIO__PAT"] = "bar"
os.environ["DREMIOAI_TOOLS__SERVER_MODE"] = "FOR_DATA_PATTERNS"
settings.configure(force=True)
from rich import print as pp

pp(settings.instance().model_dump())
assert settings.instance().dremio.uri == "https://foo"
assert settings.instance().dremio.pat == "bar"
assert settings.instance().tools.server_mode == ToolType.FOR_DATA_PATTERNS
finally:
os.environ.pop("DREMIOAI_DREMIO_URI", None)
os.environ.pop("DREMIOAI_DREMIO_PAT", None)


@pytest.mark.parametrize(
"uri,project_id,issuer,error",
[
pytest.param(
uri,
project_id,
iss,
project_id is None,
id=f"{label} with {plabel}",
)
for uri, iss, label in (
("https://foo", "https://foo", "custom-uri"),
("https://api.dremio.cloud", "https://login.dremio.cloud", "prod"),
(
"https://api.eu.dremio.cloud",
"https://login.eu.dremio.cloud",
"prodemea",
),
("https://api.dev.dremio.site", "https://login.dev.dremio.site", "dev"),
)
for project_id, plabel in (
(None, "no-project-id"),
("DREMIO_DYNAMIC", "dynamic-project-id"),
(str(uuid.uuid4()), "project-id"),
)
],
)
def test_auth_urls(uri: str, project_id: str | None, issuer: str, error: bool):
d = settings.Dremio.model_validate({"uri": uri, "project_id": project_id})
auth = (f"{issuer}/oauth/authorize", f"{issuer}/oauth/token") if not error else None
issuer = issuer if not error else None
assert d.auth_issuer_uri == issuer
assert d.auth_endpoints == auth
4 changes: 3 additions & 1 deletion tests/e2e/test_mcp_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
@pytest.mark.asyncio
async def test_basic(mock_config_dir, logging_server, logging_level):
async with http_streamable_mcp_server(logging_server, logging_level) as sf:
async with http_streamable_client_server(sf.mcp_server) as session:
async with http_streamable_client_server(
sf.mcp_server, token="my-token"
) as session:
lts = await session.list_tools()
tr = {t.name for t in lts.tools}
assert tr == {
Expand Down