Skip to content

Commit 521e64d

Browse files
authored
Test executeStreaming with tool calls (#1261)
<!-- Thank you for opening a pull request! Please add a brief description of the proposed change here. Also, please tick the appropriate points in the checklist below. --> ## Motivation and Context <!-- Why is this change needed? What problem does it solve? --> Add `integration_testExecuteStreamingWithTools` to test streaming+tools functionality. ## Breaking Changes <!-- Will users need to update their code or configurations? --> --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Tests improvement - [ ] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent 2b84ddd commit 521e64d

File tree

4 files changed

+86
-12
lines changed

4 files changed

+86
-12
lines changed

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ import ai.koog.integration.tests.utils.structuredOutput.getConfigNoFixingParserM
2323
import ai.koog.integration.tests.utils.structuredOutput.getConfigNoFixingParserNative
2424
import ai.koog.integration.tests.utils.structuredOutput.parseMarkdownStreamToCountries
2525
import ai.koog.integration.tests.utils.structuredOutput.weatherStructuredOutputPrompt
26+
import ai.koog.integration.tests.utils.tools.CalculatorOperation
2627
import ai.koog.integration.tests.utils.tools.CalculatorTool
2728
import ai.koog.integration.tests.utils.tools.LotteryTool
2829
import ai.koog.integration.tests.utils.tools.PickColorFromListTool
2930
import ai.koog.integration.tests.utils.tools.PickColorTool
3031
import ai.koog.integration.tests.utils.tools.PriceCalculatorTool
32+
import ai.koog.integration.tests.utils.tools.SimpleCalculatorTool
3133
import ai.koog.integration.tests.utils.tools.SimplePriceCalculatorTool
3234
import ai.koog.integration.tests.utils.tools.calculatorPrompt
3335
import ai.koog.integration.tests.utils.tools.calculatorPromptNotRequiredOptionalParams
@@ -181,9 +183,6 @@ abstract class ExecutorIntegrationTestBase {
181183

182184
open fun integration_testExecuteStreaming(model: LLModel) = runTest(timeout = 300.seconds) {
183185
Models.assumeAvailable(model.provider)
184-
if (model.id == OpenAIModels.Audio.GPT4oAudio.id || model.id == OpenAIModels.Audio.GPT4oMiniAudio.id) {
185-
assumeTrue(false, "https://github.com/JetBrains/koog/issues/231")
186-
}
187186

188187
val executor = getExecutor(model)
189188

@@ -196,13 +195,15 @@ abstract class ExecutorIntegrationTestBase {
196195
with(StringBuilder()) {
197196
val endMessages = mutableListOf<StreamFrame.End>()
198197
val toolMessages = mutableListOf<StreamFrame.ToolCall>()
199-
executor.executeStreaming(prompt, model).collect {
200-
when (it) {
201-
is StreamFrame.Append -> append(it.text)
202-
is StreamFrame.End -> endMessages.add(it)
203-
is StreamFrame.ToolCall -> toolMessages.add(it)
204-
}
205-
}
198+
199+
executor.executeStreamAndCollect(
200+
prompt = prompt,
201+
model = model,
202+
appendable = this,
203+
endMessages = endMessages,
204+
toolMessages = toolMessages
205+
)
206+
206207
length shouldNotBe (0)
207208
toolMessages.shouldBeEmpty()
208209
when (model.provider) {
@@ -221,6 +222,42 @@ abstract class ExecutorIntegrationTestBase {
221222
}
222223
}
223224

225+
open fun integration_testExecuteStreamingWithTools(model: LLModel) = runTest(timeout = 300.seconds) {
226+
Models.assumeAvailable(model.provider)
227+
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
228+
229+
val executor = getExecutor(model)
230+
231+
val prompt = Prompt.build("test-streaming", LLMParams(toolChoice = ToolChoice.Required)) {
232+
system("You are a helpful assistant.")
233+
user("Count three times five")
234+
}
235+
236+
withRetry(times = 3, testName = "integration_testExecuteStreamingWithTools[${model.id}]") {
237+
with(StringBuilder()) {
238+
val endMessages = mutableListOf<StreamFrame.End>()
239+
val toolMessages = mutableListOf<StreamFrame.ToolCall>()
240+
241+
executor.executeStreamAndCollect(
242+
prompt = prompt,
243+
model = model,
244+
tools = listOf(SimpleCalculatorTool.descriptor),
245+
appendable = this,
246+
endMessages = endMessages,
247+
toolMessages = toolMessages
248+
)
249+
250+
toolMessages.shouldNotBeEmpty()
251+
withClue("Expected calculator tool call but got: [$toolMessages]") {
252+
toolMessages.any {
253+
it.name == SimpleCalculatorTool.name &&
254+
it.content.contains(CalculatorOperation.MULTIPLY.name, ignoreCase = true)
255+
} shouldBe true
256+
}
257+
}
258+
}
259+
}
260+
224261
open fun integration_testToolWithRequiredParams(model: LLModel) = runTest(timeout = 300.seconds) {
225262
Models.assumeAvailable(model.provider)
226263
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
@@ -780,7 +817,7 @@ abstract class ExecutorIntegrationTestBase {
780817

781818
val prompt = calculatorPrompt
782819

783-
/** tool choice auto is default and thus is tested by [integration_testToolWithRequiredParams] */
820+
/* tool choice auto is default and thus is tested by [integration_testToolWithRequiredParams] */
784821

785822
withRetry(times = 3, testName = "integration_testToolChoiceRequired[${model.id}]") {
786823
with(
@@ -1068,7 +1105,10 @@ abstract class ExecutorIntegrationTestBase {
10681105

10691106
val prompt2 = Prompt(
10701107
id = "reasoning-multistep-2",
1071-
messages = prompt1.messages + response1 + Message.User(ContentPart.Text("Multiply the result by 2."), metaInfo = RequestMetaInfo.Empty),
1108+
messages = prompt1.messages + response1 + Message.User(
1109+
ContentPart.Text("Multiply the result by 2."),
1110+
metaInfo = RequestMetaInfo.Empty
1111+
),
10721112
params = params
10731113
)
10741114

@@ -1080,3 +1120,20 @@ abstract class ExecutorIntegrationTestBase {
10801120
}
10811121
}
10821122
}
1123+
1124+
private suspend fun PromptExecutor.executeStreamAndCollect(
1125+
prompt: Prompt,
1126+
model: LLModel,
1127+
tools: List<ToolDescriptor> = emptyList(),
1128+
appendable: Appendable,
1129+
endMessages: MutableList<StreamFrame.End>,
1130+
toolMessages: MutableList<StreamFrame.ToolCall>
1131+
) {
1132+
this.executeStreaming(prompt, model, tools).collect { frame ->
1133+
when (frame) {
1134+
is StreamFrame.Append -> appendable.append(frame.text)
1135+
is StreamFrame.End -> endMessages.add(frame)
1136+
is StreamFrame.ToolCall -> toolMessages.add(frame)
1137+
}
1138+
}
1139+
}

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/MultipleLLMPromptExecutorIntegrationTest.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ai.koog.integration.tests.utils.getLLMClientForProvider
1010
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
1111
import ai.koog.prompt.executor.model.PromptExecutor
1212
import ai.koog.prompt.llm.LLModel
13+
import org.junit.jupiter.api.Disabled
1314
import org.junit.jupiter.params.ParameterizedTest
1415
import org.junit.jupiter.params.provider.Arguments
1516
import org.junit.jupiter.params.provider.MethodSource
@@ -113,6 +114,13 @@ class MultipleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
113114
super.integration_testExecuteStreaming(model)
114115
}
115116

117+
@Disabled("KG-616")
118+
@ParameterizedTest
119+
@MethodSource("allCompletionModels")
120+
override fun integration_testExecuteStreamingWithTools(model: LLModel) {
121+
super.integration_testExecuteStreamingWithTools(model)
122+
}
123+
116124
@ParameterizedTest
117125
@MethodSource("allCompletionModels")
118126
override fun integration_testToolWithRequiredParams(model: LLModel) {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
1111
import ai.koog.prompt.executor.model.PromptExecutor
1212
import ai.koog.prompt.llm.LLMProvider
1313
import ai.koog.prompt.llm.LLModel
14+
import org.junit.jupiter.api.Disabled
1415
import org.junit.jupiter.params.ParameterizedTest
1516
import org.junit.jupiter.params.provider.Arguments
1617
import org.junit.jupiter.params.provider.MethodSource
@@ -107,6 +108,13 @@ class SingleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
107108
super.integration_testExecuteStreaming(model)
108109
}
109110

111+
@Disabled("KG-616")
112+
@ParameterizedTest
113+
@MethodSource("allCompletionModels")
114+
override fun integration_testExecuteStreamingWithTools(model: LLModel) {
115+
super.integration_testExecuteStreamingWithTools(model)
116+
}
117+
110118
@ParameterizedTest
111119
@MethodSource("allCompletionModels")
112120
override fun integration_testToolWithRequiredParams(model: LLModel) {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ object Models {
1919
return Stream.of(
2020
OpenAIModels.Chat.GPT5_1, // reasoning
2121
OpenAIModels.Chat.GPT4_1, // non-reasoning
22+
OpenAIModels.Chat.GPT5_1Codex
2223
)
2324
}
2425

0 commit comments

Comments (0)