diff --git a/Directory.Packages.props b/Directory.Packages.props index 00234fa3..fa70d20a 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -2,7 +2,7 @@ true 10.0.0-preview.3.25171.5 - 9.4.3-preview.1.25230.7 + 9.4.4-preview.1.25259.16 @@ -31,7 +31,7 @@ - + diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 6de2d30f..c42d7653 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -34,7 +34,6 @@ - diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs index 4c5da534..a31a4a28 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; @@ -68,7 +69,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, - Services = options?.Services, + CreateInstance = AIFunctionMcpServerTool.GetCreateInstanceFunc(), ConfigureParameterBinding = pi => { if (pi.ParameterType == typeof(RequestContext)) @@ -110,6 +111,32 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( }; } + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + return default; static RequestContext? GetRequestContext(AIFunctionArguments args) diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerResource.cs index 6ba2a0fd..be0cc84e 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerResource.cs @@ -1,10 +1,10 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; using System.ComponentModel; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Reflection; @@ -76,7 +76,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = McpJsonUtilities.DefaultOptions, - Services = options?.Services, + CreateInstance = AIFunctionMcpServerTool.GetCreateInstanceFunc(), ConfigureParameterBinding = pi => { if (pi.ParameterType == typeof(RequestContext)) @@ -118,6 +118,32 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( }; } + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + // These parameters are the ones and only ones to include in the schema. The schema // won't be consumed by anyone other than this instance, which will use it to determine // which properties should show up in the URI template. diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index d892039c..872f8868 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; @@ -60,6 +61,14 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool options); } + // TODO: Fix the need for this suppression. + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2111:ReflectionToDynamicallyAccessedMembers", + Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] + internal static Func GetCreateInstanceFunc() => + static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? + ActivatorUtilities.CreateInstance(services, type) : + Activator.CreateInstance(type)!; + private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MethodInfo method, McpServerToolCreateOptions? options) => new() @@ -68,7 +77,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, - Services = options?.Services, + CreateInstance = GetCreateInstanceFunc(), ConfigureParameterBinding = pi => { if (pi.ParameterType == typeof(RequestContext)) @@ -110,6 +119,32 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( }; } + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + return default; static RequestContext? GetRequestContext(AIFunctionArguments args) diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 93a76e1a..48a5e09b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -60,7 +60,7 @@ public async Task SupportsServiceFromDI() Assert.Contains("something", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); - await Assert.ThrowsAsync(async () => await prompt.GetAsync( + await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 0c765f47..10b3d17d 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -375,7 +375,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); - await Assert.ThrowsAsync(async () => await resource.ReadAsync( + await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index efd5028b..36f4b318 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -156,6 +156,10 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() [Fact] public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposableTargets() { + ServiceCollection sc = new(); + sc.AddSingleton(); + IServiceProvider services = sc.BuildServiceProvider(); + McpServerToolCreateOptions options = new() { SerializerOptions = JsonContext2.Default.Options }; McpServerTool tool1 = McpServerTool.Create( typeof(AsyncDisposableAndDisposableToolType).GetMethod(nameof(AsyncDisposableAndDisposableToolType.InstanceMethod))!, @@ -163,7 +167,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", result.Content[0].Text); } @@ -428,6 +432,11 @@ public object InstanceMethod() private class AsyncDisposableAndDisposableToolType : IAsyncDisposable, IDisposable { + public AsyncDisposableAndDisposableToolType(MyService service) + { + Assert.NotNull(service); + } + [JsonPropertyOrder(0)] public int AsyncDisposals { get; private set; }