1+ package ai.koog.agents.core.environment
2+
3+ import ai.koog.agents.core.CalculatorChatExecutor.testClock
4+ import ai.koog.agents.core.tools.reflect.ToolFromCallable
5+ import ai.koog.prompt.message.Message
6+ import kotlinx.coroutines.test.runTest
7+ import kotlinx.serialization.Serializable
8+ import kotlinx.serialization.json.Json
9+ import org.junit.jupiter.api.assertThrows
10+ import kotlin.reflect.typeOf
11+ import kotlin.test.Test
12+ import kotlin.test.assertEquals
13+ import kotlin.test.assertTrue
14+
15+ class SafeToolTest {
16+ private fun testInvalidArguments (vararg args : Any? ) = runTest {
17+ val mockEnvironment = MockEnvironment (shouldSucceed = true )
18+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
19+ assertEquals(safeTool.toolFunction, ::testFunction)
20+
21+ assertThrows<IllegalStateException > {
22+ safeTool.execute(* args)
23+ }
24+ }
25+
26+ companion object {
27+ private const val TEST_RESULT = " Test result"
28+ private const val TEST_ERROR = " Error: Test error"
29+ }
30+
31+ private fun testFunction (param1 : String , param2 : Int ): String {
32+ return " Result: $param1 - $param2 "
33+ }
34+
35+ private fun testFunctionWithDefaultParam (param1 : String , param2 : Int = 42): String {
36+ return " Result with default: $param1 - $param2 "
37+ }
38+
39+ enum class TestEnum {
40+ FIRST , SECOND , THIRD
41+ }
42+
43+ @Serializable
44+ data class SimpleDataClass (val name : String , val value : Int )
45+
46+ @Serializable
47+ data class ComplexDataClass (
48+ val id : String ,
49+ val numbers : List <Int >,
50+ val nested : SimpleDataClass ,
51+ val enumValue : TestEnum
52+ )
53+
54+ private fun testFunctionWithComplexArgs (
55+ param1 : String ,
56+ param2 : List <Int >,
57+ param3 : ComplexDataClass
58+ ): String {
59+ return " Complex result: $param1 - ${param2.size} items - ${param3.id} "
60+ }
61+
62+ private fun testFunctionWithNullableArg (param1 : String , param2 : Int? ): String {
63+ return " Nullable result: $param1 - ${param2 ? : " null" } "
64+ }
65+
66+ private class MockEnvironment (
67+ private val shouldSucceed : Boolean = true ,
68+ private val resultContent : String = " Success content" ,
69+ ) : AIAgentEnvironment {
70+ override suspend fun executeTools (toolCalls : List <Message .Tool .Call >): List <ReceivedToolResult > {
71+ return toolCalls.map { toolCall ->
72+ if (shouldSucceed) {
73+ ReceivedToolResult (
74+ id = toolCall.id,
75+ tool = toolCall.tool,
76+ content = resultContent,
77+ result = ToolFromCallable .Result (
78+ result = TEST_RESULT ,
79+ type = typeOf<String >(),
80+ json = Json ,
81+ )
82+ )
83+ } else {
84+ ReceivedToolResult (
85+ id = toolCall.id,
86+ tool = toolCall.tool,
87+ content = TEST_ERROR ,
88+ result = null ,
89+ )
90+ }
91+ }
92+ }
93+
94+ override suspend fun reportProblem (exception : Throwable ) {
95+ throw exception
96+ }
97+ }
98+
99+
100+ @Test
101+ fun testExecuteSuccess () = runTest {
102+ val mockEnvironment = MockEnvironment (shouldSucceed = true )
103+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
104+ assertEquals(safeTool.toolFunction, ::testFunction)
105+
106+ val result = safeTool.execute(" test" , 123 )
107+
108+ assertTrue(result.isSuccessful())
109+ assertEquals(TEST_RESULT , result.asSuccessful().result)
110+ assertEquals(" Success content" , result.content)
111+ }
112+
113+ @Test
114+ fun testExecuteFailure () = runTest {
115+ val mockEnvironment = MockEnvironment (shouldSucceed = false )
116+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
117+ assertEquals(safeTool.toolFunction, ::testFunction)
118+
119+ val result = safeTool.execute(" test" , 123 )
120+
121+ assertTrue(result.isFailure())
122+ assertEquals(TEST_ERROR , result.content)
123+ assertEquals(TEST_ERROR , result.asFailure().message)
124+ }
125+
126+ @Test
127+ fun testExecuteRaw () = runTest {
128+ val mockEnvironment = MockEnvironment (shouldSucceed = true , resultContent = " Raw result content" )
129+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
130+ assertEquals(safeTool.toolFunction, ::testFunction)
131+
132+ val result = safeTool.executeRaw(" test" , 123 )
133+
134+ assertEquals(" Raw result content" , result)
135+ }
136+
137+ @Test
138+ fun testResultSuccessHelpers () = runTest {
139+ val success = SafeToolFromCallable .Result .Success (TEST_RESULT , " Success content" )
140+
141+ assertTrue(success.isSuccessful())
142+ assertEquals(TEST_RESULT , success.asSuccessful().result)
143+ assertEquals(" Success content" , success.content)
144+ }
145+
146+ @Test
147+ fun testResultFailureHelpers () = runTest {
148+ val failure = SafeToolFromCallable .Result .Failure <String >(" Error message" )
149+
150+ assertTrue(failure.isFailure())
151+ assertEquals(" Error message" , failure.asFailure().message)
152+ assertEquals(" Error message" , failure.content)
153+ }
154+
155+ @Test
156+ fun testInvalidArgumentCount () = testInvalidArguments(" test" )
157+
158+ @Test
159+ fun testZeroArgumentCount () = testInvalidArguments()
160+
161+ @Test
162+ fun testTooManyArguments () = testInvalidArguments(" test" , 123 , " extra argument" )
163+
164+ @Test
165+ fun testWithNullArgumentInMockEnvironment () = runTest {
166+ val mockEnvironment = MockEnvironment (shouldSucceed = true )
167+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
168+ assertEquals(safeTool.toolFunction, ::testFunction)
169+
170+ val result = safeTool.execute(" test" , null )
171+
172+ assertTrue(result.isSuccessful())
173+ assertEquals(TEST_RESULT , result.asSuccessful().result)
174+ }
175+
176+ @Test
177+ fun testSafeToolParameters () = runTest {
178+ val mockEnvironment = MockEnvironment (shouldSucceed = true )
179+ val safeTool = SafeToolFromCallable (::testFunction, mockEnvironment, testClock)
180+ assertEquals(safeTool.toolFunction, ::testFunction)
181+
182+ val safeToolParams = safeTool.toolFunction.parameters.joinToString(" , " ) { it.name.toString() }
183+
184+ assertEquals(" param1, param2" , safeToolParams)
185+ }
186+
187+ @Test
188+ fun testWithNullArgumentInDirectCallEnvironment () = runTest {
189+ val directCallEnvironment = object : AIAgentEnvironment {
190+ override suspend fun executeTools (toolCalls : List <Message .Tool .Call >): List <ReceivedToolResult > {
191+ return toolCalls.map { toolCall ->
192+ try {
193+ val result = testFunction(" test" , null as Int )
194+
195+ ReceivedToolResult (
196+ id = toolCall.id,
197+ tool = toolCall.tool,
198+ content = " Success: $result " ,
199+ result = ToolFromCallable .Result (
200+ result = result,
201+ type = typeOf<String >(),
202+ json = Json ,
203+ )
204+ )
205+ } catch (e: Exception ) {
206+ ReceivedToolResult (
207+ id = toolCall.id,
208+ tool = toolCall.tool,
209+ content = " Error: ${e.message} " ,
210+ result = null
211+ )
212+ }
213+ }
214+ }
215+
216+ override suspend fun reportProblem (exception : Throwable ) {
217+ throw exception
218+ }
219+ }
220+
221+ val safeTool = SafeToolFromCallable (::testFunction, directCallEnvironment, testClock)
222+ assertEquals(safeTool.toolFunction, ::testFunction)
223+
224+ val result = safeTool.execute(" test" , null )
225+
226+ assertTrue(result.isFailure())
227+ assertTrue(result.content.contains(" null cannot be cast to non-null type kotlin.Int" ))
228+ }
229+
230+ @Test
231+ fun testWithDefaultParameter () = runTest {
232+ val mockEnvironment = MockEnvironment (shouldSucceed = true , resultContent = " Default param result" )
233+ val safeTool = SafeToolFromCallable (::testFunctionWithDefaultParam, mockEnvironment, testClock)
234+ assertEquals(safeTool.toolFunction, ::testFunctionWithDefaultParam)
235+
236+ val result = safeTool.execute(" test" , 123 )
237+
238+ assertTrue(result.isSuccessful())
239+ assertEquals(TEST_RESULT , result.asSuccessful().result)
240+ }
241+
242+ @Test
243+ fun testWithNullableArgument () = runTest {
244+ val mockEnvironment = MockEnvironment (shouldSucceed = true , resultContent = " Nullable arg result" )
245+ val safeTool = SafeToolFromCallable (::testFunctionWithNullableArg, mockEnvironment, testClock)
246+ assertEquals(safeTool.toolFunction, ::testFunctionWithNullableArg)
247+
248+ val resultWithValue = safeTool.execute(" test" , 123 )
249+ assertTrue(resultWithValue.isSuccessful())
250+ assertEquals(TEST_RESULT , resultWithValue.asSuccessful().result)
251+
252+ val resultWithNull = safeTool.execute(" test" , null )
253+ assertTrue(resultWithNull.isSuccessful())
254+ assertEquals(TEST_RESULT , resultWithNull.asSuccessful().result)
255+ }
256+
257+ @Test
258+ fun testWithComplexArguments () = runTest {
259+ val mockEnvironment = MockEnvironment (shouldSucceed = true , resultContent = " Complex args result" )
260+ val safeTool = SafeToolFromCallable (::testFunctionWithComplexArgs, mockEnvironment, testClock)
261+ assertEquals(safeTool.toolFunction, ::testFunctionWithComplexArgs)
262+
263+ val complexData = ComplexDataClass (
264+ id = " test-id" ,
265+ numbers = listOf (1 , 2 , 3 ),
266+ nested = SimpleDataClass (name = " nested-name" , value = 42 ),
267+ enumValue = TestEnum .SECOND
268+ )
269+
270+ val result = safeTool.execute(" test" , listOf (4 , 5 , 6 ), complexData)
271+
272+ assertTrue(result.isSuccessful())
273+ assertEquals(TEST_RESULT , result.asSuccessful().result)
274+ }
275+
276+ @Test
277+ fun testWithComplexArgumentsInDirectCallEnvironment () = runTest {
278+ val directCallEnvironment = object : AIAgentEnvironment {
279+ override suspend fun executeTools (toolCalls : List <Message .Tool .Call >): List <ReceivedToolResult > {
280+ return toolCalls.map { toolCall ->
281+ try {
282+ val complexData = ComplexDataClass (
283+ id = " direct-call-id" ,
284+ numbers = listOf (7 , 8 , 9 ),
285+ nested = SimpleDataClass (name = " direct-nested" , value = 100 ),
286+ enumValue = TestEnum .THIRD
287+ )
288+
289+ val result = testFunctionWithComplexArgs(" direct-test" , listOf (10 , 11 , 12 ), complexData)
290+
291+ ReceivedToolResult (
292+ id = toolCall.id,
293+ tool = toolCall.tool,
294+ content = " Success: $result " ,
295+ result = ToolFromCallable .Result (
296+ result = result,
297+ type = typeOf<String >(),
298+ json = Json ,
299+ )
300+ )
301+ } catch (e: Exception ) {
302+ ReceivedToolResult (
303+ id = toolCall.id,
304+ tool = toolCall.tool,
305+ content = " Error: ${e.message} " ,
306+ result = null
307+ )
308+ }
309+ }
310+ }
311+
312+ override suspend fun reportProblem (exception : Throwable ) {
313+ throw exception
314+ }
315+ }
316+
317+ val safeTool = SafeToolFromCallable (::testFunctionWithComplexArgs, directCallEnvironment, testClock)
318+ assertEquals(safeTool.toolFunction, ::testFunctionWithComplexArgs)
319+
320+ val complexData = ComplexDataClass (
321+ id = " test-complex-id" ,
322+ numbers = listOf (1 , 2 , 3 ),
323+ nested = SimpleDataClass (name = " test-nested" , value = 42 ),
324+ enumValue = TestEnum .FIRST
325+ )
326+
327+ val result = safeTool.execute(" test-param" , listOf (4 , 5 , 6 ), complexData)
328+
329+ assertTrue(result.isSuccessful())
330+ assertTrue(result.content.contains(" direct-test" ))
331+ assertTrue(result.content.contains(" direct-call-id" ))
332+ }
333+ }
0 commit comments