Commit dddcd08

Add test and make llmAsAJudge generic for any input
1 parent c2a99d4 commit dddcd08

File tree: 3 files changed (+184, -5 lines)


agents/agents-ext/build.gradle.kts

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ kotlin {
     jvmTest {
         dependencies {
             implementation(kotlin("test-junit5"))
+            implementation(libs.mockk)
         }
     }
 }

agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/LLMAsAJudge.kt

Lines changed: 19 additions & 5 deletions
@@ -1,5 +1,6 @@
 package ai.koog.agents.ext.agent
 
+import ai.koog.agents.core.annotation.InternalAgentsApi
 import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate
 import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
 import ai.koog.agents.core.tools.annotations.LLMDescription
@@ -10,9 +11,21 @@ import ai.koog.prompt.message.Message
 import ai.koog.prompt.structure.StructureFixingParser
 import kotlinx.serialization.Serializable
 
+/**
+ * Represents the result of a plan evaluation performed by an LLM (Large Language Model).
+ *
+ * This class is primarily used within internal agent-related implementations where an LLM
+ * evaluates the correctness of a plan and optionally provides feedback for improvements.
+ *
+ * @property isCorrect Indicates whether the evaluated plan is correct.
+ * @property feedback Optional feedback provided by the LLM about the evaluated plan. This property
+ * is populated only when the plan is deemed incorrect (`isCorrect == false`) and adjustments
+ * are suggested.
+ */
+@InternalAgentsApi
 @Serializable
 @LLMDescription("Result of the evaluation")
-internal data class CriticResultFromLLM(
+public data class CriticResultFromLLM(
     @property:LLMDescription("Was the plan correct?")
     val isCorrect: Boolean,
     @property:LLMDescription(
@@ -27,10 +40,10 @@ internal data class CriticResultFromLLM(
  *
  * @property successful Indicates whether the critique operation was successful.
  * @property feedback A textual message providing details about the*/
-public data class CriticResult(
+public data class CriticResult<T>(
     val successful: Boolean,
     val feedback: String,
-    val input: String
+    val input: T
 )
 
 /**
@@ -40,10 +53,11 @@ public data class CriticResult(
  * @param llmModel The optional language model to override the default model during the session. If `null`, the default model will be used.
  * @param task The task or instruction to be presented to the language model for critical evaluation.
  */
-public fun AIAgentSubgraphBuilderBase<*, *>.llmAsAJudge(
+@OptIn(InternalAgentsApi::class)
+public inline fun <reified T> AIAgentSubgraphBuilderBase<*, *>.llmAsAJudge(
     llmModel: LLModel? = null,
     task: String
-): AIAgentNodeDelegate<String, CriticResult> = node<String, CriticResult> { nodeInput ->
+): AIAgentNodeDelegate<T, CriticResult<T>> = node<T, CriticResult<T>> { nodeInput ->
     llm.writeSession {
         val initialPrompt = prompt.copy()
         val initialModel = model
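
With `CriticResult` now parameterized and `llmAsAJudge` reified over its input type, the judge node can be attached to any input, not just `String`. The following is a minimal usage sketch, not part of this commit: it assumes Koog's generic `strategy` builder and `forwardTo` edge DSL (import paths and signatures may differ by version), and the strategy name and task text are purely illustrative.

import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.builder.strategy

// Hypothetical wiring of the generic judge node over an Int input (sketch only).
val numberJudgeStrategy = strategy<Int, CriticResult<Int>>("number-judge") {
    // The node now infers its input type from T instead of being fixed to String.
    val judge by llmAsAJudge<Int>(
        llmModel = null, // null falls back to the session's default model
        task = "Check that the number is not divisible by 3"
    )
    edge(nodeStart forwardTo judge)
    edge(judge forwardTo nodeFinish)
}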
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+package ai.koog.agents.ext.agent
+
+import ai.koog.agents.core.agent.config.AIAgentConfig
+import ai.koog.agents.core.agent.context.AIAgentGraphContext
+import ai.koog.agents.core.agent.context.AIAgentLLMContext
+import ai.koog.agents.core.agent.context.DetachedPromptExecutorAPI
+import ai.koog.agents.core.agent.entity.FinishNode
+import ai.koog.agents.core.agent.entity.StartNode
+import ai.koog.agents.core.annotation.InternalAgentsApi
+import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
+import ai.koog.agents.core.environment.AIAgentEnvironment
+import ai.koog.agents.core.feature.AIAgentGraphPipeline
+import ai.koog.agents.core.tools.ToolRegistry
+import ai.koog.prompt.dsl.prompt
+import ai.koog.prompt.executor.clients.openai.OpenAIModels
+import ai.koog.prompt.executor.model.PromptExecutor
+import ai.koog.prompt.llm.OllamaModels
+import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.ResponseMetaInfo
+import io.mockk.coEvery
+import io.mockk.coVerify
+import io.mockk.mockk
+import kotlinx.coroutines.test.runTest
+import kotlinx.datetime.Clock
+import kotlinx.datetime.Instant
+import kotlinx.serialization.json.Json
+import kotlin.reflect.typeOf
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertNotEquals
+
+class LLMAsJudgeNodeTest {
+    private val defaultName = "re_act"
+
+    private val testClock: Clock = object : Clock {
+        override fun now(): Instant = Instant.parse("2023-01-01T00:00:00Z")
+    }
+
+    val CRITIC_TASK = "Find all numbers produced by LLM and check that they are not divided by 3"
+
+    @OptIn(InternalAgentsApi::class, DetachedPromptExecutorAPI::class)
+    @Test
+    fun testChatStrategyDefaultName() = runTest {
+
+        val initialPrompt = prompt("id") {
+            system("System message")
+            user("User question")
+            assistant("Assistant question")
+            user("User answer")
+            tool {
+                call(id = "tool-id-1", tool = "tool1", content = "{x=1}")
+                result(id = "tool-id-1", tool = "tool1", content = "{result=2}")
+            }
+            tool {
+                call(id = "tool-id-2", tool = "tool2", content = "{x=100}")
+                result(id = "tool-id-2", tool = "tool2", content = "{result=-200}")
+            }
+        }
+
+        val mockPromptExecutor = mockk<PromptExecutor>()
+
+        val mockEnv = mockk<AIAgentEnvironment>()
+
+        val initialModel = OllamaModels.Meta.LLAMA_3_2
+
+        val mockLLM = AIAgentLLMContext(
+            tools = emptyList(),
+            toolRegistry = ToolRegistry {},
+            prompt = initialPrompt,
+            model = initialModel,
+            promptExecutor = mockPromptExecutor,
+            environment = mockEnv,
+            config = AIAgentConfig(prompt = prompt("id") {}, model = OpenAIModels.Chat.GPT4o, maxAgentIterations = 10),
+            clock = testClock
+        )
+
+        val context = AIAgentGraphContext(
+            environment = mockEnv,
+            agentInputType = typeOf<String>(),
+            agentInput = "Hello",
+            config = mockk(),
+            llm = mockLLM,
+            stateManager = mockk(),
+            storage = mockk(),
+            runId = "run-1",
+            strategyName = "test-strategy",
+            pipeline = AIAgentGraphPipeline(),
+            agentId = "agent-01"
+        )
+
+        val subgraphContext = object : AIAgentSubgraphBuilderBase<String, String>() {
+            override val nodeStart: StartNode<String> = mockk()
+            override val nodeFinish: FinishNode<String> = mockk()
+        }
+
+        val anotherModel = OllamaModels.Meta.LLAMA_4_SCOUT
+
+        val llmJudgeNode by subgraphContext.llmAsAJudge<Int>(
+            llmModel = anotherModel,
+            task = CRITIC_TASK
+        )
+
+        coEvery { mockPromptExecutor.execute(any(), any(), any()) } returns listOf(
+            Message.Assistant(
+                content = Json.encodeToString(
+                    CriticResultFromLLM.serializer(),
+                    CriticResultFromLLM(isCorrect = true, feedback = "All good")
+                ),
+                metaInfo = ResponseMetaInfo.create(testClock),
+            )
+        )
+
+        llmJudgeNode.execute(context, input = -200)
+
+        val expectedXMLHistory = """
+            <previous_conversation>
+            <user>
+            System message
+            </user>
+            <user>
+            User question
+            </user>
+            <assistant>
+            Assistant question
+            </assistant>
+            <user>
+            User answer
+            </user>
+            <tool_call tool=tool1>
+            {x=1}
+            </tool_call>
+            <tool_result tool=tool1>
+            {result=2}
+            </tool_result>
+            <tool_call tool=tool2>
+            {x=100}
+            </tool_call>
+            <tool_result tool=tool2>
+            {result=-200}
+            </tool_result>
+            </previous_conversation>
+        """.trimIndent()
+
+        coVerify {
+            mockPromptExecutor.execute(
+                prompt = match {
+                    (it.messages.size == 2)
+                        && (it.messages.first().role == Message.Role.System && it.messages.first().content == CRITIC_TASK)
+                        && (it.messages.last().role == Message.Role.User && it.messages.last().content.trimIndent() == expectedXMLHistory)
+                        && (it.id == "critic")
+                },
+                model = match {
+                    it == anotherModel
+                },
+                tools = any()
+            )
+        }
+
+        assertEquals(initialPrompt, context.llm.prompt)
+
+        assertEquals(initialModel, context.llm.model)
+        assertNotEquals(anotherModel, context.llm.model)
+    }
+}
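
Downstream of the judge node, consumers receive the generic `CriticResult<T>` and can branch on its fields; a hypothetical handler (illustration only, not part of this commit) might look like:

// Hypothetical consumer of the generic critic result (names are illustrative).
fun reportVerdict(result: CriticResult<Int>) {
    if (result.successful) {
        println("Input ${result.input} passed the critic check")
    } else {
        println("Input ${result.input} rejected: ${result.feedback}")
    }
}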
