Skip to content

Commit 759033a

Browse files
authored
Support for Long, Double, List, and data classes as tool arguments (#210)
1 parent 3d88195 commit 759033a

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed

agents/agents-tools/src/jvmMain/kotlin/ai/koog/agents/core/tools/reflect/util.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import kotlin.reflect.KType
1414
import kotlin.reflect.full.findAnnotation
1515
import kotlin.reflect.full.functions
1616
import kotlin.reflect.full.instanceParameter
17+
import kotlin.reflect.full.memberProperties
1718
import kotlin.reflect.jvm.javaMethod
1819
import kotlin.reflect.jvm.kotlinFunction
1920

@@ -197,6 +198,15 @@ public fun KType.asToolType(): ToolParameterType {
197198
Int::class -> ToolParameterType.Integer
198199
Float::class -> ToolParameterType.Float
199200
Boolean::class -> ToolParameterType.Boolean
201+
Long::class -> ToolParameterType.Integer
202+
Double::class -> ToolParameterType.Float
203+
204+
List::class -> {
205+
val listItemType = this.arguments[0].type ?: error("List item type is null")
206+
val listItemToolType = listItemType.asToolType()
207+
ToolParameterType.List(listItemToolType)
208+
}
209+
200210
is KClass<*> -> {
201211
val classJava = classifier.java
202212
when {
@@ -211,6 +221,18 @@ public fun KType.asToolType(): ToolParameterType {
211221
ToolParameterType.List(arrayItemToolType)
212222
}
213223

224+
classifier.isData -> {
225+
val properties = classifier.memberProperties.map { prop ->
226+
val description = prop.findAnnotation<LLMDescription>()?.description ?: prop.name
227+
ToolParameterDescriptor(
228+
name = prop.name,
229+
description = description,
230+
type = prop.returnType.asToolType() // Recursive call
231+
)
232+
}
233+
ToolParameterType.Object(properties)
234+
}
235+
214236
else -> throw kotlin.IllegalArgumentException("Unsupported type $classifier")
215237
}
216238
}

agents/agents-tools/src/jvmTest/kotlin/ai/koog/agents/core/tools/reflect/ReflectionArgsSerializerTest.kt

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
@file:OptIn(InternalAgentToolsApi::class)
2+
13
package ai.koog.agents.core.tools.reflect
24

5+
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
6+
import ai.koog.agents.core.tools.annotations.LLMDescription
7+
import ai.koog.agents.core.tools.annotations.Tool
38
import ai.koog.agents.core.tools.reflect.ToolFromCallable.VarArgsSerializer
9+
import ai.koog.agents.core.tools.reflect.ToolsFromCallableTest.Companion.ToolsEnabler
10+
import kotlinx.coroutines.runBlocking
11+
import kotlinx.serialization.Serializable
412
import kotlinx.serialization.json.Json
513
import org.junit.jupiter.params.ParameterizedTest
614
import org.junit.jupiter.params.provider.Arguments
715
import org.junit.jupiter.params.provider.MethodSource
816
import kotlin.reflect.KCallable
17+
import kotlin.test.Test
918
import kotlin.test.assertContentEquals
19+
import kotlin.test.assertEquals
1020

1121
fun foo(i: Int, s: String, b: Boolean = true) = println("$i $s")
1222
fun Any.fooEx(i: Int, s: String, b: Boolean = true) = println("$i $s")
@@ -33,4 +43,97 @@ class ReflectionArgsSerializerTest {
3343
decodedArguments.asNamedValues().sortedBy { it.first }.toList()
3444
)
3545
}
46+
47+
48+
@Serializable
49+
data class MySpecificToolArgs(
50+
@LLMDescription("arg Long") val argLong: Long,
51+
@LLMDescription("arg Double") val argDouble: Double)
52+
53+
class MySpecificTool() {
54+
@Tool
55+
@LLMDescription("Specific tool")
56+
suspend fun execute(@LLMDescription("arg Long") argLong: Long): String {
57+
return "Specific tool called with $argLong"
58+
}
59+
60+
@Tool
61+
@LLMDescription("Specific tool without tool annotation")
62+
suspend fun executeDouble(@LLMDescription("arg Long") argDouble: Double): String {
63+
return "Specific tool called with $argDouble"
64+
}
65+
66+
@Tool
67+
@LLMDescription("Specific tool without tool annotation")
68+
suspend fun executeWithArgs(@LLMDescription("args Tool") args: MySpecificToolArgs,
69+
@LLMDescription("args Tool") args2: MySpecificToolArgs): String {
70+
return "Specific tool called with ${args.argLong} and ${args.argDouble}"
71+
}
72+
73+
suspend fun executeWithListArg(@LLMDescription("args Tool") args: List<MySpecificToolArgs>): String {
74+
return "Specific tool called with ${args.joinToString(", ") { "${it.argLong} and ${it.argDouble}" }}"
75+
}
76+
}
77+
78+
@Test
79+
fun testToolLongArg() {
80+
val toolClass = MySpecificTool()
81+
val tool = toolClass::execute.asTool(ToolsFromCallableTest.Companion.json)
82+
83+
assertEquals(
84+
"Specific tool called with 42",
85+
runBlocking {
86+
val args = tool.decodeArgsFromString("""{ "argLong": 42 }""")
87+
val (rawResult, _) = tool.executeAndSerialize(args, ToolsEnabler)
88+
rawResult.result
89+
},
90+
)
91+
}
92+
93+
@Test
94+
fun testToolDoubleArg() {
95+
val toolClass = MySpecificTool()
96+
val tool: ToolFromCallable = toolClass::executeDouble.asTool(ToolsFromCallableTest.Companion.json)
97+
98+
assertEquals(
99+
"Specific tool called with 42.0",
100+
runBlocking {
101+
val args = tool.decodeArgsFromString("""{ "argDouble": 42.0 }""")
102+
val (rawResult, _) = tool.executeAndSerialize(args, ToolsEnabler)
103+
rawResult.result
104+
},
105+
)
106+
}
107+
108+
@Test
109+
fun testToolWithArgs() {
110+
val toolClass = MySpecificTool()
111+
val tool = toolClass::executeWithArgs.asTool(ToolsFromCallableTest.Companion.json)
112+
113+
assertEquals(
114+
"Specific tool called with 42 and 3.14",
115+
runBlocking {
116+
val args: ToolFromCallable.VarArgs = tool.decodeArgsFromString("""{ "args": {"argLong": 42, "argDouble": 3.14 }, "args2": {"argLong": 22, "argDouble": 3.14 }}""")
117+
val (rawResult, _) = tool.executeAndSerialize(args, ToolsEnabler)
118+
rawResult.result
119+
}
120+
)
121+
}
122+
123+
@Test
124+
fun testToolWithListArg() {
125+
val toolClass = MySpecificTool()
126+
val tool = toolClass::executeWithListArg.asTool(ToolsFromCallableTest.Companion.json)
127+
128+
assertEquals(
129+
"Specific tool called with 42 and 3.14, 22 and 3.14",
130+
runBlocking {
131+
val args: ToolFromCallable.VarArgs = tool.decodeArgsFromString(
132+
"""{ "args": [{"argLong": 42, "argDouble": 3.14 }, {"argLong": 22, "argDouble": 3.14 }] }"""
133+
)
134+
val (rawResult, _) = tool.executeAndSerialize(args, ToolsEnabler)
135+
rawResult.result
136+
}
137+
)
138+
}
36139
}

agents/agents-tools/src/jvmTest/kotlin/ai/koog/agents/core/tools/reflect/ToolsFromCallableTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,5 @@ class ToolsFromCallableTest {
285285
}.trim()
286286
assertEquals(expectedDescription, rendered)
287287
}
288+
288289
}

0 commit comments

Comments
 (0)