2828import httpx
2929from authlib .common .security import generate_token
3030from authlib .integrations .httpx_client import AsyncOAuth2Client
31+ from mcp .server .auth .handlers .token import TokenErrorResponse , TokenSuccessResponse
32+ from mcp .server .auth .handlers .token import TokenHandler as _SDKTokenHandler
33+ from mcp .server .auth .json_response import PydanticJSONResponse
34+ from mcp .server .auth .middleware .client_auth import ClientAuthenticator
3135from mcp .server .auth .provider import (
3236 AccessToken ,
3337 AuthorizationCode ,
3438 AuthorizationParams ,
3539 RefreshToken ,
3640 TokenError ,
3741)
42+ from mcp .server .auth .routes import cors_middleware
3843from mcp .server .auth .settings import (
3944 ClientRegistrationOptions ,
4045 RevocationOptions ,
4146)
4247from mcp .shared .auth import OAuthClientInformationFull , OAuthToken
4348from pydantic import AnyHttpUrl , AnyUrl , SecretStr
4449from starlette .requests import Request
45- from starlette .responses import JSONResponse , RedirectResponse
50+ from starlette .responses import RedirectResponse
4651from starlette .routing import Route
4752
4853import fastmcp
@@ -122,6 +127,55 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
122127HTTP_TIMEOUT_SECONDS : Final [int ] = 30
123128
124129
130+ class TokenHandler (_SDKTokenHandler ):
131+ """TokenHandler that returns OAuth 2.1 compliant error responses.
132+
133+ The MCP SDK always returns HTTP 400 for all client authentication issues.
134+ However, OAuth 2.1 Section 5.3 and the MCP specification require that
135+ invalid or expired tokens MUST receive a HTTP 401 response.
136+
137+ This handler extends the base MCP SDK TokenHandler to transform client
138+ authentication failures into OAuth 2.1 compliant responses:
139+ - Changes 'unauthorized_client' to 'invalid_client' error code
140+ - Returns HTTP 401 status code instead of 400 for client auth failures
141+
142+ Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
143+ (Unauthorized) status code to indicate which HTTP authentication schemes
144+ are supported."
145+
146+ Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
147+ """
148+
149+ def response (self , obj : TokenSuccessResponse | TokenErrorResponse ):
150+ """Override response method to provide OAuth 2.1 compliant error handling."""
151+ # Check if this is a client authentication failure (not just unauthorized for grant type)
152+ # unauthorized_client can mean two things:
153+ # 1. Client authentication failed (client_id not found or wrong credentials) -> invalid_client 401
154+ # 2. Client not authorized for this grant type -> unauthorized_client 400 (correct per spec)
155+ if (
156+ isinstance (obj , TokenErrorResponse )
157+ and obj .error == "unauthorized_client"
158+ and obj .error_description
159+ and "Invalid client_id" in obj .error_description
160+ ):
161+ # Transform client auth failure to OAuth 2.1 compliant response
162+ return PydanticJSONResponse (
163+ content = TokenErrorResponse (
164+ error = "invalid_client" ,
165+ error_description = obj .error_description ,
166+ error_uri = obj .error_uri ,
167+ ),
168+ status_code = 401 ,
169+ headers = {
170+ "Cache-Control" : "no-store" ,
171+ "Pragma" : "no-cache" ,
172+ },
173+ )
174+
175+ # Otherwise use default behavior from parent class
176+ return super ().response (obj )
177+
178+
125179class OAuthProxy (OAuthProvider ):
126180 """OAuth provider that presents a DCR-compliant interface while proxying to non-DCR IDPs.
127181
@@ -852,12 +906,17 @@ def get_routes(
852906 and "POST" in route .methods
853907 ):
854908 token_route_found = True
855- # Replace with our custom token handler
909+ # Replace with our OAuth 2.1 compliant token handler
910+ token_handler = TokenHandler (
911+ provider = self , client_authenticator = ClientAuthenticator (self )
912+ )
856913 custom_routes .append (
857914 Route (
858915 path = "/token" ,
859- endpoint = self ._handle_token_request ,
860- methods = ["POST" ],
916+ endpoint = cors_middleware (
917+ token_handler .handle , ["POST" , "OPTIONS" ]
918+ ),
919+ methods = ["POST" , "OPTIONS" ],
861920 )
862921 )
863922 else :
@@ -878,81 +937,6 @@ def get_routes(
878937 )
879938 return custom_routes
880939
881- # -------------------------------------------------------------------------
882- # Custom Token Endpoint Handler
883- # -------------------------------------------------------------------------
884-
885- async def _handle_token_request (self , request : Request ) -> JSONResponse :
886- """Handle token requests with proper OAuth 2.1 error handling.
887-
888- This custom handler wraps the standard MCP SDK token handler but provides
889- OAuth 2.1 compliant error responses for client authentication failures:
890- - Returns HTTP 401 status code for client authentication failures
891- - Uses 'invalid_client' error code instead of 'unauthorized_client'
892-
893- Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401
894- (Unauthorized) status code to indicate which HTTP authentication schemes
895- are supported. If the client attempted to authenticate via the Authorization
896- request header field, the authorization server MUST respond with an HTTP 401
897- (Unauthorized) status code and include the WWW-Authenticate response header
898- field matching the authentication scheme used by the client."
899- """
900- from mcp .server .auth .handlers .token import TokenHandler
901- from mcp .server .auth .middleware .client_auth import ClientAuthenticator
902-
903- # Create the standard token handler and client authenticator
904- token_handler = TokenHandler (
905- provider = self , client_authenticator = ClientAuthenticator (self )
906- )
907-
908- # Handle the request normally
909- response = await token_handler .handle (request )
910-
911- # Check if the response is an error response for client authentication failure
912- if (
913- hasattr (response , "body" )
914- and hasattr (response , "status_code" )
915- and response .status_code == 400
916- ):
917- try :
918- import json
919-
920- # Parse the response body to check for client authentication errors
921- body_content = (
922- response .body .decode ("utf-8" )
923- if hasattr (response .body , "decode" )
924- else str (response .body )
925- )
926- error_data = json .loads (body_content )
927-
928- # Check if this is an unauthorized_client error (which means invalid client_id)
929- if error_data .get (
930- "error"
931- ) == "unauthorized_client" and "Invalid client_id" in str (
932- error_data .get ("error_description" , "" )
933- ):
934- logger .debug (
935- "Client authentication failed - client not found, returning OAuth 2.1 compliant error"
936- )
937-
938- # Return the correct OAuth 2.1 response
939- return JSONResponse (
940- content = {
941- "error" : "invalid_client" ,
942- "error_description" : error_data .get ("error_description" ),
943- },
944- status_code = 401 ,
945- headers = {
946- "Cache-Control" : "no-store" ,
947- "Pragma" : "no-cache" ,
948- },
949- )
950- except (json .JSONDecodeError , AttributeError , KeyError ):
951- # If we can't parse the response, return it as-is
952- pass
953-
954- return response
955-
956940 # -------------------------------------------------------------------------
957941 # IdP Callback Forwarding
958942 # -------------------------------------------------------------------------
0 commit comments