Add an IChatClient implementation to OnnxRuntimeGenAI
stephentoub committed Dec 3, 2024
1 parent 44a8f22 commit 99d78b4
Showing 4 changed files with 338 additions and 1 deletion.
232 changes: 232 additions & 0 deletions src/csharp/ChatClient.cs
@@ -0,0 +1,232 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
public sealed partial class ChatClient : IChatClient
{
/// <summary>The options used to configure the instance.</summary>
private readonly ChatClientConfiguration _config;

/// <summary>The wrapped <see cref="Model"/>.</summary>
private readonly Model _model;
/// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
private readonly Tokenizer _tokenizer;
/// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
private readonly bool _ownsModel;

/// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
/// <param name="modelPath">The file path to the model to load.</param>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <exception cref="ArgumentNullException"><paramref name="modelPath"/> or <paramref name="configuration"/> is null.</exception>
public ChatClient(string modelPath, ChatClientConfiguration configuration)
{
    if (modelPath is null)
    {
        throw new ArgumentNullException(nameof(modelPath));
    }

    _config = configuration ?? throw new ArgumentNullException(nameof(configuration));

    _ownsModel = true;
    _model = new Model(modelPath);
    _tokenizer = new Tokenizer(_model);

    Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
}

/// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
/// <param name="model">The model to employ.</param>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <param name="ownsModel">
/// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
/// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
/// The default is <see langword="true"/>.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="model"/> or <paramref name="configuration"/> is null.</exception>
public ChatClient(Model model, ChatClientConfiguration configuration, bool ownsModel = true)
{
    if (model is null)
    {
        throw new ArgumentNullException(nameof(model));
    }

    _config = configuration ?? throw new ArgumentNullException(nameof(configuration));

    _ownsModel = ownsModel;
    _model = model;
    _tokenizer = new Tokenizer(_model);

    Metadata = new("onnxruntime-genai");
}

/// <inheritdoc/>
public ChatClientMetadata Metadata { get; }

/// <inheritdoc/>
public void Dispose()
{
_tokenizer.Dispose();

if (_ownsModel)
{
_model.Dispose();
}
}

/// <inheritdoc/>
public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
{
if (chatMessages is null)
{
throw new ArgumentNullException(nameof(chatMessages));
}

StringBuilder text = new();
string completionId = Guid.NewGuid().ToString();
await Task.Run(() =>
{
using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
using GeneratorParams generatorParams = new(_model);
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

using Generator generator = new(_model, generatorParams);
generator.AppendTokenSequences(tokens);

using var tokenizerStream = _tokenizer.CreateStream();

while (!generator.IsDone())
{
cancellationToken.ThrowIfCancellationRequested();

generator.GenerateNextToken();

ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

if (IsStop(next, options))
{
break;
}

text.Append(next);
}
}, cancellationToken);

return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
{
CompletionId = completionId,
CreatedAt = DateTimeOffset.UtcNow,
ModelId = Metadata.ModelId,
};
}

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (chatMessages is null)
{
throw new ArgumentNullException(nameof(chatMessages));
}

using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
using GeneratorParams generatorParams = new(_model);
UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

using Generator generator = new(_model, generatorParams);
generator.AppendTokenSequences(tokens);

using var tokenizerStream = _tokenizer.CreateStream();

var completionId = Guid.NewGuid().ToString();
while (!generator.IsDone())
{
string next = await Task.Run(() =>
{
generator.GenerateNextToken();

ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
}, cancellationToken);

if (IsStop(next, options))
{
break;
}

yield return new StreamingChatCompletionUpdate
{
CompletionId = completionId,
CreatedAt = DateTimeOffset.UtcNow,
Role = ChatRole.Assistant,
Text = next,
};
}
}

/// <inheritdoc/>
public object GetService(Type serviceType, object key = null) =>
key is not null ? null :
serviceType == typeof(Model) ? _model :
serviceType == typeof(Tokenizer) ? _tokenizer :
serviceType?.IsInstanceOfType(this) is true ? this :
null;

/// <summary>Gets whether the specified token is a stop sequence.</summary>
private bool IsStop(string token, ChatOptions options) =>
options?.StopSequences?.Contains(token) is true ||
Array.IndexOf(_config.StopSequences, token) >= 0;

/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
{
if (options is null)
{
return;
}

if (options.MaxOutputTokens.HasValue)
{
generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
}

if (options.Temperature.HasValue)
{
generatorParams.SetSearchOption("temperature", options.Temperature.Value);
}

if (options.TopP.HasValue || options.TopK.HasValue)
{
if (options.TopP.HasValue)
{
generatorParams.SetSearchOption("top_p", options.TopP.Value);
}

if (options.TopK.HasValue)
{
generatorParams.SetSearchOption("top_k", options.TopK.Value);
}
}

if (options.Seed.HasValue)
{
generatorParams.SetSearchOption("random_seed", options.Seed.Value);
}

if (options.AdditionalProperties is { } props)
{
foreach (var entry in props)
{
switch (entry.Value)
{
case int i: generatorParams.SetSearchOption(entry.Key, i); break;
case long l: generatorParams.SetSearchOption(entry.Key, l); break;
case float f: generatorParams.SetSearchOption(entry.Key, f); break;
case double d: generatorParams.SetSearchOption(entry.Key, d); break;
case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
}
}
}
}
}
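
Taken together, the new client can be driven entirely through the Microsoft.Extensions.AI abstractions. The sketch below is illustrative only and is not part of this commit: the model path is hypothetical, and the Phi-3 style prompt markers are an assumption that must match the actual model's training format.

using System;
using System.Linq;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical model directory; substitute a locally downloaded ONNX Runtime GenAI model.
string modelPath = @"models/phi-3-mini-4k-instruct";

// Stop sequences and prompt formatter for a Phi-3 style chat template (an assumption
// for illustration; both must match the model actually being loaded).
ChatClientConfiguration config = new(
    ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
    messages =>
    {
        StringBuilder prompt = new();
        foreach (var message in messages)
            foreach (var content in message.Contents.OfType<TextContent>())
                prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
                      .Append(content.Text).Append("<|end|>\n");
        return prompt.Append("<|assistant|>\n").ToString();
    });

using IChatClient client = new ChatClient(modelPath, config);

// Non-streaming: MaxOutputTokens flows into the "max_length" search option.
Console.WriteLine(await client.CompleteAsync("What is 2 + 3?", new ChatOptions
{
    MaxOutputTokens = 20,
    Temperature = 0f,
}));

// Streaming: updates surface one decoded token at a time.
await foreach (var update in client.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "Tell me a joke.")]))
{
    Console.Write(update.Text);
}

// GetService exposes the wrapped primitives when lower-level access is needed.
Tokenizer tokenizer = (Tokenizer)client.GetService(typeof(Tokenizer));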
73 changes: 73 additions & 0 deletions src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =&gt;
/// new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
/// (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
/// {
/// StringBuilder prompt = new();
///
/// foreach (var message in messages)
/// foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
/// prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
///
/// return prompt.Append("&lt;|assistant|&gt;\n").ToString();
/// });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
private string[] _stopSequences;
private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

/// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
/// <param name="stopSequences">The stop sequences used by the model.</param>
/// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
/// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
/// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
public ChatClientConfiguration(
string[] stopSequences,
Func<IEnumerable<ChatMessage>, string> promptFormatter)
{
if (stopSequences is null)
{
throw new ArgumentNullException(nameof(stopSequences));
}

if (promptFormatter is null)
{
throw new ArgumentNullException(nameof(promptFormatter));
}

StopSequences = stopSequences;
PromptFormatter = promptFormatter;
}

/// <summary>
/// Gets or sets stop sequences to use during generation.
/// </summary>
/// <remarks>
/// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
/// </remarks>
public string[] StopSequences
{
get => _stopSequences;
set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
}

/// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
public Func<IEnumerable<ChatMessage>, string> PromptFormatter
{
get => _promptFormatter;
set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
}
}
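
Because both the stop sequences and the formatter are supplied by the caller, the same type adapts to any chat template. As a further illustrative sketch (not part of this commit; the generic ChatML-style markers shown are an assumption and must match the target model's actual training format):

using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

static class ChatClientConfigurations
{
    // A hypothetical factory for models trained on the ChatML template.
    public static ChatClientConfiguration CreateForChatML() =>
        new(["<|im_end|>"],
            (IEnumerable<ChatMessage> messages) =>
            {
                StringBuilder prompt = new();
                foreach (var message in messages)
                    foreach (var content in message.Contents.OfType<TextContent>())
                        prompt.Append("<|im_start|>").Append(message.Role.Value).Append('\n')
                              .Append(content.Text).Append("<|im_end|>\n");
                return prompt.Append("<|im_start|>assistant\n").ToString();
            });
}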
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
<PackageReference Include="System.Memory" Version="4.5.5" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
</ItemGroup>

</Project>
30 changes: 29 additions & 1 deletion test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -5,9 +5,11 @@
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.CompilerServices;
using Xunit;
using Xunit.Abstractions;
using System.Collections.Generic;
using Microsoft.Extensions.AI;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
@@ -349,6 +351,32 @@ public void TestTopKTopPSearch()
}
}

[IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
public async Task TestChatClient()
{
using var client = new ChatClient(
_phi2Path,
new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
(IEnumerable<ChatMessage> messages) =>
{
StringBuilder prompt = new();

foreach (var message in messages)
foreach (var content in message.Contents.OfType<TextContent>())
prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");

return prompt.Append("<|assistant|>\n").ToString();
}));

var completion = await client.CompleteAsync("What is 2 + 3?", new()
{
MaxOutputTokens = 20,
Temperature = 0f,
});

Assert.Contains("5", completion.ToString());
}

[IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
public void TestTokenizerBatchEncodeDecode()
{
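A streaming counterpart to the TestChatClient test above might look like the following sketch (not part of this commit; CreatePhi2Configuration is a hypothetical helper that would return the same configuration constructed inline in TestChatClient):

[IgnoreOnModelAbsenceFact(DisplayName = "TestChatClientStreaming")]
public async Task TestChatClientStreaming()
{
    // CreatePhi2Configuration is hypothetical: it would return the configuration
    // built inline in TestChatClient above.
    using var client = new ChatClient(_phi2Path, CreatePhi2Configuration());

    StringBuilder response = new();
    await foreach (var update in client.CompleteStreamingAsync(
        [new ChatMessage(ChatRole.User, "What is 2 + 3?")],
        new() { MaxOutputTokens = 20, Temperature = 0f }))
    {
        response.Append(update.Text);
    }

    Assert.Contains("5", response.ToString());
}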
