Skip to content

Commit 8dc1bb4

Browse files
authored
[prompt] Converse API support in Bedrock LLM client (#1384)
Adds [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) support to the Bedrock LLM client. Multimodality and streaming are supported. Also adds a new `apiMethod` parameter to the Bedrock LLM client settings/constructor to control which API is used: the manual `InvokeModel` or the higher-level `Converse`. The default is `InvokeModel` for now. There's also a new integration test suite, `BedrockConverseApiIntegrationTest`, to verify that it works as expected. Fixes #1050
1 parent a32d3a4 commit 8dc1bb4

File tree

20 files changed

+1570
-79
lines changed

20 files changed

+1570
-79
lines changed

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ agp = "8.12.3"
44
annotations = "26.0.2-1"
55
assertj = "3.27.6"
66
awaitility = "4.3.0"
7-
aws-sdk-kotlin = "1.5.16"
7+
aws-sdk-kotlin = "1.5.123"
88
dokka = "2.1.0"
99
exposed = "0.61.0"
1010
h2 = "2.4.240"
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
package ai.koog.integration.tests.executor
2+
3+
import ai.koog.integration.tests.utils.MediaTestScenarios.AudioTestScenario
4+
import ai.koog.integration.tests.utils.MediaTestScenarios.ImageTestScenario
5+
import ai.koog.integration.tests.utils.MediaTestScenarios.MarkdownTestScenario
6+
import ai.koog.integration.tests.utils.MediaTestScenarios.TextTestScenario
7+
import ai.koog.integration.tests.utils.Models
8+
import ai.koog.integration.tests.utils.TestCredentials.readAwsAccessKeyIdFromEnv
9+
import ai.koog.integration.tests.utils.TestCredentials.readAwsBedrockGuardrailIdFromEnv
10+
import ai.koog.integration.tests.utils.TestCredentials.readAwsBedrockGuardrailVersionFromEnv
11+
import ai.koog.integration.tests.utils.TestCredentials.readAwsSecretAccessKeyFromEnv
12+
import ai.koog.integration.tests.utils.TestCredentials.readAwsSessionTokenFromEnv
13+
import ai.koog.prompt.executor.clients.LLMClient
14+
import ai.koog.prompt.executor.clients.bedrock.BedrockAPIMethod
15+
import ai.koog.prompt.executor.clients.bedrock.BedrockClientSettings
16+
import ai.koog.prompt.executor.clients.bedrock.BedrockGuardrailsSettings
17+
import ai.koog.prompt.executor.clients.bedrock.BedrockLLMClient
18+
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
19+
import ai.koog.prompt.executor.clients.bedrock.converse.BedrockConverseParams
20+
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
21+
import ai.koog.prompt.executor.model.PromptExecutor
22+
import ai.koog.prompt.llm.LLMProvider
23+
import ai.koog.prompt.llm.LLModel
24+
import ai.koog.prompt.params.LLMParams
25+
import aws.sdk.kotlin.runtime.auth.credentials.StaticCredentialsProvider
26+
import kotlinx.serialization.json.buildJsonObject
27+
import kotlinx.serialization.json.put
28+
import org.junit.jupiter.api.Disabled
29+
import org.junit.jupiter.params.ParameterizedTest
30+
import org.junit.jupiter.params.provider.Arguments
31+
import org.junit.jupiter.params.provider.MethodSource
32+
import java.util.stream.Stream
33+
import kotlin.enums.EnumEntries
34+
35+
/**
36+
* Tests the newer Bedrock Converse API using the same suite of executor tests.
37+
*/
38+
class BedrockConverseApiIntegrationTest : ExecutorIntegrationTestBase() {
39+
companion object {
40+
private fun EnumEntries<*>.combineBedrockModels(): Stream<Arguments> {
41+
return toList()
42+
.flatMap { scenario ->
43+
Models
44+
.bedrockModels()
45+
.toArray()
46+
.map { model -> Arguments.of(scenario, model) }
47+
}
48+
.stream()
49+
}
50+
51+
@JvmStatic
52+
fun markdownScenarioModelCombinations(): Stream<Arguments> {
53+
return MarkdownTestScenario.entries.combineBedrockModels()
54+
}
55+
56+
@JvmStatic
57+
fun imageScenarioModelCombinations(): Stream<Arguments> {
58+
return ImageTestScenario.entries.combineBedrockModels()
59+
}
60+
61+
@JvmStatic
62+
fun textScenarioModelCombinations(): Stream<Arguments> {
63+
return TextTestScenario.entries.combineBedrockModels()
64+
}
65+
66+
@JvmStatic
67+
fun audioScenarioModelCombinations(): Stream<Arguments> {
68+
return AudioTestScenario.entries.combineBedrockModels()
69+
}
70+
71+
@JvmStatic
72+
fun reasoningCapableModels(): Stream<LLModel> {
73+
return listOf(BedrockModels.AnthropicClaude4_5Sonnet).stream()
74+
}
75+
76+
@JvmStatic
77+
fun allCompletionModels(): Stream<LLModel> {
78+
return Models.bedrockModels()
79+
}
80+
}
81+
82+
private val client = run {
83+
BedrockLLMClient(
84+
identityProvider = StaticCredentialsProvider {
85+
this.accessKeyId = readAwsAccessKeyIdFromEnv()
86+
this.secretAccessKey = readAwsSecretAccessKeyFromEnv()
87+
readAwsSessionTokenFromEnv()?.let { this.sessionToken = it }
88+
},
89+
settings = BedrockClientSettings(
90+
moderationGuardrailsSettings = BedrockGuardrailsSettings(
91+
guardrailIdentifier = readAwsBedrockGuardrailIdFromEnv(),
92+
guardrailVersion = readAwsBedrockGuardrailVersionFromEnv()
93+
),
94+
apiMethod = BedrockAPIMethod.Converse,
95+
)
96+
)
97+
}
98+
99+
private val executor: MultiLLMPromptExecutor = MultiLLMPromptExecutor(client)
100+
101+
override fun getLLMClient(model: LLModel): LLMClient {
102+
require(model.provider == LLMProvider.Bedrock) { "Model ${model.id} is not a Bedrock model" }
103+
104+
return client
105+
}
106+
107+
override fun getExecutor(model: LLModel): PromptExecutor = executor
108+
109+
override fun createReasoningParams(model: LLModel): LLMParams {
110+
require(model in reasoningCapableModels().toArray()) {
111+
"Model ${model.id} is not a reasoning capable model"
112+
}
113+
114+
return BedrockConverseParams(
115+
additionalProperties = mapOf(
116+
// Anthropic-specific reasoning config
117+
"reasoning_config" to buildJsonObject {
118+
put("type", "enabled")
119+
put("budget_tokens", 1024)
120+
}
121+
)
122+
)
123+
}
124+
125+
@ParameterizedTest
126+
@MethodSource("markdownScenarioModelCombinations")
127+
override fun integration_testMarkdownProcessingBasic(
128+
scenario: MarkdownTestScenario,
129+
model: LLModel
130+
) {
131+
super.integration_testMarkdownProcessingBasic(scenario, model)
132+
}
133+
134+
@ParameterizedTest
135+
@MethodSource("imageScenarioModelCombinations")
136+
override fun integration_testImageProcessing(scenario: ImageTestScenario, model: LLModel) {
137+
super.integration_testImageProcessing(scenario, model)
138+
}
139+
140+
@ParameterizedTest
141+
@MethodSource("textScenarioModelCombinations")
142+
override fun integration_testTextProcessingBasic(scenario: TextTestScenario, model: LLModel) {
143+
super.integration_testTextProcessingBasic(scenario, model)
144+
}
145+
146+
@Disabled("Converse API does not support audio processing")
147+
@ParameterizedTest
148+
@MethodSource("audioScenarioModelCombinations")
149+
override fun integration_testAudioProcessingBasic(scenario: AudioTestScenario, model: LLModel) {
150+
super.integration_testAudioProcessingBasic(scenario, model)
151+
}
152+
153+
// Core integration test methods
154+
@ParameterizedTest
155+
@MethodSource("allCompletionModels")
156+
override fun integration_testExecute(model: LLModel) {
157+
super.integration_testExecute(model)
158+
}
159+
160+
@ParameterizedTest
161+
@MethodSource("allCompletionModels")
162+
override fun integration_testExecuteStreaming(model: LLModel) {
163+
super.integration_testExecuteStreaming(model)
164+
}
165+
166+
@ParameterizedTest
167+
@MethodSource("allCompletionModels")
168+
override fun integration_testExecuteStreamingWithTools(model: LLModel) {
169+
super.integration_testExecuteStreamingWithTools(model)
170+
}
171+
172+
@ParameterizedTest
173+
@MethodSource("allCompletionModels")
174+
override fun integration_testToolWithRequiredParams(model: LLModel) {
175+
super.integration_testToolWithRequiredParams(model)
176+
}
177+
178+
@ParameterizedTest
179+
@MethodSource("allCompletionModels")
180+
override fun integration_testToolWithNotRequiredOptionalParams(model: LLModel) {
181+
super.integration_testToolWithNotRequiredOptionalParams(model)
182+
}
183+
184+
@ParameterizedTest
185+
@MethodSource("allCompletionModels")
186+
override fun integration_testToolWithOptionalParams(model: LLModel) {
187+
super.integration_testToolWithOptionalParams(model)
188+
}
189+
190+
@ParameterizedTest
191+
@MethodSource("allCompletionModels")
192+
override fun integration_testToolWithNoParams(model: LLModel) {
193+
super.integration_testToolWithNoParams(model)
194+
}
195+
196+
@ParameterizedTest
197+
@MethodSource("allCompletionModels")
198+
override fun integration_testToolWithListEnumParams(model: LLModel) {
199+
super.integration_testToolWithListEnumParams(model)
200+
}
201+
202+
@ParameterizedTest
203+
@MethodSource("allCompletionModels")
204+
override fun integration_testToolWithNestedListParams(model: LLModel) {
205+
super.integration_testToolWithNestedListParams(model)
206+
}
207+
208+
@ParameterizedTest
209+
@MethodSource("allCompletionModels")
210+
override fun integration_testToolsWithNullParams(model: LLModel) {
211+
super.integration_testToolsWithNullParams(model)
212+
}
213+
214+
@ParameterizedTest
215+
@MethodSource("allCompletionModels")
216+
override fun integration_testToolsWithAnyOfParams(model: LLModel) {
217+
super.integration_testToolsWithAnyOfParams(model)
218+
}
219+
220+
@ParameterizedTest
221+
@MethodSource("allCompletionModels")
222+
override fun integration_testMarkdownStructuredDataStreaming(model: LLModel) {
223+
super.integration_testMarkdownStructuredDataStreaming(model)
224+
}
225+
226+
@ParameterizedTest
227+
@MethodSource("allCompletionModels")
228+
override fun integration_testToolChoiceRequired(model: LLModel) {
229+
super.integration_testToolChoiceRequired(model)
230+
}
231+
232+
@Disabled("Converse API does not support tool choice none")
233+
@ParameterizedTest
234+
@MethodSource("allCompletionModels")
235+
override fun integration_testToolChoiceNone(model: LLModel) {
236+
super.integration_testToolChoiceNone(model)
237+
}
238+
239+
@ParameterizedTest
240+
@MethodSource("allCompletionModels")
241+
override fun integration_testToolChoiceNamed(model: LLModel) {
242+
super.integration_testToolChoiceNamed(model)
243+
}
244+
245+
@ParameterizedTest
246+
@MethodSource("allCompletionModels")
247+
override fun integration_testBase64EncodedAttachment(model: LLModel) {
248+
super.integration_testBase64EncodedAttachment(model)
249+
}
250+
251+
@Disabled("Converse API supports only S3 url attachments")
252+
@ParameterizedTest
253+
@MethodSource("allCompletionModels")
254+
override fun integration_testUrlBasedAttachment(model: LLModel) {
255+
super.integration_testUrlBasedAttachment(model)
256+
}
257+
258+
@Disabled("Converse API does not support native structured output")
259+
@ParameterizedTest
260+
@MethodSource("allCompletionModels")
261+
override fun integration_testStructuredOutputNative(model: LLModel) {
262+
super.integration_testStructuredOutputNative(model)
263+
}
264+
265+
@Disabled("Converse API does not support native structured output")
266+
@ParameterizedTest
267+
@MethodSource("allCompletionModels")
268+
override fun integration_testStructuredOutputNativeWithFixingParser(model: LLModel) {
269+
super.integration_testStructuredOutputNativeWithFixingParser(model)
270+
}
271+
272+
@ParameterizedTest
273+
@MethodSource("allCompletionModels")
274+
override fun integration_testStructuredOutputManual(model: LLModel) {
275+
super.integration_testStructuredOutputManual(model)
276+
}
277+
278+
@ParameterizedTest
279+
@MethodSource("allCompletionModels")
280+
override fun integration_testStructuredOutputManualWithFixingParser(model: LLModel) {
281+
super.integration_testStructuredOutputManualWithFixingParser(model)
282+
}
283+
284+
@ParameterizedTest
285+
@MethodSource("allCompletionModels")
286+
override fun integration_testMultipleSystemMessages(model: LLModel) {
287+
super.integration_testMultipleSystemMessages(model)
288+
}
289+
290+
@ParameterizedTest
291+
@MethodSource("reasoningCapableModels")
292+
override fun integration_testReasoningCapability(model: LLModel) {
293+
super.integration_testReasoningCapability(model)
294+
}
295+
296+
@ParameterizedTest
297+
@MethodSource("reasoningCapableModels")
298+
override fun integration_testReasoningMultiStep(model: LLModel) {
299+
super.integration_testReasoningMultiStep(model)
300+
}
301+
}

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ import ai.koog.prompt.llm.LLMCapability
5858
import ai.koog.prompt.llm.LLMProvider
5959
import ai.koog.prompt.llm.LLModel
6060
import ai.koog.prompt.markdown.markdown
61-
import ai.koog.prompt.message.AttachmentContent
6261
import ai.koog.prompt.message.ContentPart
6362
import ai.koog.prompt.message.Message
6463
import ai.koog.prompt.message.RequestMetaInfo
@@ -90,7 +89,7 @@ import org.junit.jupiter.api.Assumptions.assumeTrue
9089
import org.junit.jupiter.api.BeforeAll
9190
import java.nio.file.Path
9291
import java.nio.file.Paths
93-
import java.util.*
92+
import java.util.Base64
9493
import kotlin.io.path.pathString
9594
import kotlin.io.path.readBytes
9695
import kotlin.io.path.readText
@@ -124,7 +123,7 @@ abstract class ExecutorIntegrationTestBase {
124123

125124
open fun getLLMClient(model: LLModel): LLMClient = getLLMClientForProvider(model.provider)
126125

127-
private fun createReasoningParams(model: LLModel): LLMParams {
126+
open fun createReasoningParams(model: LLModel): LLMParams {
128127
return when (model.provider) {
129128
is LLMProvider.Anthropic -> AnthropicParams(
130129
thinking = AnthropicThinking.Enabled(budgetTokens = 1024)
@@ -500,21 +499,7 @@ abstract class ExecutorIntegrationTestBase {
500499
+"I'm sending you an image. Please analyze it and identify the image format if possible."
501500
}
502501

503-
when (scenario) {
504-
ImageTestScenario.LARGE_IMAGE, ImageTestScenario.LARGE_IMAGE_ANTHROPIC -> {
505-
image(
506-
ContentPart.Image(
507-
content = AttachmentContent.Binary.Bytes(imageFile.readBytes()),
508-
format = "jpg",
509-
mimeType = "image/jpeg"
510-
)
511-
)
512-
}
513-
514-
else -> {
515-
image(KtPath(imageFile.pathString))
516-
}
517-
}
502+
image(KtPath(imageFile.pathString))
518503
}
519504
}
520505

@@ -527,19 +512,23 @@ abstract class ExecutorIntegrationTestBase {
527512
ImageTestScenario.LARGE_IMAGE_ANTHROPIC, ImageTestScenario.LARGE_IMAGE -> {
528513
val message = e.message.shouldNotBeNull()
529514

530-
message.shouldContain("Status code: 400")
531-
message.shouldContain("image exceeds")
515+
listOf(
516+
"Status code: 400",
517+
"image exceeds",
518+
"Could not process image"
519+
).any { it in message }
520+
.shouldBe(true, "Must contain error message from the list")
532521
}
533522

534523
ImageTestScenario.CORRUPTED_IMAGE, ImageTestScenario.EMPTY_IMAGE -> {
535524
val message = e.message.shouldNotBeNull()
536525

537-
message.shouldContain("Status code: 400")
538-
if (model.provider == LLMProvider.Anthropic) {
539-
message.shouldContain("Could not process image")
540-
} else if (model.provider == LLMProvider.OpenAI) {
541-
message.shouldContain("You uploaded an unsupported image. Please make sure your image is valid.")
542-
}
526+
listOf(
527+
"Status code: 400",
528+
"Could not process image",
529+
"You uploaded an unsupported image. Please make sure your image is valid.",
530+
).any { it in message }
531+
.shouldBe(true, "Must contain error message from the list")
543532
}
544533

545534
else -> {

0 commit comments

Comments
 (0)