Commit 5f76fbd

fix(prompt): add missing token usage info to OpenAI-like clients in streaming mode (#1404)
* `OpenAILLMClient` was missing `includeUsage = true` in its chat completions endpoint requests (the default endpoint).
* For all OpenAI-like clients, usage metadata in streaming mode is reported in an event that arrives **after** the "stop reason" event, but the code assumed both arrive in the same event. Streaming chunk processing now handles the trailing usage metadata chunk correctly.
* Updated the integration test to check that token usage info is present when testing streaming.

Fix #1072
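The second bullet is the crux of the fix: with usage reporting enabled, OpenAI-compatible streaming APIs deliver usage statistics in a separate trailing chunk (with an empty `choices` array) that arrives after the chunk carrying the finish reason. A minimal, self-contained sketch of the corrected accumulation logic follows; `Chunk`, `Usage`, and `Frame` are hypothetical stand-ins for the provider models and Koog's `StreamFrame`, not the actual API:

    import kotlinx.coroutines.flow.Flow
    import kotlinx.coroutines.flow.flow
    import kotlinx.coroutines.flow.flowOf
    import kotlinx.coroutines.runBlocking

    // Hypothetical stand-ins for a provider's stream chunk and the emitted frame.
    data class Usage(val promptTokens: Int, val completionTokens: Int)
    data class Chunk(val content: String? = null, val finishReason: String? = null, val usage: Usage? = null)

    sealed interface Frame {
        data class Append(val text: String) : Frame
        data class End(val finishReason: String?, val usage: Usage?) : Frame
    }

    // The old logic emitted the end frame as soon as a finish reason appeared,
    // silently dropping the usage-only chunk that follows it. The fix: accumulate
    // both across the whole stream and emit the end frame once collection completes.
    fun process(chunks: Flow<Chunk>): Flow<Frame> = flow {
        var finishReason: String? = null
        var usage: Usage? = null
        chunks.collect { chunk ->
            chunk.content?.let { emit(Frame.Append(it)) }
            chunk.finishReason?.let { finishReason = it }
            chunk.usage?.let { usage = it }
        }
        emit(Frame.End(finishReason, usage))
    }

    fun main() = runBlocking {
        // With usage reporting on, the usage-only chunk trails the "stop" chunk.
        val stream = flowOf(
            Chunk(content = "Hello"),
            Chunk(content = ", world"),
            Chunk(finishReason = "stop"),
            Chunk(usage = Usage(promptTokens = 12, completionTokens = 4)),
        )
        process(stream).collect { println(it) }
    }

This prints two `Append` frames followed by `End(finishReason=stop, usage=Usage(promptTokens=12, completionTokens=4))`; emitting the end frame inside the finish-reason handler, as before the fix, would have produced `usage=null`.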
Parent: f031504

File tree (7 files changed, +174/-78 lines)
  • integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor
  • prompt/prompt-executor/prompt-executor-clients
    • prompt-executor-dashscope-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/dashscope
    • prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek
    • prompt-executor-mistralai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/mistralai
    • prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base
    • prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai
    • prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter


integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt

Lines changed: 13 additions & 1 deletion
@@ -76,6 +76,7 @@ import io.kotest.matchers.collections.shouldNotBeEmpty
 import io.kotest.matchers.collections.shouldNotContainAnyOf
 import io.kotest.matchers.ints.shouldBeGreaterThan
 import io.kotest.matchers.nulls.shouldNotBeNull
+import io.kotest.matchers.should
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.shouldNotBe
 import io.kotest.matchers.string.shouldContain
@@ -240,7 +241,18 @@ abstract class ExecutorIntegrationTestBase {
         toolMessages.shouldBeEmpty()
         when (model.provider) {
             is LLMProvider.Ollama -> endMessages.size shouldBe 0
-            else -> endMessages.size shouldBe 1
+
+            else -> {
+                endMessages.size shouldBe 1
+                endMessages.first() should { end ->
+                    end.metaInfo should { meta ->
+                        withClue("ResponseMetaInfo should contain at least some non-nullable token count info") {
+                            listOf(meta.inputTokensCount, meta.outputTokensCount, meta.totalTokensCount)
+                                .shouldForAny { it != null }
+                        }
+                    }
+                }
+            }
         }

         toString() shouldNotBeNull {
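The strengthened assertion checks not only that exactly one end frame arrives, but also that its `ResponseMetaInfo` carries at least one non-null token count (`inputTokensCount`, `outputTokensCount`, or `totalTokensCount`), which is exactly the data that was being dropped when the usage chunk trailed the finish reason.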

prompt/prompt-executor/prompt-executor-clients/prompt-executor-dashscope-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/dashscope/DashscopeLLMClient.kt

Lines changed: 27 additions & 11 deletions
@@ -16,10 +16,13 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
 import ai.koog.prompt.message.LLMChoice
+import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
-import ai.koog.prompt.streaming.StreamFrameFlowBuilder
+import ai.koog.prompt.streaming.StreamFrame
+import ai.koog.prompt.streaming.buildStreamFrameFlow
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.ktor.client.HttpClient
+import kotlinx.coroutines.flow.Flow
 import kotlinx.datetime.Clock
 import kotlin.jvm.JvmOverloads
 
@@ -123,18 +126,31 @@ public class DashscopeLLMClient @JvmOverloads constructor(
     override fun decodeResponse(data: String): DashscopeChatCompletionResponse =
         json.decodeFromString(data)
 
-    override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: DashscopeChatCompletionStreamResponse) {
-        chunk.choices.firstOrNull()?.let { choice ->
-            choice.delta.content?.let { emitAppend(it) }
-            choice.delta.toolCalls?.forEach { toolCall ->
-                val index = toolCall.index
-                val id = toolCall.id
-                val name = toolCall.function?.name
-                val arguments = toolCall.function?.arguments
-                upsertToolCall(index, id, name, arguments)
+    override fun processStreamingResponse(
+        response: Flow<DashscopeChatCompletionStreamResponse>
+    ): Flow<StreamFrame> = buildStreamFrameFlow {
+        var finishReason: String? = null
+        var metaInfo: ResponseMetaInfo? = null
+
+        response.collect { chunk ->
+            chunk.choices.firstOrNull()?.let { choice ->
+                choice.delta.content?.let { emitAppend(it) }
+
+                choice.delta.toolCalls?.forEach { toolCall ->
+                    val index = toolCall.index
+                    val id = toolCall.id
+                    val name = toolCall.function?.name
+                    val arguments = toolCall.function?.arguments
+                    upsertToolCall(index, id, name, arguments)
+                }
+
+                choice.finishReason?.let { finishReason = it }
             }
-            choice.finishReason?.let { emitEnd(it, createMetaInfo(chunk.usage)) }
+
+            chunk.usage?.let { metaInfo = createMetaInfo(chunk.usage) }
         }
+
+        emitEnd(finishReason, metaInfo)
     }
 
     public override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult {
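Note that `emitEnd` has moved outside `response.collect`: the finish reason and usage metadata are captured as chunks arrive, and the end frame is emitted exactly once, after the upstream flow completes, so a usage-only chunk trailing the finish reason is no longer lost. The DeepSeek, Mistral AI, and OpenAI clients below receive the same restructuring.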

prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt

Lines changed: 27 additions & 11 deletions
@@ -20,10 +20,13 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
 import ai.koog.prompt.message.LLMChoice
+import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
-import ai.koog.prompt.streaming.StreamFrameFlowBuilder
+import ai.koog.prompt.streaming.StreamFrame
+import ai.koog.prompt.streaming.buildStreamFrameFlow
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.ktor.client.HttpClient
+import kotlinx.coroutines.flow.Flow
 import kotlinx.datetime.Clock
 import kotlin.jvm.JvmOverloads
 
@@ -140,18 +143,31 @@ public class DeepSeekLLMClient @JvmOverloads constructor(
     override fun decodeResponse(data: String): DeepSeekChatCompletionResponse =
         json.decodeFromString(data)
 
-    override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: DeepSeekChatCompletionStreamResponse) {
-        chunk.choices.firstOrNull()?.let { choice ->
-            choice.delta.content?.let { emitAppend(it) }
-            choice.delta.toolCalls?.forEach { toolCall ->
-                val index = toolCall.index
-                val id = toolCall.id
-                val name = toolCall.function?.name
-                val arguments = toolCall.function?.arguments
-                upsertToolCall(index, id, name, arguments)
+    override fun processStreamingResponse(
+        response: Flow<DeepSeekChatCompletionStreamResponse>
+    ): Flow<StreamFrame> = buildStreamFrameFlow {
+        var finishReason: String? = null
+        var metaInfo: ResponseMetaInfo? = null
+
+        response.collect { chunk ->
+            chunk.choices.firstOrNull()?.let { choice ->
+                choice.delta.content?.let { emitAppend(it) }
+
+                choice.delta.toolCalls?.forEach { toolCall ->
+                    val index = toolCall.index
+                    val id = toolCall.id
+                    val name = toolCall.function?.name
+                    val arguments = toolCall.function?.arguments
+                    upsertToolCall(index, id, name, arguments)
+                }
+
+                choice.finishReason?.let { finishReason = it }
             }
-            choice.finishReason?.let { emitEnd(it, createMetaInfo(chunk.usage)) }
+
+            chunk.usage?.let { metaInfo = createMetaInfo(chunk.usage) }
        }
+
+        emitEnd(finishReason, metaInfo)
     }
 
     override fun createResponseFormat(schema: LLMParams.Schema?, model: LLModel): OpenAIResponseFormat? {

prompt/prompt-executor/prompt-executor-clients/prompt-executor-mistralai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/mistralai/MistralAILLMClient.kt

Lines changed: 35 additions & 16 deletions
@@ -32,11 +32,14 @@ import ai.koog.prompt.llm.LLMCapability
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
 import ai.koog.prompt.message.LLMChoice
+import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
-import ai.koog.prompt.streaming.StreamFrameFlowBuilder
+import ai.koog.prompt.streaming.StreamFrame
+import ai.koog.prompt.streaming.buildStreamFrameFlow
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.ktor.client.HttpClient
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.flow.Flow
 import kotlinx.datetime.Clock
 
 /**
@@ -153,23 +156,39 @@ public open class MistralAILLMClient(
     override fun decodeResponse(data: String): MistralAIChatCompletionResponse =
         json.decodeFromString(data)
 
-    override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: MistralAIChatCompletionStreamResponse) {
-        chunk.choices.firstOrNull()?.let { choice ->
-            choice.delta.content?.let { emitAppend(it) }
-            choice.delta.toolCalls?.forEach { toolCall ->
-                val index = toolCall.index
-                val id = toolCall.id
-                val name = toolCall.function?.name
-                val arguments = toolCall.function?.arguments
-                upsertToolCall(index, id, name, arguments)
+    override fun processStreamingResponse(
+        response: Flow<MistralAIChatCompletionStreamResponse>
+    ): Flow<StreamFrame> = buildStreamFrameFlow {
+        var finishReason: String? = null
+        var metaInfo: ResponseMetaInfo? = null
+
+        response.collect { chunk ->
+            chunk.choices.firstOrNull()?.let { choice ->
+                choice.delta.content?.let { emitAppend(it) }
+
+                choice.delta.toolCalls?.forEach { toolCall ->
+                    val index = toolCall.index
+                    val id = toolCall.id
+                    val name = toolCall.function?.name
+                    val arguments = toolCall.function?.arguments
+                    upsertToolCall(index, id, name, arguments)
+                }
+
+                choice.finishReason?.let { finishReason = it }
+            }
+
+            chunk.usage?.let { usage ->
+                metaInfo = createMetaInfo(
+                    OpenAIUsage(
+                        promptTokens = usage.promptTokens,
+                        completionTokens = usage.completionTokens,
+                        totalTokens = usage.totalTokens,
+                    )
+                )
             }
-            val usageInfo = OpenAIUsage(
-                promptTokens = chunk.usage?.promptTokens,
-                completionTokens = chunk.usage?.completionTokens,
-                totalTokens = chunk.usage?.totalTokens,
-            )
-            choice.finishReason?.let { emitEnd(it, createMetaInfo(usageInfo)) }
         }
+
+        emitEnd(finishReason, metaInfo)
     }
 
     /**

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt

Lines changed: 15 additions & 18 deletions
@@ -32,8 +32,6 @@ import ai.koog.prompt.message.Message
 import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
 import ai.koog.prompt.streaming.StreamFrame
-import ai.koog.prompt.streaming.StreamFrameFlowBuilder
-import ai.koog.prompt.streaming.buildStreamFrameFlow
 import ai.koog.prompt.structure.RegisteredBasicJsonSchemaGenerators
 import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators
 import ai.koog.prompt.structure.annotations.InternalStructuredOutputApi
@@ -49,6 +47,7 @@ import io.ktor.http.contentType
 import io.ktor.serialization.kotlinx.json.json
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.channelFlow
 import kotlinx.datetime.Clock
 import kotlinx.serialization.json.Json
 import kotlinx.serialization.json.JsonNamingStrategy
@@ -165,10 +164,10 @@ public abstract class AbstractOpenAILLMClient<TResponse : OpenAIBaseLLMResponse,
     protected abstract fun decodeResponse(data: String): TResponse
 
     /**
-     * Processes a provider-specific streaming response chunk.
+     * Processes a provider-specific streaming response.
      * Must be implemented by concrete client classes.
      */
-    protected abstract suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: TStreamResponse)
+    protected abstract fun processStreamingResponse(response: Flow<TStreamResponse>): Flow<StreamFrame>
 
     override suspend fun execute(prompt: Prompt, model: LLModel, tools: List<ToolDescriptor>): List<Message.Response> {
         val response = getResponse(prompt, model, tools)
@@ -193,27 +192,25 @@ public abstract class AbstractOpenAILLMClient<TResponse : OpenAIBaseLLMResponse,
             stream = true
         )
 
-        return buildStreamFrameFlow {
-            try {
+        return try {
+            channelFlow {
                 httpClient.sse(
                     path = chatCompletionsPath,
                     request = request,
                     requestBodyType = String::class,
                     dataFilter = { it != "[DONE]" },
                     decodeStreamingResponse = ::decodeStreamingResponse,
                     processStreamingChunk = { it }
-                ).collect {
-                    processStreamingChunk(it)
-                }
-            } catch (e: CancellationException) {
-                throw e
-            } catch (e: Exception) {
-                throw LLMClientException(
-                    clientName = clientName,
-                    message = e.message,
-                    cause = e
-                )
-            }
+                ).collect { send(it) }
+            }.let { processStreamingResponse(it) }
+        } catch (e: CancellationException) {
+            throw e
+        } catch (e: Exception) {
+            throw LLMClientException(
+                clientName = clientName,
+                message = e.message,
+                cause = e
+            )
         }
     }

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt

Lines changed: 31 additions & 11 deletions
@@ -19,6 +19,7 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIContentPart
 import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage
 import ai.koog.prompt.executor.clients.openai.base.models.OpenAIModalities
 import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStaticContent
+import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStreamOptions
 import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool
 import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
 import ai.koog.prompt.executor.clients.openai.models.InputContent
@@ -50,7 +51,7 @@ import ai.koog.prompt.message.Message
 import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
 import ai.koog.prompt.streaming.StreamFrame
-import ai.koog.prompt.streaming.StreamFrameFlowBuilder
+import ai.koog.prompt.streaming.buildStreamFrameFlow
 import ai.koog.utils.io.SuitableForIO
 import io.github.oshai.kotlinlogging.KotlinLogging
 import io.ktor.client.HttpClient
@@ -145,6 +146,11 @@ public open class OpenAILLMClient @JvmOverloads constructor(
         }
 
         val responseFormat = createResponseFormat(chatParams.schema, model)
+        val streamOptions = if (stream) {
+            OpenAIStreamOptions(includeUsage = true)
+        } else {
+            null
+        }
 
         val request = OpenAIChatCompletionRequest(
             messages = messages,
@@ -167,6 +173,7 @@ public open class OpenAILLMClient @JvmOverloads constructor(
             stop = chatParams.stop,
             store = chatParams.store,
             stream = stream,
+            streamOptions = streamOptions,
             temperature = chatParams.temperature,
             toolChoice = toolChoice,
             tools = tools,
@@ -256,18 +263,31 @@ public open class OpenAILLMClient @JvmOverloads constructor(
     override fun decodeResponse(data: String): OpenAIChatCompletionResponse =
         json.decodeFromString(data)
 
-    override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: OpenAIChatCompletionStreamResponse) {
-        chunk.choices.firstOrNull()?.let { choice ->
-            choice.delta.content?.let { emitAppend(it) }
-            choice.delta.toolCalls?.forEach { openAIToolCall ->
-                val index = openAIToolCall.index
-                val id = openAIToolCall.id
-                val functionName = openAIToolCall.function?.name
-                val functionArgs = openAIToolCall.function?.arguments
-                upsertToolCall(index, id, functionName, functionArgs)
+    override fun processStreamingResponse(
+        response: Flow<OpenAIChatCompletionStreamResponse>
+    ): Flow<StreamFrame> = buildStreamFrameFlow {
+        var finishReason: String? = null
+        var metaInfo: ResponseMetaInfo? = null
+
+        response.collect { chunk ->
+            chunk.choices.firstOrNull()?.let { choice ->
+                choice.delta.content?.let { emitAppend(it) }
+
+                choice.delta.toolCalls?.forEach { openAIToolCall ->
+                    val index = openAIToolCall.index
+                    val id = openAIToolCall.id
+                    val functionName = openAIToolCall.function?.name
+                    val functionArgs = openAIToolCall.function?.arguments
+                    upsertToolCall(index, id, functionName, functionArgs)
+                }
+
+                choice.finishReason?.let { finishReason = it }
             }
-            choice.finishReason?.let { emitEnd(it, createMetaInfo(chunk.usage)) }
+
+            chunk.usage?.let { metaInfo = createMetaInfo(it) }
         }
+
+        emitEnd(finishReason, metaInfo)
     }
 
     override suspend fun execute(prompt: Prompt, model: LLModel, tools: List<ToolDescriptor>): List<Message.Response> {
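`OpenAIStreamOptions(includeUsage = true)` maps to the `stream_options` object of the Chat Completions API: when `include_usage` is set on a streaming request, the server appends one final chunk carrying `usage` (and an empty `choices` array) before `[DONE]`. An illustrative kotlinx.serialization model of just this slice of the request; field names follow the OpenAI wire format, but these data classes are not Koog's actual models:

    import kotlinx.serialization.SerialName
    import kotlinx.serialization.Serializable
    import kotlinx.serialization.encodeToString
    import kotlinx.serialization.json.Json

    @Serializable
    data class StreamOptions(@SerialName("include_usage") val includeUsage: Boolean)

    @Serializable
    data class ChatRequest(
        val model: String,
        val stream: Boolean = false,
        @SerialName("stream_options") val streamOptions: StreamOptions? = null,
    )

    fun main() {
        // Streaming request: usage will arrive in a trailing chunk.
        val streaming = ChatRequest("gpt-4o-mini", stream = true, streamOptions = StreamOptions(includeUsage = true))
        println(Json.encodeToString(streaming))
        // {"model":"gpt-4o-mini","stream":true,"stream_options":{"include_usage":true}}

        // Non-streaming request: streamOptions stays null and is omitted,
        // matching the null branch in the diff above.
        println(Json.encodeToString(ChatRequest("gpt-4o-mini")))
        // {"model":"gpt-4o-mini"}
    }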
