Skip to content

Commit 58ba445

Browse files
authored
Add integration tests for executing tools with primitive types combinations (#889)
<!-- Thank you for opening a pull request! Please add a brief description of the proposed change here. Also, please tick the appropriate points in the checklist below. --> ## Motivation and Context <!-- Why is this change needed? What problem does it solve? --> As a follow-up to our last week discussion, it's good to test different input/output combinations in tool descriptors after the merge of #791. ## Breaking Changes <!-- Will users need to update their code or configurations? --> None. --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Tests improvement - [ ] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [x] An issue describing the proposed change exists - [x] The pull request includes a link to the issue - [x] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent 9a50bf0 commit 58ba445

File tree

1 file changed

+351
-0
lines changed

1 file changed

+351
-0
lines changed
Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
package ai.koog.integration.tests.executor
2+
3+
import ai.koog.agents.core.tools.Tool
4+
import ai.koog.integration.tests.utils.Models
5+
import ai.koog.integration.tests.utils.RetryUtils.withRetry
6+
import ai.koog.integration.tests.utils.TestUtils
7+
import ai.koog.prompt.dsl.prompt
8+
import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient
9+
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels
10+
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
11+
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
12+
import ai.koog.prompt.executor.clients.google.GoogleModels
13+
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
14+
import ai.koog.prompt.executor.clients.openai.OpenAIModels
15+
import ai.koog.prompt.executor.clients.openrouter.OpenRouterModels
16+
import ai.koog.prompt.llm.LLMCapability
17+
import ai.koog.prompt.llm.LLMProvider
18+
import ai.koog.prompt.llm.LLModel
19+
import ai.koog.prompt.message.Message
20+
import ai.koog.prompt.params.LLMParams
21+
import ai.koog.prompt.params.LLMParams.ToolChoice
22+
import kotlinx.coroutines.test.runTest
23+
import kotlinx.serialization.KSerializer
24+
import kotlinx.serialization.builtins.serializer
25+
import org.junit.jupiter.api.Assumptions.assumeTrue
26+
import org.junit.jupiter.params.ParameterizedTest
27+
import org.junit.jupiter.params.provider.Arguments
28+
import org.junit.jupiter.params.provider.MethodSource
29+
import java.util.stream.Stream
30+
import kotlin.test.assertTrue
31+
import kotlin.time.Duration.Companion.seconds
32+
33+
class ToolDescriptorIntegrationTest {
34+
35+
enum class ToolName(val value: String, val displayName: String, val testUserMessage: String) {
36+
INT_TO_STRING(
37+
"int_to_string",
38+
"Tool<Int, String>",
39+
"Convert the number 42 to its string representation using the tool."
40+
),
41+
STRING_TO_INT("string_to_int", "Tool<String, Int>", "Get the length of the string 'hello' using the tool."),
42+
INT_TO_INT("int_to_int", "Tool<Int, Int>", "Double the number 21 using the tool."),
43+
STRING_TO_STRING(
44+
"string_to_string",
45+
"Tool<String, String>",
46+
"Convert 'hello world' to uppercase using the tool."
47+
),
48+
BOOLEAN_TO_STRING(
49+
"boolean_to_string",
50+
"Tool<Boolean, String>",
51+
"Convert the boolean value true to its string representation using the tool."
52+
),
53+
STRING_TO_BOOLEAN(
54+
"string_to_boolean",
55+
"Tool<String, Boolean>",
56+
"Convert the string 'true' to a boolean using the tool."
57+
),
58+
DOUBLE_TO_INT(
59+
"double_to_int",
60+
"Tool<Double, Int>",
61+
"Convert the double value 3.7 to an integer using the tool."
62+
),
63+
INT_TO_DOUBLE("int_to_double", "Tool<Int, Double>", "Convert the integer value 42 to a double using the tool."),
64+
LONG_TO_DOUBLE(
65+
"long_to_double",
66+
"Tool<Long, Double>",
67+
"Convert the long value 100 to a double with decimal places using the tool."
68+
),
69+
DOUBLE_TO_LONG(
70+
"double_to_long",
71+
"Tool<Double, Long>",
72+
"Convert the double value 15.8 to a long using the tool."
73+
),
74+
FLOAT_TO_BOOLEAN(
75+
"float_to_boolean",
76+
"Tool<Float, Boolean>",
77+
"Convert the float value 2.5 to a boolean using the tool."
78+
),
79+
BOOLEAN_TO_FLOAT(
80+
"boolean_to_float",
81+
"Tool<Boolean, Float>",
82+
"Convert the boolean value true to a float using the tool."
83+
),
84+
LONG_TO_INT("long_to_int", "Tool<Long, Int>", "Convert the long value 12345 to an integer using the tool."),
85+
INT_TO_LONG("int_to_long", "Tool<Int, Long>", "Convert the integer value 789 to a long using the tool."),
86+
FLOAT_TO_STRING(
87+
"float_to_string",
88+
"Tool<Float, String>",
89+
"Convert the float value 3.14 to its string representation using the tool."
90+
),
91+
STRING_TO_FLOAT(
92+
"string_to_float",
93+
"Tool<String, Float>",
94+
"Convert the string 'hello' to a float based on its length using the tool."
95+
),
96+
DOUBLE_TO_STRING(
97+
"double_to_string",
98+
"Tool<Double, String>",
99+
"Convert the double value 2.718 to its string representation using the tool."
100+
),
101+
STRING_TO_DOUBLE(
102+
"string_to_double",
103+
"Tool<String, Double>",
104+
"Convert the string 'world' to a double based on its length using the tool."
105+
);
106+
107+
override fun toString(): String = displayName
108+
}
109+
110+
companion object {
111+
@JvmStatic
112+
fun allModels(): Stream<LLModel> {
113+
return Stream.of(
114+
OpenAIModels.CostOptimized.GPT4_1Mini,
115+
AnthropicModels.Sonnet_3_7,
116+
GoogleModels.Gemini2_5Flash,
117+
BedrockModels.AnthropicClaude35Haiku,
118+
OpenRouterModels.Mistral7B,
119+
)
120+
}
121+
122+
@JvmStatic
123+
fun primitiveToolAndModelCombinations(): Stream<Arguments> {
124+
val primitiveTools = listOf(
125+
IntToStringTool(),
126+
StringToIntTool(),
127+
IntToIntTool(),
128+
StringToStringTool(),
129+
BooleanToStringTool(),
130+
StringToBooleanTool(),
131+
DoubleToIntTool(),
132+
IntToDoubleTool(),
133+
LongToDoubleTool(),
134+
DoubleToLongTool(),
135+
FloatToBooleanTool(),
136+
BooleanToFloatTool(),
137+
LongToIntTool(),
138+
IntToLongTool(),
139+
FloatToStringTool(),
140+
StringToFloatTool(),
141+
DoubleToStringTool(),
142+
StringToDoubleTool()
143+
)
144+
145+
return allModels().flatMap { model ->
146+
primitiveTools.map { tool ->
147+
Arguments.arguments(tool, model)
148+
}.stream()
149+
}
150+
}
151+
}
152+
153+
abstract class TestTool<T, R> : Tool<T, R>() {
154+
abstract val toolName: ToolName
155+
override val name: String get() = toolName.value
156+
override fun toString(): String = toolName.displayName
157+
}
158+
159+
class IntToStringTool : TestTool<Int, String>() {
160+
override val toolName = ToolName.INT_TO_STRING
161+
override val argsSerializer: KSerializer<Int> = Int.serializer()
162+
override val resultSerializer: KSerializer<String> = String.serializer()
163+
override val description: String = "Converts an integer to its string representation"
164+
165+
override suspend fun execute(args: Int): String = "Number: $args"
166+
}
167+
168+
class StringToIntTool : TestTool<String, Int>() {
169+
override val toolName = ToolName.STRING_TO_INT
170+
override val argsSerializer: KSerializer<String> = String.serializer()
171+
override val resultSerializer: KSerializer<Int> = Int.serializer()
172+
override val description: String = "Converts a string to an integer"
173+
174+
override suspend fun execute(args: String): Int = args.length
175+
}
176+
177+
class IntToIntTool : TestTool<Int, Int>() {
178+
override val toolName = ToolName.INT_TO_INT
179+
override val argsSerializer: KSerializer<Int> = Int.serializer()
180+
override val resultSerializer: KSerializer<Int> = Int.serializer()
181+
override val description: String = "Doubles an integer value"
182+
183+
override suspend fun execute(args: Int): Int = args * 2
184+
}
185+
186+
class StringToStringTool : TestTool<String, String>() {
187+
override val toolName = ToolName.STRING_TO_STRING
188+
override val argsSerializer: KSerializer<String> = String.serializer()
189+
override val resultSerializer: KSerializer<String> = String.serializer()
190+
override val description: String = "Converts string to uppercase"
191+
192+
override suspend fun execute(args: String): String = args.uppercase()
193+
}
194+
195+
class BooleanToStringTool : TestTool<Boolean, String>() {
196+
override val toolName = ToolName.BOOLEAN_TO_STRING
197+
override val argsSerializer: KSerializer<Boolean> = Boolean.serializer()
198+
override val resultSerializer: KSerializer<String> = String.serializer()
199+
override val description: String = "Converts boolean to descriptive string"
200+
201+
override suspend fun execute(args: Boolean): String = if (args) "TRUE_VALUE" else "FALSE_VALUE"
202+
}
203+
204+
class DoubleToIntTool : TestTool<Double, Int>() {
205+
override val toolName = ToolName.DOUBLE_TO_INT
206+
override val argsSerializer: KSerializer<Double> = Double.serializer()
207+
override val resultSerializer: KSerializer<Int> = Int.serializer()
208+
override val description: String = "Converts double to integer by rounding"
209+
210+
override suspend fun execute(args: Double): Int = args.toInt()
211+
}
212+
213+
class LongToDoubleTool : TestTool<Long, Double>() {
214+
override val toolName = ToolName.LONG_TO_DOUBLE
215+
override val argsSerializer: KSerializer<Long> = Long.serializer()
216+
override val resultSerializer: KSerializer<Double> = Double.serializer()
217+
override val description: String = "Converts long to double with decimal places"
218+
219+
override suspend fun execute(args: Long): Double = args + 0.5
220+
}
221+
222+
class FloatToBooleanTool : TestTool<Float, Boolean>() {
223+
override val toolName = ToolName.FLOAT_TO_BOOLEAN
224+
override val argsSerializer: KSerializer<Float> = Float.serializer()
225+
override val resultSerializer: KSerializer<Boolean> = Boolean.serializer()
226+
override val description: String = "Converts float to boolean (positive = true)"
227+
228+
override suspend fun execute(args: Float): Boolean = args > 0f
229+
}
230+
231+
class StringToBooleanTool : TestTool<String, Boolean>() {
232+
override val toolName = ToolName.STRING_TO_BOOLEAN
233+
override val argsSerializer: KSerializer<String> = String.serializer()
234+
override val resultSerializer: KSerializer<Boolean> = Boolean.serializer()
235+
override val description: String = "Converts string to boolean ('true' = true, others = false)"
236+
237+
override suspend fun execute(args: String): Boolean = args.equals("true", ignoreCase = true)
238+
}
239+
240+
class IntToDoubleTool : TestTool<Int, Double>() {
241+
override val toolName = ToolName.INT_TO_DOUBLE
242+
override val argsSerializer: KSerializer<Int> = Int.serializer()
243+
override val resultSerializer: KSerializer<Double> = Double.serializer()
244+
override val description: String = "Converts integer to double"
245+
246+
override suspend fun execute(args: Int): Double = args.toDouble()
247+
}
248+
249+
class DoubleToLongTool : TestTool<Double, Long>() {
250+
override val toolName = ToolName.DOUBLE_TO_LONG
251+
override val argsSerializer: KSerializer<Double> = Double.serializer()
252+
override val resultSerializer: KSerializer<Long> = Long.serializer()
253+
override val description: String = "Converts double to long by rounding"
254+
255+
override suspend fun execute(args: Double): Long = args.toLong()
256+
}
257+
258+
class BooleanToFloatTool : TestTool<Boolean, Float>() {
259+
override val toolName = ToolName.BOOLEAN_TO_FLOAT
260+
override val argsSerializer: KSerializer<Boolean> = Boolean.serializer()
261+
override val resultSerializer: KSerializer<Float> = Float.serializer()
262+
override val description: String = "Converts boolean to float (true = 1.0f, false = 0.0f)"
263+
264+
override suspend fun execute(args: Boolean): Float = if (args) 1.0f else 0.0f
265+
}
266+
267+
class LongToIntTool : TestTool<Long, Int>() {
268+
override val toolName = ToolName.LONG_TO_INT
269+
override val argsSerializer: KSerializer<Long> = Long.serializer()
270+
override val resultSerializer: KSerializer<Int> = Int.serializer()
271+
override val description: String = "Converts long to integer"
272+
273+
override suspend fun execute(args: Long): Int = args.toInt()
274+
}
275+
276+
class IntToLongTool : TestTool<Int, Long>() {
277+
override val toolName = ToolName.INT_TO_LONG
278+
override val argsSerializer: KSerializer<Int> = Int.serializer()
279+
override val resultSerializer: KSerializer<Long> = Long.serializer()
280+
override val description: String = "Converts integer to long"
281+
282+
override suspend fun execute(args: Int): Long = args.toLong()
283+
}
284+
285+
class FloatToStringTool : TestTool<Float, String>() {
286+
override val toolName = ToolName.FLOAT_TO_STRING
287+
override val argsSerializer: KSerializer<Float> = Float.serializer()
288+
override val resultSerializer: KSerializer<String> = String.serializer()
289+
override val description: String = "Converts float to string"
290+
291+
override suspend fun execute(args: Float): String = "Float: $args"
292+
}
293+
294+
class StringToFloatTool : TestTool<String, Float>() {
295+
override val toolName = ToolName.STRING_TO_FLOAT
296+
override val argsSerializer: KSerializer<String> = String.serializer()
297+
override val resultSerializer: KSerializer<Float> = Float.serializer()
298+
override val description: String = "Converts string length to float"
299+
300+
override suspend fun execute(args: String): Float = args.length.toFloat()
301+
}
302+
303+
class DoubleToStringTool : TestTool<Double, String>() {
304+
override val toolName = ToolName.DOUBLE_TO_STRING
305+
override val argsSerializer: KSerializer<Double> = Double.serializer()
306+
override val resultSerializer: KSerializer<String> = String.serializer()
307+
override val description: String = "Converts double to string"
308+
309+
override suspend fun execute(args: Double): String = "Double: $args"
310+
}
311+
312+
class StringToDoubleTool : TestTool<String, Double>() {
313+
override val toolName = ToolName.STRING_TO_DOUBLE
314+
override val argsSerializer: KSerializer<String> = String.serializer()
315+
override val resultSerializer: KSerializer<Double> = Double.serializer()
316+
override val description: String = "Converts string length to double"
317+
318+
override suspend fun execute(args: String): Double = args.length.toDouble()
319+
}
320+
321+
@ParameterizedTest(name = "{0} with {1}")
322+
@MethodSource("primitiveToolAndModelCombinations")
323+
fun integration_testPrimitiveTools(tool: Tool<*, *>, model: LLModel) = runTest(timeout = 300.seconds) {
324+
Models.assumeAvailable(model.provider)
325+
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
326+
327+
val client = when (model.provider) {
328+
is LLMProvider.Anthropic -> AnthropicLLMClient(TestUtils.readTestAnthropicKeyFromEnv())
329+
is LLMProvider.Google -> GoogleLLMClient(TestUtils.readTestGoogleAIKeyFromEnv())
330+
else -> OpenAILLMClient(TestUtils.readTestOpenAIKeyFromEnv())
331+
}
332+
333+
val testTool = tool as TestTool<*, *>
334+
val prompt = prompt(testTool.toolName.value, params = LLMParams(toolChoice = ToolChoice.Required)) {
335+
system("You are a helpful assistant with access to tools. ALWAYS use the available tool.")
336+
user(testTool.toolName.testUserMessage)
337+
}
338+
339+
withRetry {
340+
val response = client.execute(prompt, model, listOf(tool.descriptor))
341+
assertTrue(response.isNotEmpty(), "Response should not be empty for tool ${tool.name} with model $model")
342+
val hasToolCall = response.any { message ->
343+
message is Message.Tool.Call && message.tool == tool.name
344+
}
345+
assertTrue(
346+
hasToolCall,
347+
"Response should contain a Tool.Call message for tool '${tool.name}' with model $model."
348+
)
349+
}
350+
}
351+
}

0 commit comments

Comments
 (0)