Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using System.Text.RegularExpressions;
using System.Threading;
Expand Down Expand Up @@ -627,6 +628,49 @@ private ReflectionAIFunction(
var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];

// If the configured serializer options request strict handling of unmapped members,
// verify that every argument key corresponds to a declared parameter name. This mirrors
// JsonSerializerOptions.UnmappedMemberHandling behavior for object deserialization by
// applying the same policy to top-level AIFunction argument binding. Argument name matching
// honors the comparer of the supplied AIFunctionArguments dictionary (ordinal by default).
//
// Validation is skipped when custom ParameterBindingOptions.BindParameter callbacks are in
// use, since those may legitimately source values from argument keys that do not correspond
// to the .NET parameter names.
if (FunctionDescriptor.JsonSerializerOptions.UnmappedMemberHandling is JsonUnmappedMemberHandling.Disallow &&
arguments.Count > 0 &&
!FunctionDescriptor.HasCustomParameterBinding)
{
HashSet<string> expectedNames = FunctionDescriptor.ExpectedArgumentNames;
int matched = 0;
foreach (string name in expectedNames)
{
if (arguments.ContainsKey(name))
{
matched++;
}
}

if (matched != arguments.Count)
{
foreach (KeyValuePair<string, object?> kvp in arguments)
{
if (!expectedNames.Contains(kvp.Key))
{
Throw.ArgumentException(
nameof(arguments),
$"The arguments dictionary contains an unexpected key '{kvp.Key}' that does not correspond to any parameter of '{Name}'.");
}
}

// Fallback for comparer mismatches (e.g. case-insensitive arguments dictionary
// with duplicate-casing keys aliasing to the same parameter).
Throw.ArgumentException(
nameof(arguments),
$"The arguments dictionary contains keys that do not correspond to any parameter of '{Name}'.");
}
}

for (int i = 0; i < args.Length; i++)
{
args[i] = paramMarshallers[i](arguments, cancellationToken);
Expand Down Expand Up @@ -733,6 +777,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions

// Get marshaling delegates for parameters.
ParameterMarshallers = parameters.Length > 0 ? new Func<AIFunctionArguments, CancellationToken, object?>[parameters.Length] : [];
HashSet<string> expectedArgumentNames = new(StringComparer.Ordinal);
bool hasCustomParameterBinding = false;
for (int i = 0; i < parameters.Length; i++)
{
if (boundParameters?.TryGetValue(parameters[i], out AIFunctionFactoryOptions.ParameterBindingOptions options) is not true)
Expand All @@ -741,8 +787,32 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
}

ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);

if (options.BindParameter is not null)
{
// Custom BindParameter callbacks can legally source their value from arbitrary keys in the
// AIFunctionArguments dictionary, so we cannot know in advance which keys are "expected".
// Note this down so that strict unmapped-member validation is skipped in InvokeCoreAsync.
hasCustomParameterBinding = true;
}

// Collect the set of parameter names that are potentially sourced from the arguments dictionary.
// Infrastructure parameters (CancellationToken, AIFunctionArguments, IServiceProvider) are always
// bound from dedicated sources and are never resolved by argument name, so they are excluded from
// the permitted set.
Type pType = parameters[i].ParameterType;
if (pType != typeof(CancellationToken) &&
pType != typeof(AIFunctionArguments) &&
pType != typeof(IServiceProvider) &&
!string.IsNullOrEmpty(parameters[i].Name))
{
_ = expectedArgumentNames.Add(parameters[i].Name!);
}
Comment thread
eiriktsarpalis marked this conversation as resolved.
}

ExpectedArgumentNames = expectedArgumentNames;
HasCustomParameterBinding = hasCustomParameterBinding;

ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType);
Method = key.Method;
Name = key.Name ?? key.Method.GetCustomAttribute<DisplayNameAttribute>(inherit: true)?.DisplayName ?? GetFunctionName(key.Method);
Expand Down Expand Up @@ -770,6 +840,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public JsonElement? ReturnJsonSchema { get; }
public Func<AIFunctionArguments, CancellationToken, object?>[] ParameterMarshallers { get; }
public Func<object?, CancellationToken, ValueTask<object?>> ReturnParameterMarshaller { get; }
public HashSet<string> ExpectedArgumentNames { get; }
public bool HasCustomParameterBinding { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }

private static string GetFunctionName(MethodInfo method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ public AIFunctionFactoryOptions()

/// <summary>Gets or sets the <see cref="JsonSerializerOptions"/> used to marshal .NET values being passed to the underlying delegate.</summary>
/// <remarks>
/// <para>
/// If no value has been specified, the <see cref="AIJsonUtilities.DefaultOptions"/> instance will be used.
/// </para>
/// <para>
/// The <see cref="JsonSerializerOptions.UnmappedMemberHandling"/> setting is honored by the function parameter
/// binder: when set to <see cref="System.Text.Json.Serialization.JsonUnmappedMemberHandling.Disallow"/>, invoking
/// the produced <see cref="AIFunction"/> throws if the supplied <see cref="AIFunctionArguments"/> contains keys
/// that do not correspond to a bindable parameter of the underlying method.
/// </para>
/// </remarks>
public JsonSerializerOptions? SerializerOptions { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,134 @@ public async Task AIFunctionFactory_DynamicMethod()
#endif
}

[Fact]
public async Task Parameters_UnmappedMemberHandlingDisallow_ThrowsOnExtraArgument_Async()
{
JsonSerializerOptions strictOptions = new(AIJsonUtilities.DefaultOptions)
{
UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow,
};

AIFunction func = AIFunctionFactory.Create(
(string taskId, string update, bool markComplete = false) => $"{taskId}:{update}:{markComplete}",
new AIFunctionFactoryOptions { SerializerOptions = strictOptions });

// Extra, unrecognized argument causes a throw.
ArgumentException ex = await Assert.ThrowsAsync<ArgumentException>("arguments", async () =>
await func.InvokeAsync(new()
{
["taskId"] = "abc",
["update"] = "Done",
["phase"] = "completed",
}));
Assert.Contains("phase", ex.Message);

// Still succeeds when no unexpected arguments are present (optional parameter omitted).
object? result = await func.InvokeAsync(new()
{
["taskId"] = "abc",
["update"] = "Done",
});
AssertExtensions.EqualFunctionCallResults("abc:Done:False", result);
}

[Fact]
public async Task Parameters_UnmappedMemberHandlingDefault_IgnoresExtraArgument_Async()
{
// Default behavior (Skip) should preserve pre-existing lenient binding.
AIFunction func = AIFunctionFactory.Create(
(string update, bool markComplete = false) => $"{update}:{markComplete}");

object? result = await func.InvokeAsync(new()
{
["update"] = "Done",
["phase"] = "completed",
});
AssertExtensions.EqualFunctionCallResults("Done:False", result);
}

[Fact]
public async Task Parameters_UnmappedMemberHandlingDisallow_HonorsArgumentsComparer_Async()
{
JsonSerializerOptions strictOptions = new(AIJsonUtilities.DefaultOptions)
{
UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow,
};

AIFunction func = AIFunctionFactory.Create(
(string update, bool markComplete = false) => $"{update}:{markComplete}",
new AIFunctionFactoryOptions { SerializerOptions = strictOptions });

// Case-insensitive arguments dictionary: casing variations of the parameter name must not be
// flagged as unmapped, since the binding lookup itself is case-insensitive.
AIFunctionArguments caseInsensitive = new(StringComparer.OrdinalIgnoreCase)
{
["UPDATE"] = "Done",
["MarkComplete"] = true,
};
AssertExtensions.EqualFunctionCallResults("Done:True", await func.InvokeAsync(caseInsensitive));

// A genuinely unmapped key is still flagged even with a case-insensitive comparer.
AIFunctionArguments withExtra = new(StringComparer.OrdinalIgnoreCase)
{
["update"] = "Done",
["PHASE"] = "completed",
};
ArgumentException ex = await Assert.ThrowsAsync<ArgumentException>("arguments", async () =>
await func.InvokeAsync(withExtra));
Assert.Contains("PHASE", ex.Message);
}

[Fact]
public async Task Parameters_UnmappedMemberHandlingDisallow_ParameterlessMethod_ThrowsOnAnyArgument_Async()
{
JsonSerializerOptions strictOptions = new(AIJsonUtilities.DefaultOptions)
{
UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow,
};

AIFunction func = AIFunctionFactory.Create(
() => "ok",
new AIFunctionFactoryOptions { SerializerOptions = strictOptions });

// No args is fine.
AssertExtensions.EqualFunctionCallResults("ok", await func.InvokeAsync());

// Any extra key is flagged.
ArgumentException ex = await Assert.ThrowsAsync<ArgumentException>("arguments", async () =>
await func.InvokeAsync(new() { ["phase"] = "completed" }));
Assert.Contains("phase", ex.Message);
}

[Fact]
public async Task Parameters_UnmappedMemberHandlingDisallow_CustomBindParameter_SkipsStrictValidation_Async()
{
JsonSerializerOptions strictOptions = new(AIJsonUtilities.DefaultOptions)
{
UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow,
};

// A custom BindParameter callback sources its value from a key that does not correspond
// to the .NET parameter name. Strict validation must be skipped so such binders keep working.
AIFunction func = AIFunctionFactory.Create(
(string update) => $"update:{update}",
new AIFunctionFactoryOptions
{
SerializerOptions = strictOptions,
ConfigureParameterBinding = _ => new()
{
BindParameter = (_, args) => args["aliasedKey"],
},
});

object? result = await func.InvokeAsync(new()
{
["aliasedKey"] = "hello",
["anotherKey"] = "world",
});
AssertExtensions.EqualFunctionCallResults("update:hello", result);
}

[JsonSerializable(typeof(IAsyncEnumerable<int>))]
[JsonSerializable(typeof(int[]))]
[JsonSerializable(typeof(string))]
Expand Down
Loading