Port Microsoft.Extensions.AI.AzureAIInference to Azure.AI.Inference #50097

Open · wants to merge 13 commits into base: main
9 changes: 8 additions & 1 deletion eng/Packages.Data.props
@@ -103,7 +103,7 @@
<PackageReference Update="System.ValueTuple" Version="4.5.0" />
<PackageReference Update="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
<PackageReference Update="Microsoft.CSharp" Version="4.7.0" />
<PackageReference Update="Microsoft.Extensions.Logging.Abstractions" Version="8.0.3"/>
<PackageReference Update="Microsoft.Extensions.Logging.Abstractions" Version="8.0.3"/>

<!-- Azure SDK packages -->
<PackageReference Update="Azure.AI.Inference" Version="1.0.0-beta.4" />
@@ -187,6 +187,11 @@
<PackageReference Update="Microsoft.AspNetCore.Http.Features" Version="[2.1.1,6.0)" />
</ItemGroup>

<ItemGroup Condition="$(MSBuildProjectName.StartsWith('Azure.AI.Inference'))">
<!-- Microsoft.Extensions.AI.Abstractions approved for Azure.AI.Inference; it has down-level support -->
<PackageReference Update="Microsoft.Extensions.AI.Abstractions" Version="9.5.0"/>
</ItemGroup>

<ItemGroup Condition="$(MSBuildProjectName.StartsWith('Azure.AI.OpenAI'))">
<PackageReference Update="OpenAI" Version="2.2.0-beta.4" />
</ItemGroup>
@@ -374,7 +379,9 @@
<PackageReference Update="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.8.0" />
<PackageReference Update="Microsoft.CSharp" Version="4.7.0" />
<PackageReference Update="Microsoft.Data.SqlClient" Version="5.2.2" />
<PackageReference Update="Microsoft.Extensions.AI" Version="9.5.0" /> <!-- 9.x approved for test project use, as there is no 8.x version available. -->
<PackageReference Update="Microsoft.Extensions.Azure" Version="1.11.0" />
<PackageReference Update="Microsoft.Extensions.Caching.Memory" Version="8.0.1" />
<PackageReference Update="Microsoft.Extensions.Configuration.Abstractions" Version="8.0.0" />
<PackageReference Update="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
<PackageReference Update="Microsoft.Extensions.Configuration.Json" Version="8.0.1" />
9 changes: 9 additions & 0 deletions sdk/ai/Azure.AI.Inference/api/Azure.AI.Inference.net8.0.cs
@@ -818,6 +818,15 @@ internal StreamingToolCallUpdate() { }
        public int ToolCallIndex { get { throw null; } }
    }
}
namespace Microsoft.Extensions.AI
{
    public static partial class AzureAIInferenceExtensions
    {
        public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Inference.ChatCompletionsClient chatCompletionsClient, string modelId = null) { throw null; }
        public static Microsoft.Extensions.AI.IEmbeddingGenerator<string, Microsoft.Extensions.AI.Embedding<float>> AsIEmbeddingGenerator(this Azure.AI.Inference.EmbeddingsClient embeddingsClient, string defaultModelId = null, int? defaultModelDimensions = default(int?)) { throw null; }
        public static Microsoft.Extensions.AI.IEmbeddingGenerator<Microsoft.Extensions.AI.DataContent, Microsoft.Extensions.AI.Embedding<float>> AsIEmbeddingGenerator(this Azure.AI.Inference.ImageEmbeddingsClient imageEmbeddingsClient, string defaultModelId = null, int? defaultModelDimensions = default(int?)) { throw null; }
    }
}
namespace Microsoft.Extensions.Azure
{
    public static partial class AIInferenceClientBuilderExtensions
@@ -818,6 +818,15 @@ internal StreamingToolCallUpdate() { }
        public int ToolCallIndex { get { throw null; } }
    }
}
namespace Microsoft.Extensions.AI
{
    public static partial class AzureAIInferenceExtensions
    {
        public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Inference.ChatCompletionsClient chatCompletionsClient, string modelId = null) { throw null; }
        public static Microsoft.Extensions.AI.IEmbeddingGenerator<string, Microsoft.Extensions.AI.Embedding<float>> AsIEmbeddingGenerator(this Azure.AI.Inference.EmbeddingsClient embeddingsClient, string defaultModelId = null, int? defaultModelDimensions = default(int?)) { throw null; }
        public static Microsoft.Extensions.AI.IEmbeddingGenerator<Microsoft.Extensions.AI.DataContent, Microsoft.Extensions.AI.Embedding<float>> AsIEmbeddingGenerator(this Azure.AI.Inference.ImageEmbeddingsClient imageEmbeddingsClient, string defaultModelId = null, int? defaultModelDimensions = default(int?)) { throw null; }
    }
}
namespace Microsoft.Extensions.Azure
{
    public static partial class AIInferenceClientBuilderExtensions
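The new Microsoft.Extensions.AI surface listed above is consumed by adapting the existing Azure.AI.Inference clients. Below is a minimal usage sketch, not part of this PR, assuming a placeholder endpoint, API key, and model name; GetResponseAsync and ChatResponse are the chat APIs from the referenced Microsoft.Extensions.AI.Abstractions package.

// Usage sketch (not part of this PR): endpoint, key, and model name are placeholders.
using System;
using System.Threading.Tasks;
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;

class ChatClientSample
{
    static async Task Main()
    {
        // Construct the Azure.AI.Inference client as usual.
        var chatCompletionsClient = new ChatCompletionsClient(
            new Uri("https://example.services.ai.azure.com/models"), // placeholder endpoint
            new AzureKeyCredential("<api-key>"));                    // placeholder credential

        // Adapt it to the Microsoft.Extensions.AI abstraction added by this PR.
        IChatClient chatClient = chatCompletionsClient.AsIChatClient(modelId: "example-model");

        // Call through the abstraction; the adapter forwards to ChatCompletionsClient.
        ChatResponse response = await chatClient.GetResponseAsync("What is an embedding?");
        Console.WriteLine(response.Text);
    }
}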
2 changes: 1 addition & 1 deletion sdk/ai/Azure.AI.Inference/src/Azure.AI.Inference.csproj
@@ -17,7 +17,7 @@
  </ItemGroup>
  <ItemGroup>
    <PackageReference Include="Azure.Core" />
    <PackageReference Include="System.ClientModel" />
    <PackageReference Include="Microsoft.Extensions.AI.Abstractions"/>
  </ItemGroup>

</Project>

Large diffs are not rendered by default.

@@ -0,0 +1,188 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable disable

using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;

namespace Microsoft.Extensions.AI
{
    /// <summary>Represents an <see cref="IEmbeddingGenerator{String, Embedding}"/> for an Azure.AI.Inference <see cref="EmbeddingsClient"/>.</summary>
    internal sealed class AzureAIInferenceEmbeddingGenerator :
        IEmbeddingGenerator<string, Embedding<float>>
    {
        /// <summary>Metadata about the embedding generator.</summary>
        private readonly EmbeddingGeneratorMetadata _metadata;

        /// <summary>The underlying <see cref="EmbeddingsClient" />.</summary>
        private readonly EmbeddingsClient _embeddingsClient;

        /// <summary>The number of dimensions produced by the generator.</summary>
        private readonly int? _dimensions;

        /// <summary>Initializes a new instance of the <see cref="AzureAIInferenceEmbeddingGenerator"/> class.</summary>
        /// <param name="embeddingsClient">The underlying client.</param>
        /// <param name="defaultModelId">
        /// The ID of the model to use. This can also be overridden per request via <see cref="EmbeddingGenerationOptions.ModelId"/>.
        /// Either this parameter or <see cref="EmbeddingGenerationOptions.ModelId"/> must provide a valid model ID.
        /// </param>
        /// <param name="defaultModelDimensions">The number of dimensions to generate in each embedding.</param>
        /// <exception cref="ArgumentNullException"><paramref name="embeddingsClient"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentException"><paramref name="defaultModelId"/> is empty or composed entirely of whitespace.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="defaultModelDimensions"/> is not positive.</exception>
        public AzureAIInferenceEmbeddingGenerator(
            EmbeddingsClient embeddingsClient, string defaultModelId = null, int? defaultModelDimensions = null)
        {
            Argument.AssertNotNull(embeddingsClient, nameof(embeddingsClient));

            if (defaultModelId is not null)
            {
                Argument.AssertNotNullOrWhiteSpace(defaultModelId, nameof(defaultModelId));
            }

            if (defaultModelDimensions is { } modelDimensions)
            {
                Argument.AssertInRange(modelDimensions, 1, int.MaxValue, nameof(defaultModelDimensions));
            }

            _embeddingsClient = embeddingsClient;
            _dimensions = defaultModelDimensions;
            _metadata = new EmbeddingGeneratorMetadata("az.ai.inference", embeddingsClient.Endpoint, defaultModelId, defaultModelDimensions);
        }

        /// <inheritdoc />
        object IEmbeddingGenerator.GetService(Type serviceType, object serviceKey)
        {
            Argument.AssertNotNull(serviceType, nameof(serviceType));

            return
                serviceKey is not null ? null :
                serviceType == typeof(EmbeddingsClient) ? _embeddingsClient :
                serviceType == typeof(EmbeddingGeneratorMetadata) ? _metadata :
                serviceType.IsInstanceOfType(this) ? this :
                null;
        }

        /// <inheritdoc />
        public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
            IEnumerable<string> values, EmbeddingGenerationOptions options = null, CancellationToken cancellationToken = default)
        {
            Argument.AssertNotNull(values, nameof(values));

            var azureAIOptions = ToAzureAIOptions(values, options);

            var embeddings = (await _embeddingsClient.EmbedAsync(azureAIOptions, cancellationToken).ConfigureAwait(false)).Value;

            GeneratedEmbeddings<Embedding<float>> result = new(embeddings.Data.Select(e =>
                new Embedding<float>(ParseBase64Floats(e.Embedding))
                {
                    CreatedAt = DateTimeOffset.UtcNow,
                    ModelId = embeddings.Model ?? azureAIOptions.Model,
                }));

            if (embeddings.Usage is not null)
            {
                result.Usage = new()
                {
                    InputTokenCount = embeddings.Usage.PromptTokens,
                    TotalTokenCount = embeddings.Usage.TotalTokens
                };
            }

            return result;
        }

        /// <inheritdoc />
        void IDisposable.Dispose()
        {
            // Nothing to dispose. Implementation required for the IEmbeddingGenerator interface.
        }

        internal static float[] ParseBase64Floats(BinaryData binaryData)
        {
            ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span;

            // Remove quotes around base64 string.
            if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"')
            {
                ThrowInvalidData();
            }

            base64 = base64.Slice(1, base64.Length - 2);

            // Decode base64 string to bytes.
            byte[] bytes = ArrayPool<byte>.Shared.Rent(Base64.GetMaxDecodedFromUtf8Length(base64.Length));
            OperationStatus status = Base64.DecodeFromUtf8(base64, bytes.AsSpan(), out int bytesConsumed, out int bytesWritten);
            if (status != OperationStatus.Done || bytesWritten % sizeof(float) != 0)
            {
                ThrowInvalidData();
            }

            // Interpret bytes as floats
            float[] vector = new float[bytesWritten / sizeof(float)];
            bytes.AsSpan(0, bytesWritten).CopyTo(MemoryMarshal.AsBytes(vector.AsSpan()));
            if (!BitConverter.IsLittleEndian)
            {
                Span<int> ints = MemoryMarshal.Cast<float, int>(vector.AsSpan());
#if NET
                BinaryPrimitives.ReverseEndianness(ints, ints);
#else
                for (int i = 0; i < ints.Length; i++)
                {
                    ints[i] = BinaryPrimitives.ReverseEndianness(ints[i]);
                }
#endif
            }

            ArrayPool<byte>.Shared.Return(bytes);
            return vector;

            static void ThrowInvalidData() =>
                throw new FormatException("The input is not a valid Base64 string of encoded floats.");
        }

        /// <summary>Converts an extensions options instance to an Azure.AI.Inference options instance.</summary>
        private EmbeddingsOptions ToAzureAIOptions(IEnumerable<string> inputs, EmbeddingGenerationOptions options)
        {
            if (options?.RawRepresentationFactory?.Invoke(this) is not EmbeddingsOptions result)
            {
                result = new EmbeddingsOptions(inputs);
            }
            else
            {
                foreach (var input in inputs)
                {
                    result.Input.Add(input);
                }
            }

            result.Dimensions ??= options?.Dimensions ?? _dimensions;
            result.Model ??= options?.ModelId ?? _metadata.DefaultModelId;
            result.EncodingFormat = EmbeddingEncodingFormat.Base64;

            if (options?.AdditionalProperties is { } props)
            {
                foreach (var prop in props)
                {
                    if (prop.Value is not null)
                    {
                        byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
                        result.AdditionalProperties[prop.Key] = new BinaryData(data);
                    }
                }
            }

            return result;
        }
    }
}
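ParseBase64Floats above reverses the encoding requested via EmbeddingEncodingFormat.Base64: judging from the adapter code, the service returns each embedding as a quoted base64 JSON string whose decoded bytes are the little-endian IEEE-754 floats of the vector (with a byte swap applied on big-endian hosts). A small standalone sketch of that payload shape, using only BCL APIs and made-up values:

// Sketch of the payload shape ParseBase64Floats expects: a quoted base64 string
// whose bytes are the raw little-endian floats of the embedding vector.
using System;
using System.Runtime.InteropServices;

class Base64FloatFormatSketch
{
    static void Main()
    {
        // Made-up embedding values purely for illustration.
        float[] vector = { 0.25f, -1.5f, 3.0f };

        // Reinterpret the floats as raw bytes (little-endian on common hardware).
        byte[] raw = MemoryMarshal.AsBytes(vector.AsSpan()).ToArray();

        // The service returns the vector as a JSON string token, quotes included;
        // wrapped in a BinaryData, this is the value handed to ParseBase64Floats.
        string payload = "\"" + Convert.ToBase64String(raw) + "\"";
        Console.WriteLine(payload);
    }
}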
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable disable

using Azure.AI.Inference;

namespace Microsoft.Extensions.AI
{
    /// <summary>Provides extension methods for working with Azure AI Inference.</summary>
    public static class AzureAIInferenceExtensions
    {
        /// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="ChatCompletionsClient"/>.</summary>
        /// <param name="chatCompletionsClient">The client.</param>
        /// <param name="modelId">The ID of the model to use. If <see langword="null"/>, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
        /// <returns>An <see cref="IChatClient"/> that can be used to converse via the <see cref="ChatCompletionsClient"/>.</returns>
        public static IChatClient AsIChatClient(
            this ChatCompletionsClient chatCompletionsClient, string modelId = null) =>
            new AzureAIInferenceChatClient(chatCompletionsClient, modelId);

/// <summary>Gets an <see cref="IEmbeddingGenerator{String, Single}"/> for use with this <see cref="EmbeddingsClient"/>.</summary>
/// <param name="embeddingsClient">The client.</param>
/// <param name="defaultModelId">The ID of the model to use. If <see langword="null"/>, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
/// <param name="defaultModelDimensions">The number of dimensions generated in each embedding.</param>
/// <returns>An <see cref="IEmbeddingGenerator{String, Embedding}"/> that can be used to generate embeddings via the <see cref="EmbeddingsClient"/>.</returns>
public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerator(
this EmbeddingsClient embeddingsClient, string defaultModelId = null, int? defaultModelDimensions = null) =>
new AzureAIInferenceEmbeddingGenerator(embeddingsClient, defaultModelId, defaultModelDimensions);

/// <summary>Gets an <see cref="IEmbeddingGenerator{DataContent, Single}"/> for use with this <see cref="EmbeddingsClient"/>.</summary>
/// <param name="imageEmbeddingsClient">The client.</param>
/// <param name="defaultModelId">The ID of the model to use. If <see langword="null"/>, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
/// <param name="defaultModelDimensions">The number of dimensions generated in each embedding.</param>
/// <returns>An <see cref="IEmbeddingGenerator{DataContent, Embedding}"/> that can be used to generate embeddings via the <see cref="ImageEmbeddingsClient"/>.</returns>
public static IEmbeddingGenerator<DataContent, Embedding<float>> AsIEmbeddingGenerator(
this ImageEmbeddingsClient imageEmbeddingsClient, string defaultModelId = null, int? defaultModelDimensions = null) =>
new AzureAIInferenceImageEmbeddingGenerator(imageEmbeddingsClient, defaultModelId, defaultModelDimensions);
}
}
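A matching usage sketch for the embedding side, again not part of this PR and with a placeholder endpoint, key, and model name; GenerateAsync is the IEmbeddingGenerator interface method implemented by AzureAIInferenceEmbeddingGenerator above.

// Usage sketch (not part of this PR): endpoint, key, and model name are placeholders.
using System;
using System.Threading.Tasks;
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;

class EmbeddingGeneratorSample
{
    static async Task Main()
    {
        var embeddingsClient = new EmbeddingsClient(
            new Uri("https://example.services.ai.azure.com/models"), // placeholder endpoint
            new AzureKeyCredential("<api-key>"));                    // placeholder credential

        // Adapt the Azure.AI.Inference client to the Microsoft.Extensions.AI abstraction.
        IEmbeddingGenerator<string, Embedding<float>> generator =
            embeddingsClient.AsIEmbeddingGenerator(defaultModelId: "example-embedding-model");

        // GenerateAsync is implemented by AzureAIInferenceEmbeddingGenerator shown above.
        GeneratedEmbeddings<Embedding<float>> embeddings =
            await generator.GenerateAsync(new[] { "hello", "world" });

        Console.WriteLine($"{embeddings.Count} embeddings, {embeddings[0].Vector.Length} dimensions each");
    }
}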