Skip to content

Commit 1c323a8

Browse files
authored
feat: Follow OAuth 2.1 spec requirements on auth failures (#1923)
Co-authored-by: Tomas <>
1 parent 66221ed commit 1c323a8

2 files changed

Lines changed: 144 additions & 4 deletions

File tree

src/fastmcp/server/auth/oauth_proxy.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
4343
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
4444
from starlette.requests import Request
45-
from starlette.responses import RedirectResponse
45+
from starlette.responses import JSONResponse, RedirectResponse
4646
from starlette.routing import Route
4747

4848
import fastmcp
@@ -844,16 +844,25 @@ def get_routes(
844844
f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
845845
)
846846

847-
# Keep all standard OAuth routes unchanged - our DCR-compliant flow handles everything
848-
custom_routes.append(route)
849-
847+
# Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
850848
if (
851849
isinstance(route, Route)
852850
and route.path == "/token"
853851
and route.methods is not None
854852
and "POST" in route.methods
855853
):
856854
token_route_found = True
855+
# Replace with our custom token handler
856+
custom_routes.append(
857+
Route(
858+
path="/token",
859+
endpoint=self._handle_token_request,
860+
methods=["POST"],
861+
)
862+
)
863+
else:
864+
# Keep all other standard OAuth routes unchanged
865+
custom_routes.append(route)
857866

858867
# Add OAuth callback endpoint for forwarding to client callbacks
859868
custom_routes.append(
@@ -869,6 +878,81 @@ def get_routes(
869878
)
870879
return custom_routes
871880

881+
# -------------------------------------------------------------------------
882+
# Custom Token Endpoint Handler
883+
# -------------------------------------------------------------------------
884+
885+
async def _handle_token_request(self, request: Request) -> JSONResponse:
886+
"""Handle token requests with proper OAuth 2.1 error handling.
887+
888+
This custom handler wraps the standard MCP SDK token handler but provides
889+
OAuth 2.1 compliant error responses for client authentication failures:
890+
- Returns HTTP 401 status code for client authentication failures
891+
- Uses 'invalid_client' error code instead of 'unauthorized_client'
892+
893+
Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401
894+
(Unauthorized) status code to indicate which HTTP authentication schemes
895+
are supported. If the client attempted to authenticate via the Authorization
896+
request header field, the authorization server MUST respond with an HTTP 401
897+
(Unauthorized) status code and include the WWW-Authenticate response header
898+
field matching the authentication scheme used by the client."
899+
"""
900+
from mcp.server.auth.handlers.token import TokenHandler
901+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
902+
903+
# Create the standard token handler and client authenticator
904+
token_handler = TokenHandler(
905+
provider=self, client_authenticator=ClientAuthenticator(self)
906+
)
907+
908+
# Handle the request normally
909+
response = await token_handler.handle(request)
910+
911+
# Check if the response is an error response for client authentication failure
912+
if (
913+
hasattr(response, "body")
914+
and hasattr(response, "status_code")
915+
and response.status_code == 400
916+
):
917+
try:
918+
import json
919+
920+
# Parse the response body to check for client authentication errors
921+
body_content = (
922+
response.body.decode("utf-8")
923+
if hasattr(response.body, "decode")
924+
else str(response.body)
925+
)
926+
error_data = json.loads(body_content)
927+
928+
# Check if this is an unauthorized_client error (which means invalid client_id)
929+
if error_data.get(
930+
"error"
931+
) == "unauthorized_client" and "Invalid client_id" in str(
932+
error_data.get("error_description", "")
933+
):
934+
logger.debug(
935+
"Client authentication failed - client not found, returning OAuth 2.1 compliant error"
936+
)
937+
938+
# Return the correct OAuth 2.1 response
939+
return JSONResponse(
940+
content={
941+
"error": "invalid_client",
942+
"error_description": error_data.get("error_description"),
943+
},
944+
status_code=401,
945+
headers={
946+
"Cache-Control": "no-store",
947+
"Pragma": "no-cache",
948+
},
949+
)
950+
except (json.JSONDecodeError, AttributeError, KeyError):
951+
# If we can't parse the response, return it as-is
952+
pass
953+
954+
return response
955+
872956
# -------------------------------------------------------------------------
873957
# IdP Callback Forwarding
874958
# -------------------------------------------------------------------------

tests/server/auth/test_oauth_proxy.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,3 +973,59 @@ async def test_multiple_extra_params(self, jwt_verifier):
973973
assert query_params["audience"][0] == "https://api.example.com"
974974
assert query_params["prompt"][0] == "consent"
975975
assert query_params["max_age"][0] == "3600"
976+
977+
@pytest.mark.asyncio
978+
async def test_token_endpoint_invalid_client_error(self, jwt_verifier):
979+
"""Test that invalid client_id returns OAuth 2.1 compliant error response.
980+
981+
When a client ID is not found during token exchange, the proxy should:
982+
1. Return HTTP 401 status code
983+
2. Use 'invalid_client' error code instead of 'unauthorized_client'
984+
985+
This aligns with OAuth 2.1 spec and enables Claude's automatic client re-registration.
986+
"""
987+
from starlette.applications import Starlette
988+
from starlette.testclient import TestClient
989+
990+
proxy = OAuthProxy(
991+
upstream_authorization_endpoint="https://oauth.example.com/authorize",
992+
upstream_token_endpoint="https://oauth.example.com/token",
993+
upstream_client_id="upstream-client",
994+
upstream_client_secret="upstream-secret",
995+
token_verifier=jwt_verifier,
996+
base_url="https://proxy.example.com",
997+
)
998+
999+
# Create a test app with OAuth routes
1000+
app = Starlette(routes=proxy.get_routes())
1001+
1002+
# Test the token endpoint with an invalid (non-existent) client_id
1003+
with TestClient(app) as client:
1004+
response = client.post(
1005+
"/token",
1006+
data={
1007+
"grant_type": "authorization_code",
1008+
"code": "test-auth-code",
1009+
"client_id": "non-existent-client-id",
1010+
"code_verifier": "test-code-verifier",
1011+
"redirect_uri": "http://localhost:12345/callback",
1012+
},
1013+
headers={
1014+
"Content-Type": "application/x-www-form-urlencoded",
1015+
},
1016+
)
1017+
1018+
# Verify OAuth 2.1 compliant error response
1019+
assert response.status_code == 401, (
1020+
f"Expected 401 but got {response.status_code}"
1021+
)
1022+
1023+
error_data = response.json()
1024+
assert error_data["error"] == "invalid_client", (
1025+
f"Expected 'invalid_client' but got '{error_data.get('error')}'"
1026+
)
1027+
assert "Invalid client_id" in error_data["error_description"]
1028+
1029+
# Verify proper cache headers are set
1030+
assert response.headers.get("Cache-Control") == "no-store"
1031+
assert response.headers.get("Pragma") == "no-cache"

0 commit comments

Comments
 (0)