Skip to content

Commit 86ee696

Browse files
authored
Cover SafeTool with tests (#359)
Signed-off-by: Sergey Karpov <[email protected]>
1 parent 54d33d3 commit 86ee696

File tree

1 file changed

+333
-0
lines changed
  • agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/environment

1 file changed

+333
-0
lines changed
Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
package ai.koog.agents.core.environment
2+
3+
import ai.koog.agents.core.CalculatorChatExecutor.testClock
4+
import ai.koog.agents.core.tools.reflect.ToolFromCallable
5+
import ai.koog.prompt.message.Message
6+
import kotlinx.coroutines.test.runTest
7+
import kotlinx.serialization.Serializable
8+
import kotlinx.serialization.json.Json
9+
import org.junit.jupiter.api.assertThrows
10+
import kotlin.reflect.typeOf
11+
import kotlin.test.Test
12+
import kotlin.test.assertEquals
13+
import kotlin.test.assertTrue
14+
15+
class SafeToolTest {
16+
private fun testInvalidArguments(vararg args: Any?) = runTest {
17+
val mockEnvironment = MockEnvironment(shouldSucceed = true)
18+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
19+
assertEquals(safeTool.toolFunction, ::testFunction)
20+
21+
assertThrows<IllegalStateException> {
22+
safeTool.execute(*args)
23+
}
24+
}
25+
26+
companion object {
27+
private const val TEST_RESULT = "Test result"
28+
private const val TEST_ERROR = "Error: Test error"
29+
}
30+
31+
private fun testFunction(param1: String, param2: Int): String {
32+
return "Result: $param1 - $param2"
33+
}
34+
35+
private fun testFunctionWithDefaultParam(param1: String, param2: Int = 42): String {
36+
return "Result with default: $param1 - $param2"
37+
}
38+
39+
enum class TestEnum {
40+
FIRST, SECOND, THIRD
41+
}
42+
43+
@Serializable
44+
data class SimpleDataClass(val name: String, val value: Int)
45+
46+
@Serializable
47+
data class ComplexDataClass(
48+
val id: String,
49+
val numbers: List<Int>,
50+
val nested: SimpleDataClass,
51+
val enumValue: TestEnum
52+
)
53+
54+
private fun testFunctionWithComplexArgs(
55+
param1: String,
56+
param2: List<Int>,
57+
param3: ComplexDataClass
58+
): String {
59+
return "Complex result: $param1 - ${param2.size} items - ${param3.id}"
60+
}
61+
62+
private fun testFunctionWithNullableArg(param1: String, param2: Int?): String {
63+
return "Nullable result: $param1 - ${param2 ?: "null"}"
64+
}
65+
66+
private class MockEnvironment(
67+
private val shouldSucceed: Boolean = true,
68+
private val resultContent: String = "Success content",
69+
) : AIAgentEnvironment {
70+
override suspend fun executeTools(toolCalls: List<Message.Tool.Call>): List<ReceivedToolResult> {
71+
return toolCalls.map { toolCall ->
72+
if (shouldSucceed) {
73+
ReceivedToolResult(
74+
id = toolCall.id,
75+
tool = toolCall.tool,
76+
content = resultContent,
77+
result = ToolFromCallable.Result(
78+
result = TEST_RESULT,
79+
type = typeOf<String>(),
80+
json = Json,
81+
)
82+
)
83+
} else {
84+
ReceivedToolResult(
85+
id = toolCall.id,
86+
tool = toolCall.tool,
87+
content = TEST_ERROR,
88+
result = null,
89+
)
90+
}
91+
}
92+
}
93+
94+
override suspend fun reportProblem(exception: Throwable) {
95+
throw exception
96+
}
97+
}
98+
99+
100+
@Test
101+
fun testExecuteSuccess() = runTest {
102+
val mockEnvironment = MockEnvironment(shouldSucceed = true)
103+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
104+
assertEquals(safeTool.toolFunction, ::testFunction)
105+
106+
val result = safeTool.execute("test", 123)
107+
108+
assertTrue(result.isSuccessful())
109+
assertEquals(TEST_RESULT, result.asSuccessful().result)
110+
assertEquals("Success content", result.content)
111+
}
112+
113+
@Test
114+
fun testExecuteFailure() = runTest {
115+
val mockEnvironment = MockEnvironment(shouldSucceed = false)
116+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
117+
assertEquals(safeTool.toolFunction, ::testFunction)
118+
119+
val result = safeTool.execute("test", 123)
120+
121+
assertTrue(result.isFailure())
122+
assertEquals(TEST_ERROR, result.content)
123+
assertEquals(TEST_ERROR, result.asFailure().message)
124+
}
125+
126+
@Test
127+
fun testExecuteRaw() = runTest {
128+
val mockEnvironment = MockEnvironment(shouldSucceed = true, resultContent = "Raw result content")
129+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
130+
assertEquals(safeTool.toolFunction, ::testFunction)
131+
132+
val result = safeTool.executeRaw("test", 123)
133+
134+
assertEquals("Raw result content", result)
135+
}
136+
137+
@Test
138+
fun testResultSuccessHelpers() = runTest {
139+
val success = SafeToolFromCallable.Result.Success(TEST_RESULT, "Success content")
140+
141+
assertTrue(success.isSuccessful())
142+
assertEquals(TEST_RESULT, success.asSuccessful().result)
143+
assertEquals("Success content", success.content)
144+
}
145+
146+
@Test
147+
fun testResultFailureHelpers() = runTest {
148+
val failure = SafeToolFromCallable.Result.Failure<String>("Error message")
149+
150+
assertTrue(failure.isFailure())
151+
assertEquals("Error message", failure.asFailure().message)
152+
assertEquals("Error message", failure.content)
153+
}
154+
155+
@Test
156+
fun testInvalidArgumentCount() = testInvalidArguments("test")
157+
158+
@Test
159+
fun testZeroArgumentCount() = testInvalidArguments()
160+
161+
@Test
162+
fun testTooManyArguments() = testInvalidArguments("test", 123, "extra argument")
163+
164+
@Test
165+
fun testWithNullArgumentInMockEnvironment() = runTest {
166+
val mockEnvironment = MockEnvironment(shouldSucceed = true)
167+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
168+
assertEquals(safeTool.toolFunction, ::testFunction)
169+
170+
val result = safeTool.execute("test", null)
171+
172+
assertTrue(result.isSuccessful())
173+
assertEquals(TEST_RESULT, result.asSuccessful().result)
174+
}
175+
176+
@Test
177+
fun testSafeToolParameters() = runTest {
178+
val mockEnvironment = MockEnvironment(shouldSucceed = true)
179+
val safeTool = SafeToolFromCallable(::testFunction, mockEnvironment, testClock)
180+
assertEquals(safeTool.toolFunction, ::testFunction)
181+
182+
val safeToolParams = safeTool.toolFunction.parameters.joinToString(", ") { it.name.toString() }
183+
184+
assertEquals("param1, param2", safeToolParams)
185+
}
186+
187+
@Test
188+
fun testWithNullArgumentInDirectCallEnvironment() = runTest {
189+
val directCallEnvironment = object : AIAgentEnvironment {
190+
override suspend fun executeTools(toolCalls: List<Message.Tool.Call>): List<ReceivedToolResult> {
191+
return toolCalls.map { toolCall ->
192+
try {
193+
val result = testFunction("test", null as Int)
194+
195+
ReceivedToolResult(
196+
id = toolCall.id,
197+
tool = toolCall.tool,
198+
content = "Success: $result",
199+
result = ToolFromCallable.Result(
200+
result = result,
201+
type = typeOf<String>(),
202+
json = Json,
203+
)
204+
)
205+
} catch (e: Exception) {
206+
ReceivedToolResult(
207+
id = toolCall.id,
208+
tool = toolCall.tool,
209+
content = "Error: ${e.message}",
210+
result = null
211+
)
212+
}
213+
}
214+
}
215+
216+
override suspend fun reportProblem(exception: Throwable) {
217+
throw exception
218+
}
219+
}
220+
221+
val safeTool = SafeToolFromCallable(::testFunction, directCallEnvironment, testClock)
222+
assertEquals(safeTool.toolFunction, ::testFunction)
223+
224+
val result = safeTool.execute("test", null)
225+
226+
assertTrue(result.isFailure())
227+
assertTrue(result.content.contains("null cannot be cast to non-null type kotlin.Int"))
228+
}
229+
230+
@Test
231+
fun testWithDefaultParameter() = runTest {
232+
val mockEnvironment = MockEnvironment(shouldSucceed = true, resultContent = "Default param result")
233+
val safeTool = SafeToolFromCallable(::testFunctionWithDefaultParam, mockEnvironment, testClock)
234+
assertEquals(safeTool.toolFunction, ::testFunctionWithDefaultParam)
235+
236+
val result = safeTool.execute("test", 123)
237+
238+
assertTrue(result.isSuccessful())
239+
assertEquals(TEST_RESULT, result.asSuccessful().result)
240+
}
241+
242+
@Test
243+
fun testWithNullableArgument() = runTest {
244+
val mockEnvironment = MockEnvironment(shouldSucceed = true, resultContent = "Nullable arg result")
245+
val safeTool = SafeToolFromCallable(::testFunctionWithNullableArg, mockEnvironment, testClock)
246+
assertEquals(safeTool.toolFunction, ::testFunctionWithNullableArg)
247+
248+
val resultWithValue = safeTool.execute("test", 123)
249+
assertTrue(resultWithValue.isSuccessful())
250+
assertEquals(TEST_RESULT, resultWithValue.asSuccessful().result)
251+
252+
val resultWithNull = safeTool.execute("test", null)
253+
assertTrue(resultWithNull.isSuccessful())
254+
assertEquals(TEST_RESULT, resultWithNull.asSuccessful().result)
255+
}
256+
257+
@Test
258+
fun testWithComplexArguments() = runTest {
259+
val mockEnvironment = MockEnvironment(shouldSucceed = true, resultContent = "Complex args result")
260+
val safeTool = SafeToolFromCallable(::testFunctionWithComplexArgs, mockEnvironment, testClock)
261+
assertEquals(safeTool.toolFunction, ::testFunctionWithComplexArgs)
262+
263+
val complexData = ComplexDataClass(
264+
id = "test-id",
265+
numbers = listOf(1, 2, 3),
266+
nested = SimpleDataClass(name = "nested-name", value = 42),
267+
enumValue = TestEnum.SECOND
268+
)
269+
270+
val result = safeTool.execute("test", listOf(4, 5, 6), complexData)
271+
272+
assertTrue(result.isSuccessful())
273+
assertEquals(TEST_RESULT, result.asSuccessful().result)
274+
}
275+
276+
@Test
277+
fun testWithComplexArgumentsInDirectCallEnvironment() = runTest {
278+
val directCallEnvironment = object : AIAgentEnvironment {
279+
override suspend fun executeTools(toolCalls: List<Message.Tool.Call>): List<ReceivedToolResult> {
280+
return toolCalls.map { toolCall ->
281+
try {
282+
val complexData = ComplexDataClass(
283+
id = "direct-call-id",
284+
numbers = listOf(7, 8, 9),
285+
nested = SimpleDataClass(name = "direct-nested", value = 100),
286+
enumValue = TestEnum.THIRD
287+
)
288+
289+
val result = testFunctionWithComplexArgs("direct-test", listOf(10, 11, 12), complexData)
290+
291+
ReceivedToolResult(
292+
id = toolCall.id,
293+
tool = toolCall.tool,
294+
content = "Success: $result",
295+
result = ToolFromCallable.Result(
296+
result = result,
297+
type = typeOf<String>(),
298+
json = Json,
299+
)
300+
)
301+
} catch (e: Exception) {
302+
ReceivedToolResult(
303+
id = toolCall.id,
304+
tool = toolCall.tool,
305+
content = "Error: ${e.message}",
306+
result = null
307+
)
308+
}
309+
}
310+
}
311+
312+
override suspend fun reportProblem(exception: Throwable) {
313+
throw exception
314+
}
315+
}
316+
317+
val safeTool = SafeToolFromCallable(::testFunctionWithComplexArgs, directCallEnvironment, testClock)
318+
assertEquals(safeTool.toolFunction, ::testFunctionWithComplexArgs)
319+
320+
val complexData = ComplexDataClass(
321+
id = "test-complex-id",
322+
numbers = listOf(1, 2, 3),
323+
nested = SimpleDataClass(name = "test-nested", value = 42),
324+
enumValue = TestEnum.FIRST
325+
)
326+
327+
val result = safeTool.execute("test-param", listOf(4, 5, 6), complexData)
328+
329+
assertTrue(result.isSuccessful())
330+
assertTrue(result.content.contains("direct-test"))
331+
assertTrue(result.content.contains("direct-call-id"))
332+
}
333+
}

0 commit comments

Comments
 (0)