Skip to content

Commit 721f865

Browse files
committed
[prompt] Initial support for Converse API for Bedrock LLM client
1 parent 4804618 commit 721f865

File tree

6 files changed

+995
-20
lines changed

6 files changed

+995
-20
lines changed

koog-ktor/src/jvmMain/kotlin/ai/koog/ktor/BedrockConfig.kt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ public fun KoogAgentsConfig.bedrock(
4949
}
5050
addLLMClient(
5151
LLMProvider.Bedrock,
52-
BedrockLLMClient(client, moderationGuardrailsSettings, null, clock)
52+
BedrockLLMClient(
53+
bedrockClient = client,
54+
moderationGuardrailsSettings = moderationGuardrailsSettings,
55+
clock = clock
56+
)
5357
)
5458
}
5559

@@ -70,6 +74,10 @@ public fun KoogAgentsConfig.bedrock(
7074
}
7175
addLLMClient(
7276
LLMProvider.Bedrock,
73-
BedrockLLMClient(client, moderationGuardrailsSettings, null, clock)
77+
BedrockLLMClient(
78+
bedrockClient = client,
79+
moderationGuardrailsSettings = moderationGuardrailsSettings,
80+
clock = clock
81+
)
7482
)
7583
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 125 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig
99
import ai.koog.prompt.executor.clients.LLMClient
1010
import ai.koog.prompt.executor.clients.LLMClientException
1111
import ai.koog.prompt.executor.clients.LLMEmbeddingProvider
12+
import ai.koog.prompt.executor.clients.bedrock.converse.BedrockConverseConverters
1213
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.BedrockAnthropicInvokeModel
1314
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.BedrockAI21JambaSerialization
1415
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest
@@ -60,29 +61,56 @@ import kotlinx.coroutines.flow.transform
6061
import kotlinx.coroutines.withContext
6162
import kotlinx.datetime.Clock
6263
import kotlinx.serialization.json.Json
63-
import kotlin.time.Duration.Companion.milliseconds
64+
import org.jetbrains.annotations.VisibleForTesting
6465

66+
import kotlin.time.Duration.Companion.milliseconds
6567
/**
6668
* Configuration settings for connecting to the AWS Bedrock API.
6769
*
6870
* @property region The AWS region where Bedrock service is hosted.
6971
* @property timeoutConfig Configuration for connection timeouts.
7072
* @property endpointUrl Optional custom endpoint URL for testing or private deployments.
73+
* @property apiMethod The API method to use for interacting with Bedrock models that support messages, defaults to [BedrockAPIMethod.InvokeModel].
7174
* @property maxRetries Maximum number of retries for failed requests.
7275
* @property enableLogging Whether to enable detailed AWS SDK logging.
7376
* @property moderationGuardrailsSettings Optional settings of the AWS Bedrock Guardrails (see [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-independent-api.html)) that would be used for the [LLMClient.moderate] request
7477
* @property fallbackModelFamily Optional fallback model family to use for unsupported models. If not provided, unsupported models will throw an exception.
7578
*/
7679
public class BedrockClientSettings(
77-
internal val region: String = BedrockRegions.US_WEST_2.regionCode,
78-
internal val timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(),
79-
internal val endpointUrl: String? = null,
80-
internal val maxRetries: Int = 3,
81-
internal val enableLogging: Boolean = false,
82-
internal val moderationGuardrailsSettings: BedrockGuardrailsSettings? = null,
83-
internal val fallbackModelFamily: BedrockModelFamilies? = null
80+
public val region: String = BedrockRegions.US_WEST_2.regionCode,
81+
public val timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(),
82+
public val endpointUrl: String? = null,
83+
public val apiMethod: BedrockAPIMethod = BedrockAPIMethod.InvokeModel,
84+
public val maxRetries: Int = 3,
85+
public val enableLogging: Boolean = false,
86+
public val moderationGuardrailsSettings: BedrockGuardrailsSettings? = null,
87+
public val fallbackModelFamily: BedrockModelFamilies? = null
8488
)
8589

90+
/**
91+
* Defines Bedrock API methods to interact with the models that support messages.
92+
*/
93+
public sealed interface BedrockAPIMethod {
94+
/**
95+
* Defines `/model/{modelId}/invoke` API method.
96+
* When using this method, request body is formatted manually and is specific to the invoked model.
97+
*
98+
* Consider using [Converse] if possible, since this is a newer method aiming to provide a consistent interface for all models.
99+
*
100+
* [AWS API docs](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html)
101+
*/
102+
public object InvokeModel : BedrockAPIMethod
103+
104+
/**
105+
* Defines `/model/{modelId}/converse` API method.
106+
* Provides a consistent interface that works with all models that support messages.
107+
* Supports custom inference parameters for models that require it.
108+
*
109+
* [AWS API docs](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html)
110+
*/
111+
public object Converse : BedrockAPIMethod
112+
}
113+
86114
/**
87115
* Represents the settings configuration for Bedrock guardrails.
88116
*
@@ -92,21 +120,24 @@ public class BedrockClientSettings(
92120
* @property guardrailVersion The version of the guardrail configuration.
93121
*/
94122
public class BedrockGuardrailsSettings(
95-
internal val guardrailIdentifier: String,
96-
internal val guardrailVersion: String,
123+
public val guardrailIdentifier: String,
124+
public val guardrailVersion: String,
97125
)
98126

99127
/**
100128
* Creates a new Bedrock LLM client configured with the specified AWS credentials and settings.
101129
*
102130
* @param bedrockClient The runtime client for interacting with Bedrock, highly configurable
131+
* @param apiMethod The API method to use for interacting with Bedrock models that support messages, defaults to [BedrockAPIMethod.InvokeModel].
103132
* @param moderationGuardrailsSettings Optional settings of the AWS Bedrock Guardrails (see [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-independent-api.html)) that would be used for the [LLMClient.moderate] request
104-
* @param fallbackModelFamily Optional fallback model family to use for unsupported models
133+
* @param fallbackModelFamily Optional fallback model family to use for unsupported models. If not provided, unsupported models will throw an exception.
105134
* @param clock A clock used for time-based operations
106135
* @return A configured [LLMClient] instance for Bedrock
107136
*/
108137
public class BedrockLLMClient @JvmOverloads constructor(
138+
@VisibleForTesting
109139
internal val bedrockClient: BedrockRuntimeClient,
140+
private val apiMethod: BedrockAPIMethod = BedrockAPIMethod.InvokeModel,
110141
private val moderationGuardrailsSettings: BedrockGuardrailsSettings? = null,
111142
private val fallbackModelFamily: BedrockModelFamilies? = null,
112143
private val clock: Clock = Clock.System,
@@ -161,6 +192,7 @@ public class BedrockLLMClient @JvmOverloads constructor(
161192
},
162193
moderationGuardrailsSettings = settings.moderationGuardrailsSettings,
163194
fallbackModelFamily = settings.fallbackModelFamily,
195+
apiMethod = settings.apiMethod,
164196
clock = clock
165197
)
166198

@@ -171,8 +203,10 @@ public class BedrockLLMClient @JvmOverloads constructor(
171203
encodeDefaults = true
172204
}
173205

206+
@VisibleForTesting
174207
internal fun getBedrockModelFamily(model: LLModel): BedrockModelFamilies {
175208
require(model.provider == LLMProvider.Bedrock) { "Model ${model.id} is not a Bedrock model" }
209+
176210
return when {
177211
model.id.contains("anthropic.claude") -> BedrockModelFamilies.AnthropicClaude
178212

@@ -197,11 +231,6 @@ public class BedrockLLMClient @JvmOverloads constructor(
197231
}
198232
}
199233

200-
/**
201-
* Provides the current language learning model provider utilized by this client.
202-
*
203-
* @return the [LLMProvider] instance, specifically `LLMProvider.Bedrock` for this client.
204-
*/
205234
override fun llmProvider(): LLMProvider = LLMProvider.Bedrock
206235

207236
override suspend fun execute(
@@ -210,12 +239,28 @@ public class BedrockLLMClient @JvmOverloads constructor(
210239
tools: List<ToolDescriptor>
211240
): List<Message.Response> {
212241
logger.debug { "Executing prompt for model: ${model.id}" }
213-
val modelFamily = getBedrockModelFamily(model)
242+
214243
model.requireCapability(LLMCapability.Completion, "Model ${model.id} does not support chat completions")
215244
// Check tool support
216245
if (tools.isNotEmpty() && !model.capabilities.contains(LLMCapability.Tools)) {
217246
throw LLMClientException(clientName, "Model ${model.id} does not support tools")
218247
}
248+
249+
return when (apiMethod) {
250+
is BedrockAPIMethod.InvokeModel -> doExecuteInvokeModel(prompt, model, tools)
251+
is BedrockAPIMethod.Converse -> doExecuteConverse(prompt, model, tools)
252+
}
253+
}
254+
255+
/**
256+
* Executes prompt using [BedrockAPIMethod.InvokeModel].
257+
*/
258+
private suspend fun doExecuteInvokeModel(
259+
prompt: Prompt,
260+
model: LLModel,
261+
tools: List<ToolDescriptor>
262+
): List<Message.Response> {
263+
val modelFamily = getBedrockModelFamily(model)
219264
val requestBody = createRequestBody(prompt, model, tools)
220265
val invokeRequest = InvokeModelRequest {
221266
this.modelId = model.id
@@ -273,15 +318,65 @@ public class BedrockLLMClient @JvmOverloads constructor(
273318
}
274319
}
275320

321+
/**
322+
* Executes prompt using [BedrockAPIMethod.Converse].
323+
*/
324+
private suspend fun doExecuteConverse(
325+
prompt: Prompt,
326+
model: LLModel,
327+
tools: List<ToolDescriptor>
328+
): List<Message.Response> {
329+
val converseRequest = BedrockConverseConverters.createConverseRequest(prompt, model, tools)
330+
331+
return withContext(Dispatchers.SuitableForIO) {
332+
try {
333+
logger.debug { "Bedrock Converse Request: ModelID: ${model.id}, Request: $converseRequest" }
334+
val response = bedrockClient.converse(converseRequest)
335+
logger.debug { "Bedrock Converse Response: $response" }
336+
337+
BedrockConverseConverters.convertConverseResponse(response, clock)
338+
} catch (e: CancellationException) {
339+
throw e
340+
} catch (e: Exception) {
341+
throw LLMClientException(
342+
clientName = clientName,
343+
message = e.message,
344+
cause = e
345+
)
346+
}
347+
}
348+
}
349+
276350
@OptIn(ExperimentalCoroutinesApi::class, FlowPreview::class)
277351
override fun executeStreaming(
278352
prompt: Prompt,
279353
model: LLModel,
280354
tools: List<ToolDescriptor>
281355
): Flow<StreamFrame> {
282356
logger.debug { "Executing streaming prompt for model: ${model.id}" }
283-
val modelFamily = getBedrockModelFamily(model)
357+
284358
model.requireCapability(LLMCapability.Completion, "Model ${model.id} does not support chat completions")
359+
// Check tool support
360+
if (tools.isNotEmpty() && !model.capabilities.contains(LLMCapability.Tools)) {
361+
throw LLMClientException(clientName, "Model ${model.id} does not support tools")
362+
}
363+
364+
return when (apiMethod) {
365+
is BedrockAPIMethod.InvokeModel -> doExecuteStreamingInvokeModel(prompt, model, tools)
366+
is BedrockAPIMethod.Converse -> doExecuteStreamingConverse(prompt, model, tools)
367+
}
368+
}
369+
370+
/**
371+
* Executes prompt using [BedrockAPIMethod.InvokeModel] in streaming mode.
372+
*/
373+
@OptIn(ExperimentalCoroutinesApi::class, FlowPreview::class)
374+
private fun doExecuteStreamingInvokeModel(
375+
prompt: Prompt,
376+
model: LLModel,
377+
tools: List<ToolDescriptor>
378+
): Flow<StreamFrame> {
379+
val modelFamily = getBedrockModelFamily(model)
285380
val requestBody = createRequestBody(prompt, model, tools)
286381
val streamRequest = InvokeModelWithResponseStreamRequest {
287382
this.modelId = model.id
@@ -372,8 +467,20 @@ public class BedrockLLMClient @JvmOverloads constructor(
372467
frames.forEach { emit(it) }
373468
}
374469

470+
/**
471+
* Executes prompt using [BedrockAPIMethod.Converse] in streaming mode.
472+
*/
473+
private fun doExecuteStreamingConverse(
474+
prompt: Prompt,
475+
model: LLModel,
476+
tools: List<ToolDescriptor>
477+
): Flow<StreamFrame> {
478+
throw NotImplementedError("Converse API method is not yet supported")
479+
}
480+
375481
override suspend fun embed(text: String, model: LLModel): List<Double> {
376482
model.requireCapability(LLMCapability.Embed)
483+
377484
logger.debug { "Embedding text with model: ${model.id}" }
378485
val modelFamily = getBedrockModelFamily(model)
379486
val requestBody = createEmbeddingRequestBody(text, model)

0 commit comments

Comments
 (0)