diff --git a/Directory.Packages.props b/Directory.Packages.props index db97ab1c..2e377c4f 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -19,6 +19,7 @@ + @@ -26,6 +27,8 @@ + + diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index e4fd42fe..5ed8ba0d 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -12,6 +12,8 @@ + + @@ -33,6 +35,7 @@ + diff --git a/samples/ProtectedMCPClient/Program.cs b/samples/ProtectedMCPClient/Program.cs new file mode 100644 index 00000000..516227b3 --- /dev/null +++ b/samples/ProtectedMCPClient/Program.cs @@ -0,0 +1,148 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Diagnostics; +using System.Net; +using System.Text; +using System.Web; + +var serverUrl = "http://localhost:7071/"; + +Console.WriteLine("Protected MCP Client"); +Console.WriteLine($"Connecting to weather server at {serverUrl}..."); +Console.WriteLine(); + +// We can customize a shared HttpClient with a custom handler if desired +var sharedHandler = new SocketsHttpHandler +{ + PooledConnectionLifetime = TimeSpan.FromMinutes(2), + PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1) +}; +var httpClient = new HttpClient(sharedHandler); + +var consoleLoggerFactory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +var transport = new SseClientTransport(new() +{ + Endpoint = new Uri(serverUrl), + Name = "Secure Weather Client", + OAuth = new() + { + ClientName = "ProtectedMcpClient", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + } +}, httpClient, consoleLoggerFactory); + +var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); + +var tools = await client.ListToolsAsync(); +if (tools.Count == 0) +{ + Console.WriteLine("No tools available on the server."); + return; +} + +Console.WriteLine($"Found {tools.Count} tools on the server."); +Console.WriteLine(); + +if (tools.Any(t => t.Name == "get_alerts")) +{ + Console.WriteLine("Calling get_alerts tool..."); + + var result = await client.CallToolAsync( + "get_alerts", + new Dictionary { { "state", "WA" } } + ); + + Console.WriteLine("Result: " + ((TextContentBlock)result.Content[0]).Text); + Console.WriteLine(); +} + +/// Handles the OAuth authorization URL by starting a local HTTP server and opening a browser. +/// This implementation demonstrates how SDK consumers can provide their own authorization flow. +/// +/// The authorization URL to open in the browser. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// The authorization code extracted from the callback, or null if the operation failed. +static async Task HandleAuthorizationUrlAsync(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) +{ + Console.WriteLine("Starting OAuth authorization flow..."); + Console.WriteLine($"Opening browser to: {authorizationUrl}"); + + var listenerPrefix = redirectUri.GetLeftPart(UriPartial.Authority); + if (!listenerPrefix.EndsWith("/")) listenerPrefix += "/"; + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + try + { + listener.Start(); + Console.WriteLine($"Listening for OAuth callback on: {listenerPrefix}"); + + OpenBrowser(authorizationUrl); + + var context = await listener.GetContextAsync(); + var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty); + var code = query["code"]; + var error = query["error"]; + + string responseHtml = "

Authentication complete

You can close this window now.

"; + byte[] buffer = Encoding.UTF8.GetBytes(responseHtml); + context.Response.ContentLength64 = buffer.Length; + context.Response.ContentType = "text/html"; + context.Response.OutputStream.Write(buffer, 0, buffer.Length); + context.Response.Close(); + + if (!string.IsNullOrEmpty(error)) + { + Console.WriteLine($"Auth error: {error}"); + return null; + } + + if (string.IsNullOrEmpty(code)) + { + Console.WriteLine("No authorization code received"); + return null; + } + + Console.WriteLine("Authorization code received successfully."); + return code; + } + catch (Exception ex) + { + Console.WriteLine($"Error getting auth code: {ex.Message}"); + return null; + } + finally + { + if (listener.IsListening) listener.Stop(); + } +} + +/// +/// Opens the specified URL in the default browser. +/// +/// The URL to open. +static void OpenBrowser(Uri url) +{ + try + { + var psi = new ProcessStartInfo + { + FileName = url.ToString(), + UseShellExecute = true + }; + Process.Start(psi); + } + catch (Exception ex) + { + Console.WriteLine($"Error opening browser. {ex.Message}"); + Console.WriteLine($"Please manually open this URL: {url}"); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPClient/ProtectedMCPClient.csproj b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj new file mode 100644 index 00000000..d1d47637 --- /dev/null +++ b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj @@ -0,0 +1,18 @@ + + + + Exe + net9.0 + enable + enable + + + + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPClient/README.md b/samples/ProtectedMCPClient/README.md new file mode 100644 index 00000000..977331a0 --- /dev/null +++ b/samples/ProtectedMCPClient/README.md @@ -0,0 +1,93 @@ +# Protected MCP Client Sample + +This sample demonstrates how to create an MCP client that connects to a protected MCP server using OAuth 2.0 authentication. The client implements a custom OAuth authorization flow with browser-based authentication. + +## Overview + +The Protected MCP Client sample shows how to: +- Connect to an OAuth-protected MCP server +- Handle OAuth 2.0 authorization code flow +- Use custom authorization redirect handling +- Call protected MCP tools with authentication + +## Prerequisites + +- .NET 9.0 or later +- A running TestOAuthServer (for OAuth authentication) +- A running ProtectedMCPServer (for MCP services) + +## Setup and Running + +### Step 1: Start the Test OAuth Server + +First, you need to start the TestOAuthServer which provides OAuth authentication: + +```bash +cd tests\ModelContextProtocol.TestOAuthServer +dotnet run --framework net9.0 +``` + +The OAuth server will start at `https://localhost:7029` + +### Step 2: Start the Protected MCP Server + +Next, start the ProtectedMCPServer which provides the weather tools: + +```bash +cd samples\ProtectedMCPServer +dotnet run +``` + +The protected server will start at `http://localhost:7071` + +### Step 3: Run the Protected MCP Client + +Finally, run this client: + +```bash +cd samples\ProtectedMCPClient +dotnet run +``` + +## What Happens + +1. The client attempts to connect to the protected MCP server at `http://localhost:7071` +2. The server responds with OAuth metadata indicating authentication is required +3. The client initiates OAuth 2.0 authorization code flow: + - Opens a browser to the authorization URL at the OAuth server + - Starts a local HTTP listener on `http://localhost:1179/callback` to receive the authorization code + - Exchanges the authorization code for an access token +4. The client uses the access token to authenticate with the MCP server +5. The client lists available tools and calls the `GetAlerts` tool for Washington state + +## OAuth Configuration + +The client is configured with: +- **Client ID**: `demo-client` +- **Client Secret**: `demo-secret` +- **Redirect URI**: `http://localhost:1179/callback` +- **OAuth Server**: `https://localhost:7029` +- **Protected Resource**: `http://localhost:7071` + +## Available Tools + +Once authenticated, the client can access weather tools including: +- **GetAlerts**: Get weather alerts for a US state +- **GetForecast**: Get weather forecast for a location (latitude/longitude) + +## Troubleshooting + +- Ensure the ASP.NET Core dev certificate is trusted. + ``` + dotnet dev-certs https --clean + dotnet dev-certs https --trust + ``` +- Ensure all three services are running in the correct order +- Check that ports 7029, 7071, and 1179 are available +- If the browser doesn't open automatically, copy the authorization URL from the console and open it manually +- Make sure to allow the OAuth server's self-signed certificate in your browser + +## Key Files + +- `Program.cs`: Main client application with OAuth flow implementation +- `ProtectedMCPClient.csproj`: Project file with dependencies \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Program.cs b/samples/ProtectedMCPServer/Program.cs new file mode 100644 index 00000000..ef70fe73 --- /dev/null +++ b/samples/ProtectedMCPServer/Program.cs @@ -0,0 +1,93 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ProtectedMCPServer.Tools; +using System.Net.Http.Headers; +using System.Security.Claims; + +var builder = WebApplication.CreateBuilder(args); + +var serverUrl = "http://localhost:7071/"; +var inMemoryOAuthServerUrl = "https://localhost:7029"; + +builder.Services.AddAuthentication(options => +{ + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; +}) +.AddJwtBearer(options => +{ + // Configure to validate tokens from our in-memory OAuth server + options.Authority = inMemoryOAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = serverUrl, // Validate that the audience matches the resource metadata as suggested in RFC 8707 + ValidIssuer = inMemoryOAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + + options.Events = new JwtBearerEvents + { + OnTokenValidated = context => + { + var name = context.Principal?.Identity?.Name ?? "unknown"; + var email = context.Principal?.FindFirstValue("preferred_username") ?? "unknown"; + Console.WriteLine($"Token validated for: {name} ({email})"); + return Task.CompletedTask; + }, + OnAuthenticationFailed = context => + { + Console.WriteLine($"Authentication failed: {context.Exception.Message}"); + return Task.CompletedTask; + }, + OnChallenge = context => + { + Console.WriteLine($"Challenging client to authenticate with Entra ID"); + return Task.CompletedTask; + } + }; +}) +.AddMcp(options => +{ + options.ResourceMetadata = new() + { + Resource = new Uri(serverUrl), + ResourceDocumentation = new Uri("https://docs.example.com/api/weather"), + AuthorizationServers = { new Uri(inMemoryOAuthServerUrl) }, + ScopesSupported = ["mcp:tools"], + }; +}); + +builder.Services.AddAuthorization(); + +builder.Services.AddHttpContextAccessor(); +builder.Services.AddMcpServer() + .WithTools() + .WithHttpTransport(); + +// Configure HttpClientFactory for weather.gov API +builder.Services.AddHttpClient("WeatherApi", client => +{ + client.BaseAddress = new Uri("https://api.weather.gov"); + client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); +}); + +var app = builder.Build(); + +app.UseAuthentication(); +app.UseAuthorization(); + +// Use the default MCP policy name that we've configured +app.MapMcp().RequireAuthorization(); + +Console.WriteLine($"Starting MCP server with authorization at {serverUrl}"); +Console.WriteLine($"Using in-memory OAuth server at {inMemoryOAuthServerUrl}"); +Console.WriteLine($"Protected Resource Metadata URL: {serverUrl}.well-known/oauth-protected-resource"); +Console.WriteLine("Press Ctrl+C to stop the server"); + +app.Run(serverUrl); diff --git a/samples/ProtectedMCPServer/Properties/launchSettings.json b/samples/ProtectedMCPServer/Properties/launchSettings.json new file mode 100644 index 00000000..31b04db8 --- /dev/null +++ b/samples/ProtectedMCPServer/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "ProtectedMCPServer": { + "commandName": "Project", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "http://localhost:7071" + } + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/ProtectedMCPServer.csproj b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj new file mode 100644 index 00000000..b4c35c77 --- /dev/null +++ b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + 783daef3-9c45-408d-a1d3-7caf44724f39 + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPServer/README.md b/samples/ProtectedMCPServer/README.md new file mode 100644 index 00000000..f0ac708a --- /dev/null +++ b/samples/ProtectedMCPServer/README.md @@ -0,0 +1,125 @@ +# Protected MCP Server Sample + +This sample demonstrates how to create an MCP server that requires OAuth 2.0 authentication to access its tools and resources. The server provides weather-related tools protected by JWT bearer token authentication. + +## Overview + +The Protected MCP Server sample shows how to: +- Create an MCP server with OAuth 2.0 protection +- Configure JWT bearer token authentication +- Implement protected MCP tools and resources +- Integrate with ASP.NET Core authentication and authorization +- Provide OAuth resource metadata for client discovery + +## Prerequisites + +- .NET 9.0 or later +- A running TestOAuthServer (for OAuth authentication) + +## Setup and Running + +### Step 1: Start the Test OAuth Server + +First, you need to start the TestOAuthServer which issues access tokens: + +```bash +cd tests\ModelContextProtocol.TestOAuthServer +dotnet run --framework net9.0 +``` + +The OAuth server will start at `https://localhost:7029` + +### Step 2: Start the Protected MCP Server + +Run this protected server: + +```bash +cd samples\ProtectedMCPServer +dotnet run +``` + +The protected server will start at `http://localhost:7071` + +### Step 3: Test with Protected MCP Client + +You can test the server using the ProtectedMCPClient sample: + +```bash +cd samples\ProtectedMCPClient +dotnet run +``` + +## What the Server Provides + +### Protected Resources + +- **MCP Endpoint**: `http://localhost:7071/` (requires authentication) +- **OAuth Resource Metadata**: `http://localhost:7071/.well-known/oauth-protected-resource` + +### Available Tools + +The server provides weather-related tools that require authentication: + +1. **GetAlerts**: Get weather alerts for a US state + - Parameter: `state` (string) - 2-letter US state abbreviation + - Example: `GetAlerts` with `state: "WA"` + +2. **GetForecast**: Get weather forecast for a location + - Parameters: + - `latitude` (double) - Latitude coordinate + - `longitude` (double) - Longitude coordinate + - Example: `GetForecast` with `latitude: 47.6062, longitude: -122.3321` + +### Authentication Configuration + +The server is configured to: +- Accept JWT bearer tokens from the OAuth server at `https://localhost:7029` +- Validate token audience as `demo-client` +- Require tokens to have appropriate scopes (`mcp:tools`) +- Provide OAuth resource metadata for client discovery + +## Architecture + +The server uses: +- **ASP.NET Core** for hosting and HTTP handling +- **JWT Bearer Authentication** for token validation +- **MCP Authentication Extensions** for OAuth resource metadata +- **HttpClient** for calling the weather.gov API +- **Authorization** to protect MCP endpoints + +## Configuration Details + +- **Server URL**: `http://localhost:7071` +- **OAuth Server**: `https://localhost:7029` +- **Demo Client ID**: `demo-client` + +## Testing Without Client + +You can test the server directly using HTTP tools: + +1. Get an access token from the OAuth server +2. Include the token in the `Authorization: Bearer ` header +3. Make requests to the MCP endpoints + +## External Dependencies + +The weather tools use the National Weather Service API at `api.weather.gov` to fetch real weather data. + +## Troubleshooting + +- Ensure the ASP.NET Core dev certificate is trusted. + ``` + dotnet dev-certs https --clean + dotnet dev-certs https --trust + ``` +- Ensure the TestOAuthServer is running first +- Check that port 7071 is available +- Verify the OAuth server is accessible at `https://localhost:7029` +- Check console output for authentication events and errors + +## Key Files + +- `Program.cs`: Server setup with authentication and MCP configuration +- `Tools/WeatherTools.cs`: Weather tool implementations +- `Tools/HttpClientExt.cs`: HTTP client extensions +- `Properties/launchSettings.json`: Development launch configuration \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/HttpClientExt.cs b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs new file mode 100644 index 00000000..f7b2b549 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs @@ -0,0 +1,13 @@ +using System.Text.Json; + +namespace ModelContextProtocol; + +internal static class HttpClientExt +{ + public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) + { + using var response = await client.GetAsync(requestUri); + response.EnsureSuccessStatusCode(); + return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/WeatherTools.cs b/samples/ProtectedMCPServer/Tools/WeatherTools.cs new file mode 100644 index 00000000..7c8c0851 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/WeatherTools.cs @@ -0,0 +1,67 @@ +using ModelContextProtocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; + +namespace ProtectedMCPServer.Tools; + +[McpServerToolType] +public sealed class WeatherTools +{ + private readonly IHttpClientFactory _httpClientFactory; + + public WeatherTools(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; + } + + [McpServerTool, Description("Get weather alerts for a US state.")] + public async Task GetAlerts( + [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); + var jsonElement = jsonDocument.RootElement; + var alerts = jsonElement.GetProperty("features").EnumerateArray(); + + if (!alerts.Any()) + { + return "No active alerts for this state."; + } + + return string.Join("\n--\n", alerts.Select(alert => + { + JsonElement properties = alert.GetProperty("properties"); + return $""" + Event: {properties.GetProperty("event").GetString()} + Area: {properties.GetProperty("areaDesc").GetString()} + Severity: {properties.GetProperty("severity").GetString()} + Description: {properties.GetProperty("description").GetString()} + Instruction: {properties.GetProperty("instruction").GetString()} + """; + })); + } + + [McpServerTool, Description("Get weather forecast for a location.")] + public async Task GetForecast( + [Description("Latitude of the location.")] double latitude, + [Description("Longitude of the location.")] double longitude) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); + using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); + var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); + var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); + + return string.Join("\n---\n", periods.Select(period => $""" + {period.GetProperty("name").GetString()} + Temperature: {period.GetProperty("temperature").GetInt32()}°F + Wind: {period.GetProperty("windSpeed").GetString()} {period.GetProperty("windDirection").GetString()} + Forecast: {period.GetProperty("detailedForecast").GetString()} + """)); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs new file mode 100644 index 00000000..4c720c65 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Default values used by MCP authentication. +/// +public static class McpAuthenticationDefaults +{ + /// + /// The default value used for authentication scheme name. + /// + public const string AuthenticationScheme = "McpAuth"; + + /// + /// The default value used for authentication scheme display name. + /// + public const string DisplayName = "MCP Authentication"; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs new file mode 100644 index 00000000..0d430225 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Represents the authentication events for Model Context Protocol. +/// +public class McpAuthenticationEvents +{ + /// + /// Gets or sets the function that is invoked when resource metadata is requested. + /// + /// + /// This function is called when a resource metadata request is made to the protected resource metadata endpoint. + /// The implementer should set the property + /// to provide the appropriate metadata for the current request. + /// + public Func OnResourceMetadataRequest { get; set; } = context => Task.CompletedTask; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs new file mode 100644 index 00000000..f103357c --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs @@ -0,0 +1,47 @@ +using Microsoft.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Authentication; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for adding MCP authorization support to ASP.NET Core applications. +/// +public static class McpAuthenticationExtensions +{ + /// + /// Adds MCP authorization support to the application. + /// + /// The authentication builder. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + Action? configureOptions = null) + { + return AddMcp( + builder, + McpAuthenticationDefaults.AuthenticationScheme, + McpAuthenticationDefaults.DisplayName, + configureOptions); + } + + /// + /// Adds MCP authorization support to the application with a custom scheme name. + /// + /// The authentication builder. + /// The authentication scheme name to use. + /// The display name for the authentication scheme. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + string authenticationScheme, + string displayName, + Action? configureOptions = null) + { + return builder.AddScheme( + authenticationScheme, + displayName, + configureOptions); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs new file mode 100644 index 00000000..942db1b6 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs @@ -0,0 +1,157 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Authentication; +using System.Text.Encodings.Web; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Authentication handler for MCP protocol that adds resource metadata to challenge responses +/// and handles resource metadata endpoint requests. +/// +public class McpAuthenticationHandler : AuthenticationHandler, IAuthenticationRequestHandler +{ + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory logger, + UrlEncoder encoder) + : base(options, logger, encoder) + { + } + + /// + public async Task HandleRequestAsync() + { + // Check if the request is for the resource metadata endpoint + string requestPath = Request.Path.Value ?? string.Empty; + + string expectedMetadataPath = Options.ResourceMetadataUri?.ToString() ?? string.Empty; + if (Options.ResourceMetadataUri != null && !Options.ResourceMetadataUri.IsAbsoluteUri) + { + // For relative URIs, it's just the path component. + expectedMetadataPath = Options.ResourceMetadataUri.OriginalString; + } + + // If the path doesn't match, let the request continue through the pipeline + if (!string.Equals(requestPath, expectedMetadataPath, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var cancellationToken = Request.HttpContext.RequestAborted; + await HandleResourceMetadataRequestAsync(cancellationToken); + return true; + } + + /// + /// Gets the base URL from the current request, including scheme, host, and path base. + /// + private string GetBaseUrl() => $"{Request.Scheme}://{Request.Host}{Request.PathBase}"; + + /// + /// Gets the absolute URI for the resource metadata endpoint. + /// + private string GetAbsoluteResourceMetadataUri() + { + var resourceMetadataUri = Options.ResourceMetadataUri; + + string currentPath = resourceMetadataUri?.ToString() ?? string.Empty; + + if (resourceMetadataUri != null && resourceMetadataUri.IsAbsoluteUri) + { + return currentPath; + } + + // For relative URIs, combine with the base URL + string baseUrl = GetBaseUrl(); + string relativePath = resourceMetadataUri?.OriginalString.TrimStart('/') ?? string.Empty; + + if (!Uri.TryCreate($"{baseUrl.TrimEnd('/')}/{relativePath}", UriKind.Absolute, out var absoluteUri)) + { + throw new InvalidOperationException($"Could not create absolute URI for resource metadata. Base URL: {baseUrl}, Relative Path: {relativePath}"); + } + + return absoluteUri.ToString(); + } + + /// + /// Handles the resource metadata request. + /// + /// A token to cancel the operation. + private async Task HandleResourceMetadataRequestAsync(CancellationToken cancellationToken = default) + { + var resourceMetadata = Options.ResourceMetadata; + + if (Options.Events.OnResourceMetadataRequest is not null) + { + var context = new ResourceMetadataRequestContext(Request.HttpContext, Scheme, Options) + { + ResourceMetadata = CloneResourceMetadata(resourceMetadata), + }; + + await Options.Events.OnResourceMetadataRequest(context); + } + + + if (resourceMetadata == null) + { + throw new InvalidOperationException("ResourceMetadata has not been configured. Please set McpAuthenticationOptions.ResourceMetadata."); + } + + await Results.Json(resourceMetadata, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))).ExecuteAsync(Context); + } + + /// + // If no forwarding is configured, this handler doesn't perform authentication + protected override async Task HandleAuthenticateAsync() => AuthenticateResult.NoResult(); + + /// + protected override Task HandleChallengeAsync(AuthenticationProperties properties) + { + // Get the absolute URI for the resource metadata + string rawPrmDocumentUri = GetAbsoluteResourceMetadataUri(); + + properties ??= new AuthenticationProperties(); + + // Store the resource_metadata in properties in case other handlers need it + properties.Items["resource_metadata"] = rawPrmDocumentUri; + + // Add the WWW-Authenticate header with Bearer scheme and resource metadata + string headerValue = $"Bearer realm=\"{Scheme.Name}\", resource_metadata=\"{rawPrmDocumentUri}\""; + Response.Headers.Append("WWW-Authenticate", headerValue); + + return base.HandleChallengeAsync(properties); + } + + internal static ProtectedResourceMetadata? CloneResourceMetadata(ProtectedResourceMetadata? resourceMetadata) + { + if (resourceMetadata is null) + { + return null; + } + + return new ProtectedResourceMetadata + { + Resource = resourceMetadata.Resource, + AuthorizationServers = [.. resourceMetadata.AuthorizationServers], + BearerMethodsSupported = [.. resourceMetadata.BearerMethodsSupported], + ScopesSupported = [.. resourceMetadata.ScopesSupported], + JwksUri = resourceMetadata.JwksUri, + ResourceSigningAlgValuesSupported = resourceMetadata.ResourceSigningAlgValuesSupported is not null ? [.. resourceMetadata.ResourceSigningAlgValuesSupported] : null, + ResourceName = resourceMetadata.ResourceName, + ResourceDocumentation = resourceMetadata.ResourceDocumentation, + ResourcePolicyUri = resourceMetadata.ResourcePolicyUri, + ResourceTosUri = resourceMetadata.ResourceTosUri, + TlsClientCertificateBoundAccessTokens = resourceMetadata.TlsClientCertificateBoundAccessTokens, + AuthorizationDetailsTypesSupported = resourceMetadata.AuthorizationDetailsTypesSupported is not null ? [.. resourceMetadata.AuthorizationDetailsTypesSupported] : null, + DpopSigningAlgValuesSupported = resourceMetadata.DpopSigningAlgValuesSupported is not null ? [.. resourceMetadata.DpopSigningAlgValuesSupported] : null, + DpopBoundAccessTokensRequired = resourceMetadata.DpopBoundAccessTokensRequired + }; + } + +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs new file mode 100644 index 00000000..ecb6c6c8 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs @@ -0,0 +1,49 @@ +using Microsoft.AspNetCore.Authentication; +using ModelContextProtocol.Authentication; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Options for the MCP authentication handler. +/// +public class McpAuthenticationOptions : AuthenticationSchemeOptions +{ + private static readonly Uri DefaultResourceMetadataUri = new("/.well-known/oauth-protected-resource", UriKind.Relative); + + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationOptions() + { + // "Bearer" is JwtBearerDefaults.AuthenticationScheme, but we don't have a reference to the JwtBearer package here. + ForwardAuthenticate = "Bearer"; + ResourceMetadataUri = DefaultResourceMetadataUri; + Events = new McpAuthenticationEvents(); + } + + /// + /// Gets or sets the events used to handle authentication events. + /// + public new McpAuthenticationEvents Events + { + get { return (McpAuthenticationEvents)base.Events!; } + set { base.Events = value; } + } + + /// + /// The URI to the resource metadata document. + /// + /// + /// This URI will be included in the WWW-Authenticate header when a 401 response is returned. + /// + public Uri ResourceMetadataUri { get; set; } + + /// + /// Gets or sets the protected resource metadata. + /// + /// + /// This contains the OAuth metadata for the protected resource, including authorization servers, + /// supported scopes, and other information needed for clients to authenticate. + /// + public ProtectedResourceMetadata? ResourceMetadata { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs b/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs new file mode 100644 index 00000000..0d064123 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs @@ -0,0 +1,30 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using ModelContextProtocol.Authentication; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Context for resource metadata request events. +/// +public class ResourceMetadataRequestContext : HandleRequestContext +{ + /// + /// Initializes a new instance of the class. + /// + /// The HTTP context. + /// The authentication scheme. + /// The authentication options. + public ResourceMetadataRequestContext( + HttpContext context, + AuthenticationScheme scheme, + McpAuthenticationOptions options) + : base(context, scheme, options) + { + } + + /// + /// Gets or sets the protected resource metadata for the current request. + /// + public ProtectedResourceMetadata? ResourceMetadata { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs new file mode 100644 index 00000000..1cc08189 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Net.Http.Headers; + +namespace ModelContextProtocol.Authentication; + +/// +/// A delegating handler that adds authentication tokens to requests and handles 401 responses. +/// +internal sealed class AuthenticatingMcpHttpClient(HttpClient httpClient, ClientOAuthProvider credentialProvider) : McpHttpClient(httpClient) +{ + // Select first supported scheme as the default + private string _currentScheme = credentialProvider.SupportedSchemes.FirstOrDefault() ?? + throw new ArgumentException("Authorization provider must support at least one authentication scheme.", nameof(credentialProvider)); + + /// + /// Sends an HTTP request with authentication handling. + /// + internal override async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + if (request.Headers.Authorization == null) + { + await AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false); + } + + var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false); + + if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) + { + return await HandleUnauthorizedResponseAsync(request, message, response, cancellationToken).ConfigureAwait(false); + } + + return response; + } + + /// + /// Handles a 401 Unauthorized response by attempting to authenticate and retry the request. + /// + private async Task HandleUnauthorizedResponseAsync( + HttpRequestMessage originalRequest, + JsonRpcMessage? originalJsonRpcMessage, + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Gather the schemes the server wants us to use from WWW-Authenticate headers + var serverSchemes = ExtractServerSupportedSchemes(response); + + if (!serverSchemes.Contains(_currentScheme)) + { + // Find the first server scheme that's in our supported set + var bestSchemeMatch = serverSchemes.Intersect(credentialProvider.SupportedSchemes, StringComparer.OrdinalIgnoreCase).FirstOrDefault(); + + if (bestSchemeMatch is not null) + { + _currentScheme = bestSchemeMatch; + } + else if (serverSchemes.Count > 0) + { + // If no match was found, either throw an exception or use default + throw new McpException( + $"The server does not support any of the provided authentication schemes." + + $"Server supports: [{string.Join(", ", serverSchemes)}], " + + $"Provider supports: [{string.Join(", ", credentialProvider.SupportedSchemes)}]."); + } + } + + // Try to handle the 401 response with the selected scheme + await credentialProvider.HandleUnauthorizedResponseAsync(_currentScheme, response, cancellationToken).ConfigureAwait(false); + + using var retryRequest = new HttpRequestMessage(originalRequest.Method, originalRequest.RequestUri); + + // Copy headers except Authorization which we'll set separately + foreach (var header in originalRequest.Headers) + { + if (!header.Key.Equals("Authorization", StringComparison.OrdinalIgnoreCase)) + { + retryRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } + + await AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false); + return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false); + } + + /// + /// Extracts the authentication schemes that the server supports from the WWW-Authenticate headers. + /// + private static HashSet ExtractServerSupportedSchemes(HttpResponseMessage response) + { + var serverSchemes = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var header in response.Headers.WwwAuthenticate) + { + serverSchemes.Add(header.Scheme); + } + + return serverSchemes; + } + + /// + /// Adds an authorization header to the request. + /// + private async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, string scheme, CancellationToken cancellationToken) + { + if (request.RequestUri is null) + { + return; + } + + var token = await credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false); + if (string.IsNullOrEmpty(token)) + { + return; + } + + request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs new file mode 100644 index 00000000..d3c33231 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs @@ -0,0 +1,28 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a method that handles the OAuth authorization URL and returns the authorization code. +/// +/// The authorization URL that the user needs to visit. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// A task that represents the asynchronous operation. The task result contains the authorization code if successful, or null if the operation failed or was cancelled. +/// +/// +/// This delegate provides SDK consumers with full control over how the OAuth authorization flow is handled. +/// Implementers can choose to: +/// +/// +/// Start a local HTTP server and open a browser (default behavior) +/// Display the authorization URL to the user for manual handling +/// Integrate with a custom UI or authentication flow +/// Use a different redirect mechanism altogether +/// +/// +/// The implementation should handle user interaction to visit the authorization URL and extract +/// the authorization code from the callback. The authorization code is typically provided as +/// a query parameter in the redirect URI callback. +/// +/// +public delegate Task AuthorizationRedirectDelegate(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..e94fce7a --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs @@ -0,0 +1,69 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the metadata about an OAuth authorization server. +/// +internal sealed class AuthorizationServerMetadata +{ + /// + /// The authorization endpoint URI. + /// + [JsonPropertyName("authorization_endpoint")] + public Uri AuthorizationEndpoint { get; set; } = null!; + + /// + /// The token endpoint URI. + /// + [JsonPropertyName("token_endpoint")] + public Uri TokenEndpoint { get; set; } = null!; + + /// + /// The registration endpoint URI. + /// + [JsonPropertyName("registration_endpoint")] + public Uri? RegistrationEndpoint { get; set; } + + /// + /// The revocation endpoint URI. + /// + [JsonPropertyName("revocation_endpoint")] + public Uri? RevocationEndpoint { get; set; } + + /// + /// The response types supported by the authorization server. + /// + [JsonPropertyName("response_types_supported")] + public List? ResponseTypesSupported { get; set; } + + /// + /// The grant types supported by the authorization server. + /// + [JsonPropertyName("grant_types_supported")] + public List? GrantTypesSupported { get; set; } + + /// + /// The token endpoint authentication methods supported by the authorization server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public List? TokenEndpointAuthMethodsSupported { get; set; } + + /// + /// The code challenge methods supported by the authorization server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public List? CodeChallengeMethodsSupported { get; set; } + + /// + /// The issuer URI of the authorization server. + /// + [JsonPropertyName("issuer")] + public Uri? Issuer { get; set; } + + /// + /// The scopes supported by the authorization server. + /// + [JsonPropertyName("scopes_supported")] + public List? ScopesSupported { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs new file mode 100644 index 00000000..686316f5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -0,0 +1,99 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Provides configuration options for the . +/// +public sealed class ClientOAuthOptions +{ + /// + /// Gets or sets the OAuth redirect URI. + /// + public required Uri RedirectUri { get; set; } + + /// + /// Gets or sets the OAuth client ID. If not provided, the client will attempt to register dynamically. + /// + public string? ClientId { get; set; } + + /// + /// Gets or sets the OAuth client secret. + /// + /// + /// This is optional for public clients or when using PKCE without client authentication. + /// + public string? ClientSecret { get; set; } + + /// + /// Gets or sets the OAuth scopes to request. + /// + /// + /// + /// When specified, these scopes will be used instead of the scopes advertised by the protected resource. + /// If not specified, the provider will use the scopes from the protected resource metadata. + /// + /// + /// Common OAuth scopes include "openid", "profile", "email", etc. + /// + /// + public IEnumerable? Scopes { get; set; } + + /// + /// Gets or sets the authorization redirect delegate for handling the OAuth authorization flow. + /// + /// + /// + /// This delegate is responsible for handling the OAuth authorization URL and obtaining the authorization code. + /// If not specified, a default implementation will be used that prompts the user to enter the code manually. + /// + /// + /// Custom implementations might open a browser, start an HTTP listener, or use other mechanisms to capture + /// the authorization code from the OAuth redirect. + /// + /// + public AuthorizationRedirectDelegate? AuthorizationRedirectDelegate { get; set; } + + /// + /// Gets or sets the authorization server selector function. + /// + /// + /// + /// This function is used to select which authorization server to use when multiple servers are available. + /// If not specified, the first available server will be selected. + /// + /// + /// The function receives a list of available authorization server URIs and should return the selected server, + /// or null if no suitable server is found. + /// + /// + public Func, Uri?>? AuthServerSelector { get; set; } + + /// + /// Gets or sets the client name to use during dynamic client registration. + /// + /// + /// This is a human-readable name for the client that may be displayed to users during authorization. + /// Only used when a is not specified. + /// + public string? ClientName { get; set; } + + /// + /// Gets or sets the client URI to use during dynamic client registration. + /// + /// + /// This should be a URL pointing to the client's home page or information page. + /// Only used when a is not specified. + /// + public Uri? ClientUri { get; set; } + + /// + /// Gets or sets additional parameters to include in the query string of the OAuth authorization request + /// providing extra information or fulfilling specific requirements of the OAuth provider. + /// + /// + /// + /// Parameters specified cannot override or append to any automatically set parameters like the "redirect_uri" + /// which should instead be configured via . + /// + /// + public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs new file mode 100644 index 00000000..96356028 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -0,0 +1,687 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using System.Collections.Specialized; +using System.Diagnostics.CodeAnalysis; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Web; + +namespace ModelContextProtocol.Authentication; + +/// +/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token +/// protection or caching - it acquires a token and server metadata and holds it in memory. +/// This is suitable for demonstration and development purposes. +/// +internal sealed partial class ClientOAuthProvider +{ + /// + /// The Bearer authentication scheme. + /// + private const string BearerScheme = "Bearer"; + + private readonly Uri _serverUrl; + private readonly Uri _redirectUri; + private readonly string[]? _scopes; + private readonly IDictionary _additionalAuthorizationParameters; + private readonly Func, Uri?> _authServerSelector; + private readonly AuthorizationRedirectDelegate _authorizationRedirectDelegate; + + // _clientName and _client URI is used for dynamic client registration (RFC 7591) + private readonly string? _clientName; + private readonly Uri? _clientUri; + + private readonly HttpClient _httpClient; + private readonly ILogger _logger; + + private string? _clientId; + private string? _clientSecret; + + private TokenContainer? _token; + private AuthorizationServerMetadata? _authServerMetadata; + + /// + /// Initializes a new instance of the class using the specified options. + /// + /// The MCP server URL. + /// The OAuth provider configuration options. + /// The HTTP client to use for OAuth requests. If null, a default HttpClient will be used. + /// A logger factory to handle diagnostic messages. + /// Thrown when serverUrl or options are null. + public ClientOAuthProvider( + Uri serverUrl, + ClientOAuthOptions options, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + _serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl)); + _httpClient = httpClient ?? new HttpClient(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + if (options is null) + { + throw new ArgumentNullException(nameof(options)); + } + + _clientId = options.ClientId; + _clientSecret = options.ClientSecret; + _redirectUri = options.RedirectUri ?? throw new ArgumentException("ClientOAuthOptions.RedirectUri must configured."); + _clientName = options.ClientName; + _clientUri = options.ClientUri; + _scopes = options.Scopes?.ToArray(); + _additionalAuthorizationParameters = options.AdditionalAuthorizationParameters; + + // Set up authorization server selection strategy + _authServerSelector = options.AuthServerSelector ?? DefaultAuthServerSelector; + + // Set up authorization URL handler (use default if not provided) + _authorizationRedirectDelegate = options.AuthorizationRedirectDelegate ?? DefaultAuthorizationUrlHandler; + } + + /// + /// Default authorization server selection strategy that selects the first available server. + /// + /// List of available authorization servers. + /// The selected authorization server, or null if none are available. + private static Uri? DefaultAuthServerSelector(IReadOnlyList availableServers) => availableServers.FirstOrDefault(); + + /// + /// Default authorization URL handler that displays the URL to the user for manual input. + /// + /// The authorization URL to handle. + /// The redirect URI where the authorization code will be sent. + /// The cancellation token. + /// The authorization code entered by the user, or null if none was provided. + private static Task DefaultAuthorizationUrlHandler(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) + { + Console.WriteLine($"Please open the following URL in your browser to authorize the application:"); + Console.WriteLine($"{authorizationUrl}"); + Console.WriteLine(); + Console.Write("Enter the authorization code from the redirect URL: "); + var authorizationCode = Console.ReadLine(); + return Task.FromResult(authorizationCode); + } + + /// + /// Gets the collection of authentication schemes supported by this provider. + /// + /// + /// + /// This property returns all authentication schemes that this provider can handle, + /// allowing clients to select the appropriate scheme based on server capabilities. + /// + /// + /// Common values include "Bearer" for JWT tokens, "Basic" for username/password authentication, + /// and "Negotiate" for integrated Windows authentication. + /// + /// + public IEnumerable SupportedSchemes => [BearerScheme]; + + /// + /// Gets an authentication token or credential for authenticating requests to a resource + /// using the specified authentication scheme. + /// + /// The authentication scheme to use. + /// The URI of the resource requiring authentication. + /// A token to cancel the operation. + /// An authentication token string or null if no token could be obtained for the specified scheme. + public async Task GetCredentialAsync(string scheme, Uri resourceUri, CancellationToken cancellationToken = default) + { + ThrowIfNotBearerScheme(scheme); + + // Return the token if it's valid + if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + { + return _token.AccessToken; + } + + // Try to refresh the token if we have a refresh token + if (_token?.RefreshToken != null && _authServerMetadata != null) + { + var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newToken != null) + { + _token = newToken; + return _token.AccessToken; + } + } + + // No valid token - auth handler will trigger the 401 flow + return null; + } + + /// + /// Handles a 401 Unauthorized response from a resource. + /// + /// The authentication scheme that was used when the unauthorized response was received. + /// The HTTP response that contained the 401 status code. + /// A token to cancel the operation. + /// + /// A result object indicating if the provider was able to handle the unauthorized response, + /// and the authentication scheme that should be used for the next attempt, if any. + /// + public async Task HandleUnauthorizedResponseAsync( + string scheme, + HttpResponseMessage response, + CancellationToken cancellationToken = default) + { + // This provider only supports Bearer scheme + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException("This credential provider only supports the Bearer scheme"); + } + + await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false); + } + + /// + /// Performs OAuth authorization by selecting an appropriate authorization server and completing the OAuth flow. + /// + /// The 401 Unauthorized response containing authentication challenge. + /// Cancellation token. + /// Result indicating whether authorization was successful. + private async Task PerformOAuthAuthorizationAsync( + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Get available authorization servers from the 401 response + var protectedResourceMetadata = await ExtractProtectedResourceMetadata(response, _serverUrl, cancellationToken).ConfigureAwait(false); + var availableAuthorizationServers = protectedResourceMetadata.AuthorizationServers; + + if (availableAuthorizationServers.Count == 0) + { + ThrowFailedToHandleUnauthorizedResponse("No authorization servers found in authentication challenge"); + } + + // Select authorization server using configured strategy + var selectedAuthServer = _authServerSelector(availableAuthorizationServers); + + if (selectedAuthServer is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selection returned null. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + if (!availableAuthorizationServers.Contains(selectedAuthServer)) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selector returned a server not in the available list: {selectedAuthServer}. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + LogSelectedAuthorizationServer(selectedAuthServer, availableAuthorizationServers.Count); + + // Get auth server metadata + var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false); + + if (authServerMetadata is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Failed to retrieve metadata for authorization server: '{selectedAuthServer}'"); + } + + // Store auth server metadata for future refresh operations + _authServerMetadata = authServerMetadata; + + // Perform dynamic client registration if needed + if (string.IsNullOrEmpty(_clientId)) + { + await PerformDynamicClientRegistrationAsync(authServerMetadata, cancellationToken).ConfigureAwait(false); + } + + // Perform the OAuth flow + var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + + if (token is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); + } + + _token = token; + LogOAuthAuthorizationCompleted(); + } + + private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) + { + if (!authServerUri.OriginalString.EndsWith("/")) + { + authServerUri = new Uri(authServerUri.OriginalString + "/"); + } + + foreach (var path in new[] { ".well-known/openid-configuration", ".well-known/oauth-authorization-server" }) + { + try + { + var response = await _httpClient.GetAsync(new Uri(authServerUri, path), cancellationToken).ConfigureAwait(false); + if (!response.IsSuccessStatusCode) + { + continue; + } + + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var metadata = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata, cancellationToken).ConfigureAwait(false); + + if (metadata != null) + { + metadata.ResponseTypesSupported ??= ["code"]; + metadata.GrantTypesSupported ??= ["authorization_code", "refresh_token"]; + metadata.TokenEndpointAuthMethodsSupported ??= ["client_secret_post"]; + metadata.CodeChallengeMethodsSupported ??= ["S256"]; + + return metadata; + } + } + catch (Exception ex) + { + LogErrorFetchingAuthServerMetadata(ex, path); + } + } + + return null; + } + + private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "refresh_token", + ["refresh_token"] = refreshToken, + ["client_id"] = GetClientIdOrThrow(), + ["client_secret"] = _clientSecret ?? string.Empty, + ["resource"] = resourceUri.ToString(), + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task InitiateAuthorizationCodeFlowAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + var codeVerifier = GenerateCodeVerifier(); + var codeChallenge = GenerateCodeChallenge(codeVerifier); + + var authUrl = BuildAuthorizationUrl(protectedResourceMetadata, authServerMetadata, codeChallenge); + var authCode = await _authorizationRedirectDelegate(authUrl, _redirectUri, cancellationToken).ConfigureAwait(false); + + if (string.IsNullOrEmpty(authCode)) + { + return null; + } + + return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); + } + + private Uri BuildAuthorizationUrl( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + string codeChallenge) + { + if (authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttp && + authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException("AuthorizationEndpoint must use HTTP or HTTPS.", nameof(authServerMetadata)); + } + + var queryParamsDictionary = new Dictionary + { + ["client_id"] = GetClientIdOrThrow(), + ["redirect_uri"] = _redirectUri.ToString(), + ["response_type"] = "code", + ["code_challenge"] = codeChallenge, + ["code_challenge_method"] = "S256", + ["resource"] = protectedResourceMetadata.Resource.ToString(), + }; + + var scopesSupported = protectedResourceMetadata.ScopesSupported; + if (_scopes is not null || scopesSupported.Count > 0) + { + queryParamsDictionary["scope"] = string.Join(" ", _scopes ?? scopesSupported.ToArray()); + } + + // Add extra parameters if provided. Load into a dictionary before constructing to avoid overwiting values. + foreach (var kvp in _additionalAuthorizationParameters) + { + queryParamsDictionary.Add(kvp.Key, kvp.Value); + } + + var queryParams = HttpUtility.ParseQueryString(string.Empty); + foreach (var kvp in queryParamsDictionary) + { + queryParams[kvp.Key] = kvp.Value; + } + + var uriBuilder = new UriBuilder(authServerMetadata.AuthorizationEndpoint) + { + Query = queryParams.ToString() + }; + + return uriBuilder.Uri; + } + + private async Task ExchangeCodeForTokenAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + string authorizationCode, + string codeVerifier, + CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "authorization_code", + ["code"] = authorizationCode, + ["redirect_uri"] = _redirectUri.ToString(), + ["client_id"] = GetClientIdOrThrow(), + ["code_verifier"] = codeVerifier, + ["client_secret"] = _clientSecret ?? string.Empty, + ["resource"] = protectedResourceMetadata.Resource.ToString(), + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + + if (tokenResponse is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); + } + + tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; + return tokenResponse; + } + + /// + /// Fetches the protected resource metadata from the provided URL. + /// + /// The URL to fetch the metadata from. + /// A token to cancel the operation. + /// The fetched ProtectedResourceMetadata, or null if it couldn't be fetched. + private async Task FetchProtectedResourceMetadataAsync(Uri metadataUrl, CancellationToken cancellationToken = default) + { + using var httpResponse = await _httpClient.GetAsync(metadataUrl, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + return await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.ProtectedResourceMetadata, cancellationToken).ConfigureAwait(false); + } + + /// + /// Performs dynamic client registration with the authorization server. + /// + /// The authorization server metadata. + /// Cancellation token. + /// A task representing the asynchronous operation. + private async Task PerformDynamicClientRegistrationAsync( + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + if (authServerMetadata.RegistrationEndpoint is null) + { + ThrowFailedToHandleUnauthorizedResponse("Authorization server does not support dynamic client registration"); + } + + LogPerformingDynamicClientRegistration(authServerMetadata.RegistrationEndpoint); + + var registrationRequest = new DynamicClientRegistrationRequest + { + RedirectUris = [_redirectUri.ToString()], + GrantTypes = ["authorization_code", "refresh_token"], + ResponseTypes = ["code"], + TokenEndpointAuthMethod = "client_secret_post", + ClientName = _clientName, + ClientUri = _clientUri?.ToString(), + Scope = _scopes is not null ? string.Join(" ", _scopes) : null + }; + + var requestJson = JsonSerializer.Serialize(registrationRequest, McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest); + using var requestContent = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.RegistrationEndpoint) + { + Content = requestContent + }; + + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + + if (!httpResponse.IsSuccessStatusCode) + { + var errorContent = await httpResponse.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + ThrowFailedToHandleUnauthorizedResponse($"Dynamic client registration failed with status {httpResponse.StatusCode}: {errorContent}"); + } + + using var responseStream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var registrationResponse = await JsonSerializer.DeserializeAsync( + responseStream, + McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationResponse, + cancellationToken).ConfigureAwait(false); + + if (registrationResponse is null) + { + ThrowFailedToHandleUnauthorizedResponse("Dynamic client registration returned empty response"); + } + + // Update client credentials + _clientId = registrationResponse.ClientId; + if (!string.IsNullOrEmpty(registrationResponse.ClientSecret)) + { + _clientSecret = registrationResponse.ClientSecret; + } + + LogDynamicClientRegistrationSuccessful(_clientId!); + } + + /// + /// Verifies that the resource URI in the metadata exactly matches the original request URL as required by the RFC. + /// Per RFC: The resource value must be identical to the URL that the client used to make the request to the resource server. + /// + /// The metadata to verify. + /// The original URL the client used to make the request to the resource server. + /// True if the resource URI exactly matches the original request URL, otherwise false. + private static bool VerifyResourceMatch(ProtectedResourceMetadata protectedResourceMetadata, Uri resourceLocation) + { + if (protectedResourceMetadata.Resource == null || resourceLocation == null) + { + return false; + } + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server. Compare entire URIs, not just the host. + + // Normalize the URIs to ensure consistent comparison + string normalizedMetadataResource = NormalizeUri(protectedResourceMetadata.Resource); + string normalizedResourceLocation = NormalizeUri(resourceLocation); + + return string.Equals(normalizedMetadataResource, normalizedResourceLocation, StringComparison.OrdinalIgnoreCase); + } + + /// + /// Normalizes a URI for consistent comparison. + /// + /// The URI to normalize. + /// A normalized string representation of the URI. + private static string NormalizeUri(Uri uri) + { + var builder = new UriBuilder(uri) + { + Port = -1 // Always remove port + }; + + if (builder.Path == "/") + { + builder.Path = string.Empty; + } + else if (builder.Path.Length > 1 && builder.Path.EndsWith("/")) + { + builder.Path = builder.Path.TrimEnd('/'); + } + + return builder.Uri.ToString(); + } + + /// + /// Responds to a 401 challenge by parsing the WWW-Authenticate header, fetching the resource metadata, + /// verifying the resource match, and returning the metadata if valid. + /// + /// The HTTP response containing the WWW-Authenticate header. + /// The server URL to verify against the resource metadata. + /// A token to cancel the operation. + /// The resource metadata if the resource matches the server, otherwise throws an exception. + /// Thrown when the response is not a 401, lacks a WWW-Authenticate header, + /// lacks a resource_metadata parameter, the metadata can't be fetched, or the resource URI doesn't match the server URL. + private async Task ExtractProtectedResourceMetadata(HttpResponseMessage response, Uri serverUrl, CancellationToken cancellationToken = default) + { + if (response.StatusCode != System.Net.HttpStatusCode.Unauthorized) + { + throw new InvalidOperationException($"Expected a 401 Unauthorized response, but received {(int)response.StatusCode} {response.StatusCode}"); + } + + // Extract the WWW-Authenticate header + if (response.Headers.WwwAuthenticate.Count == 0) + { + throw new McpException("The 401 response does not contain a WWW-Authenticate header"); + } + + // Look for the Bearer authentication scheme with resource_metadata parameter + string? resourceMetadataUrl = null; + foreach (var header in response.Headers.WwwAuthenticate) + { + if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) + { + resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata"); + if (resourceMetadataUrl != null) + { + break; + } + } + } + + if (resourceMetadataUrl == null) + { + throw new McpException("The WWW-Authenticate header does not contain a resource_metadata parameter"); + } + + Uri metadataUri = new(resourceMetadataUrl); + var metadata = await FetchProtectedResourceMetadataAsync(metadataUri, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"Failed to fetch resource metadata from {resourceMetadataUrl}"); + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server + LogValidatingResourceMetadata(serverUrl); + + if (!VerifyResourceMatch(metadata, serverUrl)) + { + throw new McpException($"Resource URI in metadata ({metadata.Resource}) does not match the expected URI ({serverUrl})"); + } + + return metadata; + } + + /// + /// Parses the WWW-Authenticate header parameters to extract a specific parameter. + /// + /// The parameter string from the WWW-Authenticate header. + /// The name of the parameter to extract. + /// The value of the parameter, or null if not found. + private static string? ParseWwwAuthenticateParameters(string parameters, string parameterName) + { + if (parameters.IndexOf(parameterName, StringComparison.OrdinalIgnoreCase) == -1) + { + return null; + } + + foreach (var part in parameters.Split(',')) + { + string trimmedPart = part.Trim(); + int equalsIndex = trimmedPart.IndexOf('='); + + if (equalsIndex <= 0) + { + continue; + } + + string key = trimmedPart.Substring(0, equalsIndex).Trim(); + + if (string.Equals(key, parameterName, StringComparison.OrdinalIgnoreCase)) + { + string value = trimmedPart.Substring(equalsIndex + 1).Trim(); + + if (value.StartsWith("\"") && value.EndsWith("\"")) + { + value = value.Substring(1, value.Length - 2); + } + + return value; + } + } + + return null; + } + + private static string GenerateCodeVerifier() + { + var bytes = new byte[32]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + return Convert.ToBase64String(bytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private static string GenerateCodeChallenge(string codeVerifier) + { + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + return Convert.ToBase64String(challengeBytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private string GetClientIdOrThrow() => _clientId ?? throw new InvalidOperationException("Client ID is not available. This may indicate an issue with dynamic client registration."); + + private static void ThrowIfNotBearerScheme(string scheme) + { + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"The '{scheme}' is not supported. This credential provider only supports the '{BearerScheme}' scheme"); + } + } + + [DoesNotReturn] + private static void ThrowFailedToHandleUnauthorizedResponse(string message) => + throw new McpException($"Failed to handle unauthorized response with 'Bearer' scheme. {message}"); + + [LoggerMessage(Level = LogLevel.Information, Message = "Selected authorization server: {Server} from {Count} available servers")] + partial void LogSelectedAuthorizationServer(Uri server, int count); + + [LoggerMessage(Level = LogLevel.Information, Message = "OAuth authorization completed successfully")] + partial void LogOAuthAuthorizationCompleted(); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error fetching auth server metadata from {Path}")] + partial void LogErrorFetchingAuthServerMetadata(Exception ex, string path); + + [LoggerMessage(Level = LogLevel.Information, Message = "Performing dynamic client registration with {RegistrationEndpoint}")] + partial void LogPerformingDynamicClientRegistration(Uri registrationEndpoint); + + [LoggerMessage(Level = LogLevel.Information, Message = "Dynamic client registration successful. Client ID: {ClientId}")] + partial void LogDynamicClientRegistrationSuccessful(string clientId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Validating resource metadata against original server URL: {ServerUrl}")] + partial void LogValidatingResourceMetadata(Uri serverUrl); +} diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs new file mode 100644 index 00000000..8496610e --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs @@ -0,0 +1,51 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a client registration request for OAuth 2.0 Dynamic Client Registration (RFC 7591). +/// +internal sealed class DynamicClientRegistrationRequest +{ + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required string[] RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs new file mode 100644 index 00000000..dcd51d68 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a client registration response for OAuth 2.0 Dynamic Client Registration (RFC 7591). +/// +internal sealed class DynamicClientRegistrationResponse +{ + /// + /// Gets or sets the client identifier. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + [JsonPropertyName("client_secret")] + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public string[]? RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; init; } + + /// + /// Gets or sets the client ID issued timestamp. + /// + [JsonPropertyName("client_id_issued_at")] + public long? ClientIdIssuedAt { get; init; } + + /// + /// Gets or sets the client secret expiration time. + /// + [JsonPropertyName("client_secret_expires_at")] + public long? ClientSecretExpiresAt { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs new file mode 100644 index 00000000..88b5bcc0 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs @@ -0,0 +1,145 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the resource metadata for OAuth authorization as defined in RFC 9396. +/// Defined by RFC 9728. +/// +public sealed class ProtectedResourceMetadata +{ + /// + /// The resource URI. + /// + /// + /// REQUIRED. The protected resource's resource identifier. + /// + [JsonPropertyName("resource")] + public required Uri Resource { get; set; } + + /// + /// The list of authorization server URIs. + /// + /// + /// OPTIONAL. JSON array containing a list of OAuth authorization server issuer identifiers + /// for authorization servers that can be used with this protected resource. + /// + [JsonPropertyName("authorization_servers")] + public List AuthorizationServers { get; set; } = []; + + /// + /// The supported bearer token methods. + /// + /// + /// OPTIONAL. JSON array containing a list of the supported methods of sending an OAuth 2.0 bearer token + /// to the protected resource. Defined values are ["header", "body", "query"]. + /// + [JsonPropertyName("bearer_methods_supported")] + public List BearerMethodsSupported { get; set; } = ["header"]; + + /// + /// The supported scopes. + /// + /// + /// RECOMMENDED. JSON array containing a list of scope values that are used in authorization + /// requests to request access to this protected resource. + /// + [JsonPropertyName("scopes_supported")] + public List ScopesSupported { get; set; } = []; + + /// + /// URL of the protected resource's JSON Web Key (JWK) Set document. + /// + /// + /// OPTIONAL. This contains public keys belonging to the protected resource, such as signing key(s) + /// that the resource server uses to sign resource responses. This URL MUST use the https scheme. + /// + [JsonPropertyName("jwks_uri")] + public Uri? JwksUri { get; set; } + + /// + /// List of the JWS signing algorithms supported by the protected resource for signing resource responses. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms (alg values) supported by the protected resource + /// for signing resource responses. No default algorithms are implied if this entry is omitted. The value none MUST NOT be used. + /// + [JsonPropertyName("resource_signing_alg_values_supported")] + public List? ResourceSigningAlgValuesSupported { get; set; } + + /// + /// Human-readable name of the protected resource intended for display to the end user. + /// + /// + /// RECOMMENDED. It is recommended that protected resource metadata include this field. + /// The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_name")] + public string? ResourceName { get; set; } + + /// + /// The URI to the resource documentation. + /// + /// + /// OPTIONAL. URL of a page containing human-readable information that developers might want or need to know + /// when using the protected resource. + /// + [JsonPropertyName("resource_documentation")] + public Uri? ResourceDocumentation { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's requirements. + /// + /// + /// OPTIONAL. Information about how the client can use the data provided by the protected resource. + /// + [JsonPropertyName("resource_policy_uri")] + public Uri? ResourcePolicyUri { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's terms of service. + /// + /// + /// OPTIONAL. The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_tos_uri")] + public Uri? ResourceTosUri { get; set; } + + /// + /// Boolean value indicating protected resource support for mutual-TLS client certificate-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("tls_client_certificate_bound_access_tokens")] + public bool? TlsClientCertificateBoundAccessTokens { get; set; } + + /// + /// List of the authorization details type values supported by the resource server. + /// + /// + /// OPTIONAL. JSON array containing a list of the authorization details type values supported by the resource server + /// when the authorization_details request parameter is used. + /// + [JsonPropertyName("authorization_details_types_supported")] + public List? AuthorizationDetailsTypesSupported { get; set; } + + /// + /// List of the JWS algorithm values supported by the resource server for validating DPoP proof JWTs. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS alg values supported by the resource server + /// for validating Demonstrating Proof of Possession (DPoP) proof JWTs. + /// + [JsonPropertyName("dpop_signing_alg_values_supported")] + public List? DpopSigningAlgValuesSupported { get; set; } + + /// + /// Boolean value specifying whether the protected resource always requires the use of DPoP-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("dpop_bound_access_tokens_required")] + public bool? DpopBoundAccessTokensRequired { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs new file mode 100644 index 00000000..dc55292b --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a token response from the OAuth server. +/// +internal sealed class TokenContainer +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + [JsonPropertyName("expires_in")] + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + [JsonPropertyName("ext_expires_in")] + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + [JsonPropertyName("token_type")] + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + [JsonIgnore] + public DateTimeOffset ObtainedAt { get; set; } + + /// + /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. + /// + [JsonIgnore] + public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); +} diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 39ae7e81..06f2e0bf 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -13,13 +13,13 @@ namespace ModelContextProtocol.Client; internal sealed partial class AutoDetectingClientSessionTransport : ITransport { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; - public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/McpHttpClient.cs b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs new file mode 100644 index 00000000..77ca78fb --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs @@ -0,0 +1,42 @@ +using ModelContextProtocol.Protocol; +using System.Diagnostics; + +#if NET +using System.Net.Http.Json; +#else +using System.Text; +using System.Text.Json; +#endif + +namespace ModelContextProtocol.Client; + +internal class McpHttpClient(HttpClient httpClient) +{ + internal virtual async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + Debug.Assert(request.Content is null, "The request body should only be supplied as a JsonRpcMessage"); + Debug.Assert(message is null || request.Method == HttpMethod.Post, "All messages should be sent in POST requests."); + + using var content = CreatePostBodyContent(message); + request.Content = content; + return await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + } + + private HttpContent? CreatePostBodyContent(JsonRpcMessage? message) + { + if (message is null) + { + return null; + } + +#if NET + return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); +#else + return new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ); +#endif + } +} diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 93559b7d..aba7bbcf 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Net.Http.Headers; using System.Net.ServerSentEvents; -using System.Text; using System.Text.Json; using System.Threading.Channels; @@ -15,7 +14,7 @@ namespace ModelContextProtocol.Client; /// internal sealed partial class SseClientSessionTransport : TransportBase { - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; @@ -31,7 +30,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase public SseClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -74,12 +73,6 @@ public override async Task SendMessageAsync( if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); - string messageId = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) @@ -87,12 +80,9 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) - { - Content = content, - }; + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -154,11 +144,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs index 3fba349b..b31c3479 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; @@ -15,9 +16,10 @@ namespace ModelContextProtocol.Client; public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly McpHttpClient _mcpHttpClient; private readonly ILoggerFactory? _loggerFactory; - private readonly bool _ownsHttpClient; + + private readonly HttpClient? _ownedHttpClient; /// /// Initializes a new instance of the class. @@ -45,10 +47,23 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient Throw.IfNull(httpClient); _options = transportOptions; - _httpClient = httpClient; _loggerFactory = loggerFactory; - _ownsHttpClient = ownsHttpClient; Name = transportOptions.Name ?? transportOptions.Endpoint.ToString(); + + if (transportOptions.OAuth is { } clientOAuthOptions) + { + var oAuthProvider = new ClientOAuthProvider(_options.Endpoint, clientOAuthOptions, httpClient, loggerFactory); + _mcpHttpClient = new AuthenticatingMcpHttpClient(httpClient, oAuthProvider); + } + else + { + _mcpHttpClient = new(httpClient); + } + + if (ownsHttpClient) + { + _ownedHttpClient = httpClient; + } } /// @@ -59,8 +74,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = { return _options.TransportMode switch { - HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name), - HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory), + HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory), + HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory), HttpTransportMode.Sse => await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false), _ => throw new InvalidOperationException($"Unsupported transport mode: {_options.TransportMode}"), }; @@ -68,7 +83,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = private async Task ConnectSseTransportAsync(CancellationToken cancellationToken) { - var sessionTransport = new SseClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory); + var sessionTransport = new SseClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory); try { @@ -85,11 +100,7 @@ private async Task ConnectSseTransportAsync(CancellationToken cancel /// public ValueTask DisposeAsync() { - if (_ownsHttpClient) - { - _httpClient.Dispose(); - } - + _ownedHttpClient?.Dispose(); return default; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs index 9b4af6db..4097844c 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs @@ -1,3 +1,5 @@ +using ModelContextProtocol.Authentication; + namespace ModelContextProtocol.Client; /// @@ -46,7 +48,7 @@ public required Uri Endpoint public HttpTransportMode TransportMode { get; set; } = HttpTransportMode.AutoDetect; /// - /// Gets a transport identifier used for logging purposes. + /// Gets or sets a transport identifier used for logging purposes. /// public string? Name { get; set; } @@ -70,4 +72,9 @@ public required Uri Endpoint /// Use this property to specify custom HTTP headers that should be sent with each request to the server. /// public IDictionary? AdditionalHeaders { get; set; } + + /// + /// Gets sor sets the authorization provider to use for authentication. + /// + public ClientOAuthOptions? OAuth { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 78217868..190bec0b 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -6,12 +6,6 @@ using ModelContextProtocol.Protocol; using System.Threading.Channels; -#if NET -using System.Net.Http.Json; -#else -using System.Text; -#endif - namespace ModelContextProtocol.Client; /// @@ -22,7 +16,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; @@ -36,7 +30,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa public StreamableHttpClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -69,19 +63,8 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; -#if NET - using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); -#else - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); -#endif - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) { - Content = content, Headers = { Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, @@ -90,7 +73,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); // We'll let the caller decide whether to throw or fall back given an unsuccessful response. if (!response.IsSuccessStatusCode) @@ -192,7 +175,7 @@ private async Task ReceiveUnsolicitedMessagesAsync() request.Headers.Accept.Add(s_textEventStreamMediaType); CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -261,7 +244,7 @@ private async Task SendDeleteRequest() try { // Do not validate we get a successful status code, because server support for the DELETE request is optional - (await _httpClient.SendAsync(deleteRequest, CancellationToken.None).ConfigureAwait(false)).Dispose(); + (await _httpClient.SendAsync(deleteRequest, message: null, CancellationToken.None).ConfigureAwait(false)).Dispose(); } catch (Exception ex) { diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 696e0ec0..21e2468d 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; using System.Text.Json; @@ -154,6 +155,12 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] + [JsonSerializable(typeof(ProtectedResourceMetadata))] + [JsonSerializable(typeof(AuthorizationServerMetadata))] + [JsonSerializable(typeof(TokenContainer))] + [JsonSerializable(typeof(DynamicClientRegistrationRequest))] + [JsonSerializable(typeof(DynamicClientRegistrationResponse))] + // Primitive types for use in consuming AIFunctions [JsonSerializable(typeof(string))] [JsonSerializable(typeof(byte))] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs new file mode 100644 index 00000000..2252b1b7 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs @@ -0,0 +1,407 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; +using System.Net; +using System.Reflection; +using Xunit.Sdk; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class AuthTests : KestrelInMemoryTest, IAsyncDisposable +{ + private const string McpServerUrl = "http://localhost:5000"; + private const string OAuthServerUrl = "https://localhost:7029"; + + private readonly CancellationTokenSource _testCts = new(); + private readonly TestOAuthServer.Program _testOAuthServer; + private readonly Task _testOAuthRunTask; + + private Uri? _lastAuthorizationUri; + + public AuthTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + // Let the HandleAuthorizationUrlAsync take a look at the Location header + SocketsHttpHandler.AllowAutoRedirect = false; + // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. + // The easiest workaround is to disable cert validation for testing purposes. + SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; + + _testOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); + _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); + + Builder.Services.AddAuthentication(options => + { + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + }) + .AddJwtBearer(options => + { + options.Backchannel = HttpClient; + options.Authority = OAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = McpServerUrl, + ValidIssuer = OAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + }) + .AddMcp(options => + { + options.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"] + }; + }); + + Builder.Services.AddAuthorization(); + } + + public async ValueTask DisposeAsync() + { + _testCts.Cancel(); + try + { + await _testOAuthRunTask; + } + catch (OperationCanceledException) + { + } + finally + { + _testCts.Dispose(); + } + } + + [Fact] + public async Task CanAuthenticate() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task CannotAuthenticate_WithoutOAuthConfiguration() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + }, HttpClient, LoggerFactory); + + var httpEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal(HttpStatusCode.Unauthorized, httpEx.StatusCode); + } + + [Fact] + public async Task CannotAuthenticate_WithUnregisteredClient() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "unregistered-demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + // The EqualException is thrown by HandleAuthorizationUrlAsync when the /authorize request gets a 400 + var equalEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task CanAuthenticate_WithDynamicClientRegistration() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions() + { + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + ClientName = "Test MCP Client", + ClientUri = new Uri("https://example.com"), + Scopes = ["mcp:tools"] + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task CanAuthenticate_WithTokenRefresh() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "test-refresh-client", + ClientSecret = "test-refresh-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + // The test-refresh-client should get an expired token first, + // then automatically refresh it to get a working token + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(_testOAuthServer.HasIssuedRefreshToken); + } + + [Fact] + public async Task CanAuthenticate_WithExtraParams() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AdditionalAuthorizationParameters = new Dictionary + { + ["custom_param"] = "custom_value", + } + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(_lastAuthorizationUri?.Query); + Assert.Contains("custom_param=custom_value", _lastAuthorizationUri?.Query); + } + + [Fact] + public async Task CannotOverrideExistingParameters_WithExtraParams() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AdditionalAuthorizationParameters = new Dictionary + { + ["redirect_uri"] = "custom_value", + } + }, + }, HttpClient, LoggerFactory); + + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public void CloneResourceMetadataClonesAllProperties() + { + var propertyNames = typeof(ProtectedResourceMetadata).GetProperties().Select(property => property.Name).ToList(); + + // Set metadata properties to non-default values to verify they're copied. + var metadata = new ProtectedResourceMetadata + { + Resource = new Uri("https://example.com/resource"), + AuthorizationServers = [new Uri("https://auth1.example.com"), new Uri("https://auth2.example.com")], + BearerMethodsSupported = ["header", "body", "query"], + ScopesSupported = ["read", "write", "admin"], + JwksUri = new Uri("https://example.com/.well-known/jwks.json"), + ResourceSigningAlgValuesSupported = ["RS256", "ES256"], + ResourceName = "Test Resource", + ResourceDocumentation = new Uri("https://docs.example.com"), + ResourcePolicyUri = new Uri("https://example.com/policy"), + ResourceTosUri = new Uri("https://example.com/terms"), + TlsClientCertificateBoundAccessTokens = true, + AuthorizationDetailsTypesSupported = ["payment_initiation", "account_information"], + DpopSigningAlgValuesSupported = ["RS256", "PS256"], + DpopBoundAccessTokensRequired = true + }; + + // Use reflection to call the internal CloneResourceMetadata method + var handlerType = typeof(McpAuthenticationHandler); + var cloneMethod = handlerType.GetMethod("CloneResourceMetadata", BindingFlags.Static | BindingFlags.NonPublic); + Assert.NotNull(cloneMethod); + + var clonedMetadata = (ProtectedResourceMetadata?)cloneMethod.Invoke(null, [metadata]); + Assert.NotNull(clonedMetadata); + + // Ensure the cloned metadata is not the same instance + Assert.NotSame(metadata, clonedMetadata); + + // Verify Resource property + Assert.Equal(metadata.Resource, clonedMetadata.Resource); + Assert.True(propertyNames.Remove(nameof(metadata.Resource))); + + // Verify AuthorizationServers list is cloned and contains the same values + Assert.NotSame(metadata.AuthorizationServers, clonedMetadata.AuthorizationServers); + Assert.Equal(metadata.AuthorizationServers, clonedMetadata.AuthorizationServers); + Assert.True(propertyNames.Remove(nameof(metadata.AuthorizationServers))); + + // Verify BearerMethodsSupported list is cloned and contains the same values + Assert.NotSame(metadata.BearerMethodsSupported, clonedMetadata.BearerMethodsSupported); + Assert.Equal(metadata.BearerMethodsSupported, clonedMetadata.BearerMethodsSupported); + Assert.True(propertyNames.Remove(nameof(metadata.BearerMethodsSupported))); + + // Verify ScopesSupported list is cloned and contains the same values + Assert.NotSame(metadata.ScopesSupported, clonedMetadata.ScopesSupported); + Assert.Equal(metadata.ScopesSupported, clonedMetadata.ScopesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.ScopesSupported))); + + // Verify JwksUri property + Assert.Equal(metadata.JwksUri, clonedMetadata.JwksUri); + Assert.True(propertyNames.Remove(nameof(metadata.JwksUri))); + + // Verify ResourceSigningAlgValuesSupported list is cloned (nullable list) + Assert.NotSame(metadata.ResourceSigningAlgValuesSupported, clonedMetadata.ResourceSigningAlgValuesSupported); + Assert.Equal(metadata.ResourceSigningAlgValuesSupported, clonedMetadata.ResourceSigningAlgValuesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceSigningAlgValuesSupported))); + + // Verify ResourceName property + Assert.Equal(metadata.ResourceName, clonedMetadata.ResourceName); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceName))); + + // Verify ResourceDocumentation property + Assert.Equal(metadata.ResourceDocumentation, clonedMetadata.ResourceDocumentation); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceDocumentation))); + + // Verify ResourcePolicyUri property + Assert.Equal(metadata.ResourcePolicyUri, clonedMetadata.ResourcePolicyUri); + Assert.True(propertyNames.Remove(nameof(metadata.ResourcePolicyUri))); + + // Verify ResourceTosUri property + Assert.Equal(metadata.ResourceTosUri, clonedMetadata.ResourceTosUri); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceTosUri))); + + // Verify TlsClientCertificateBoundAccessTokens property + Assert.Equal(metadata.TlsClientCertificateBoundAccessTokens, clonedMetadata.TlsClientCertificateBoundAccessTokens); + Assert.True(propertyNames.Remove(nameof(metadata.TlsClientCertificateBoundAccessTokens))); + + // Verify AuthorizationDetailsTypesSupported list is cloned (nullable list) + Assert.NotSame(metadata.AuthorizationDetailsTypesSupported, clonedMetadata.AuthorizationDetailsTypesSupported); + Assert.Equal(metadata.AuthorizationDetailsTypesSupported, clonedMetadata.AuthorizationDetailsTypesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.AuthorizationDetailsTypesSupported))); + + // Verify DpopSigningAlgValuesSupported list is cloned (nullable list) + Assert.NotSame(metadata.DpopSigningAlgValuesSupported, clonedMetadata.DpopSigningAlgValuesSupported); + Assert.Equal(metadata.DpopSigningAlgValuesSupported, clonedMetadata.DpopSigningAlgValuesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.DpopSigningAlgValuesSupported))); + + // Verify DpopBoundAccessTokensRequired property + Assert.Equal(metadata.DpopBoundAccessTokensRequired, clonedMetadata.DpopBoundAccessTokensRequired); + Assert.True(propertyNames.Remove(nameof(metadata.DpopBoundAccessTokensRequired))); + + // Ensure we've checked every property. When new properties get added, we'll have to update this test along with the CloneResourceMetadata implementation. + Assert.Empty(propertyNames); + } + + private async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) + { + _lastAuthorizationUri = authorizationUri; + + var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + var location = redirectResponse.Headers.Location; + + if (location is not null && !string.IsNullOrEmpty(location.Query)) + { + var queryParams = QueryHelpers.ParseQuery(location.Query); + return queryParams["code"]; + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index eb623686..9b3c91b9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -118,7 +118,7 @@ public async Task CallTool_EchoSessionId_ReturnsTheSameSessionId() Assert.Null(result1.IsError); Assert.Null(result2.IsError); Assert.Null(result3.IsError); - + var textContent1 = Assert.Single(result1.Content.OfType()); var textContent2 = Assert.Single(result2.Content.OfType()); var textContent3 = Assert.Single(result3.Content.OfType()); @@ -267,10 +267,10 @@ public async Task Sampling_Sse_TestServer() // Call the server's sampleLLM tool which should trigger our sampling handler var result = await client.CallToolAsync("sampleLLM", new Dictionary - { - ["prompt"] = "Test prompt", - ["maxTokens"] = 100 - }, + { + ["prompt"] = "Test prompt", + ["maxTokens"] = 100 + }, cancellationToken: TestContext.Current.CancellationToken); // assert @@ -288,7 +288,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently() for (int i = 0; i < 4; i++) { var client = (i % 2 == 0) ? client1 : client2; - var result = await client.CallToolAsync( + var result = await client.CallToolAsync( "echo", new Dictionary { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index 0a423850..f3162130 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -20,7 +20,7 @@ public async Task Allows_Customizing_Route(string pattern) await app.StartAsync(TestContext.Current.CancellationToken); - using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var response = await HttpClient.GetAsync($"http://localhost:5000{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); response.EnsureSuccessStatusCode(); using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index f14cc10a..cb1f86db 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -56,7 +56,7 @@ public async Task StreamableHttpMode_Works_WithRootEndpoint() await using var mcpClient = await ConnectAsync("/", new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), TransportMode = HttpTransportMode.AutoDetect }); @@ -82,7 +82,7 @@ public async Task AutoDetectMode_Works_WithRootEndpoint() await using var mcpClient = await ConnectAsync("/", new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), TransportMode = HttpTransportMode.AutoDetect }); @@ -110,7 +110,7 @@ public async Task AutoDetectMode_Works_WithSseEndpoint() await using var mcpClient = await ConnectAsync("/sse", new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), TransportMode = HttpTransportMode.AutoDetect }); @@ -138,7 +138,7 @@ public async Task SseMode_Works_WithSseEndpoint() await using var mcpClient = await ConnectAsync(transportOptions: new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), TransportMode = HttpTransportMode.Sse }); @@ -171,14 +171,16 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectAsync(clientOptions: new() + await using (var mcpClient = await ConnectAsync(clientOptions: new() { ProtocolVersion = "2025-03-26", - }); - await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + })) + { + await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + } - // The header should be included in the GET request, the initialized notification, and the tools/list call. - Assert.Equal(3, protocolVersionHeaderValues.Count); + // The header should be included in the GET request, the initialized notification, the tools/list call, and the delete request. + Assert.NotEmpty(protocolVersionHeaderValues); Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v)); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 6635a8b9..cf54e777 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -33,7 +33,7 @@ protected async Task ConnectAsync( await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions { - Endpoint = new Uri($"http://localhost{path}"), + Endpoint = new Uri($"http://localhost:5000{path}"), TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj index bbcac5f5..34801c73 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj @@ -34,6 +34,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + @@ -56,6 +57,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 756f9e4e..8191f609 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -17,7 +17,7 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr { private readonly SseClientTransportOptions DefaultTransportOptions = new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; @@ -197,7 +197,7 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() var sseOptions = new SseClientTransportOptions { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", AdditionalHeaders = new Dictionary { @@ -224,7 +224,7 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() var sseOptions = new SseClientTransportOptions { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", AdditionalHeaders = new Dictionary() { @@ -251,7 +251,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b response.Headers.ContentType = "text/event-stream"; - await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message"); + await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost:5000/message"); session = transport; try diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index c01333d0..2aa675c8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; +using System.Net; namespace ModelContextProtocol.AspNetCore.Tests; @@ -19,7 +20,7 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable private SseClientTransportOptions DefaultTransportOptions { get; set; } = new() { - Endpoint = new("http://localhost/"), + Endpoint = new("http://localhost:5000/"), }; public SseServerIntegrationTestFixture() @@ -28,14 +29,14 @@ public SseServerIntegrationTestFixture() { ConnectCallback = (context, token) => { - var connection = _inMemoryTransport.CreateConnection(); + var connection = _inMemoryTransport.CreateConnection(new DnsEndPoint("localhost", 5000)); return new(connection.ClientStream); }, }; HttpClient = new HttpClient(socketsHttpHandler) { - BaseAddress = new("http://localhost/"), + BaseAddress = new("http://localhost:5000/"), }; _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index eb89912a..2d4a7868 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -10,7 +10,7 @@ public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, { protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index a9e2e5f5..d16e510c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -7,7 +7,7 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix { protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/stateless"), + Endpoint = new("http://localhost:5000/stateless"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index 1e21eb45..b50a43ed 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -16,7 +16,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem private readonly SseClientTransportOptions DefaultTransportOptions = new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 98de9b13..7ce3516e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -114,7 +114,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); @@ -134,7 +134,7 @@ public async Task CanCallToolConcurrently() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); @@ -160,7 +160,7 @@ public async Task SendsDeleteRequestOnDispose() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 7c4366f1..3524c60a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -13,7 +13,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index a6ada604..4ae743f7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -8,8 +8,6 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils; public class KestrelInMemoryTest : LoggedTest { - private readonly KestrelInMemoryTransport _inMemoryTransport = new(); - public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { @@ -17,17 +15,16 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) // or a helper that does the same every test. But clear out the existing socket transport to avoid potential port conflicts. Builder = WebApplication.CreateSlimBuilder(); Builder.Services.RemoveAll(); - Builder.Services.AddSingleton(_inMemoryTransport); + Builder.Services.AddSingleton(KestrelInMemoryTransport); Builder.Services.AddSingleton(XunitLoggerProvider); - HttpClient = new HttpClient(new SocketsHttpHandler + SocketsHttpHandler.ConnectCallback = (context, token) => { - ConnectCallback = (context, token) => - { - var connection = _inMemoryTransport.CreateConnection(); - return new(connection.ClientStream); - }, - }) + var connection = KestrelInMemoryTransport.CreateConnection(context.DnsEndPoint); + return new(connection.ClientStream); + }; + + HttpClient = new HttpClient(SocketsHttpHandler) { BaseAddress = new Uri("http://localhost:5000/"), Timeout = TimeSpan.FromSeconds(10), @@ -38,6 +35,10 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) public HttpClient HttpClient { get; } + public SocketsHttpHandler SocketsHttpHandler { get; } = new(); + + public KestrelInMemoryTransport KestrelInMemoryTransport { get; } = new(); + public override void Dispose() { HttpClient.Dispose(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs index 399e9a83..71809ad6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs @@ -1,50 +1,59 @@ using Microsoft.AspNetCore.Connections; +using System.Collections.Concurrent; using System.Net; using System.Threading.Channels; namespace ModelContextProtocol.AspNetCore.Tests.Utils; -public sealed class KestrelInMemoryTransport : IConnectionListenerFactory, IConnectionListener +public sealed class KestrelInMemoryTransport : IConnectionListenerFactory { - private readonly Channel _acceptQueue = Channel.CreateUnbounded(); - private EndPoint? _endPoint; + // socket accept queues keyed by listen port. + private readonly ConcurrentDictionary> _acceptQueues = []; - public EndPoint EndPoint => _endPoint ?? throw new InvalidOperationException("EndPoint is not set. Call BindAsync first."); - - public KestrelInMemoryConnection CreateConnection() + public KestrelInMemoryConnection CreateConnection(EndPoint endpoint) { var connection = new KestrelInMemoryConnection(); - _acceptQueue.Writer.TryWrite(connection); + GetAcceptQueue(endpoint).Writer.TryWrite(connection); return connection; } - public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) + public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) => + new(new KestrelInMemoryListener(endpoint, GetAcceptQueue(endpoint))); + + private Channel GetAcceptQueue(EndPoint endpoint) => + _acceptQueues.GetOrAdd(GetEndpointPort(endpoint), _ => Channel.CreateUnbounded()); + + private static int GetEndpointPort(EndPoint endpoint) => + endpoint switch + { + DnsEndPoint dnsEndpoint => dnsEndpoint.Port, + IPEndPoint ipEndpoint => ipEndpoint.Port, + _ => throw new InvalidOperationException($"Unexpected endpoint type: '{endpoint.GetType()}'"), + }; + + private sealed class KestrelInMemoryListener(EndPoint endpoint, Channel acceptQueue) : IConnectionListener { - if (await _acceptQueue.Reader.WaitToReadAsync(cancellationToken)) + public EndPoint EndPoint => endpoint; + + public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) { - while (_acceptQueue.Reader.TryRead(out var item)) + if (await acceptQueue.Reader.WaitToReadAsync(cancellationToken)) { - return item; + while (acceptQueue.Reader.TryRead(out var item)) + { + return item; + } } - } - - return null; - } - public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) - { - _endPoint = endpoint; - return new ValueTask(this); - } + return null; + } - public ValueTask DisposeAsync() - { - return UnbindAsync(default); - } + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + { + acceptQueue.Writer.TryComplete(); + return default; + } - public ValueTask UnbindAsync(CancellationToken cancellationToken = default) - { - _acceptQueue.Writer.TryComplete(); - return default; + public ValueTask DisposeAsync() => UnbindAsync(CancellationToken.None); } } diff --git a/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs new file mode 100644 index 00000000..9d7142ce --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs @@ -0,0 +1,34 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents authorization code information for OAuth flow. +/// +internal sealed class AuthorizationCodeInfo +{ + /// + /// Gets or sets the client ID associated with this authorization code. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the redirect URI associated with this authorization code. + /// + public required string RedirectUri { get; init; } + + /// + /// Gets or sets the code challenge associated with this authorization code (for PKCE). + /// + public required string CodeChallenge { get; init; } + + /// + /// Gets or sets the list of scopes approved for this authorization code. + /// + public List Scope { get; init; } = []; + + /// + /// Gets or sets the optional resource URI this authorization code is for. + /// + public Uri? Resource { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..32472a88 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs @@ -0,0 +1,63 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the authorization server metadata for OAuth discovery. +/// +internal sealed class AuthorizationServerMetadata +{ + /// + /// Gets or sets the issuer URL. + /// + [JsonPropertyName("issuer")] + public required Uri Issuer { get; init; } + + /// + /// Gets or sets the authorization endpoint URL. + /// + [JsonPropertyName("authorization_endpoint")] + public required Uri AuthorizationEndpoint { get; init; } + + /// + /// Gets or sets the token endpoint URL. + /// + [JsonPropertyName("token_endpoint")] + public required Uri TokenEndpoint { get; init; } + + /// + /// Gets the introspection endpoint URL. + /// + [JsonPropertyName("introspection_endpoint")] + public Uri? IntrospectionEndpoint => new Uri($"{Issuer}/introspect"); + + /// + /// Gets or sets the response types supported by this server. + /// + [JsonPropertyName("response_types_supported")] + public required List ResponseTypesSupported { get; init; } + + /// + /// Gets or sets the grant types supported by this server. + /// + [JsonPropertyName("grant_types_supported")] + public required List GrantTypesSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication methods supported by this server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public required List TokenEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the code challenge methods supported by this server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public required List CodeChallengeMethodsSupported { get; init; } + + /// + /// Gets or sets the scopes supported by this server. + /// + [JsonPropertyName("scopes_supported")] + public List? ScopesSupported { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs new file mode 100644 index 00000000..7983476f --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs @@ -0,0 +1,24 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents client information for OAuth flow. +/// +internal sealed class ClientInfo +{ + /// + /// Gets or sets the client ID. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + public required string ClientSecret { get; init; } + + /// + /// Gets or sets the list of redirect URIs allowed for this client. + /// + public List RedirectUris { get; init; } = []; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs new file mode 100644 index 00000000..50592bbe --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs @@ -0,0 +1,93 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a client registration request as defined in RFC 7591. +/// +internal sealed class ClientRegistrationRequest +{ + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required List RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public List? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public List? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the URL for the client's logo. + /// + [JsonPropertyName("logo_uri")] + public string? LogoUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } + + /// + /// Gets or sets the contacts for the client. + /// + [JsonPropertyName("contacts")] + public List? Contacts { get; init; } + + /// + /// Gets or sets the URL for the client's terms of service. + /// + [JsonPropertyName("tos_uri")] + public string? TosUri { get; init; } + + /// + /// Gets or sets the URL for the client's privacy policy. + /// + [JsonPropertyName("policy_uri")] + public string? PolicyUri { get; init; } + + /// + /// Gets or sets the JWK Set URL for the client. + /// + [JsonPropertyName("jwks_uri")] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the software identifier for the client. + /// + [JsonPropertyName("software_id")] + public string? SoftwareId { get; init; } + + /// + /// Gets or sets the software version for the client. + /// + [JsonPropertyName("software_version")] + public string? SoftwareVersion { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs new file mode 100644 index 00000000..3833c490 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs @@ -0,0 +1,147 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a client registration response as defined in RFC 7591. +/// +internal sealed class ClientRegistrationResponse +{ + /// + /// Gets or sets the client identifier. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + [JsonPropertyName("client_secret")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required List RedirectUris { get; init; } + + /// + /// Gets or sets the registration access token. + /// + [JsonPropertyName("registration_access_token")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationAccessToken { get; init; } + + /// + /// Gets or sets the registration client URI. + /// + [JsonPropertyName("registration_client_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationClientUri { get; init; } + + /// + /// Gets or sets the client ID issued timestamp. + /// + [JsonPropertyName("client_id_issued_at")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public long? ClientIdIssuedAt { get; init; } + + /// + /// Gets or sets the client secret expiration time. + /// + [JsonPropertyName("client_secret_expires_at")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public long? ClientSecretExpiresAt { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the URL for the client's logo. + /// + [JsonPropertyName("logo_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? LogoUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Scope { get; init; } + + /// + /// Gets or sets the contacts for the client. + /// + [JsonPropertyName("contacts")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? Contacts { get; init; } + + /// + /// Gets or sets the URL for the client's terms of service. + /// + [JsonPropertyName("tos_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? TosUri { get; init; } + + /// + /// Gets or sets the URL for the client's privacy policy. + /// + [JsonPropertyName("policy_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? PolicyUri { get; init; } + + /// + /// Gets or sets the JWK Set URL for the client. + /// + [JsonPropertyName("jwks_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the software identifier for the client. + /// + [JsonPropertyName("software_id")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? SoftwareId { get; init; } + + /// + /// Gets or sets the software version for the client. + /// + [JsonPropertyName("software_version")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? SoftwareVersion { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs new file mode 100644 index 00000000..562efa52 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs @@ -0,0 +1,45 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a JSON Web Key. +/// +internal sealed class JsonWebKey +{ + /// + /// Gets or sets the key type (e.g., "RSA"). + /// + [JsonPropertyName("kty")] + public required string KeyType { get; init; } + + /// + /// Gets or sets the intended use of the key (e.g., "sig" for signature). + /// + [JsonPropertyName("use")] + public required string Use { get; init; } + + /// + /// Gets or sets the key ID. + /// + [JsonPropertyName("kid")] + public required string KeyId { get; init; } + + /// + /// Gets or sets the algorithm intended for use with the key (e.g., "RS256"). + /// + [JsonPropertyName("alg")] + public required string Algorithm { get; init; } + + /// + /// Gets or sets the RSA exponent (base64url-encoded). + /// + [JsonPropertyName("e")] + public required string Exponent { get; init; } + + /// + /// Gets or sets the RSA modulus (base64url-encoded). + /// + [JsonPropertyName("n")] + public required string Modulus { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs new file mode 100644 index 00000000..223407b7 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a JSON Web Key Set (JWKS) response. +/// +internal sealed class JsonWebKeySet +{ + /// + /// Gets or sets the array of JSON Web Keys. + /// + [JsonPropertyName("keys")] + public required JsonWebKey[] Keys { get; init; } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj b/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj new file mode 100644 index 00000000..51092f56 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj @@ -0,0 +1,9 @@ + + + + net9.0;net8.0 + enable + enable + + + diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs new file mode 100644 index 00000000..c9174fa3 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs @@ -0,0 +1,21 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents an OAuth error response. +/// +internal sealed class OAuthErrorResponse +{ + /// + /// Gets or sets the error code. + /// + [JsonPropertyName("error")] + public required string Error { get; init; } + + /// + /// Gets or sets the error description. + /// + [JsonPropertyName("error_description")] + public required string ErrorDescription { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs new file mode 100644 index 00000000..6caaaea0 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +[JsonSerializable(typeof(OAuthServerMetadata))] +[JsonSerializable(typeof(AuthorizationServerMetadata))] +[JsonSerializable(typeof(TokenResponse))] +[JsonSerializable(typeof(JsonWebKeySet))] +[JsonSerializable(typeof(JsonWebKey))] +[JsonSerializable(typeof(TokenIntrospectionResponse))] +[JsonSerializable(typeof(OAuthErrorResponse))] +[JsonSerializable(typeof(ClientRegistrationRequest))] +[JsonSerializable(typeof(ClientRegistrationResponse))] +[JsonSerializable(typeof(Dictionary))] +internal sealed partial class OAuthJsonContext : JsonSerializerContext; diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs new file mode 100644 index 00000000..646a3992 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs @@ -0,0 +1,174 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the OAuth 2.0 Authorization Server Metadata as defined in RFC 8414. +/// +internal sealed class OAuthServerMetadata +{ + /// + /// Gets or sets the issuer URL. + /// REQUIRED. The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components. + /// + [JsonPropertyName("issuer")] + public required string Issuer { get; init; } + + /// + /// Gets or sets the authorization endpoint URL. + /// URL of the authorization server's authorization endpoint. This is REQUIRED unless no grant types are supported that use the authorization endpoint. + /// + [JsonPropertyName("authorization_endpoint")] + public required string AuthorizationEndpoint { get; init; } + + /// + /// Gets or sets the token endpoint URL. + /// URL of the authorization server's token endpoint. This is REQUIRED unless only the implicit grant type is supported. + /// + [JsonPropertyName("token_endpoint")] + public required string TokenEndpoint { get; init; } + + /// + /// Gets or sets the JWKS URI. + /// OPTIONAL. URL of the authorization server's JWK Set document. + /// + [JsonPropertyName("jwks_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the registration endpoint URL for dynamic client registration. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 Dynamic Client Registration endpoint. + /// + [JsonPropertyName("registration_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationEndpoint { get; init; } + + /// + /// Gets or sets the scopes supported by this server. + /// RECOMMENDED. JSON array containing a list of the OAuth 2.0 scope values that this server supports. + /// + [JsonPropertyName("scopes_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ScopesSupported { get; init; } + + /// + /// Gets or sets the response types supported by this server. + /// RECOMMENDED. JSON array containing a list of the OAuth 2.0 "response_type" values that this server supports. + /// + [JsonPropertyName("response_types_supported")] + public required List ResponseTypesSupported { get; init; } + + /// + /// Gets or sets the response modes supported by this server. + /// OPTIONAL. JSON array containing a list of the OAuth 2.0 "response_mode" values that this server supports. + /// + [JsonPropertyName("response_modes_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ResponseModesSupported { get; init; } + + /// + /// Gets or sets the grant types supported by this server. + /// OPTIONAL. JSON array containing a list of the OAuth 2.0 grant type values that this server supports. + /// + [JsonPropertyName("grant_types_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? GrantTypesSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? TokenEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? TokenEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the introspection endpoint URL. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? IntrospectionEndpoint { get; init; } + + /// + /// Gets or sets the introspection endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IntrospectionEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the introspection endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IntrospectionEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the revocation endpoint URL. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RevocationEndpoint { get; init; } + + /// + /// Gets or sets the revocation endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? RevocationEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the revocation endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? RevocationEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the code challenge methods supported by this server. + /// OPTIONAL. JSON array containing a list of Proof Key for Code Exchange (PKCE) code challenge methods supported by this server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? CodeChallengeMethodsSupported { get; init; } + + // OpenID Connect specific fields that are commonly included in OAuth metadata + /// + /// Gets or sets the subject types supported by this server. + /// REQUIRED for OpenID Connect. JSON array containing a list of the Subject Identifier types that this OP supports. + /// + [JsonPropertyName("subject_types_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? SubjectTypesSupported { get; init; } + + /// + /// Gets or sets the ID token signing algorithms supported by this server. + /// REQUIRED for OpenID Connect. JSON array containing a list of the JWS signing algorithms (alg values) supported by the OP for the ID Token. + /// + [JsonPropertyName("id_token_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IdTokenSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the claims supported by this server. + /// RECOMMENDED for OpenID Connect. JSON array containing a list of the Claim Names of the Claims that the OpenID Provider MAY be able to supply values for. + /// + [JsonPropertyName("claims_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ClaimsSupported { get; init; } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs new file mode 100644 index 00000000..3970394b --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -0,0 +1,634 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.WebUtilities; +using System.Collections.Concurrent; +using System.Globalization; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.TestOAuthServer; + +public sealed class Program +{ + private const int _port = 7029; + private static readonly string _url = $"https://localhost:{_port}"; + + // Port 5000 is used by tests and port 7071 is used by the ProtectedMCPServer sample + private static readonly string[] ValidResources = ["http://localhost:5000/", "http://localhost:7071/"]; + + private readonly ConcurrentDictionary _authCodes = new(); + private readonly ConcurrentDictionary _tokens = new(); + private readonly ConcurrentDictionary _clients = new(); + + private readonly RSA _rsa; + private readonly string _keyId; + + private readonly ILoggerProvider? _loggerProvider; + private readonly IConnectionListenerFactory? _kestrelTransport; + + /// + /// Initializes a new instance of the class with logging and transport parameters. + /// + /// Optional logger provider for logging. + /// Optional Kestrel transport for in-memory connections. + public Program(ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null) + { + _rsa = RSA.Create(2048); + _keyId = Guid.NewGuid().ToString(); + _loggerProvider = loggerProvider; + _kestrelTransport = kestrelTransport; + } + + // Track if we've already issued an already-expired token for the CanAuthenticate_WithTokenRefresh test which uses the test-refresh-client registration. + public bool HasIssuedExpiredToken { get; set; } + public bool HasIssuedRefreshToken { get; set; } + + /// + /// Entry point for the application. + /// + /// Command line arguments. + /// A task representing the asynchronous operation. + public static Task Main(string[] args) => new Program().RunServerAsync(args); + + /// + /// Runs the OAuth server with the specified parameters. + /// + /// Command line arguments. + /// Cancellation token to stop the server. + /// A task representing the asynchronous operation. + public async Task RunServerAsync(string[]? args = null, CancellationToken cancellationToken = default) + { + Console.WriteLine("Starting in-memory test-only OAuth Server..."); + + var builder = WebApplication.CreateEmptyBuilder(new() + { + Args = args, + }); + + if (_kestrelTransport is not null) + { + // Add passed-in transport before calling UseKestrel() to avoid the SocketsHttpHandler getting added. + builder.Services.AddSingleton(_kestrelTransport); + } + + builder.WebHost.UseKestrel(kestrelOptions => + { + kestrelOptions.ListenLocalhost(_port, listenOptions => + { + listenOptions.UseHttps(); + }); + }); + + builder.Services.AddRoutingCore(); + builder.Services.AddLogging(); + + builder.Services.ConfigureHttpJsonOptions(jsonOptions => + { + jsonOptions.SerializerOptions.TypeInfoResolverChain.Add(OAuthJsonContext.Default); + }); + + builder.Logging.AddConsole(); + if (_loggerProvider is not null) + { + builder.Logging.AddProvider(_loggerProvider); + } + + var app = builder.Build(); + + app.UseRouting(); + app.UseEndpoints(_ => { }); + + // Set up the demo client + var clientId = "demo-client"; + var clientSecret = "demo-secret"; + _clients[clientId] = new ClientInfo + { + ClientId = clientId, + ClientSecret = clientSecret, + RedirectUris = ["http://localhost:1179/callback"], + }; + + // When this client ID is used, the first token issued will already be expired to make + // testing the refresh flow easier. + _clients["test-refresh-client"] = new ClientInfo + { + ClientId = "test-refresh-client", + ClientSecret = "test-refresh-secret", + RedirectUris = ["http://localhost:1179/callback"], + }; + + // The MCP spec tells the client to use /.well-known/oauth-authorization-server but AddJwtBearer looks for + // /.well-known/openid-configuration by default. To make things easier, we support both with the same response + // which seems to be common. Ex. https://github.com/keycloak/keycloak/pull/29628 + // + // The requirements for these endpoints are at https://www.rfc-editor.org/rfc/rfc8414 and + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata respectively. + // They do differ, but it's close enough at least for our current testing to use the same response for both. + // See https://gist.github.com/localden/26d8bcf641703c08a5d8741aa9c3336c + string[] metadataEndpoints = ["/.well-known/oauth-authorization-server", "/.well-known/openid-configuration"]; + foreach (var metadataEndpoint in metadataEndpoints) + { + // OAuth 2.0 Authorization Server Metadata (RFC 8414) + app.MapGet(metadataEndpoint, () => + { + var metadata = new OAuthServerMetadata + { + Issuer = _url, + AuthorizationEndpoint = $"{_url}/authorize", + TokenEndpoint = $"{_url}/token", + JwksUri = $"{_url}/.well-known/jwks.json", + ResponseTypesSupported = ["code"], + SubjectTypesSupported = ["public"], + IdTokenSigningAlgValuesSupported = ["RS256"], + ScopesSupported = ["openid", "profile", "email", "mcp:tools"], + TokenEndpointAuthMethodsSupported = ["client_secret_post"], + ClaimsSupported = ["sub", "iss", "name", "email", "aud"], + CodeChallengeMethodsSupported = ["S256"], + GrantTypesSupported = ["authorization_code", "refresh_token"], + IntrospectionEndpoint = $"{_url}/introspect", + RegistrationEndpoint = $"{_url}/register" + }; + + return Results.Ok(metadata); + }); + } + + // JWKS endpoint to expose the public key + app.MapGet("/.well-known/jwks.json", () => + { + var parameters = _rsa.ExportParameters(false); + + // Convert parameters to base64url encoding + var e = WebEncoders.Base64UrlEncode(parameters.Exponent ?? Array.Empty()); + var n = WebEncoders.Base64UrlEncode(parameters.Modulus ?? Array.Empty()); + + var jwks = new JsonWebKeySet + { + Keys = [ + new JsonWebKey + { + KeyType = "RSA", + Use = "sig", + KeyId = _keyId, + Algorithm = "RS256", + Exponent = e, + Modulus = n + } + ] + }; + + return Results.Ok(jwks); + }); + + // Authorize endpoint + app.MapGet("/authorize", ( + [FromQuery] string client_id, + [FromQuery] string? redirect_uri, + [FromQuery] string response_type, + [FromQuery] string code_challenge, + [FromQuery] string code_challenge_method, + [FromQuery] string? scope, + [FromQuery] string? state, + [FromQuery] string? resource) => + { + // Validate client + if (!_clients.TryGetValue(client_id, out var client)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_client", + ErrorDescription = "Client not found" + }); + } + + // Validate redirect_uri + if (string.IsNullOrEmpty(redirect_uri)) + { + if (client.RedirectUris.Count == 1) + { + redirect_uri = client.RedirectUris[0]; + } + else + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "redirect_uri is required when client has multiple registered URIs" + }); + } + } + else if (!client.RedirectUris.Contains(redirect_uri)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Unregistered redirect_uri" + }); + } + + // Validate response_type + if (response_type != "code") + { + return Results.Redirect($"{redirect_uri}?error=unsupported_response_type&error_description=Only+code+response_type+is+supported&state={state}"); + } + + // Validate code challenge method + if (code_challenge_method != "S256") + { + return Results.Redirect($"{redirect_uri}?error=invalid_request&error_description=Only+S256+code_challenge_method+is+supported&state={state}"); + } + + // Validate resource in accordance with RFC 8707 + if (string.IsNullOrEmpty(resource) || !ValidResources.Contains(resource)) + { + return Results.Redirect($"{redirect_uri}?error=invalid_target&error_description=The+specified+resource+is+not+valid&state={state}"); + } + + // Generate a new authorization code + var code = GenerateRandomToken(); + var requestedScopes = scope?.Split(' ').ToList() ?? []; + + // Store code information for later verification + _authCodes[code] = new AuthorizationCodeInfo + { + ClientId = client_id, + RedirectUri = redirect_uri, + CodeChallenge = code_challenge, + Scope = requestedScopes, + Resource = !string.IsNullOrEmpty(resource) ? new Uri(resource) : null + }; + + // Redirect back to client with the code + var redirectUrl = $"{redirect_uri}?code={code}"; + if (!string.IsNullOrEmpty(state)) + { + redirectUrl += $"&state={Uri.EscapeDataString(state)}"; + } + + return Results.Redirect(redirectUrl); + }); + + // Token endpoint + app.MapPost("/token", async (HttpContext context) => + { + var form = await context.Request.ReadFormAsync(); + + // Authenticate client + var client = AuthenticateClient(context, form); + if (client == null) + { + context.Response.StatusCode = 401; + return Results.Problem( + statusCode: 401, + title: "Unauthorized", + detail: "Invalid client credentials", + type: "https://tools.ietf.org/html/rfc6749#section-5.2"); + } + + // Validate resource in accordance with RFC 8707 + var resource = form["resource"].ToString(); + if (string.IsNullOrEmpty(resource) || !ValidResources.Contains(resource)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_target", + ErrorDescription = "The specified resource is not valid." + }); + } + + var grant_type = form["grant_type"].ToString(); + if (grant_type == "authorization_code") + { + var code = form["code"].ToString(); + var code_verifier = form["code_verifier"].ToString(); + var redirect_uri = form["redirect_uri"].ToString(); + + // Validate code + if (string.IsNullOrEmpty(code) || !_authCodes.TryRemove(code, out var codeInfo)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Invalid authorization code" + }); + } + + // Validate client_id + if (codeInfo.ClientId != client.ClientId) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Authorization code was not issued to this client" + }); + } + + // Validate redirect_uri if provided + if (!string.IsNullOrEmpty(redirect_uri) && redirect_uri != codeInfo.RedirectUri) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Redirect URI mismatch" + }); + } + + // Validate code verifier + if (string.IsNullOrEmpty(code_verifier) || !VerifyCodeChallenge(code_verifier, codeInfo.CodeChallenge)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Code verifier does not match the challenge" + }); + } + + // Generate JWT token response + var response = GenerateJwtTokenResponse(client.ClientId, codeInfo.Scope, codeInfo.Resource); + return Results.Ok(response); + } + else if (grant_type == "refresh_token") + { + var refresh_token = form["refresh_token"].ToString(); + + // Validate refresh token + if (string.IsNullOrEmpty(refresh_token) || !_tokens.TryGetValue(refresh_token, out var tokenInfo) || tokenInfo.ClientId != client.ClientId) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Invalid refresh token" + }); + } + + // Generate new token response, keeping the same scopes + var response = GenerateJwtTokenResponse(client.ClientId, tokenInfo.Scopes, tokenInfo.Resource); + + // Remove the old refresh token + if (!string.IsNullOrEmpty(refresh_token)) + { + _tokens.TryRemove(refresh_token, out _); + } + + HasIssuedRefreshToken = true; + return Results.Ok(response); + } + else + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "unsupported_grant_type", + ErrorDescription = "Unsupported grant type" + }); + } + }); + + // Introspection endpoint + app.MapPost("/introspect", async (HttpContext context) => + { + var form = await context.Request.ReadFormAsync(); + var token = form["token"].ToString(); + + if (string.IsNullOrEmpty(token)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Token is required" + }); + } + + // Check opaque access tokens + if (_tokens.TryGetValue(token, out var tokenInfo)) + { + if (tokenInfo.ExpiresAt < DateTimeOffset.UtcNow) + { + return Results.Ok(new TokenIntrospectionResponse { Active = false }); + } + + return Results.Ok(new TokenIntrospectionResponse + { + Active = true, + ClientId = tokenInfo.ClientId, + Scope = string.Join(" ", tokenInfo.Scopes), + ExpirationTime = tokenInfo.ExpiresAt.ToUnixTimeSeconds(), + Audience = tokenInfo.Resource?.ToString() + }); + } + + return Results.Ok(new TokenIntrospectionResponse { Active = false }); + }); + + // Dynamic Client Registration endpoint (RFC 7591) + app.MapPost("/register", async (HttpContext context) => + { + using var stream = context.Request.Body; + var registrationRequest = await JsonSerializer.DeserializeAsync( + stream, + OAuthJsonContext.Default.ClientRegistrationRequest, + context.RequestAborted); + + if (registrationRequest is null) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Invalid registration request" + }); + } + + // Validate redirect URIs are provided + if (registrationRequest.RedirectUris.Count == 0) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_redirect_uri", + ErrorDescription = "At least one redirect URI must be provided" + }); + } + + // Validate redirect URIs + foreach (var redirectUri in registrationRequest.RedirectUris) + { + if (!Uri.TryCreate(redirectUri, UriKind.Absolute, out var uri) || + (uri.Scheme != "http" && uri.Scheme != "https")) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_redirect_uri", + ErrorDescription = $"Invalid redirect URI: {redirectUri}" + }); + } + } + + // Generate client credentials + var clientId = $"dyn-{Guid.NewGuid():N}"; + var clientSecret = GenerateRandomToken(); + var issuedAt = DateTimeOffset.UtcNow; + + // Store the registered client + _clients[clientId] = new ClientInfo + { + ClientId = clientId, + ClientSecret = clientSecret, + RedirectUris = registrationRequest.RedirectUris, + }; + + var registrationResponse = new ClientRegistrationResponse + { + ClientId = clientId, + ClientSecret = clientSecret, + ClientIdIssuedAt = issuedAt.ToUnixTimeSeconds(), + RedirectUris = registrationRequest.RedirectUris, + GrantTypes = ["authorization_code", "refresh_token"], + ResponseTypes = ["code"], + TokenEndpointAuthMethod = "client_secret_post", + }; + + return Results.Ok(registrationResponse); + }); + + app.MapGet("/", () => "Demo In-Memory OAuth 2.0 Server with JWT Support"); + + Console.WriteLine($"OAuth Authorization Server running at {_url}"); + Console.WriteLine($"OAuth Server Metadata at {_url}/.well-known/oauth-authorization-server"); + Console.WriteLine($"JWT keys available at {_url}/.well-known/jwks.json"); + Console.WriteLine($"Demo Client ID: {clientId}"); + Console.WriteLine($"Demo Client Secret: {clientSecret}"); + + await app.RunAsync(cancellationToken); + } + + /// + /// Authenticates a client based on client credentials in the request. + /// + /// The HTTP context. + /// The form collection containing client credentials. + /// The client info if authentication succeeds, null otherwise. + private ClientInfo? AuthenticateClient(HttpContext context, IFormCollection form) + { + var clientId = form["client_id"].ToString(); + var clientSecret = form["client_secret"].ToString(); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + if (_clients.TryGetValue(clientId, out var client) && client.ClientSecret == clientSecret) + { + return client; + } + + return null; + } + + /// + /// Generates a JWT token response. + /// + /// The client ID. + /// The approved scopes. + /// The resource URI. + /// A token response. + private TokenResponse GenerateJwtTokenResponse(string clientId, List scopes, Uri? resource) + { + var expiresIn = TimeSpan.FromHours(1); + var issuedAt = DateTimeOffset.UtcNow; + + // For test-refresh-client, make the first token expired to test refresh functionality. + if (clientId == "test-refresh-client" && !HasIssuedExpiredToken) + { + HasIssuedExpiredToken = true; + expiresIn = TimeSpan.FromHours(-1); + } + + var expiresAt = issuedAt.Add(expiresIn); + var jwtId = Guid.NewGuid().ToString(); + + // Create JWT header and payload + var header = new Dictionary + { + { "alg", "RS256" }, + { "typ", "JWT" }, + { "kid", _keyId } + }; + + var payload = new Dictionary + { + { "iss", _url }, + { "sub", $"user-{clientId}" }, + { "name", $"user-{clientId}" }, + { "aud", resource?.ToString() ?? clientId }, + { "client_id", clientId }, + { "jti", jwtId }, + { "iat", issuedAt.ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture) }, + { "exp", expiresAt.ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture) }, + { "scope", string.Join(" ", scopes) } + }; + + // Create JWT token + var headerJson = JsonSerializer.Serialize(header, OAuthJsonContext.Default.DictionaryStringString); + var payloadJson = JsonSerializer.Serialize(payload, OAuthJsonContext.Default.DictionaryStringString); + + var headerBase64 = WebEncoders.Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson)); + var payloadBase64 = WebEncoders.Base64UrlEncode(Encoding.UTF8.GetBytes(payloadJson)); + + var dataToSign = $"{headerBase64}.{payloadBase64}"; + var signature = _rsa.SignData(Encoding.UTF8.GetBytes(dataToSign), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var signatureBase64 = WebEncoders.Base64UrlEncode(signature); + + var jwtToken = $"{headerBase64}.{payloadBase64}.{signatureBase64}"; + + // Generate opaque refresh token + var refreshToken = GenerateRandomToken(); + + // Store token info (for refresh token and introspection) + var tokenInfo = new TokenInfo + { + ClientId = clientId, + Scopes = scopes, + IssuedAt = issuedAt, + ExpiresAt = expiresAt, + Resource = resource, + JwtId = jwtId + }; + + _tokens[refreshToken] = tokenInfo; + + return new TokenResponse + { + AccessToken = jwtToken, + RefreshToken = refreshToken, + TokenType = "Bearer", + ExpiresIn = (int)expiresIn.TotalSeconds, + Scope = string.Join(" ", scopes) + }; + } + + /// + /// Generates a random token for authorization code or refresh token. + /// + /// A Base64Url encoded random token. + public static string GenerateRandomToken() + { + var bytes = new byte[32]; + Random.Shared.NextBytes(bytes); + return WebEncoders.Base64UrlEncode(bytes); + } + + /// + /// Verifies a PKCE code challenge against a code verifier. + /// + /// The code verifier to verify. + /// The code challenge to verify against. + /// True if the code challenge is valid, false otherwise. + public static bool VerifyCodeChallenge(string codeVerifier, string codeChallenge) + { + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + var computedChallenge = WebEncoders.Base64UrlEncode(challengeBytes); + + return computedChallenge == codeChallenge; + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json b/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json new file mode 100644 index 00000000..71b2b21f --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json @@ -0,0 +1,14 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:7029", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs new file mode 100644 index 00000000..159ef34e --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents token information for OAuth flow. +/// +internal sealed class TokenInfo +{ + /// + /// Gets or sets the client ID associated with this token. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the list of scopes approved for this token. + /// + public List Scopes { get; init; } = []; + + /// + /// Gets or sets the issued time of this token. + /// + public required DateTimeOffset IssuedAt { get; init; } + + /// + /// Gets or sets the expiration time of this token. + /// + public required DateTimeOffset ExpiresAt { get; init; } + + /// + /// Gets or sets the optional resource URI this token is for. + /// + public Uri? Resource { get; init; } + + /// + /// Gets or sets the JWT ID for this token. + /// + public string? JwtId { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs new file mode 100644 index 00000000..a27b624a --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the response from the token introspection endpoint. +/// +internal sealed class TokenIntrospectionResponse +{ + /// + /// Gets or sets a value indicating whether the token is active. + /// + [JsonPropertyName("active")] + public required bool Active { get; init; } + + /// + /// Gets or sets the client ID associated with the token. + /// + [JsonPropertyName("client_id")] + public string? ClientId { get; init; } + + /// + /// Gets or sets the scope of the token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } + + /// + /// Gets or sets the expiration timestamp of the token (Unix timestamp). + /// + [JsonPropertyName("exp")] + public long? ExpirationTime { get; init; } + + /// + /// Gets or sets the audience of the token. + /// + [JsonPropertyName("aud")] + public string? Audience { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs new file mode 100644 index 00000000..20789feb --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the token response for OAuth flow. +/// +internal sealed class TokenResponse +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public required string AccessToken { get; init; } + + /// + /// Gets or sets the token type. + /// + [JsonPropertyName("token_type")] + public required string TokenType { get; init; } + + /// + /// Gets or sets the token expiration time in seconds. + /// + [JsonPropertyName("expires_in")] + public required int ExpiresIn { get; init; } + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; init; } + + /// + /// Gets or sets the scope approved for this token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index cc6b4e0a..3ff50430 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -35,7 +35,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public void Constructor_Throws_For_Null_HttpClient() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, LoggerFactory)); + var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); }