
Commit ecc8cc4

Parse tool usage in Bedrock Anthropic streaming (#1310)
Updates how Bedrock Anthropic responses are parsed to include tool usage. It reuses the logic from the Anthropic LLM client. As a result, this PR also has a couple of additional effects:

- Enforcing snake case for properties in Bedrock Anthropic responses
- Changing how other non-Anthropic Bedrock providers are handled

This should resolve [KG-627](https://youtrack.jetbrains.com/issue/KG-627/Error-from-Bedrock-executor-on-a-streaming-with-a-tool-call).

## Motivation and Context

Tool usage is missing when using Anthropic models with streaming in Bedrock.
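As a rough consumer-side sketch of what the fix enables: once tool-use events are parsed, a streaming call against a Bedrock Anthropic model can surface `StreamFrame.ToolCall` frames alongside text. Only the StreamFrame variants come from this commit; the property names (`text`, `name`, `content`, `finishReason`) and the origin of the frame flow are assumptions for illustration.

```kotlin
import ai.koog.prompt.streaming.StreamFrame
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect

// Hypothetical consumer-side sketch: only the StreamFrame variants come from
// the diff; the property names and the source of `frames` are assumed.
suspend fun handleFrames(frames: Flow<StreamFrame>) {
    frames.collect { frame ->
        when (frame) {
            is StreamFrame.Append -> print(frame.text)                                    // streamed text delta
            is StreamFrame.ToolCall -> println("\n[tool] ${frame.name}: ${frame.content}") // parsed tool call
            is StreamFrame.End -> println("\n[done] reason=${frame.finishReason}")         // finish reason + usage
            else -> Unit                                                                    // ignore any other frame types
        }
    }
}
```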
1 parent 6f670e8 commit ecc8cc4

File tree

5 files changed (+291 -187 lines changed)


integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt

Lines changed: 0 additions & 4 deletions

```diff
@@ -1132,10 +1132,6 @@ abstract class ExecutorIntegrationTestBase {
             model.provider !== LLMProvider.OpenRouter,
             "KG-626 Error from OpenRouter on a streaming with a tool call"
         )
-        assumeTrue(
-            model.provider !== LLMProvider.Bedrock,
-            "KG-627 Error from Bedrock executor on a streaming with a tool call"
-        )
 
         val executor = getExecutor(model)
 
```
prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 42 additions & 24 deletions
```diff
@@ -54,6 +54,7 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.FlowPreview
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.channelFlow
+import kotlinx.coroutines.flow.filterNot
 import kotlinx.coroutines.flow.map
 import kotlinx.coroutines.flow.transform
 import kotlinx.coroutines.withContext
@@ -319,40 +320,57 @@ public class BedrockLLMClient @JvmOverloads constructor(
                 logger.error(exception) { exception.message }
                 close(exception)
             }
-        }.map { chunkJsonString ->
-            try {
-                if (chunkJsonString.isBlank()) return@map emptyList()
-                when (modelFamily) {
-                    is BedrockModelFamilies.AI21Jamba -> BedrockAI21JambaSerialization.parseJambaStreamChunk(
-                        chunkJsonString
-                    )
+        }.filterNot {
+            it.isBlank()
+        }.run {
+            when (modelFamily) {
+                is BedrockModelFamilies.AI21Jamba -> genericProcessStream(
+                    this,
+                    BedrockAI21JambaSerialization::parseJambaStreamChunk
+                )
 
-                    is BedrockModelFamilies.AmazonNova -> BedrockAmazonNovaSerialization.parseNovaStreamChunk(
-                        chunkJsonString
-                    )
+                is BedrockModelFamilies.AmazonNova -> genericProcessStream(
+                    this,
+                    BedrockAmazonNovaSerialization::parseNovaStreamChunk
+                )
 
-                    is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.parseAnthropicStreamChunk(
-                        chunkJsonString
-                    )
+                is BedrockModelFamilies.Meta -> genericProcessStream(
+                    this,
+                    BedrockMetaLlamaSerialization::parseLlamaStreamChunk
+                )
 
-                    is BedrockModelFamilies.Meta -> BedrockMetaLlamaSerialization.parseLlamaStreamChunk(chunkJsonString)
+                is BedrockModelFamilies.AnthropicClaude -> BedrockAnthropicClaudeSerialization.transformAnthropicStreamChunks(
+                    chunkJsonStringFlow = this,
+                    clock = clock,
+                )
 
-                    is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere ->
-                        throw LLMClientException(
-                            clientName,
-                            "Embedding models do not support streaming chat completions. Use embed() instead."
-                        )
-                }
+                is BedrockModelFamilies.TitanEmbedding, is BedrockModelFamilies.Cohere ->
+                    throw LLMClientException(
+                        clientName,
+                        "Embedding models do not support streaming chat completions. Use embed() instead."
+                    )
+            }
+        }
+    }
+
+    /**
+     * Processes a flow of JSON strings into StreamFrames using the provided processor function.
+     * Handles exceptions by logging and re-throwing them.
+     */
+    private fun genericProcessStream(
+        chunkJsonStringFlow: Flow<String>,
+        processor: (String) -> List<StreamFrame>
+    ): Flow<StreamFrame> =
+        chunkJsonStringFlow.map { chunkJsonString ->
+            try {
+                processor(chunkJsonString)
             } catch (e: Exception) {
                 logger.warn(e) { "Failed to parse Bedrock stream chunk: $chunkJsonString" }
                 throw e
            }
         }.transform { frames ->
-            frames.forEach {
-                emit(it)
-            }
+            frames.forEach { emit(it) }
         }
-    }
 
     override suspend fun embed(text: String, model: LLModel): List<Double> {
         model.requireCapability(LLMCapability.Embed)
```
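The new `genericProcessStream` helper captures a simple pattern: parse each JSON chunk into a list of frames, then flatten the lists into a single frame flow. A minimal, self-contained sketch of that shape, using stand-in types (`RawFrame`, `parseChunk`) rather than the actual Koog classes:

```kotlin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.transform

// Stand-ins for StreamFrame and the per-provider chunk parsers.
data class RawFrame(val text: String)

fun parseChunk(json: String): List<RawFrame> = listOf(RawFrame(json)) // placeholder parser

// One chunk may yield zero or more frames; map produces the lists and
// transform flattens them into a plain Flow<RawFrame>.
fun processStream(chunks: Flow<String>): Flow<RawFrame> =
    chunks
        .map { chunk -> parseChunk(chunk) }
        .transform { frames -> frames.forEach { emit(it) } }

// Usage idea: processStream(someChunkFlow).collect { println(it.text) }
```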

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/BedrockDataClasses.kt

Lines changed: 3 additions & 4 deletions
```diff
@@ -5,7 +5,6 @@ import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.json.JsonClassDiscriminator
 import kotlinx.serialization.json.JsonElement
-import kotlinx.serialization.json.JsonNames
 import kotlinx.serialization.json.JsonObject
 
 /**
@@ -205,7 +204,7 @@ public data class BedrockAnthropicResponse(
     val role: String,
     val content: List<AnthropicContent>,
     val model: String,
-    @JsonNames("stop_reason") val stopReason: String? = null,
+    val stopReason: String? = null,
     val usage: BedrockAnthropicUsage? = null
 )
 
@@ -220,6 +219,6 @@ public data class BedrockAnthropicResponse(
  */
 @Serializable
 public data class BedrockAnthropicUsage(
-    @SerialName("input_tokens") @JsonNames("inputTokens", "input_tokens") val inputTokens: Int,
-    @SerialName("output_tokens") @JsonNames("outputTokens", "output_tokens") val outputTokens: Int
+    val inputTokens: Int,
+    val outputTokens: Int
 )
```
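Dropping the `@SerialName`/`@JsonNames` annotations is safe because the serializer (next file) now sets `JsonNamingStrategy.SnakeCase`, which maps snake_case JSON keys onto camelCase Kotlin properties globally. A small standalone sketch of that behavior, using a hypothetical `Usage` class in place of `BedrockAnthropicUsage`:

```kotlin
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonNamingStrategy

@Serializable
data class Usage(val inputTokens: Int, val outputTokens: Int) // stand-in for BedrockAnthropicUsage

@OptIn(ExperimentalSerializationApi::class)
fun main() {
    // SnakeCase naming strategy maps "input_tokens" -> inputTokens with no per-property annotations.
    val json = Json { namingStrategy = JsonNamingStrategy.SnakeCase }
    val usage = json.decodeFromString<Usage>("""{"input_tokens": 10, "output_tokens": 5}""")
    println(usage) // Usage(inputTokens=10, outputTokens=5)
}
```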

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt

Lines changed: 93 additions & 55 deletions
```diff
@@ -3,7 +3,10 @@ package ai.koog.prompt.executor.clients.bedrock.modelfamilies.anthropic
 import ai.koog.agents.core.tools.ToolDescriptor
 import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.anthropic.models.AnthropicContent
+import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamDeltaContentType
+import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamEventType
 import ai.koog.prompt.executor.clients.anthropic.models.AnthropicStreamResponse
+import ai.koog.prompt.executor.clients.anthropic.models.AnthropicUsage
 import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModel
 import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelContent
 import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModelMessage
@@ -15,9 +18,12 @@ import ai.koog.prompt.message.Message
 import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
 import ai.koog.prompt.streaming.StreamFrame
+import ai.koog.prompt.streaming.buildStreamFrameFlow
 import io.github.oshai.kotlinlogging.KotlinLogging
+import kotlinx.coroutines.flow.Flow
 import kotlinx.datetime.Clock
 import kotlinx.serialization.json.Json
+import kotlinx.serialization.json.JsonNamingStrategy
 import kotlinx.serialization.json.buildJsonArray
 import kotlinx.serialization.json.buildJsonObject
 import kotlinx.serialization.json.encodeToJsonElement
@@ -31,6 +37,7 @@ internal object BedrockAnthropicClaudeSerialization {
         ignoreUnknownKeys = true
         isLenient = true
         explicitNulls = false
+        namingStrategy = JsonNamingStrategy.SnakeCase
     }
 
     private fun buildMessagesHistory(prompt: Prompt): MutableList<BedrockAnthropicInvokeModelMessage> {
@@ -217,72 +224,103 @@ internal object BedrockAnthropicClaudeSerialization {
         }
     }
 
-    internal fun parseAnthropicStreamChunk(chunkJsonString: String, clock: Clock = Clock.System): List<StreamFrame> {
-        val streamResponse = json.decodeFromString<AnthropicStreamResponse>(chunkJsonString)
-
-        return when (streamResponse.type) {
-            "content_block_delta" -> {
-                streamResponse.delta?.let {
-                    buildList {
-                        it.text?.let(StreamFrame::Append)?.let(::add)
-                        it.toolUse?.let { toolUse ->
-                            StreamFrame.ToolCall(
-                                id = toolUse.id,
-                                name = toolUse.name,
-                                content = toolUse.input.toString()
+    internal fun transformAnthropicStreamChunks(
+        chunkJsonStringFlow: Flow<String>,
+        clock: Clock = Clock.System
+    ): Flow<StreamFrame> = buildStreamFrameFlow {
+        var inputTokens: Int? = null
+        var outputTokens: Int? = null
+
+        fun updateUsage(usage: AnthropicUsage) {
+            inputTokens = usage.inputTokens ?: inputTokens
+            outputTokens = usage.outputTokens ?: outputTokens
+        }
+
+        fun getMetaInfo(): ResponseMetaInfo = ResponseMetaInfo.create(
+            clock = clock,
+            totalTokensCount = inputTokens?.plus(outputTokens ?: 0) ?: outputTokens,
+            inputTokensCount = inputTokens,
+            outputTokensCount = outputTokens,
+        )
+
+        chunkJsonStringFlow.collect { chunkJsonString ->
+            val response = json.decodeFromString<AnthropicStreamResponse>(chunkJsonString)
+
+            when (response.type) {
+                AnthropicStreamEventType.MESSAGE_START.value -> {
+                    response.message?.usage?.let(::updateUsage)
+                }
+
+                AnthropicStreamEventType.CONTENT_BLOCK_START.value -> {
+                    when (val contentBlock = response.contentBlock) {
+                        is AnthropicContent.Text -> {
+                            emitAppend(contentBlock.text)
+                        }
+
+                        is AnthropicContent.ToolUse -> {
+                            upsertToolCall(
+                                index = response.index ?: error("Tool index is missing"),
+                                id = contentBlock.id,
+                                name = contentBlock.name,
                             )
-                        }?.let(::add)
+                        }
+
+                        else -> {
+                            contentBlock?.let { logger.warn { "Unknown Anthropic stream content block type: ${it::class}" } }
+                                ?: logger.warn { "Anthropic stream content block is missing" }
+                        }
                     }
-                } ?: emptyList()
-            }
+                }
 
-            "message_delta" -> {
-                streamResponse.message?.content?.map { content ->
-                    when (content) {
-                        is AnthropicContent.Text ->
-                            StreamFrame.Append(content.text)
+                AnthropicStreamEventType.CONTENT_BLOCK_DELTA.value -> {
+                    response.delta?.let { delta ->
+                        // Handles deltas for tool calls and text
 
-                        is AnthropicContent.Thinking ->
-                            StreamFrame.Append(content.thinking)
+                        when (delta.type) {
+                            AnthropicStreamDeltaContentType.INPUT_JSON_DELTA.value -> {
+                                upsertToolCall(
+                                    index = response.index ?: error("Tool index is missing"),
+                                    args = delta.partialJson ?: error("Tool args are missing")
+                                )
+                            }
 
-                        is AnthropicContent.ToolUse ->
-                            StreamFrame.ToolCall(
-                                id = content.id,
-                                name = content.name,
-                                content = content.input.toString()
-                            )
+                            AnthropicStreamDeltaContentType.TEXT_DELTA.value -> {
+                                emitAppend(
+                                    delta.text ?: error("Text delta is missing")
+                                )
+                            }
 
-                        else -> throw IllegalArgumentException(
-                            "Unsupported AnthropicContent type in message_delta. Content: $content"
-                        )
+                            else -> {
+                                logger.warn { "Unknown Anthropic stream delta type: ${delta.type}" }
+                            }
+                        }
                     }
-                } ?: emptyList()
-            }
+                }
 
-            "message_start" -> {
-                val inputTokens = streamResponse.message?.usage?.inputTokens
-                logger.debug { "Bedrock stream starts. Input tokens: $inputTokens" }
-                emptyList()
-            }
+                AnthropicStreamEventType.CONTENT_BLOCK_STOP.value -> {
+                    tryEmitPendingToolCall()
+                }
 
-            "message_stop" -> {
-                val inputTokens = streamResponse.message?.usage?.inputTokens
-                val outputTokens = streamResponse.message?.usage?.outputTokens
-                logger.debug { "Bedrock stream stops. Output tokens: $outputTokens" }
-                listOf(
-                    StreamFrame.End(
-                        finishReason = streamResponse.message?.stopReason,
-                        metaInfo = ResponseMetaInfo.create(
-                            clock = clock,
-                            totalTokensCount = inputTokens?.let { it + (outputTokens ?: 0) } ?: outputTokens,
-                            inputTokensCount = inputTokens,
-                            outputTokensCount = outputTokens
-                        )
+                AnthropicStreamEventType.MESSAGE_DELTA.value -> {
+                    response.usage?.let(::updateUsage)
+                    emitEnd(
+                        finishReason = response.delta?.stopReason,
+                        metaInfo = getMetaInfo()
                     )
-                )
-            }
+                }
+
+                AnthropicStreamEventType.MESSAGE_STOP.value -> {
+                    logger.debug { "Received stop message event from Anthropic" }
+                }
 
-            else -> emptyList()
+                AnthropicStreamEventType.ERROR.value -> {
+                    error("Anthropic error: ${response.error}")
+                }
+
+                AnthropicStreamEventType.PING.value -> {
+                    logger.debug { "Received ping from Anthropic" }
+                }
+            }
         }
     }
 }
```
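The event handling above leans on a small amount of bookkeeping provided by `buildStreamFrameFlow`: a tool call is opened on `content_block_start`, its JSON arguments arrive in `input_json_delta` fragments, and the assembled call is emitted on `content_block_stop` via `tryEmitPendingToolCall()`. The sketch below illustrates that accumulation pattern with hypothetical types; it is not the actual `buildStreamFrameFlow` implementation.

```kotlin
// Hypothetical illustration of the accumulate-then-emit pattern used for
// streamed tool calls. PendingToolCall, ToolCallAccumulator, and the callback
// signature are stand-ins, not Koog APIs.
data class PendingToolCall(
    val id: String? = null,
    val name: String? = null,
    val args: StringBuilder = StringBuilder()
)

class ToolCallAccumulator(private val onToolCall: (id: String?, name: String?, args: String) -> Unit) {
    private val pending = mutableMapOf<Int, PendingToolCall>()

    // content_block_start / input_json_delta: create or update the call at this block index.
    fun upsert(index: Int, id: String? = null, name: String? = null, args: String? = null) {
        val current = pending.getOrPut(index) { PendingToolCall(id, name) }
        val updated = current.copy(id = id ?: current.id, name = name ?: current.name)
        args?.let { updated.args.append(it) } // partial JSON fragments are concatenated
        pending[index] = updated
    }

    // content_block_stop: flush the assembled call, if one exists for this index.
    fun flush(index: Int) {
        pending.remove(index)?.let { onToolCall(it.id, it.name, it.args.toString()) }
    }
}
```

In the actual change this bookkeeping is handled by `upsertToolCall` and `tryEmitPendingToolCall` inside `buildStreamFrameFlow`, so the Bedrock parser only forwards the event fields.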
