4242from mcp .shared .auth import OAuthClientInformationFull , OAuthToken
4343from pydantic import AnyHttpUrl , AnyUrl , SecretStr
4444from starlette .requests import Request
45- from starlette .responses import RedirectResponse
45+ from starlette .responses import JSONResponse , RedirectResponse
4646from starlette .routing import Route
4747
4848import fastmcp
@@ -844,16 +844,25 @@ def get_routes(
844844 f"Route { i } : { route } - path: { getattr (route , 'path' , 'N/A' )} , methods: { getattr (route , 'methods' , 'N/A' )} "
845845 )
846846
847- # Keep all standard OAuth routes unchanged - our DCR-compliant flow handles everything
848- custom_routes .append (route )
849-
847+ # Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
850848 if (
851849 isinstance (route , Route )
852850 and route .path == "/token"
853851 and route .methods is not None
854852 and "POST" in route .methods
855853 ):
856854 token_route_found = True
855+ # Replace with our custom token handler
856+ custom_routes .append (
857+ Route (
858+ path = "/token" ,
859+ endpoint = self ._handle_token_request ,
860+ methods = ["POST" ],
861+ )
862+ )
863+ else :
864+ # Keep all other standard OAuth routes unchanged
865+ custom_routes .append (route )
857866
858867 # Add OAuth callback endpoint for forwarding to client callbacks
859868 custom_routes .append (
@@ -869,6 +878,81 @@ def get_routes(
869878 )
870879 return custom_routes
871880
881+ # -------------------------------------------------------------------------
882+ # Custom Token Endpoint Handler
883+ # -------------------------------------------------------------------------
884+
885+ async def _handle_token_request (self , request : Request ) -> JSONResponse :
886+ """Handle token requests with proper OAuth 2.1 error handling.
887+
888+ This custom handler wraps the standard MCP SDK token handler but provides
889+ OAuth 2.1 compliant error responses for client authentication failures:
890+ - Returns HTTP 401 status code for client authentication failures
891+ - Uses 'invalid_client' error code instead of 'unauthorized_client'
892+
893+ Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401
894+ (Unauthorized) status code to indicate which HTTP authentication schemes
895+ are supported. If the client attempted to authenticate via the Authorization
896+ request header field, the authorization server MUST respond with an HTTP 401
897+ (Unauthorized) status code and include the WWW-Authenticate response header
898+ field matching the authentication scheme used by the client."
899+ """
900+ from mcp .server .auth .handlers .token import TokenHandler
901+ from mcp .server .auth .middleware .client_auth import ClientAuthenticator
902+
903+ # Create the standard token handler and client authenticator
904+ token_handler = TokenHandler (
905+ provider = self , client_authenticator = ClientAuthenticator (self )
906+ )
907+
908+ # Handle the request normally
909+ response = await token_handler .handle (request )
910+
911+ # Check if the response is an error response for client authentication failure
912+ if (
913+ hasattr (response , "body" )
914+ and hasattr (response , "status_code" )
915+ and response .status_code == 400
916+ ):
917+ try :
918+ import json
919+
920+ # Parse the response body to check for client authentication errors
921+ body_content = (
922+ response .body .decode ("utf-8" )
923+ if hasattr (response .body , "decode" )
924+ else str (response .body )
925+ )
926+ error_data = json .loads (body_content )
927+
928+ # Check if this is an unauthorized_client error (which means invalid client_id)
929+ if error_data .get (
930+ "error"
931+ ) == "unauthorized_client" and "Invalid client_id" in str (
932+ error_data .get ("error_description" , "" )
933+ ):
934+ logger .debug (
935+ "Client authentication failed - client not found, returning OAuth 2.1 compliant error"
936+ )
937+
938+ # Return the correct OAuth 2.1 response
939+ return JSONResponse (
940+ content = {
941+ "error" : "invalid_client" ,
942+ "error_description" : error_data .get ("error_description" ),
943+ },
944+ status_code = 401 ,
945+ headers = {
946+ "Cache-Control" : "no-store" ,
947+ "Pragma" : "no-cache" ,
948+ },
949+ )
950+ except (json .JSONDecodeError , AttributeError , KeyError ):
951+ # If we can't parse the response, return it as-is
952+ pass
953+
954+ return response
955+
872956 # -------------------------------------------------------------------------
873957 # IdP Callback Forwarding
874958 # -------------------------------------------------------------------------
0 commit comments