Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -164,13 +165,21 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
{
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

// Aggregation state for multi-iteration function calling loops.
// Text content from intermediate iterations (before tool calls) would otherwise be lost.
// Token usage must be summed across all API calls to report accurate totals.
StringBuilder? aggregatedContent = null;
int totalPromptTokens = 0;
int totalCandidatesTokens = 0;
bool hadMultipleIterations = false;

for (state.Iteration = 1; ; state.Iteration++)
{
List<GeminiChatMessageContent> chatResponses;
GeminiResponse geminiResponse;
using (var activity = ModelDiagnostics.StartCompletionActivity(
this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
GeminiResponse geminiResponse;
try
{
geminiResponse = await this.SendRequestAndReturnValidGeminiResponseAsync(
Expand All @@ -190,19 +199,42 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
geminiResponse.UsageMetadata?.CandidatesTokenCount);
}

// Aggregate usage across all iterations.
totalPromptTokens += geminiResponse.UsageMetadata?.PromptTokenCount ?? 0;
totalCandidatesTokens += geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0;

// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
if (!state.AutoInvoke || chatResponses.Count != 1)
{
// Apply aggregated content and usage to the final message.
this.ApplyAggregatedState(chatResponses, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
return chatResponses;
}

state.LastMessage = chatResponses[0];
if (state.LastMessage.ToolCalls is null || state.LastMessage.ToolCalls.Count == 0)
{
// Apply aggregated content and usage to the final message.
this.ApplyAggregatedState(chatResponses, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
return chatResponses;
}

// We're about to process tool calls and continue the loop - mark that we have multiple iterations.
hadMultipleIterations = true;

// Accumulate text content from this iteration before processing tool calls.
// The LLM may generate text (e.g., "Let me check that for you...") before tool calls.
if (!string.IsNullOrEmpty(state.LastMessage.Content))
{
aggregatedContent ??= new StringBuilder();
if (aggregatedContent.Length > 0)
{
aggregatedContent.Append("\n\n");
}
aggregatedContent.Append(state.LastMessage.Content);
}

// ToolCallBehavior is not null because we are in auto-invoke mode but we check it again to be sure it wasn't changed in the meantime
Verify.NotNull(state.ExecutionSettings.ToolCallBehavior);

Expand All @@ -213,7 +245,10 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
// and return the last chat message content that was added to chat history
if (state.FilterTerminationRequested)
{
return [state.ChatHistory.Last()];
var lastMessage = state.ChatHistory.Last();
// Apply aggregated content and usage to the filter-terminated message.
this.ApplyAggregatedState(lastMessage, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
return [lastMessage];
}
}
}
Expand Down Expand Up @@ -889,6 +924,85 @@ private static GeminiMetadata GetResponseMetadata(
ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
};

/// <summary>
/// Applies aggregated text content and usage from previous iterations to the final message(s).
/// This ensures that text generated before tool calls is not lost and that token usage
/// reflects the total across all API calls in the function calling loop.
/// Only the first message is updated, since a single choice is requested from the service.
/// </summary>
/// <param name="messages">The list of messages to update.</param>
/// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
/// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
/// <param name="totalCandidatesTokens">Total candidates tokens across all iterations.</param>
/// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
private void ApplyAggregatedState(
    List<GeminiChatMessageContent> messages,
    StringBuilder? aggregatedContent,
    int totalPromptTokens,
    int totalCandidatesTokens,
    bool hadMultipleIterations)
{
    // List pattern both guards against an empty list and binds the first element.
    if (messages is [var firstMessage, ..])
    {
        this.ApplyAggregatedStateToMessage(firstMessage, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);
    }
}

/// <summary>
/// Applies aggregated text content and usage from previous iterations to a single message.
/// Convenience overload that forwards directly to the core implementation.
/// </summary>
/// <param name="message">The message to update.</param>
/// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
/// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
/// <param name="totalCandidatesTokens">Total candidates tokens across all iterations.</param>
/// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
private void ApplyAggregatedState(
    ChatMessageContent message,
    StringBuilder? aggregatedContent,
    int totalPromptTokens,
    int totalCandidatesTokens,
    bool hadMultipleIterations) =>
    this.ApplyAggregatedStateToMessage(message, aggregatedContent, totalPromptTokens, totalCandidatesTokens, hadMultipleIterations);

/// <summary>
/// Core implementation for applying aggregated state to a message.
/// Prepends text accumulated from earlier iterations to the message content and,
/// when the function calling loop ran more than once, replaces the metadata token
/// counts with the totals summed across all API calls.
/// </summary>
/// <param name="message">The message to update in place.</param>
/// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
/// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
/// <param name="totalCandidatesTokens">Total candidates tokens across all iterations.</param>
/// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
private void ApplyAggregatedStateToMessage(
    ChatMessageContent message,
    StringBuilder? aggregatedContent,
    int totalPromptTokens,
    int totalCandidatesTokens,
    bool hadMultipleIterations)
{
    // Prepend aggregated content from previous iterations. The builder is deliberately
    // NOT mutated here: the previous implementation appended the message's own content
    // into the shared builder, so a second invocation (or applying to a second message)
    // would have duplicated text. Building the combined string locally keeps this
    // method idempotent with respect to the caller's aggregation state.
    if (aggregatedContent is { Length: > 0 })
    {
        message.Content = string.IsNullOrEmpty(message.Content)
            ? aggregatedContent.ToString()
            : $"{aggregatedContent}\n\n{message.Content}";
    }

    // Update metadata with aggregated usage if we had multiple iterations.
    // This ensures token counts are accurate even when intermediate iterations had no text content.
    if (hadMultipleIterations && message.Metadata is GeminiMetadata existingMetadata)
    {
        // Create a new metadata dictionary with aggregated values; TotalTokenCount is
        // recomputed as the sum so the three counts stay mutually consistent.
        var updatedDict = new Dictionary<string, object?>(existingMetadata)
        {
            [nameof(GeminiMetadata.PromptTokenCount)] = totalPromptTokens,
            [nameof(GeminiMetadata.CandidatesTokenCount)] = totalCandidatesTokens,
            [nameof(GeminiMetadata.TotalTokenCount)] = totalPromptTokens + totalCandidatesTokens
        };
        message.Metadata = GeminiMetadata.FromDictionary(updatedDict);
    }
}

private sealed class ChatCompletionState
{
internal ChatHistory ChatHistory { get; set; } = null!;
Expand Down
105 changes: 104 additions & 1 deletion dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
var endpoint = this.GetEndpoint(mistralExecutionSettings, path: "chat/completions");
var autoInvoke = kernel is not null && mistralExecutionSettings.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;

// Aggregation state for multi-iteration function calling loops.
// Text content from intermediate iterations (before tool calls) would otherwise be lost.
// Token usage must be summed across all API calls to report accurate totals.
StringBuilder? aggregatedContent = null;
int totalPromptTokens = 0;
int totalCompletionTokens = 0;
bool hadMultipleIterations = false;

for (int requestIndex = 1; ; requestIndex++)
{
var chatRequest = this.CreateChatCompletionRequest(modelId, stream: false, chatHistory, mistralExecutionSettings, kernel);
Expand Down Expand Up @@ -105,10 +113,16 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
activity?.SetCompletionResponse(responseContent, responseData.Usage?.PromptTokens, responseData.Usage?.CompletionTokens);
}

// Aggregate usage across all iterations.
totalPromptTokens += responseData.Usage?.PromptTokens ?? 0;
totalCompletionTokens += responseData.Usage?.CompletionTokens ?? 0;

// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
if (!autoInvoke || responseData.Choices.Count != 1)
{
// Apply aggregated content and usage to the final message.
ApplyAggregatedState(responseContent, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
return responseContent;
}

Expand All @@ -120,9 +134,27 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
MistralChatChoice chatChoice = responseData.Choices[0]; // TODO Handle multiple choices
if (!chatChoice.IsToolCall)
{
// Apply aggregated content and usage to the final message.
ApplyAggregatedState(responseContent, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
return responseContent;
}

// We're about to process tool calls and continue the loop - mark that we have multiple iterations.
hadMultipleIterations = true;

// Accumulate text content from this iteration before processing tool calls.
// The LLM may generate text (e.g., "Let me check that for you...") before tool calls.
var currentContent = responseContent.Count > 0 ? responseContent[0].Content : null;
if (!string.IsNullOrEmpty(currentContent))
{
aggregatedContent ??= new StringBuilder();
if (aggregatedContent.Length > 0)
{
aggregatedContent.Append("\n\n");
}
aggregatedContent.Append(currentContent);
}

if (this._logger.IsEnabled(LogLevel.Debug))
{
this._logger.LogDebug("Tool requests: {Requests}", chatChoice.ToolCallCount);
Expand Down Expand Up @@ -226,7 +258,10 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
this._logger.LogDebug("Filter requested termination of automatic function invocation.");
}

return [chatHistory.Last()];
var lastMessage = chatHistory.Last();
// Apply aggregated content and usage to the filter-terminated message.
ApplyAggregatedState(lastMessage, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
return [lastMessage];
}
}

Expand Down Expand Up @@ -1088,5 +1123,73 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(context
await functionCallCallback(context).ConfigureAwait(false);
}
}

/// <summary>
/// Applies aggregated text content and usage from previous iterations to the final message(s).
/// This ensures that text generated before tool calls is not lost and that token usage
/// reflects the total across all API calls in the function calling loop.
/// Only the first message is updated, since a single choice is requested from the service.
/// </summary>
/// <param name="messages">The list of messages to update.</param>
/// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
/// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
/// <param name="totalCompletionTokens">Total completion tokens across all iterations.</param>
/// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
private static void ApplyAggregatedState(
    List<ChatMessageContent> messages,
    StringBuilder? aggregatedContent,
    int totalPromptTokens,
    int totalCompletionTokens,
    bool hadMultipleIterations)
{
    // List pattern both guards against an empty list and binds the first element.
    if (messages is [var firstMessage, ..])
    {
        ApplyAggregatedState(firstMessage, aggregatedContent, totalPromptTokens, totalCompletionTokens, hadMultipleIterations);
    }
}

/// <summary>
/// Applies aggregated text content and usage from previous iterations to a single message.
/// Prepends text accumulated from earlier iterations to the message content and, when the
/// function calling loop ran more than once, attaches the summed token usage to metadata
/// under the "AggregatedUsage" key.
/// </summary>
/// <param name="message">The message to update.</param>
/// <param name="aggregatedContent">Accumulated text content from previous iterations, or null if none.</param>
/// <param name="totalPromptTokens">Total prompt tokens across all iterations.</param>
/// <param name="totalCompletionTokens">Total completion tokens across all iterations.</param>
/// <param name="hadMultipleIterations">Whether the function calling loop had multiple iterations.</param>
private static void ApplyAggregatedState(
    ChatMessageContent message,
    StringBuilder? aggregatedContent,
    int totalPromptTokens,
    int totalCompletionTokens,
    bool hadMultipleIterations)
{
    // Prepend aggregated content from previous iterations. The builder is deliberately
    // NOT mutated here: the previous implementation appended the message's own content
    // into the shared builder, so a second invocation (or applying to a second message)
    // would have duplicated text. Building the combined string locally keeps this
    // method idempotent with respect to the caller's aggregation state.
    if (aggregatedContent is { Length: > 0 })
    {
        message.Content = string.IsNullOrEmpty(message.Content)
            ? aggregatedContent.ToString()
            : $"{aggregatedContent}\n\n{message.Content}";
    }

    // Update metadata with aggregated usage if we had multiple iterations.
    // This ensures token counts are accurate even when intermediate iterations had no text content.
    if (hadMultipleIterations && message.Metadata is not null)
    {
        // Copy the existing metadata and add the aggregated usage totals; TotalTokens is
        // recomputed as the sum so the three counts stay mutually consistent.
        var updatedMetadata = new Dictionary<string, object?>(message.Metadata)
        {
            ["AggregatedUsage"] = new Dictionary<string, int>
            {
                ["PromptTokens"] = totalPromptTokens,
                ["CompletionTokens"] = totalCompletionTokens,
                ["TotalTokens"] = totalPromptTokens + totalCompletionTokens
            }
        };
        message.Metadata = new System.Collections.ObjectModel.ReadOnlyDictionary<string, object?>(updatedMetadata);
    }
}
#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@
<None Update="TestData\filters_chatclient_streaming_multiple_function_calls_test_response.txt">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\aggregation_function_call_with_text_response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\aggregation_final_response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
Loading