Skip to content

Commit d7c6051

Browse files
authored
Refactor OAuth 2.1 error handling with TokenHandler subclass (#1948)
1 parent 1c323a8 commit d7c6051

2 files changed

Lines changed: 130 additions & 79 deletions

File tree

src/fastmcp/server/auth/oauth_proxy.py

Lines changed: 63 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,26 @@
2828
import httpx
2929
from authlib.common.security import generate_token
3030
from authlib.integrations.httpx_client import AsyncOAuth2Client
31+
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
32+
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
33+
from mcp.server.auth.json_response import PydanticJSONResponse
34+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
3135
from mcp.server.auth.provider import (
3236
AccessToken,
3337
AuthorizationCode,
3438
AuthorizationParams,
3539
RefreshToken,
3640
TokenError,
3741
)
42+
from mcp.server.auth.routes import cors_middleware
3843
from mcp.server.auth.settings import (
3944
ClientRegistrationOptions,
4045
RevocationOptions,
4146
)
4247
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
4348
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
4449
from starlette.requests import Request
45-
from starlette.responses import JSONResponse, RedirectResponse
50+
from starlette.responses import RedirectResponse
4651
from starlette.routing import Route
4752

4853
import fastmcp
@@ -122,6 +127,55 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
122127
HTTP_TIMEOUT_SECONDS: Final[int] = 30
123128

124129

130+
class TokenHandler(_SDKTokenHandler):
131+
"""TokenHandler that returns OAuth 2.1 compliant error responses.
132+
133+
The MCP SDK always returns HTTP 400 for all client authentication issues.
134+
However, OAuth 2.1 Section 5.3 and the MCP specification require that
135+
invalid or expired tokens MUST receive a HTTP 401 response.
136+
137+
This handler extends the base MCP SDK TokenHandler to transform client
138+
authentication failures into OAuth 2.1 compliant responses:
139+
- Changes 'unauthorized_client' to 'invalid_client' error code
140+
- Returns HTTP 401 status code instead of 400 for client auth failures
141+
142+
Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
143+
(Unauthorized) status code to indicate which HTTP authentication schemes
144+
are supported."
145+
146+
Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
147+
"""
148+
149+
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
150+
"""Override response method to provide OAuth 2.1 compliant error handling."""
151+
# Check if this is a client authentication failure (not just unauthorized for grant type)
152+
# unauthorized_client can mean two things:
153+
# 1. Client authentication failed (client_id not found or wrong credentials) -> invalid_client 401
154+
# 2. Client not authorized for this grant type -> unauthorized_client 400 (correct per spec)
155+
if (
156+
isinstance(obj, TokenErrorResponse)
157+
and obj.error == "unauthorized_client"
158+
and obj.error_description
159+
and "Invalid client_id" in obj.error_description
160+
):
161+
# Transform client auth failure to OAuth 2.1 compliant response
162+
return PydanticJSONResponse(
163+
content=TokenErrorResponse(
164+
error="invalid_client",
165+
error_description=obj.error_description,
166+
error_uri=obj.error_uri,
167+
),
168+
status_code=401,
169+
headers={
170+
"Cache-Control": "no-store",
171+
"Pragma": "no-cache",
172+
},
173+
)
174+
175+
# Otherwise use default behavior from parent class
176+
return super().response(obj)
177+
178+
125179
class OAuthProxy(OAuthProvider):
126180
"""OAuth provider that presents a DCR-compliant interface while proxying to non-DCR IDPs.
127181
@@ -852,12 +906,17 @@ def get_routes(
852906
and "POST" in route.methods
853907
):
854908
token_route_found = True
855-
# Replace with our custom token handler
909+
# Replace with our OAuth 2.1 compliant token handler
910+
token_handler = TokenHandler(
911+
provider=self, client_authenticator=ClientAuthenticator(self)
912+
)
856913
custom_routes.append(
857914
Route(
858915
path="/token",
859-
endpoint=self._handle_token_request,
860-
methods=["POST"],
916+
endpoint=cors_middleware(
917+
token_handler.handle, ["POST", "OPTIONS"]
918+
),
919+
methods=["POST", "OPTIONS"],
861920
)
862921
)
863922
else:
@@ -878,81 +937,6 @@ def get_routes(
878937
)
879938
return custom_routes
880939

881-
# -------------------------------------------------------------------------
882-
# Custom Token Endpoint Handler
883-
# -------------------------------------------------------------------------
884-
885-
async def _handle_token_request(self, request: Request) -> JSONResponse:
886-
"""Handle token requests with proper OAuth 2.1 error handling.
887-
888-
This custom handler wraps the standard MCP SDK token handler but provides
889-
OAuth 2.1 compliant error responses for client authentication failures:
890-
- Returns HTTP 401 status code for client authentication failures
891-
- Uses 'invalid_client' error code instead of 'unauthorized_client'
892-
893-
Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401
894-
(Unauthorized) status code to indicate which HTTP authentication schemes
895-
are supported. If the client attempted to authenticate via the Authorization
896-
request header field, the authorization server MUST respond with an HTTP 401
897-
(Unauthorized) status code and include the WWW-Authenticate response header
898-
field matching the authentication scheme used by the client."
899-
"""
900-
from mcp.server.auth.handlers.token import TokenHandler
901-
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
902-
903-
# Create the standard token handler and client authenticator
904-
token_handler = TokenHandler(
905-
provider=self, client_authenticator=ClientAuthenticator(self)
906-
)
907-
908-
# Handle the request normally
909-
response = await token_handler.handle(request)
910-
911-
# Check if the response is an error response for client authentication failure
912-
if (
913-
hasattr(response, "body")
914-
and hasattr(response, "status_code")
915-
and response.status_code == 400
916-
):
917-
try:
918-
import json
919-
920-
# Parse the response body to check for client authentication errors
921-
body_content = (
922-
response.body.decode("utf-8")
923-
if hasattr(response.body, "decode")
924-
else str(response.body)
925-
)
926-
error_data = json.loads(body_content)
927-
928-
# Check if this is an unauthorized_client error (which means invalid client_id)
929-
if error_data.get(
930-
"error"
931-
) == "unauthorized_client" and "Invalid client_id" in str(
932-
error_data.get("error_description", "")
933-
):
934-
logger.debug(
935-
"Client authentication failed - client not found, returning OAuth 2.1 compliant error"
936-
)
937-
938-
# Return the correct OAuth 2.1 response
939-
return JSONResponse(
940-
content={
941-
"error": "invalid_client",
942-
"error_description": error_data.get("error_description"),
943-
},
944-
status_code=401,
945-
headers={
946-
"Cache-Control": "no-store",
947-
"Pragma": "no-cache",
948-
},
949-
)
950-
except (json.JSONDecodeError, AttributeError, KeyError):
951-
# If we can't parse the response, return it as-is
952-
pass
953-
954-
return response
955-
956940
# -------------------------------------------------------------------------
957941
# IdP Callback Forwarding
958942
# -------------------------------------------------------------------------

tests/server/auth/test_oauth_proxy.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,3 +1029,70 @@ async def test_token_endpoint_invalid_client_error(self, jwt_verifier):
10291029
# Verify proper cache headers are set
10301030
assert response.headers.get("Cache-Control") == "no-store"
10311031
assert response.headers.get("Pragma") == "no-cache"
1032+
1033+
1034+
class TestTokenHandlerErrorTransformation:
1035+
"""Tests for TokenHandler's OAuth 2.1 compliant error transformation."""
1036+
1037+
def test_transforms_client_auth_failure_to_invalid_client_401(self):
1038+
"""Test that client authentication failures return invalid_client with 401."""
1039+
from mcp.server.auth.handlers.token import TokenErrorResponse
1040+
1041+
from fastmcp.server.auth.oauth_proxy import TokenHandler
1042+
1043+
handler = TokenHandler(provider=Mock(), client_authenticator=Mock())
1044+
1045+
# Simulate error from ClientAuthenticator.authenticate() failure
1046+
error_response = TokenErrorResponse(
1047+
error="unauthorized_client",
1048+
error_description="Invalid client_id 'test-client-id'",
1049+
)
1050+
1051+
response = handler.response(error_response)
1052+
1053+
# Should transform to OAuth 2.1 compliant response
1054+
assert response.status_code == 401
1055+
assert b'"error":"invalid_client"' in response.body
1056+
assert (
1057+
b'"error_description":"Invalid client_id \'test-client-id\'"'
1058+
in response.body
1059+
)
1060+
1061+
def test_does_not_transform_grant_type_unauthorized_to_invalid_client(self):
1062+
"""Test that grant type authorization errors stay as unauthorized_client with 400."""
1063+
from mcp.server.auth.handlers.token import TokenErrorResponse
1064+
1065+
from fastmcp.server.auth.oauth_proxy import TokenHandler
1066+
1067+
handler = TokenHandler(provider=Mock(), client_authenticator=Mock())
1068+
1069+
# Simulate error from grant_type not in client_info.grant_types
1070+
error_response = TokenErrorResponse(
1071+
error="unauthorized_client",
1072+
error_description="Client not authorized for this grant type",
1073+
)
1074+
1075+
response = handler.response(error_response)
1076+
1077+
# Should NOT transform - keep as 400 unauthorized_client
1078+
assert response.status_code == 400
1079+
assert b'"error":"unauthorized_client"' in response.body
1080+
1081+
def test_does_not_transform_other_errors(self):
1082+
"""Test that other error types pass through unchanged."""
1083+
from mcp.server.auth.handlers.token import TokenErrorResponse
1084+
1085+
from fastmcp.server.auth.oauth_proxy import TokenHandler
1086+
1087+
handler = TokenHandler(provider=Mock(), client_authenticator=Mock())
1088+
1089+
error_response = TokenErrorResponse(
1090+
error="invalid_grant",
1091+
error_description="Authorization code has expired",
1092+
)
1093+
1094+
response = handler.response(error_response)
1095+
1096+
# Should pass through unchanged
1097+
assert response.status_code == 400
1098+
assert b'"error":"invalid_grant"' in response.body

0 commit comments

Comments
 (0)