Skip to content

Commit 1edb25f

Browse files
Auth server support (#51)
1 parent d0ab19b commit 1edb25f

4 files changed

Lines changed: 145 additions & 21 deletions

File tree

src/dremioai/config/settings.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#
1616
import uuid
1717
from uuid import UUID
18+
from urllib.parse import urlparse
1819

1920
from pydantic import (
2021
Field,
@@ -26,7 +27,18 @@
2627
AliasChoices,
2728
)
2829
from pydantic_settings import BaseSettings, SettingsConfigDict
29-
from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable, Literal
30+
from typing import (
31+
Optional,
32+
Union,
33+
Annotated,
34+
Self,
35+
List,
36+
Dict,
37+
Any,
38+
Callable,
39+
Literal,
40+
Tuple,
41+
)
3042
from dremioai.config.tools import ToolType
3143
from enum import auto, StrEnum
3244
from pathlib import Path
@@ -168,6 +180,28 @@ def pat(self, v: str):
168180
self.raw_pat = v
169181
self._pat_resolved = None
170182

183+
@property
184+
def is_cloud(self) -> bool:
185+
return self.project_id is not None
186+
187+
@property
188+
def auth_issuer_uri(self) -> Optional[str]:
189+
if self.is_cloud:
190+
uri = urlparse(self.uri)
191+
if uri.netloc.startswith("api."):
192+
uri = uri._replace(netloc=f"login.{uri.netloc[4:]}")
193+
return uri.geturl()
194+
return None
195+
196+
@property
197+
def auth_endpoints(self) -> Optional[Tuple[str, str]]:
198+
if issuer_uri := self.auth_issuer_uri:
199+
return (
200+
f"{issuer_uri}/oauth/authorize",
201+
f"{issuer_uri}/oauth/token",
202+
)
203+
return None
204+
171205

172206
class OpenAi(BaseModel):
173207
api_key: Annotated[str, AfterValidator(_resolve_token_file)] = None

src/dremioai/servers/mcp.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,17 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
16+
from mcp.server.auth.json_response import PydanticJSONResponse
1717
from mcp.server.fastmcp import FastMCP
1818
from mcp.server.fastmcp.prompts import Prompt
1919
from mcp.server.fastmcp.resources import FunctionResource
2020
from mcp.cli.claude import get_claude_config_path
21+
from mcp.shared.auth import OAuthMetadata
22+
from pydantic import AnyHttpUrl
2123
from pydantic.networks import AnyUrl
24+
from starlette.requests import Request
25+
from starlette.responses import Response
26+
2227
from dremioai.tools import tools
2328
import os
2429
from typing import List, Union, Annotated, Optional, Tuple, Dict, Any
@@ -42,13 +47,43 @@
4247
from mcp.server.auth.middleware.auth_context import (
4348
AuthContextMiddleware,
4449
)
45-
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
50+
from mcp.server.auth.middleware.bearer_auth import (
51+
BearerAuthBackend,
52+
RequireAuthMiddleware,
53+
)
4654
from mcp.server.auth.provider import AccessToken, TokenVerifier
4755
from starlette.middleware.authentication import AuthenticationMiddleware
56+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
57+
from starlette.responses import Response as StarletteResponse
4858

4959
from dremioai.tools.tools import ProjectIdMiddleware
5060

5161

62+
class RequireAuthWithWWWAuthenticateMiddleware(BaseHTTPMiddleware):
63+
"""
64+
Custom middleware that requires authentication and returns WWW-Authenticate header
65+
for unauthorized requests. This middleware should be placed AFTER AuthenticationMiddleware
66+
so that request.user is available.
67+
"""
68+
69+
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
70+
# Check if user is authenticated (request.user is available after AuthenticationMiddleware)
71+
if (
72+
not hasattr(request, "user")
73+
or not request.user.is_authenticated
74+
and request.url.path.startswith("/mcp")
75+
):
76+
# Return 401 with WWW-Authenticate header
77+
return StarletteResponse(
78+
content="Unauthorized",
79+
status_code=401,
80+
headers={"WWW-Authenticate": "Bearer"},
81+
)
82+
83+
# User is authenticated, proceed with the request
84+
return await call_next(request)
85+
86+
5287
class Transports(StrEnum):
5388
stdio = auto()
5489
streamable_http = "streamable-http"
@@ -57,24 +92,26 @@ class Transports(StrEnum):
5792
class FastMCPServerWithAuthToken(FastMCP):
5893
class DelegatingTokenVerifier(TokenVerifier):
5994
async def verify_token(self, token: str) -> AccessToken | None:
60-
log.logger("verify_token").info(f"Verifying token: {token}")
61-
return (
62-
AccessToken(
95+
if token:
96+
log.logger("verify_token").info(f"Token verified: {token}")
97+
return AccessToken(
6398
token=token, # Include the token itself
6499
client_id="unused-client",
65100
scopes=["read"],
66101
)
67-
if token
68-
else None
69-
)
102+
else:
103+
log.logger("verify_token").info(f"Token not provided: {token}")
104+
return None
70105

71106
def streamable_http_app(self):
72107
token_verifier = FastMCPServerWithAuthToken.DelegatingTokenVerifier()
73108
app = super().streamable_http_app()
109+
app.add_middleware(RequireAuthWithWWWAuthenticateMiddleware)
74110
app.add_middleware(AuthContextMiddleware)
75111
app.add_middleware(
76112
AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)
77113
)
114+
# Add middleware in reverse order (last added = first executed)
78115
if self.support_project_id_endpoints:
79116
# this means, dynamically allow endpoints
80117
# like ../mcp/{project_id}/.. and extract that project id as
@@ -97,7 +134,7 @@ def init(
97134
log.logger("init").info(
98135
f"Initializing MCP server with mode={mode}, class={mcp_cls.__name__}"
99136
)
100-
opts = {"log_level": "DEBUG"}
137+
opts = {"log_level": "DEBUG", "debug": True}
101138
if port is not None:
102139
opts["port"] = port
103140
mcp = mcp_cls("Dremio", **opts)
@@ -127,6 +164,24 @@ def init(
127164
mcp.add_prompt(
128165
Prompt.from_function(tools.system_prompt, "System Prompt", "System Prompt")
129166
)
167+
168+
@mcp.custom_route("/.well-known/oauth-authorization-server", methods=["GET"])
169+
async def authorization_server_metadata(request: Request) -> Response:
170+
if issuer := settings.instance().dremio.auth_issuer_uri:
171+
auth, tok = settings.instance().dremio.auth_endpoints
172+
md = OAuthMetadata(
173+
issuer=AnyHttpUrl(issuer),
174+
authorization_endpoint=auth,
175+
token_endpoint=tok,
176+
scopes_supported=["dremio.all", "offline_access"],
177+
response_types_supported=["code"],
178+
grant_types_supported=["authorization_code", "refresh_token"],
179+
code_challenge_methods_supported=["S256"],
180+
token_endpoint_auth_methods_supported=["client_secret_post"],
181+
)
182+
return PydanticJSONResponse(md)
183+
return Response(status_code=404)
184+
130185
return mcp
131186

132187

@@ -182,6 +237,7 @@ def main(
182237
mode=cfg.tools.server_mode,
183238
transport=transport,
184239
port=port,
240+
support_project_id_endpoints=True,
185241
)
186242
app.run(transport=transport.value)
187243

tests/config/test_settings.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,16 @@ def test_experimental_rename(name: str, value: bool):
9797

9898

9999
@pytest.mark.parametrize(
100-
"name,project_id,error",
100+
"project_id,error",
101101
[
102-
["valid project id", str(uuid.uuid4()), False],
103-
["no project id", None, False],
104-
["invalid project id", "asdfsa safsa", True],
105-
["invalid project id", str(uuid.uuid4())[:-1] + "a", True],
106-
["dynamic project id", "DREMIO_DYNAMIC", False],
102+
pytest.param(str(uuid.uuid4()), False, id="valid project id"),
103+
pytest.param(None, False, id="no project id"),
104+
pytest.param("asdfsa safsa", True, id="invalid project id"),
105+
pytest.param(str(uuid.uuid4())[:-1] + "a", True, id="invalid project id"),
106+
pytest.param("DREMIO_DYNAMIC", False, id="dynamic project id"),
107107
],
108108
)
109-
def test_projects(name: str, project_id: str | None, error: bool):
109+
def test_projects(project_id: str | None, error: bool):
110110
val = {"uri": "https://foo", "project_id": project_id}
111111
if error:
112112
try:
@@ -125,12 +125,44 @@ def test_env_file(mock_config_dir):
125125
os.environ["DREMIOAI_DREMIO__PAT"] = "bar"
126126
os.environ["DREMIOAI_TOOLS__SERVER_MODE"] = "FOR_DATA_PATTERNS"
127127
settings.configure(force=True)
128-
from rich import print as pp
129-
130-
pp(settings.instance().model_dump())
131128
assert settings.instance().dremio.uri == "https://foo"
132129
assert settings.instance().dremio.pat == "bar"
133130
assert settings.instance().tools.server_mode == ToolType.FOR_DATA_PATTERNS
134131
finally:
135132
os.environ.pop("DREMIOAI_DREMIO_URI", None)
136133
os.environ.pop("DREMIOAI_DREMIO_PAT", None)
134+
135+
136+
@pytest.mark.parametrize(
137+
"uri,project_id,issuer,error",
138+
[
139+
pytest.param(
140+
uri,
141+
project_id,
142+
iss,
143+
project_id is None,
144+
id=f"{label} with {plabel}",
145+
)
146+
for uri, iss, label in (
147+
("https://foo", "https://foo", "custom-uri"),
148+
("https://api.dremio.cloud", "https://login.dremio.cloud", "prod"),
149+
(
150+
"https://api.eu.dremio.cloud",
151+
"https://login.eu.dremio.cloud",
152+
"prodemea",
153+
),
154+
("https://api.dev.dremio.site", "https://login.dev.dremio.site", "dev"),
155+
)
156+
for project_id, plabel in (
157+
(None, "no-project-id"),
158+
("DREMIO_DYNAMIC", "dynamic-project-id"),
159+
(str(uuid.uuid4()), "project-id"),
160+
)
161+
],
162+
)
163+
def test_auth_urls(uri: str, project_id: str | None, issuer: str, error: bool):
164+
d = settings.Dremio.model_validate({"uri": uri, "project_id": project_id})
165+
auth = (f"{issuer}/oauth/authorize", f"{issuer}/oauth/token") if not error else None
166+
issuer = issuer if not error else None
167+
assert d.auth_issuer_uri == issuer
168+
assert d.auth_endpoints == auth

tests/e2e/test_mcp_e2e.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
@pytest.mark.asyncio
2525
async def test_basic(mock_config_dir, logging_server, logging_level):
2626
async with http_streamable_mcp_server(logging_server, logging_level) as sf:
27-
async with http_streamable_client_server(sf.mcp_server) as session:
27+
async with http_streamable_client_server(
28+
sf.mcp_server, token="my-token"
29+
) as session:
2830
lts = await session.list_tools()
2931
tr = {t.name for t in lts.tools}
3032
assert tr == {

0 commit comments

Comments
 (0)