1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515#
16-
16+ from mcp . server . auth . json_response import PydanticJSONResponse
1717from mcp .server .fastmcp import FastMCP
1818from mcp .server .fastmcp .prompts import Prompt
1919from mcp .server .fastmcp .resources import FunctionResource
2020from mcp .cli .claude import get_claude_config_path
21+ from mcp .shared .auth import OAuthMetadata
22+ from pydantic import AnyHttpUrl
2123from pydantic .networks import AnyUrl
24+ from starlette .requests import Request
25+ from starlette .responses import Response
26+
2227from dremioai .tools import tools
2328import os
2429from typing import List , Union , Annotated , Optional , Tuple , Dict , Any
4247from mcp .server .auth .middleware .auth_context import (
4348 AuthContextMiddleware ,
4449)
45- from mcp .server .auth .middleware .bearer_auth import BearerAuthBackend
50+ from mcp .server .auth .middleware .bearer_auth import (
51+ BearerAuthBackend ,
52+ RequireAuthMiddleware ,
53+ )
4654from mcp .server .auth .provider import AccessToken , TokenVerifier
4755from starlette .middleware .authentication import AuthenticationMiddleware
56+ from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
57+ from starlette .responses import Response as StarletteResponse
4858
4959from dremioai .tools .tools import ProjectIdMiddleware
5060
5161
62+ class RequireAuthWithWWWAuthenticateMiddleware (BaseHTTPMiddleware ):
63+ """
64+ Custom middleware that requires authentication and returns WWW-Authenticate header
65+ for unauthorized requests. This middleware should be placed AFTER AuthenticationMiddleware
66+ so that request.user is available.
67+ """
68+
69+ async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
70+ # Check if user is authenticated (request.user is available after AuthenticationMiddleware)
71+ if (
72+ not hasattr (request , "user" )
73+ or not request .user .is_authenticated
74+ and request .url .path .startswith ("/mcp" )
75+ ):
76+ # Return 401 with WWW-Authenticate header
77+ return StarletteResponse (
78+ content = "Unauthorized" ,
79+ status_code = 401 ,
80+ headers = {"WWW-Authenticate" : "Bearer" },
81+ )
82+
83+ # User is authenticated, proceed with the request
84+ return await call_next (request )
85+
86+
5287class Transports (StrEnum ):
5388 stdio = auto ()
5489 streamable_http = "streamable-http"
@@ -57,24 +92,26 @@ class Transports(StrEnum):
5792class FastMCPServerWithAuthToken (FastMCP ):
5893 class DelegatingTokenVerifier (TokenVerifier ):
5994 async def verify_token (self , token : str ) -> AccessToken | None :
60- log . logger ( "verify_token" ). info ( f"Verifying token: { token } " )
61- return (
62- AccessToken (
95+ if token :
96+ log . logger ( "verify_token" ). info ( f"Token verified: { token } " )
97+ return AccessToken (
6398 token = token , # Include the token itself
6499 client_id = "unused-client" ,
65100 scopes = ["read" ],
66101 )
67- if token
68- else None
69- )
102+ else :
103+ log . logger ( "verify_token" ). info ( f"Token not provided: { token } " )
104+ return None
70105
71106 def streamable_http_app (self ):
72107 token_verifier = FastMCPServerWithAuthToken .DelegatingTokenVerifier ()
73108 app = super ().streamable_http_app ()
109+ app .add_middleware (RequireAuthWithWWWAuthenticateMiddleware )
74110 app .add_middleware (AuthContextMiddleware )
75111 app .add_middleware (
76112 AuthenticationMiddleware , backend = BearerAuthBackend (token_verifier )
77113 )
114+ # Add middleware in reverse order (last added = first executed)
78115 if self .support_project_id_endpoints :
79116 # this means, dynamically allow endpoints
80117 # like ../mcp/{project_id}/.. and extract that project id as
@@ -97,7 +134,7 @@ def init(
97134 log .logger ("init" ).info (
98135 f"Initializing MCP server with mode={ mode } , class={ mcp_cls .__name__ } "
99136 )
100- opts = {"log_level" : "DEBUG" }
137+ opts = {"log_level" : "DEBUG" , "debug" : True }
101138 if port is not None :
102139 opts ["port" ] = port
103140 mcp = mcp_cls ("Dremio" , ** opts )
@@ -127,6 +164,24 @@ def init(
127164 mcp .add_prompt (
128165 Prompt .from_function (tools .system_prompt , "System Prompt" , "System Prompt" )
129166 )
167+
168+ @mcp .custom_route ("/.well-known/oauth-authorization-server" , methods = ["GET" ])
169+ async def authorization_server_metadata (request : Request ) -> Response :
170+ if issuer := settings .instance ().dremio .auth_issuer_uri :
171+ auth , tok = settings .instance ().dremio .auth_endpoints
172+ md = OAuthMetadata (
173+ issuer = AnyHttpUrl (issuer ),
174+ authorization_endpoint = auth ,
175+ token_endpoint = tok ,
176+ scopes_supported = ["dremio.all" , "offline_access" ],
177+ response_types_supported = ["code" ],
178+ grant_types_supported = ["authorization_code" , "refresh_token" ],
179+ code_challenge_methods_supported = ["S256" ],
180+ token_endpoint_auth_methods_supported = ["client_secret_post" ],
181+ )
182+ return PydanticJSONResponse (md )
183+ return Response (status_code = 404 )
184+
130185 return mcp
131186
132187
@@ -182,6 +237,7 @@ def main(
182237 mode = cfg .tools .server_mode ,
183238 transport = transport ,
184239 port = port ,
240+ support_project_id_endpoints = True ,
185241 )
186242 app .run (transport = transport .value )
187243
0 commit comments