22
33package ai.koog.agents.test
44
5+ import ai.koog.agents.core.tools.SimpleTool
6+ import ai.koog.agents.core.tools.Tool
7+ import ai.koog.agents.core.tools.ToolDescriptor
8+ import ai.koog.agents.core.tools.ToolException
9+ import ai.koog.agents.core.tools.ToolParameterDescriptor
10+ import ai.koog.agents.core.tools.ToolParameterType
511import ai.koog.agents.core.tools.ToolRegistry
612import ai.koog.agents.ext.agent.simpleSingleRunAgent
713import ai.koog.agents.ext.tool.ExitTool
@@ -12,12 +18,32 @@ import ai.koog.agents.testing.tools.getMockExecutor
1218import ai.koog.agents.testing.tools.mockLLMAnswer
1319import ai.koog.prompt.executor.clients.openai.OpenAIModels
1420import kotlinx.coroutines.runBlocking
21+ import kotlinx.serialization.KSerializer
22+ import kotlinx.serialization.Serializable
23+ import org.junit.jupiter.params.ParameterizedTest
24+ import org.junit.jupiter.params.provider.MethodSource
1525import kotlin.test.AfterTest
1626import kotlin.test.Test
1727import kotlin.test.assertTrue
1828import kotlin.uuid.ExperimentalUuidApi
1929
2030class SimpleAgentMockedTest {
31+ companion object {
32+ @JvmStatic
33+ fun getInputMessage (): Array <String > = arrayOf(
34+ " Call conditional tool with success." ,
35+ " Call conditional tool with error." ,
36+ )
37+
38+ @JvmStatic
39+ fun getToolRegistry (): Array <ToolRegistry > = arrayOf(
40+ ToolRegistry { },
41+ ToolRegistry { tool(SayToUser ) }
42+ )
43+ }
44+
45+ val errorTrigger = " Trigger an error."
46+
2147 val systemPrompt = """
2248 You are a helpful assistant.
2349 You MUST use tools to communicate to the user.
@@ -32,18 +58,41 @@ class SimpleAgentMockedTest {
3258 SayToUser ,
3359 SayToUser .Args (" Calculating..." )
3460 ) onRequestEquals " Write a Kotlin function to calculate factorial."
61+ mockLLMToolCall(
62+ ErrorTool ,
63+ ErrorTool .Args (" test" )
64+ ) onRequestEquals errorTrigger
65+ mockLLMToolCall(
66+ ConditionalTool ,
67+ ConditionalTool .Args (" success" )
68+ ) onRequestEquals " Call conditional tool with success."
69+ mockLLMToolCall(
70+ ConditionalTool ,
71+ ConditionalTool .Args (" error" )
72+ ) onRequestEquals " Call conditional tool with error."
3573 }
3674
3775 val eventHandlerConfig: EventHandlerConfig .() -> Unit = {
3876 onToolCall = { tool, args ->
3977 println (" Tool called: tool ${tool.name} , args $args " )
4078 actualToolCalls.add(tool.name)
79+ iterationCount++
4180 }
4281
4382 onAgentRunError = { strategyName, sessionUuid, throwable ->
4483 errors.add(throwable)
4584 }
4685
86+ onToolCall = { tool, args ->
87+ println (" Tool called: tool ${tool.name} , args $args " )
88+ actualToolCalls.add(tool.name)
89+ }
90+
91+ onToolCallFailure = { tool, args, throwable ->
92+ println (" Tool call failure: tool ${tool.name} , args $args , error=${throwable.message} " )
93+ errors.add(throwable)
94+ }
95+
4796 onAgentFinished = { strategyName, result ->
4897 results.add(result)
4998 }
@@ -52,16 +101,67 @@ class SimpleAgentMockedTest {
52101 val actualToolCalls = mutableListOf<String >()
53102 val errors = mutableListOf<Throwable >()
54103 val results = mutableListOf<String ?>()
104+ var iterationCount = 0
55105
56106 @AfterTest
57107 fun teardown () {
58108 actualToolCalls.clear()
59109 errors.clear()
60110 results.clear()
111+ iterationCount = 0
112+ }
113+
114+ object ErrorTool : SimpleTool<ErrorTool.Args>() {
115+ @Serializable
116+ data class Args (val message : String ) : Tool.Args
117+
118+ override val argsSerializer: KSerializer <Args > = Args .serializer()
119+
120+ override val descriptor: ToolDescriptor = ToolDescriptor (
121+ name = " error_tool" ,
122+ description = " A tool that always throws an exception" ,
123+ requiredParameters = listOf (
124+ ToolParameterDescriptor (
125+ name = " message" ,
126+ description = " Message for the error" ,
127+ type = ToolParameterType .String
128+ )
129+ )
130+ )
131+
132+ override suspend fun doExecute (args : Args ): String {
133+ throw ToolException .ValidationFailure (" This tool always fails" )
134+ }
135+ }
136+
137+ object ConditionalTool : SimpleTool<ConditionalTool.Args>() {
138+ @Serializable
139+ data class Args (val condition : String ) : Tool.Args
140+
141+ override val argsSerializer: KSerializer <Args > = Args .serializer()
142+
143+ override val descriptor: ToolDescriptor = ToolDescriptor (
144+ name = " conditional_tool" ,
145+ description = " A tool that conditionally throws an exception" ,
146+ requiredParameters = listOf (
147+ ToolParameterDescriptor (
148+ name = " condition" ,
149+ description = " Condition that determines if the tool will succeed or fail" ,
150+ type = ToolParameterType .String
151+ )
152+ )
153+ )
154+
155+ override suspend fun doExecute (args : Args ): String {
156+ if (args.condition == " error" ) {
157+ throw ToolException .ValidationFailure (" Conditional failure triggered" )
158+ }
159+ return " Conditional success"
160+ }
61161 }
62162
63163 @Test
64- fun `simpleSingleRunAgent should not call tools by default` () = runBlocking {
164+ fun ` test simpleSingleRunAgent doesn't call tools by default` () = runBlocking {
65165 val agent = simpleSingleRunAgent(
66166 systemPrompt = systemPrompt,
67167 llmModel = OpenAIModels .Reasoning .GPT4oMini ,
@@ -83,7 +183,7 @@ class SimpleAgentMockedTest {
83183 }
84184
85185 @Test
86- fun `simpleSingleRunAgent should call a custom tool` () = runBlocking {
186+ fun `test simpleSingleRunAgent calls a custom tool` () = runBlocking {
87187 val toolRegistry = ToolRegistry {
88188 tool(SayToUser )
89189 }
@@ -108,4 +208,117 @@ class SimpleAgentMockedTest {
108208 " Expected no errors, but got: ${errors.joinToString(" \n " ) { it.message ? : " " }} "
109209 )
110210 }
111- }
211+
212+ @ParameterizedTest
213+ @MethodSource(" getToolRegistry" )
214+ fun `test simpleSingleRunAgent handles non-registered tools` (toolRegistry : ToolRegistry ) = runBlocking {
215+ val agent = simpleSingleRunAgent(
216+ systemPrompt = systemPrompt,
217+ llmModel = OpenAIModels .Reasoning .GPT4oMini ,
218+ temperature = 1.0 ,
219+ toolRegistry = toolRegistry,
220+ maxIterations = 10 ,
221+ executor = testExecutor,
222+ installFeatures = { install(EventHandler , eventHandlerConfig) }
223+ )
224+
225+ try {
226+ agent.run (errorTrigger)
227+ } catch (e: Throwable ) {
228+ assertTrue(e is IllegalArgumentException , " Expected IllegalArgumentException" )
229+ assertTrue(e.message?.contains(" is not defined" ) == true , " Expected 'not defined' error message" )
230+ }
231+
232+ assertTrue(errors.isNotEmpty(), " Expected errors to be recorded" )
233+ }
234+
235+ @Test
236+ fun `test simpleSingleRunAgent handles tool execution errors` () = runBlocking {
237+ val toolRegistry = ToolRegistry {
238+ tool(ErrorTool )
239+ }
240+
241+ val agent = simpleSingleRunAgent(
242+ systemPrompt = systemPrompt,
243+ llmModel = OpenAIModels .Reasoning .GPT4oMini ,
244+ temperature = 1.0 ,
245+ toolRegistry = toolRegistry,
246+ maxIterations = 10 ,
247+ executor = testExecutor,
248+ installFeatures = { install(EventHandler , eventHandlerConfig) }
249+ )
250+
251+ try {
252+ agent.run (errorTrigger)
253+ } catch (e: Throwable ) {
254+ errors.add(e)
255+ }
256+
257+ assertTrue(actualToolCalls.contains(ErrorTool .name), " The ${ErrorTool .name} tool was not called" )
258+ assertTrue(
259+ errors.isEmpty(),
260+ " Expected no errors, but got: ${errors.joinToString(" \n " ) { it.message ? : " " }} "
261+ )
262+ }
263+
264+ @ParameterizedTest
265+ @MethodSource(" getInputMessage" )
266+ fun `test simpleSingleRunAgent handles conditional tool execution` (agentMessage : String ) = runBlocking {
267+ val toolRegistry = ToolRegistry {
268+ tool(ConditionalTool )
269+ }
270+
271+ val successAgent = simpleSingleRunAgent(
272+ systemPrompt = systemPrompt,
273+ llmModel = OpenAIModels .Reasoning .GPT4oMini ,
274+ temperature = 1.0 ,
275+ toolRegistry = toolRegistry,
276+ maxIterations = 10 ,
277+ executor = testExecutor,
278+ installFeatures = { install(EventHandler , eventHandlerConfig) }
279+ )
280+
281+ successAgent.run (agentMessage)
282+
283+ assertTrue(actualToolCalls.contains(ConditionalTool .name), " The ${ConditionalTool .name} tool was not called" )
284+ assertTrue(errors.isEmpty(), " No errors should be recorded for success case" )
285+ }
286+
287+ @Test
288+ fun `test simpleSingleRunAgent fails after reaching maxIterations` () = runBlocking {
289+ val toolRegistry = ToolRegistry {
290+ tool(SayToUser )
291+ }
292+
293+ val loopExecutor = getMockExecutor {
294+ mockLLMToolCall(SayToUser , SayToUser .Args (" Looping..." )) onRequestEquals " Make the agent loop."
295+ }
296+
297+ iterationCount = 0
298+
299+ val agent = simpleSingleRunAgent(
300+ systemPrompt = systemPrompt,
301+ llmModel = OpenAIModels .Reasoning .GPT4oMini ,
302+ temperature = 1.0 ,
303+ toolRegistry = toolRegistry,
304+ maxIterations = 3 ,
305+ executor = loopExecutor,
306+ installFeatures = { install(EventHandler , eventHandlerConfig) }
307+ )
308+
309+ try {
310+ agent.run (" Make the agent loop." )
311+ } catch (e: Throwable ) {
312+ errors.add(e)
313+ }
314+
315+ assertTrue(errors.isNotEmpty(), " Error should be recorded when maxIterations is reached" )
316+ assertTrue(
317+ errors.any {
318+ it.message?.contains(" Maximum number of iterations" ) == true ||
319+ it.message?.contains(" Agent couldn't finish in given number of steps" ) == true
320+ },
321+ " Expected error about maximum iterations"
322+ )
323+ }
324+ }
0 commit comments