Skip to content

Commit 8a2d21f

Browse files
authored
JBAI-13751 Extend SimpleAgentMockedTest coverage (#226)
1 parent e0e1e41 commit 8a2d21f

File tree

2 files changed

+220
-7
lines changed

2 files changed

+220
-7
lines changed

agents/agents-test/build.gradle.kts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ kotlin {
1212
sourceSets {
1313
commonMain {
1414
dependencies {
15-
api(kotlin("test"))
16-
1715
api(project(":agents:agents-core"))
1816
api(project(":agents:agents-ext"))
1917
api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
2018
api(project(":prompt:prompt-executor:prompt-executor-llms-all"))
2119
api(project(":prompt:prompt-tokenizer"))
2220

21+
api(kotlin("test"))
22+
2323
api(libs.jetbrains.annotations)
2424
api(libs.kotlinx.coroutines.core)
2525
api(libs.kotlinx.serialization.json)
@@ -36,9 +36,9 @@ kotlin {
3636

3737
jvmTest {
3838
dependencies {
39-
implementation(kotlin("test-junit5"))
4039
implementation(project(":agents:agents-features:agents-features-event-handler"))
41-
40+
implementation(kotlin("test-junit5"))
41+
implementation(libs.junit.jupiter.params)
4242
implementation(libs.ktor.client.cio)
4343
}
4444
}

agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt

Lines changed: 216 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
package ai.koog.agents.test
44

5+
import ai.koog.agents.core.tools.SimpleTool
6+
import ai.koog.agents.core.tools.Tool
7+
import ai.koog.agents.core.tools.ToolDescriptor
8+
import ai.koog.agents.core.tools.ToolException
9+
import ai.koog.agents.core.tools.ToolParameterDescriptor
10+
import ai.koog.agents.core.tools.ToolParameterType
511
import ai.koog.agents.core.tools.ToolRegistry
612
import ai.koog.agents.ext.agent.simpleSingleRunAgent
713
import ai.koog.agents.ext.tool.ExitTool
@@ -12,12 +18,32 @@ import ai.koog.agents.testing.tools.getMockExecutor
1218
import ai.koog.agents.testing.tools.mockLLMAnswer
1319
import ai.koog.prompt.executor.clients.openai.OpenAIModels
1420
import kotlinx.coroutines.runBlocking
21+
import kotlinx.serialization.KSerializer
22+
import kotlinx.serialization.Serializable
23+
import org.junit.jupiter.params.ParameterizedTest
24+
import org.junit.jupiter.params.provider.MethodSource
1525
import kotlin.test.AfterTest
1626
import kotlin.test.Test
1727
import kotlin.test.assertTrue
1828
import kotlin.uuid.ExperimentalUuidApi
1929

2030
class SimpleAgentMockedTest {
31+
companion object {
32+
@JvmStatic
33+
fun getInputMessage(): Array<String> = arrayOf(
34+
"Call conditional tool with success.",
35+
"Call conditional tool with error.",
36+
)
37+
38+
@JvmStatic
39+
fun getToolRegistry(): Array<ToolRegistry> = arrayOf(
40+
ToolRegistry { },
41+
ToolRegistry { tool(SayToUser) }
42+
)
43+
}
44+
45+
val errorTrigger = "Trigger an error."
46+
2147
val systemPrompt = """
2248
You are a helpful assistant.
2349
You MUST use tools to communicate to the user.
@@ -32,18 +58,41 @@ class SimpleAgentMockedTest {
3258
SayToUser,
3359
SayToUser.Args("Calculating...")
3460
) onRequestEquals "Write a Kotlin function to calculate factorial."
61+
mockLLMToolCall(
62+
ErrorTool,
63+
ErrorTool.Args("test")
64+
) onRequestEquals errorTrigger
65+
mockLLMToolCall(
66+
ConditionalTool,
67+
ConditionalTool.Args("success")
68+
) onRequestEquals "Call conditional tool with success."
69+
mockLLMToolCall(
70+
ConditionalTool,
71+
ConditionalTool.Args("error")
72+
) onRequestEquals "Call conditional tool with error."
3573
}
3674

3775
val eventHandlerConfig: EventHandlerConfig.() -> Unit = {
3876
onToolCall = { tool, args ->
3977
println("Tool called: tool ${tool.name}, args $args")
4078
actualToolCalls.add(tool.name)
79+
iterationCount++
4180
}
4281

4382
onAgentRunError = { strategyName, sessionUuid, throwable ->
4483
errors.add(throwable)
4584
}
4685

86+
onToolCall = { tool, args ->
87+
println("Tool called: tool ${tool.name}, args $args")
88+
actualToolCalls.add(tool.name)
89+
}
90+
91+
onToolCallFailure = { tool, args, throwable ->
92+
println("Tool call failure: tool ${tool.name}, args $args, error=${throwable.message}")
93+
errors.add(throwable)
94+
}
95+
4796
onAgentFinished = { strategyName, result ->
4897
results.add(result)
4998
}
@@ -52,16 +101,67 @@ class SimpleAgentMockedTest {
52101
val actualToolCalls = mutableListOf<String>()
53102
val errors = mutableListOf<Throwable>()
54103
val results = mutableListOf<String?>()
104+
var iterationCount = 0
55105

56106
@AfterTest
57107
fun teardown() {
58108
actualToolCalls.clear()
59109
errors.clear()
60110
results.clear()
111+
iterationCount = 0
112+
}
113+
114+
object ErrorTool : SimpleTool<ErrorTool.Args>() {
115+
@Serializable
116+
data class Args(val message: String) : Tool.Args
117+
118+
override val argsSerializer: KSerializer<Args> = Args.serializer()
119+
120+
override val descriptor: ToolDescriptor = ToolDescriptor(
121+
name = "error_tool",
122+
description = "A tool that always throws an exception",
123+
requiredParameters = listOf(
124+
ToolParameterDescriptor(
125+
name = "message",
126+
description = "Message for the error",
127+
type = ToolParameterType.String
128+
)
129+
)
130+
)
131+
132+
override suspend fun doExecute(args: Args): String {
133+
throw ToolException.ValidationFailure("This tool always fails")
134+
}
135+
}
136+
137+
object ConditionalTool : SimpleTool<ConditionalTool.Args>() {
138+
@Serializable
139+
data class Args(val condition: String) : Tool.Args
140+
141+
override val argsSerializer: KSerializer<Args> = Args.serializer()
142+
143+
override val descriptor: ToolDescriptor = ToolDescriptor(
144+
name = "conditional_tool",
145+
description = "A tool that conditionally throws an exception",
146+
requiredParameters = listOf(
147+
ToolParameterDescriptor(
148+
name = "condition",
149+
description = "Condition that determines if the tool will succeed or fail",
150+
type = ToolParameterType.String
151+
)
152+
)
153+
)
154+
155+
override suspend fun doExecute(args: Args): String {
156+
if (args.condition == "error") {
157+
throw ToolException.ValidationFailure("Conditional failure triggered")
158+
}
159+
return "Conditional success"
160+
}
61161
}
62162

63163
@Test
64-
fun `simpleSingleRunAgent should not call tools by default`() = runBlocking {
164+
fun ` test simpleSingleRunAgent doesn't call tools by default`() = runBlocking {
65165
val agent = simpleSingleRunAgent(
66166
systemPrompt = systemPrompt,
67167
llmModel = OpenAIModels.Reasoning.GPT4oMini,
@@ -83,7 +183,7 @@ class SimpleAgentMockedTest {
83183
}
84184

85185
@Test
86-
fun `simpleSingleRunAgent should call a custom tool`() = runBlocking {
186+
fun `test simpleSingleRunAgent calls a custom tool`() = runBlocking {
87187
val toolRegistry = ToolRegistry {
88188
tool(SayToUser)
89189
}
@@ -108,4 +208,117 @@ class SimpleAgentMockedTest {
108208
"Expected no errors, but got: ${errors.joinToString("\n") { it.message ?: "" }}"
109209
)
110210
}
111-
}
211+
212+
@ParameterizedTest
213+
@MethodSource("getToolRegistry")
214+
fun `test simpleSingleRunAgent handles non-registered tools`(toolRegistry: ToolRegistry) = runBlocking {
215+
val agent = simpleSingleRunAgent(
216+
systemPrompt = systemPrompt,
217+
llmModel = OpenAIModels.Reasoning.GPT4oMini,
218+
temperature = 1.0,
219+
toolRegistry = toolRegistry,
220+
maxIterations = 10,
221+
executor = testExecutor,
222+
installFeatures = { install(EventHandler, eventHandlerConfig) }
223+
)
224+
225+
try {
226+
agent.run(errorTrigger)
227+
} catch (e: Throwable) {
228+
assertTrue(e is IllegalArgumentException, "Expected IllegalArgumentException")
229+
assertTrue(e.message?.contains("is not defined") == true, "Expected 'not defined' error message")
230+
}
231+
232+
assertTrue(errors.isNotEmpty(), "Expected errors to be recorded")
233+
}
234+
235+
@Test
236+
fun `test simpleSingleRunAgent handles tool execution errors`() = runBlocking {
237+
val toolRegistry = ToolRegistry {
238+
tool(ErrorTool)
239+
}
240+
241+
val agent = simpleSingleRunAgent(
242+
systemPrompt = systemPrompt,
243+
llmModel = OpenAIModels.Reasoning.GPT4oMini,
244+
temperature = 1.0,
245+
toolRegistry = toolRegistry,
246+
maxIterations = 10,
247+
executor = testExecutor,
248+
installFeatures = { install(EventHandler, eventHandlerConfig) }
249+
)
250+
251+
try {
252+
agent.run(errorTrigger)
253+
} catch (e: Throwable) {
254+
errors.add(e)
255+
}
256+
257+
assertTrue(actualToolCalls.contains(ErrorTool.name), "The ${ErrorTool.name} tool was not called")
258+
assertTrue(
259+
errors.isEmpty(),
260+
"Expected no errors, but got: ${errors.joinToString("\n") { it.message ?: "" }}"
261+
)
262+
}
263+
264+
@ParameterizedTest
265+
@MethodSource("getInputMessage")
266+
fun `test simpleSingleRunAgent handles conditional tool execution`(agentMessage: String) = runBlocking {
267+
val toolRegistry = ToolRegistry {
268+
tool(ConditionalTool)
269+
}
270+
271+
val successAgent = simpleSingleRunAgent(
272+
systemPrompt = systemPrompt,
273+
llmModel = OpenAIModels.Reasoning.GPT4oMini,
274+
temperature = 1.0,
275+
toolRegistry = toolRegistry,
276+
maxIterations = 10,
277+
executor = testExecutor,
278+
installFeatures = { install(EventHandler, eventHandlerConfig) }
279+
)
280+
281+
successAgent.run(agentMessage)
282+
283+
assertTrue(actualToolCalls.contains(ConditionalTool.name), "The ${ConditionalTool.name} tool was not called")
284+
assertTrue(errors.isEmpty(), "No errors should be recorded for success case")
285+
}
286+
287+
@Test
288+
fun `test simpleSingleRunAgent fails after reaching maxIterations`() = runBlocking {
289+
val toolRegistry = ToolRegistry {
290+
tool(SayToUser)
291+
}
292+
293+
val loopExecutor = getMockExecutor {
294+
mockLLMToolCall(SayToUser, SayToUser.Args("Looping...")) onRequestEquals "Make the agent loop."
295+
}
296+
297+
iterationCount = 0
298+
299+
val agent = simpleSingleRunAgent(
300+
systemPrompt = systemPrompt,
301+
llmModel = OpenAIModels.Reasoning.GPT4oMini,
302+
temperature = 1.0,
303+
toolRegistry = toolRegistry,
304+
maxIterations = 3,
305+
executor = loopExecutor,
306+
installFeatures = { install(EventHandler, eventHandlerConfig) }
307+
)
308+
309+
try {
310+
agent.run("Make the agent loop.")
311+
} catch (e: Throwable) {
312+
errors.add(e)
313+
}
314+
315+
assertTrue(errors.isNotEmpty(), "Error should be recorded when maxIterations is reached")
316+
assertTrue(
317+
errors.any {
318+
it.message?.contains("Maximum number of iterations") == true ||
319+
it.message?.contains("Agent couldn't finish in given number of steps") == true
320+
},
321+
"Expected error about maximum iterations"
322+
)
323+
}
324+
}

0 commit comments

Comments
 (0)