Skip to content

Commit 57be8f8

Browse files
authored
fix: properly pass BedrockLLMClient timeout setting to BedrockRuntimeClient.HttpClient (#1190)
1 parent 80aedf6 commit 57be8f8

File tree

2 files changed

+77
-18
lines changed
  • prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src

2 files changed

+77
-18
lines changed

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ import kotlinx.coroutines.flow.transform
5959
import kotlinx.coroutines.withContext
6060
import kotlinx.datetime.Clock
6161
import kotlinx.serialization.json.Json
62+
import kotlin.time.Duration.Companion.milliseconds
6263

6364
/**
6465
* Configuration settings for connecting to the AWS Bedrock API.
@@ -104,7 +105,7 @@ public class BedrockGuardrailsSettings(
104105
* @return A configured [LLMClient] instance for Bedrock
105106
*/
106107
public class BedrockLLMClient(
107-
private val bedrockClient: BedrockRuntimeClient,
108+
internal val bedrockClient: BedrockRuntimeClient,
108109
private val moderationGuardrailsSettings: BedrockGuardrailsSettings? = null,
109110
private val fallbackModelFamily: BedrockModelFamilies? = null,
110111
private val clock: Clock = Clock.System,
@@ -146,6 +147,16 @@ public class BedrockLLMClient(
146147
this.retryStrategy = StandardRetryStrategy {
147148
maxAttempts = settings.maxRetries
148149
}
150+
151+
val timeoutConfig = settings.timeoutConfig
152+
153+
this.callTimeout = timeoutConfig.requestTimeoutMillis.milliseconds
154+
155+
this.httpClient {
156+
connectTimeout = timeoutConfig.connectTimeoutMillis.milliseconds
157+
socketReadTimeout = timeoutConfig.socketTimeoutMillis.milliseconds
158+
socketWriteTimeout = timeoutConfig.socketTimeoutMillis.milliseconds
159+
}
149160
},
150161
moderationGuardrailsSettings = settings.moderationGuardrailsSettings,
151162
fallbackModelFamily = settings.fallbackModelFamily,

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClientTest.kt

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,23 @@ import aws.sdk.kotlin.services.bedrockruntime.model.ListAsyncInvokesRequest
3737
import aws.sdk.kotlin.services.bedrockruntime.model.ListAsyncInvokesResponse
3838
import aws.sdk.kotlin.services.bedrockruntime.model.StartAsyncInvokeRequest
3939
import aws.sdk.kotlin.services.bedrockruntime.model.StartAsyncInvokeResponse
40+
import io.kotest.matchers.nulls.shouldNotBeNull
41+
import io.kotest.matchers.shouldBe
4042
import kotlinx.coroutines.flow.toList
4143
import kotlinx.coroutines.test.runTest
4244
import kotlinx.datetime.Clock
4345
import org.junit.jupiter.api.parallel.Execution
4446
import org.junit.jupiter.api.parallel.ExecutionMode
47+
import kotlin.random.Random.Default.nextInt
48+
import kotlin.random.Random.Default.nextLong
4549
import kotlin.test.Test
4650
import kotlin.test.assertEquals
4751
import kotlin.test.assertFails
4852
import kotlin.test.assertFailsWith
4953
import kotlin.test.assertFalse
5054
import kotlin.test.assertNotNull
5155
import kotlin.test.assertTrue
56+
import kotlin.time.Duration.Companion.milliseconds
5257

5358
class BedrockLLMClientTest {
5459
@Test
@@ -159,15 +164,22 @@ class BedrockLLMClientTest {
159164

160165
@Test
161166
fun `client configuration options work correctly`() {
167+
// given
168+
val requestTimeoutMillis = nextLong(1000, 2000)
169+
val connectTimeoutMillis = nextLong(100, 200)
170+
val socketTimeoutMillis = nextLong(200, 300)
171+
val maxRetries = nextInt(5, 10)
172+
173+
// when
162174
val customSettings = BedrockClientSettings(
163175
region = BedrockRegions.EU_WEST_1.regionCode,
164176
endpointUrl = "https://custom.endpoint.com",
165-
maxRetries = 5,
177+
maxRetries = maxRetries,
166178
enableLogging = true,
167179
timeoutConfig = ConnectionTimeoutConfig(
168-
requestTimeoutMillis = 120_000,
169-
connectTimeoutMillis = 10_000,
170-
socketTimeoutMillis = 120_000
180+
requestTimeoutMillis = requestTimeoutMillis,
181+
connectTimeoutMillis = connectTimeoutMillis,
182+
socketTimeoutMillis = socketTimeoutMillis
171183
)
172184
)
173185

@@ -180,11 +192,21 @@ class BedrockLLMClientTest {
180192
clock = Clock.System
181193
)
182194

183-
assertNotNull(client)
184-
assertEquals(BedrockRegions.EU_WEST_1.regionCode, customSettings.region)
185-
assertEquals("https://custom.endpoint.com", customSettings.endpointUrl)
186-
assertEquals(5, customSettings.maxRetries)
187-
assertEquals(true, customSettings.enableLogging)
195+
// then
196+
client shouldNotBeNull {
197+
bedrockClient.config shouldNotBeNull {
198+
callTimeout shouldBe requestTimeoutMillis.milliseconds
199+
endpointUrl.toString() shouldBe "https://custom.endpoint.com"
200+
region shouldBe BedrockRegions.EU_WEST_1.regionCode
201+
retryStrategy.config.maxAttempts shouldBe maxRetries
202+
203+
httpClient.config shouldNotBeNull {
204+
socketReadTimeout.inWholeMilliseconds shouldBe socketTimeoutMillis
205+
socketWriteTimeout.inWholeMilliseconds shouldBe socketTimeoutMillis
206+
connectTimeout.inWholeMilliseconds shouldBe connectTimeoutMillis
207+
}
208+
}
209+
}
188210
}
189211

190212
@Test
@@ -271,7 +293,10 @@ class BedrockLLMClientTest {
271293
override suspend fun converse(input: ConverseRequest): ConverseResponse =
272294
throw UnsupportedOperationException("converse not implemented in mock client")
273295

274-
override suspend fun <T> converseStream(input: ConverseStreamRequest, block: suspend (ConverseStreamResponse) -> T): T =
296+
override suspend fun <T> converseStream(
297+
input: ConverseStreamRequest,
298+
block: suspend (ConverseStreamResponse) -> T
299+
): T =
275300
throw UnsupportedOperationException("converseStream not implemented in mock client")
276301

277302
override suspend fun getAsyncInvoke(input: GetAsyncInvokeRequest): GetAsyncInvokeResponse =
@@ -280,10 +305,16 @@ class BedrockLLMClientTest {
280305
override suspend fun invokeModel(input: InvokeModelRequest): InvokeModelResponse =
281306
throw UnsupportedOperationException("invokeModel not implemented in mock client")
282307

283-
override suspend fun <T> invokeModelWithBidirectionalStream(input: InvokeModelWithBidirectionalStreamRequest, block: suspend (InvokeModelWithBidirectionalStreamResponse) -> T): T =
308+
override suspend fun <T> invokeModelWithBidirectionalStream(
309+
input: InvokeModelWithBidirectionalStreamRequest,
310+
block: suspend (InvokeModelWithBidirectionalStreamResponse) -> T
311+
): T =
284312
throw UnsupportedOperationException("invokeModelWithBidirectionalStream not implemented in mock client")
285313

286-
override suspend fun <T> invokeModelWithResponseStream(input: InvokeModelWithResponseStreamRequest, block: suspend (InvokeModelWithResponseStreamResponse) -> T): T =
314+
override suspend fun <T> invokeModelWithResponseStream(
315+
input: InvokeModelWithResponseStreamRequest,
316+
block: suspend (InvokeModelWithResponseStreamResponse) -> T
317+
): T =
287318
throw UnsupportedOperationException("invokeModelWithResponseStream not implemented in mock client")
288319

289320
override suspend fun listAsyncInvokes(input: ListAsyncInvokesRequest): ListAsyncInvokesResponse =
@@ -402,7 +433,11 @@ class BedrockLLMClientTest {
402433

403434
client.moderate(prompt, model)
404435

405-
assertEquals(2, applyGuardrailCallCount, "Should call applyGuardrail exactly twice for prompts with both Request and Response")
436+
assertEquals(
437+
2,
438+
applyGuardrailCallCount,
439+
"Should call applyGuardrail exactly twice for prompts with both Request and Response"
440+
)
406441
} finally {
407442
client.close()
408443
}
@@ -434,7 +469,11 @@ class BedrockLLMClientTest {
434469

435470
client.moderate(prompt, model)
436471

437-
assertEquals(1, applyGuardrailCallCount, "Should call applyGuardrail exactly once for Response-only prompts")
472+
assertEquals(
473+
1,
474+
applyGuardrailCallCount,
475+
"Should call applyGuardrail exactly once for Response-only prompts"
476+
)
438477
} finally {
439478
client.close()
440479
}
@@ -458,7 +497,10 @@ class BedrockLLMClientTest {
458497
override suspend fun converse(input: ConverseRequest): ConverseResponse =
459498
throw UnsupportedOperationException("converse not implemented in mock client")
460499

461-
override suspend fun <T> converseStream(input: ConverseStreamRequest, block: suspend (ConverseStreamResponse) -> T): T =
500+
override suspend fun <T> converseStream(
501+
input: ConverseStreamRequest,
502+
block: suspend (ConverseStreamResponse) -> T
503+
): T =
462504
throw UnsupportedOperationException("converseStream not implemented in mock client")
463505

464506
override suspend fun getAsyncInvoke(input: GetAsyncInvokeRequest): GetAsyncInvokeResponse =
@@ -467,10 +509,16 @@ class BedrockLLMClientTest {
467509
override suspend fun invokeModel(input: InvokeModelRequest): InvokeModelResponse =
468510
throw UnsupportedOperationException("invokeModel not implemented in mock client")
469511

470-
override suspend fun <T> invokeModelWithBidirectionalStream(input: InvokeModelWithBidirectionalStreamRequest, block: suspend (InvokeModelWithBidirectionalStreamResponse) -> T): T =
512+
override suspend fun <T> invokeModelWithBidirectionalStream(
513+
input: InvokeModelWithBidirectionalStreamRequest,
514+
block: suspend (InvokeModelWithBidirectionalStreamResponse) -> T
515+
): T =
471516
throw UnsupportedOperationException("invokeModelWithBidirectionalStream not implemented in mock client")
472517

473-
override suspend fun <T> invokeModelWithResponseStream(input: InvokeModelWithResponseStreamRequest, block: suspend (InvokeModelWithResponseStreamResponse) -> T): T =
518+
override suspend fun <T> invokeModelWithResponseStream(
519+
input: InvokeModelWithResponseStreamRequest,
520+
block: suspend (InvokeModelWithResponseStreamResponse) -> T
521+
): T =
474522
throw UnsupportedOperationException("invokeModelWithResponseStream not implemented in mock client")
475523

476524
override suspend fun listAsyncInvokes(input: ListAsyncInvokesRequest): ListAsyncInvokesResponse =

0 commit comments

Comments
 (0)