Skip to content

Commit 1aba521

Browse files
Add LLM as a Judge component (#866)
<!-- Thank you for opening a pull request! Please add a brief description of the proposed change here. Also, please tick the appropriate points in the checklist below. --> ## Motivation and Context <!-- Why is this change needed? What problem does it solve? --> ## Breaking Changes <!-- Will users need to update their code or configurations? --> --- #### Type of the changes - [x] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent c433b8e commit 1aba521

File tree

6 files changed

+328
-30
lines changed

6 files changed

+328
-30
lines changed

agents/agents-ext/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ kotlin {
3232
jvmTest {
3333
dependencies {
3434
implementation(kotlin("test-junit5"))
35+
implementation(libs.mockk)
3536
}
3637
}
3738
}

agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ai.koog.agents.ext.agent
33
import ai.koog.agents.core.agent.context.AIAgentGraphContextBase
44
import ai.koog.agents.core.agent.entity.ToolSelectionStrategy
55
import ai.koog.agents.core.agent.entity.createStorageKey
6+
import ai.koog.agents.core.annotation.InternalAgentsApi
67
import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker
78
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
89
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphDelegate
@@ -25,22 +26,9 @@ import ai.koog.prompt.llm.LLModel
2526
import ai.koog.prompt.message.Message
2627
import ai.koog.prompt.params.LLMParams
2728
import kotlinx.serialization.KSerializer
28-
import kotlinx.serialization.Serializable
2929
import kotlinx.serialization.json.Json
3030
import kotlinx.serialization.serializer
3131

32-
/**
33-
* Represents the result of a verification process for a subgraph.
34-
*
35-
* @property correct Indicates whether the subgraph verification was successful.
36-
* @property message A message providing details about the verification outcome.
37-
*/
38-
@Serializable
39-
public data class VerifiedSubgraphResult(
40-
val correct: Boolean,
41-
val message: String,
42-
)
43-
4432
/**
4533
* Utility object providing tools and methods for working with subgraphs and tasks in a controlled
4634
* and structured way. These utilities are designed to help finalize subgraph-related tasks and
@@ -240,22 +228,43 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
240228
}
241229

242230
/**
243-
* [subgraphWithTask] with [VerifiedSubgraphResult] result.
231+
* [subgraphWithTask] with [CriticResult] result.
244232
* It verifies if the task was performed correctly or not, and describes the problems if any.
245233
*/
234+
@OptIn(InternalAgentsApi::class)
246235
@Suppress("unused")
247236
@AIAgentBuilderDslMarker
248-
public inline fun <reified Input> AIAgentSubgraphBuilderBase<*, *>.subgraphWithVerification(
237+
public inline fun <reified Input : Any> AIAgentSubgraphBuilderBase<*, *>.subgraphWithVerification(
249238
toolSelectionStrategy: ToolSelectionStrategy,
250239
llmModel: LLModel? = null,
251240
llmParams: LLMParams? = null,
252241
noinline defineTask: suspend AIAgentGraphContextBase.(input: Input) -> String
253-
): AIAgentSubgraphDelegate<Input, VerifiedSubgraphResult> = subgraphWithTask(
254-
toolSelectionStrategy = toolSelectionStrategy,
255-
llmModel = llmModel,
256-
llmParams = llmParams,
257-
defineTask = defineTask
258-
)
242+
): AIAgentSubgraphDelegate<Input, CriticResult<Input>> = subgraph {
243+
val inputKey = createStorageKey<Input>("subgraphWithVerification-input-key")
244+
245+
val saveInput by node<Input, Input> { input ->
246+
storage.set(inputKey, input)
247+
248+
input
249+
}
250+
251+
val verifyTask by subgraphWithTask<Input, CriticResultFromLLM>(
252+
toolSelectionStrategy = toolSelectionStrategy,
253+
llmModel = llmModel,
254+
llmParams = llmParams,
255+
defineTask = defineTask
256+
)
257+
258+
val provideResult by node<CriticResultFromLLM, CriticResult<Input>> { result ->
259+
CriticResult(
260+
successful = result.isCorrect,
261+
feedback = result.feedback,
262+
input = storage.get(inputKey)!!
263+
)
264+
}
265+
266+
nodeStart then saveInput then verifyTask then provideResult then nodeFinish
267+
}
259268

260269
/**
261270
* Constructs a subgraph within an AI agent's strategy graph with additional verification capabilities.
@@ -271,16 +280,16 @@ public inline fun <reified Input> AIAgentSubgraphBuilderBase<*, *>.subgraphWithV
271280
* @param defineTask A suspendable function defining the task that the subgraph will execute,
272281
* which takes an input and produces a string-based task description.
273282
* @return A delegate representing the constructed subgraph with input type `Input` and output type
274-
* as a verified subgraph result `VerifiedSubgraphResult`.
283+
* as a verified subgraph result `CriticResult`.
275284
*/
276285
@Suppress("unused")
277286
@AIAgentBuilderDslMarker
278-
public inline fun <reified Input> AIAgentSubgraphBuilderBase<*, *>.subgraphWithVerification(
287+
public inline fun <reified Input : Any> AIAgentSubgraphBuilderBase<*, *>.subgraphWithVerification(
279288
tools: List<Tool<*, *>>,
280289
llmModel: LLModel? = null,
281290
llmParams: LLMParams? = null,
282291
noinline defineTask: suspend AIAgentGraphContextBase.(input: Input) -> String
283-
): AIAgentSubgraphDelegate<Input, VerifiedSubgraphResult> = subgraphWithVerification(
292+
): AIAgentSubgraphDelegate<Input, CriticResult<Input>> = subgraphWithVerification(
284293
toolSelectionStrategy = ToolSelectionStrategy.Tools(tools.map { it.descriptor }),
285294
llmModel = llmModel,
286295
llmParams = llmParams,
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package ai.koog.agents.ext.agent
2+
3+
import ai.koog.agents.core.annotation.InternalAgentsApi
4+
import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate
5+
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
6+
import ai.koog.agents.core.tools.annotations.LLMDescription
7+
import ai.koog.prompt.dsl.prompt
8+
import ai.koog.prompt.executor.clients.openai.OpenAIModels
9+
import ai.koog.prompt.llm.LLModel
10+
import ai.koog.prompt.message.Message
11+
import ai.koog.prompt.structure.StructureFixingParser
12+
import kotlinx.serialization.Serializable
13+
14+
/**
15+
* Represents the result of a plan evaluation performed by an LLM (Large Language Model).
16+
*
17+
* This class is primarily used within internal agent-related implementations where an LLM
18+
* evaluates the correctness of a plan and optionally provides feedback for improvements.
19+
*
20+
* @property isCorrect Indicates whether the evaluated plan is correct.
21+
* @property feedback Optional feedback provided by the LLM about the evaluated plan. This property
22+
* is populated only when the plan is deemed incorrect (`isCorrect == false`) and adjustments
23+
* are suggested.
24+
*/
25+
@InternalAgentsApi
26+
@Serializable
27+
@LLMDescription("Result of the evaluation")
28+
public data class CriticResultFromLLM(
29+
@property:LLMDescription("Was the task solved correctly?")
30+
val isCorrect: Boolean,
31+
@property:LLMDescription(
32+
"Optional feedback about the provided solution. " +
33+
"Only needed if `isCorrect == false` and if solution needs adjustments."
34+
)
35+
val feedback: String
36+
)
37+
38+
/**
39+
* Represents the result of a critique or feedback process.
40+
*
41+
* @property successful Indicates whether the critique operation was successful.
42+
* @property feedback A textual message providing details about the critique outcome. */
43+
public data class CriticResult<T>(
44+
val successful: Boolean,
45+
val feedback: String,
46+
val input: T
47+
)
48+
49+
/**
50+
* A method to utilize a language model (LLM) as a critic or judge for evaluating tasks with context-aware feedback.
51+
* This method processes a given task and the interaction history to provide structured feedback on the task's correctness.
52+
*
53+
* @param llmModel The optional language model to override the default model during the session. If `null`, the default model will be used.
54+
* @param task The task or instruction to be presented to the language model for critical evaluation.
55+
*/
56+
@OptIn(InternalAgentsApi::class)
57+
public inline fun <reified T> AIAgentSubgraphBuilderBase<*, *>.llmAsAJudge(
58+
llmModel: LLModel? = null,
59+
task: String
60+
): AIAgentNodeDelegate<T, CriticResult<T>> = node<T, CriticResult<T>> { nodeInput ->
61+
llm.writeSession {
62+
val initialPrompt = prompt.copy()
63+
val initialModel = model
64+
65+
prompt = prompt("critic") {
66+
// Combine all history into one message with XML tags
67+
// to prevent LLM from continuing answering in a tool_call -> tool_result pattern
68+
val combinedMessage = buildString {
69+
append("<previous_conversation>\n")
70+
initialPrompt.messages.forEach { message ->
71+
when (message) {
72+
is Message.System -> append("<user>\n${message.content}\n</user>\n")
73+
is Message.User -> append("<user>\n${message.content}\n</user>\n")
74+
is Message.Assistant -> append("<assistant>\n${message.content}\n</assistant>\n")
75+
is Message.Tool.Call -> append(
76+
"<tool_call tool=${message.tool}>\n${message.content}\n</tool_call>\n"
77+
)
78+
79+
is Message.Tool.Result -> append(
80+
"<tool_result tool=${message.tool}>\n${message.content}\n</tool_result>\n"
81+
)
82+
}
83+
}
84+
append("</previous_conversation>\n")
85+
}
86+
87+
// Put Critic Task as a System instruction
88+
system(task)
89+
// And rest of the history -- in a combined XML message
90+
user(combinedMessage)
91+
}
92+
93+
if (llmModel != null) {
94+
model = llmModel
95+
}
96+
97+
val result = requestLLMStructured<CriticResultFromLLM>(
98+
// optional field -- recommended for LLM awareness and reliability of the output
99+
examples = listOf(
100+
CriticResultFromLLM(
101+
isCorrect = true,
102+
feedback = "All good"
103+
),
104+
CriticResultFromLLM(
105+
isCorrect = false,
106+
feedback = "Following parts of the plan have problems: *, *, *. Please consider changing ..."
107+
)
108+
),
109+
// optional field -- recommended for reliability of the format
110+
fixingParser = StructureFixingParser(
111+
fixingModel = OpenAIModels.CostOptimized.GPT4oMini,
112+
retries = 3,
113+
)
114+
).getOrThrow().structure
115+
116+
prompt = initialPrompt
117+
model = initialModel
118+
119+
CriticResult(
120+
successful = result.isCorrect,
121+
feedback = result.feedback,
122+
input = nodeInput
123+
)
124+
}
125+
}

0 commit comments

Comments
 (0)