Skip to content
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
Expand Down
2 changes: 1 addition & 1 deletion extensions/Chunkers/Chunkers/PlainTextChunker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ internal enum SeparatorTypes
".", "?", "!", "⁉", "⁈", "⁇", "…",
// Chinese punctuation
"。", "?", "!", ";", ":"
]);
]);

// Prioritized list of characters to split inside a sentence.
private static readonly SeparatorTrie s_potentialSeparators = new([
Expand Down
3 changes: 3 additions & 0 deletions extensions/OpenAI/OpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Context;
using OpenAI;

#pragma warning disable IDE0130 // reduce number of "using" statements
Expand Down Expand Up @@ -260,6 +261,7 @@ public static IServiceCollection AddOpenAITextGeneration(
return services
.AddSingleton<ITextGenerator, OpenAITextGenerator>(serviceProvider => new OpenAITextGenerator(
config: config,
contextProvider: serviceProvider.GetService<IContextProvider>(),
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
Expand All @@ -276,6 +278,7 @@ public static IServiceCollection AddOpenAITextGeneration(
.AddSingleton<ITextGenerator, OpenAITextGenerator>(serviceProvider => new OpenAITextGenerator(
config: config,
openAIClient: openAIClient,
contextProvider: serviceProvider.GetService<IContextProvider>(),
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
}
Expand Down
31 changes: 17 additions & 14 deletions extensions/OpenAI/OpenAI/OpenAITextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI.Internals;
using Microsoft.KernelMemory.Context;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand All @@ -30,8 +31,9 @@ public sealed class OpenAITextGenerator : ITextGenerator
private readonly OpenAIChatCompletionService _client;
private readonly ITextTokenizer _textTokenizer;
private readonly ILogger<OpenAITextGenerator> _log;
private readonly IContextProvider _contextProvider;

private readonly string _textModel;
private readonly string _modelName;

/// <inheritdoc/>
public int MaxTokenTotal { get; }
Expand All @@ -40,18 +42,17 @@ public sealed class OpenAITextGenerator : ITextGenerator
/// Create a new instance.
/// </summary>
/// <param name="config">Client and model configuration</param>
/// <param name="contextProvider">Request context provider with runtime configuration overrides</param>
/// <param name="textTokenizer">Text tokenizer, possibly matching the model used</param>
/// <param name="loggerFactory">App logger factory</param>
/// <param name="httpClient">Optional HTTP client with custom settings</param>
public OpenAITextGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
OpenAIConfig config, IContextProvider? contextProvider = null, ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(
config,
OpenAIClientBuilder.Build(config, httpClient, loggerFactory),
textTokenizer,
OpenAIClientBuilder.Build(config, httpClient, loggerFactory), contextProvider, textTokenizer,
loggerFactory)
{
}
Expand All @@ -61,17 +62,16 @@ public OpenAITextGenerator(
/// </summary>
/// <param name="config">Model configuration</param>
/// <param name="openAIClient">Custom OpenAI client, already configured</param>
/// <param name="contextProvider">Request context provider with runtime configuration overrides</param>
/// <param name="textTokenizer">Text tokenizer, possibly matching the model used</param>
/// <param name="loggerFactory">App logger factory</param>
public OpenAITextGenerator(
OpenAIConfig config,
OpenAIClient openAIClient,
ITextTokenizer? textTokenizer = null,
OpenAIClient openAIClient, IContextProvider? contextProvider = null, ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(
config,
SkClientBuilder.BuildChatClient(config.TextModel, openAIClient, loggerFactory),
textTokenizer,
SkClientBuilder.BuildChatClient(config.TextModel, openAIClient, loggerFactory), contextProvider, textTokenizer,
loggerFactory)
{
}
Expand All @@ -81,17 +81,18 @@ public OpenAITextGenerator(
/// </summary>
/// <param name="config">Model configuration</param>
/// <param name="skClient">Custom Semantic Kernel client, already configured</param>
/// <param name="contextProvider">Request context provider with runtime configuration overrides</param>
/// <param name="textTokenizer">Text tokenizer, possibly matching the model used</param>
/// <param name="loggerFactory">App logger factory</param>
public OpenAITextGenerator(
OpenAIConfig config,
OpenAIChatCompletionService skClient,
ITextTokenizer? textTokenizer = null,
OpenAIChatCompletionService skClient, IContextProvider? contextProvider = null, ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
{
this._client = skClient;
this._contextProvider = contextProvider ?? new RequestContextProvider();
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<OpenAITextGenerator>();
this._textModel = config.TextModel;
this._modelName = config.TextModel;
this.MaxTokenTotal = config.TextModelMaxTokenTotal;

if (textTokenizer == null && !string.IsNullOrEmpty(config.TextModelTokenizer))
Expand Down Expand Up @@ -129,13 +130,15 @@ public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
TextGenerationOptions options,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var modelName = this._contextProvider.GetContext().GetCustomTextGenerationModelNameOrDefault(this._modelName);
var skOptions = new OpenAIPromptExecutionSettings
{
MaxTokens = options.MaxTokens,
Temperature = options.Temperature,
FrequencyPenalty = options.FrequencyPenalty,
PresencePenalty = options.PresencePenalty,
TopP = options.NucleusSampling
TopP = options.NucleusSampling,
ModelId = modelName
};

if (options.StopSequences is { Count: > 0 })
Expand Down Expand Up @@ -178,7 +181,7 @@ public async IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(
Timestamp = (DateTimeOffset?)x.Metadata["CreatedAt"] ?? DateTimeOffset.UtcNow,
ServiceType = "OpenAI",
ModelType = Constants.ModelType.TextGeneration,
ModelName = this._textModel,
ModelName = modelName,
ServiceTokensIn = usage!.InputTokenCount,
ServiceTokensOut = usage.OutputTokenCount,
ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
Expand Down
2 changes: 1 addition & 1 deletion service/Core/Search/AnswerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private IAsyncEnumerable<GeneratedTextContent> GenerateAnswerTokensAsync(string
PresencePenalty = this._config.PresencePenalty,
FrequencyPenalty = this._config.FrequencyPenalty,
StopSequences = this._config.StopSequences,
TokenSelectionBiases = this._config.TokenSelectionBiases,
TokenSelectionBiases = this._config.TokenSelectionBiases
};

if (this._log.IsEnabled(LogLevel.Debug))
Expand Down