Skip to content

Commit 59c0a67

Browse files
authored
Update event handlers. Add LLModelDefinitions interface for classes with predefined LLModels (#212)
1 parent bc7fd52 commit 59c0a67

File tree

31 files changed

+255
-103
lines changed

31 files changed

+255
-103
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgent.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public open class AIAgent(
140140
toolRegistry,
141141
agentConfig.prompt,
142142
agentConfig.model,
143-
promptExecutor = PromptExecutorProxy(promptExecutor, pipeline),
143+
promptExecutor = PromptExecutorProxy(promptExecutor, pipeline, sessionUuid!!),
144144
environment = preparedEnvironment,
145145
agentConfig,
146146
clock
@@ -359,7 +359,7 @@ public open class AIAgent(
359359
throw error.asException()
360360
} catch (e: AgentEngineException) {
361361
logger.error(e) { "Execution exception reported by server!" }
362-
pipeline.onAgentRunError(strategyName = strategy.name, throwable = e)
362+
pipeline.onAgentRunError(strategyName = strategy.name, sessionUuid = sessionUuid, throwable = e)
363363
}
364364
}
365365

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentStrategy.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ public class AIAgentStrategy(
2424

2525
override suspend fun execute(context: AIAgentContextBase, input: String): String {
2626
return runCatchingCancellable {
27-
context.pipeline.onStrategyStarted(this)
27+
context.pipeline.onStrategyStarted(this, context)
2828
val result = super.execute(context, input)
29-
context.pipeline.onStrategyFinished(name, result)
29+
context.pipeline.onStrategyFinished(this, context, result)
3030
result
3131
}.onSuccess {
3232
context.environment.sendTermination(it)

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@file:OptIn(ExperimentalUuidApi::class)
2+
13
package ai.koog.agents.core.feature
24

35
import ai.koog.agents.core.agent.AIAgent
@@ -13,11 +15,14 @@ import ai.koog.agents.core.tools.ToolDescriptor
1315
import ai.koog.agents.core.tools.ToolResult
1416
import ai.koog.agents.features.common.config.FeatureConfig
1517
import ai.koog.prompt.dsl.Prompt
18+
import ai.koog.prompt.llm.LLModel
1619
import ai.koog.prompt.message.Message
1720
import io.github.oshai.kotlinlogging.KotlinLogging
1821
import kotlinx.coroutines.Dispatchers
1922
import kotlinx.coroutines.launch
2023
import kotlinx.coroutines.withContext
24+
import kotlin.uuid.ExperimentalUuidApi
25+
import kotlin.uuid.Uuid
2126

2227
/**
2328
* Pipeline for AI agent features that provides interception points for various agent lifecycle events.
@@ -169,8 +174,9 @@ public class AIAgentPipeline {
169174
* @param strategyName The name of the strategy during which the error occurred
170175
* @param throwable The exception that was thrown during agent execution
171176
*/
172-
public suspend fun onAgentRunError(strategyName: String, throwable: Throwable) {
173-
agentHandlers.values.forEach { handler -> handler.agentRunErrorHandler.handle(strategyName, throwable) }
177+
@OptIn(ExperimentalUuidApi::class)
178+
public suspend fun onAgentRunError(strategyName: String, sessionUuid: Uuid?, throwable: Throwable) {
179+
agentHandlers.values.forEach { handler -> handler.agentRunErrorHandler.handle(strategyName, sessionUuid, throwable) }
174180
}
175181

176182
/**
@@ -203,22 +209,29 @@ public class AIAgentPipeline {
203209
* Notifies all registered strategy handlers that a strategy has started execution.
204210
*
205211
* @param strategy The strategy that has started execution
212+
* @param context The context of the strategy execution
206213
*/
207-
public suspend fun onStrategyStarted(strategy: AIAgentStrategy) {
214+
@OptIn(ExperimentalUuidApi::class)
215+
public suspend fun onStrategyStarted(strategy: AIAgentStrategy, context: AIAgentContextBase) {
208216
strategyHandlers.values.forEach { handler ->
209-
val updateContext = StrategyUpdateContext(strategy, handler.feature)
217+
val updateContext = StrategyUpdateContext(strategy, context.sessionUuid, handler.feature)
210218
handler.handleStrategyStartedUnsafe(updateContext)
211219
}
212220
}
213221

214222
/**
215223
* Notifies all registered strategy handlers that a strategy has finished execution.
216224
*
217-
* @param strategyName The name of the strategy that has finished
225+
* @param strategy The strategy that has started execution
226+
* @param context The context of the strategy execution
218227
* @param result The result produced by the strategy execution
219228
*/
220-
public suspend fun onStrategyFinished(strategyName: String, result: String) {
221-
strategyHandlers.values.forEach { handler -> handler.strategyFinishedHandler.handle(strategyName, result) }
229+
@OptIn(ExperimentalUuidApi::class)
230+
public suspend fun onStrategyFinished(strategy: AIAgentStrategy, context: AIAgentContextBase, result: String) {
231+
strategyHandlers.values.forEach { handler ->
232+
val updateContext = StrategyUpdateContext(strategy, context.sessionUuid, handler.feature)
233+
handler.handleStrategyFinishedUnsafe(updateContext, result)
234+
}
222235
}
223236

224237
//endregion Trigger Strategy Handlers
@@ -281,17 +294,17 @@ public class AIAgentPipeline {
281294
*
282295
* @param prompt The prompt that will be sent to the language model
283296
*/
284-
public suspend fun onBeforeLLMCall(prompt: Prompt, tools: List<ToolDescriptor>) {
285-
executeLLMHandlers.values.forEach { handler -> handler.beforeLLMCallHandler.handle(prompt, tools) }
297+
public suspend fun onBeforeLLMCall(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, sessionUuid: Uuid) {
298+
executeLLMHandlers.values.forEach { handler -> handler.beforeLLMCallHandler.handle(prompt, tools, model, sessionUuid) }
286299
}
287300

288301
/**
289302
* Notifies all registered LLM handlers after a language model call has completed.
290303
*
291304
* @param responses A single or multiple response messages received from the language model
292305
*/
293-
public suspend fun onAfterLLMCall(responses: List<Message.Response>) {
294-
executeLLMHandlers.values.forEach { handler -> handler.afterLLMCallHandler.handle(responses) }
306+
public suspend fun onAfterLLMCall(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, responses: List<Message.Response>, sessionUuid: Uuid) {
307+
executeLLMHandlers.values.forEach { handler -> handler.afterLLMCallHandler.handle(prompt, tools, model, responses, sessionUuid) }
295308
}
296309

297310
//endregion Trigger LLM Call Handlers
@@ -470,15 +483,16 @@ public class AIAgentPipeline {
470483
* }
471484
* ```
472485
*/
486+
@OptIn(ExperimentalUuidApi::class)
473487
public fun <TFeature : Any> interceptAgentRunError(
474488
feature: AIAgentFeature<*, TFeature>,
475489
featureImpl: TFeature,
476-
handle: suspend TFeature.(strategyName: String, throwable: Throwable) -> Unit
490+
handle: suspend TFeature.(strategyName: String, sessionUuid: Uuid?, throwable: Throwable) -> Unit
477491
) {
478492
val existingHandler = agentHandlers.getOrPut(feature.key) { AgentHandler(featureImpl) }
479493

480-
existingHandler.agentRunErrorHandler = AgentRunErrorHandler { strategyName, throwable ->
481-
with(featureImpl) { handle(strategyName, throwable) }
494+
existingHandler.agentRunErrorHandler = AgentRunErrorHandler { strategyName, sessionUuid, throwable ->
495+
with(featureImpl) { handle(strategyName, sessionUuid, throwable) }
482496
}
483497
}
484498

@@ -532,12 +546,21 @@ public class AIAgentPipeline {
532546
public fun <TFeature : Any> interceptStrategyFinished(
533547
feature: AIAgentFeature<*, TFeature>,
534548
featureImpl: TFeature,
535-
handle: suspend TFeature.(strategyName: String, result: String) -> Unit
549+
handle: suspend StrategyUpdateContext<TFeature>.(String) -> Unit
536550
) {
537551
val existingHandler = strategyHandlers.getOrPut(feature.key) { StrategyHandler(featureImpl) }
538552

539-
existingHandler.strategyFinishedHandler = StrategyFinishedHandler { strategyName, result ->
540-
with(featureImpl) { handle(strategyName, result) }
553+
@Suppress("UNCHECKED_CAST")
554+
if (existingHandler as? StrategyHandler<TFeature> == null) {
555+
logger.debug {
556+
"Expected to get an agent handler for feature of type <${featureImpl::class}>, but get a handler of type <${feature.key}> instead. " +
557+
"Skipping adding strategy finished interceptor for feature."
558+
}
559+
return
560+
}
561+
562+
existingHandler.strategyFinishedHandler = StrategyFinishedHandler { updateContext, result ->
563+
handle(updateContext, result)
541564
}
542565
}
543566

@@ -609,12 +632,12 @@ public class AIAgentPipeline {
609632
public fun <TFeature : Any> interceptBeforeLLMCall(
610633
feature: AIAgentFeature<*, TFeature>,
611634
featureImpl: TFeature,
612-
handle: suspend TFeature.(prompt: Prompt, tools: List<ToolDescriptor>) -> Unit
635+
handle: suspend TFeature.(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, sessionUuid: Uuid) -> Unit
613636
) {
614637
val existingHandler = executeLLMHandlers.getOrPut(feature.key) { ExecuteLLMHandler() }
615638

616-
existingHandler.beforeLLMCallHandler = BeforeLLMCallHandler { prompt, tools ->
617-
with(featureImpl) { handle(prompt, tools) }
639+
existingHandler.beforeLLMCallHandler = BeforeLLMCallHandler { prompt, tools, model, sessionUuid ->
640+
with(featureImpl) { handle(prompt, tools, model, sessionUuid) }
618641
}
619642
}
620643

@@ -633,12 +656,12 @@ public class AIAgentPipeline {
633656
public fun <TFeature : Any> interceptAfterLLMCall(
634657
feature: AIAgentFeature<*, TFeature>,
635658
featureImpl: TFeature,
636-
handle: suspend TFeature.(responses: List<Message.Response>) -> Unit
659+
handle: suspend TFeature.(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, responses: List<Message.Response>, sessionUuid: Uuid) -> Unit
637660
) {
638661
val existingHandler = executeLLMHandlers.getOrPut(feature.key) { ExecuteLLMHandler() }
639662

640-
existingHandler.afterLLMCallHandler = AfterLLMCallHandler { responses ->
641-
with(featureImpl) { handle(responses) }
663+
existingHandler.afterLLMCallHandler = AfterLLMCallHandler { prompt, tools, model, responses, sessionUuid ->
664+
with(featureImpl) { handle(prompt, tools, model, responses, sessionUuid) }
642665
}
643666
}
644667

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/PromptExecutorProxy.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import ai.koog.prompt.llm.LLModel
77
import ai.koog.prompt.message.Message
88
import io.github.oshai.kotlinlogging.KotlinLogging
99
import kotlinx.coroutines.flow.Flow
10+
import kotlin.uuid.ExperimentalUuidApi
11+
import kotlin.uuid.Uuid
1012

1113
/**
1214
* A wrapper around [ai.koog.prompt.executor.model.PromptExecutor] that allows for adding internal functionality to the executor
@@ -15,9 +17,11 @@ import kotlinx.coroutines.flow.Flow
1517
* @property executor The [ai.koog.prompt.executor.model.PromptExecutor] to wrap.
1618
* @property pipeline The [AIAgentPipeline] associated with the executor.
1719
*/
20+
@OptIn(ExperimentalUuidApi::class)
1821
public class PromptExecutorProxy(
1922
private val executor: PromptExecutor,
20-
private val pipeline: AIAgentPipeline
23+
private val pipeline: AIAgentPipeline,
24+
private val sessionUuid: Uuid,
2125
) : PromptExecutor {
2226

2327
private companion object {
@@ -26,12 +30,12 @@ public class PromptExecutorProxy(
2630

2731
override suspend fun execute(prompt: Prompt, model: LLModel, tools: List<ToolDescriptor>): List<Message.Response> {
2832
logger.debug { "Executing LLM call (prompt: $prompt, tools: [${tools.joinToString { it.name }}])" }
29-
pipeline.onBeforeLLMCall(prompt, tools)
33+
pipeline.onBeforeLLMCall(prompt, tools, model, sessionUuid)
3034

3135
val responses = executor.execute(prompt, model, tools)
3236

3337
logger.debug { "Finished LLM call with responses: [${responses.joinToString { "${it.role}: ${it.content}" } }]" }
34-
pipeline.onAfterLLMCall(responses)
38+
pipeline.onAfterLLMCall(prompt, tools, model, responses, sessionUuid)
3539

3640
return responses
3741
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentHandler.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
@file:OptIn(ExperimentalUuidApi::class)
2+
13
package ai.koog.agents.core.feature.handler
24

35
import ai.koog.agents.core.agent.AIAgent
46
import ai.koog.agents.core.agent.entity.AIAgentStrategy
57
import ai.koog.agents.core.annotation.InternalAgentsApi
68
import ai.koog.agents.core.environment.AIAgentEnvironment
9+
import kotlin.uuid.ExperimentalUuidApi
10+
import kotlin.uuid.Uuid
711

812
/**
913
* Feature implementation for agent and strategy interception.
@@ -28,7 +32,7 @@ public class AgentHandler<FeatureT : Any>(public val feature: FeatureT) {
2832
AgentFinishedHandler { _, _ -> }
2933

3034
public var agentRunErrorHandler: AgentRunErrorHandler =
31-
AgentRunErrorHandler { _, _ -> }
35+
AgentRunErrorHandler { _, _, _ -> }
3236

3337
/**
3438
* Transforms the provided AgentEnvironment using the configured environment transformer.
@@ -87,8 +91,9 @@ public fun interface AgentFinishedHandler {
8791
public suspend fun handle(strategyName: String, result: String?)
8892
}
8993

94+
@OptIn(ExperimentalUuidApi::class)
9095
public fun interface AgentRunErrorHandler {
91-
public suspend fun handle(strategyName: String, throwable: Throwable)
96+
public suspend fun handle(strategyName: String, sessionUuid: Uuid?, throwable: Throwable)
9297
}
9398

9499
public class AgentCreateContext<FeatureT>(
@@ -111,8 +116,10 @@ public class AgentStartContext<TFeature>(
111116
}
112117
}
113118

119+
@OptIn(ExperimentalUuidApi::class)
114120
public class StrategyUpdateContext<FeatureT>(
115121
public val strategy: AIAgentStrategy,
122+
public val sessionUuid: Uuid,
116123
public val feature: FeatureT
117124
) {
118125
public suspend fun readStrategy(block: suspend (AIAgentStrategy) -> Unit) {
Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
1+
@file:OptIn(ExperimentalUuidApi::class)
2+
13
package ai.koog.agents.core.feature.handler
24

35
import ai.koog.agents.core.tools.ToolDescriptor
46
import ai.koog.prompt.dsl.Prompt
7+
import ai.koog.prompt.llm.LLModel
58
import ai.koog.prompt.message.Message
9+
import kotlin.uuid.ExperimentalUuidApi
10+
import kotlin.uuid.Uuid
611

712
public class ExecuteLLMHandler {
813

914
public var beforeLLMCallHandler: BeforeLLMCallHandler =
10-
BeforeLLMCallHandler { prompt, tools -> }
15+
BeforeLLMCallHandler { prompt, tools, model, sessionUuid -> }
1116

1217
public var afterLLMCallHandler: AfterLLMCallHandler =
13-
AfterLLMCallHandler { response -> }
18+
AfterLLMCallHandler { prompt, tools, model, response, sessionUuid -> }
1419
}
1520

1621
public fun interface BeforeLLMCallHandler {
17-
public suspend fun handle(prompt: Prompt, tools: List<ToolDescriptor>)
22+
public suspend fun handle(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, sessionUuid: Uuid)
1823
}
1924

2025
public fun interface AfterLLMCallHandler {
21-
public suspend fun handle(responses: List<Message.Response>)
26+
public suspend fun handle(prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, responses: List<Message.Response>, sessionUuid: Uuid)
2227
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/StrategyHandler.kt

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ public class StrategyHandler<FeatureT : Any>(public val feature: FeatureT) {
99
public var strategyStartedHandler: StrategyStartedHandler<FeatureT> =
1010
StrategyStartedHandler { context -> }
1111

12-
public var strategyFinishedHandler: StrategyFinishedHandler =
13-
StrategyFinishedHandler { strategyName, result -> }
12+
public var strategyFinishedHandler: StrategyFinishedHandler<FeatureT> =
13+
StrategyFinishedHandler { context, result -> }
1414

1515
/**
1616
* Handles strategy starts events by delegating to the handler.
@@ -30,12 +30,31 @@ public class StrategyHandler<FeatureT : Any>(public val feature: FeatureT) {
3030
public suspend fun handleStrategyStartedUnsafe(context: StrategyUpdateContext<*>) {
3131
handleStrategyStarted(context as StrategyUpdateContext<FeatureT>)
3232
}
33+
34+
/**
35+
* Handles strategy finish events by delegating to the handler.
36+
*
37+
* @param context The context for updating the agent with the feature
38+
*/
39+
public suspend fun handleStrategyFinished(context: StrategyUpdateContext<FeatureT>, result: String) {
40+
strategyFinishedHandler.handle(context, result)
41+
}
42+
43+
/**
44+
* Internal API for handling strategy finish events with type casting.
45+
*
46+
* @param context The context for updating the agent
47+
*/
48+
@Suppress("UNCHECKED_CAST")
49+
public suspend fun handleStrategyFinishedUnsafe(context: StrategyUpdateContext<*>, result: String) {
50+
handleStrategyFinished(context as StrategyUpdateContext<FeatureT>, result)
51+
}
3352
}
3453

3554
public fun interface StrategyStartedHandler<FeatureT : Any> {
3655
public suspend fun handle(context: StrategyUpdateContext<FeatureT>)
3756
}
3857

39-
public fun interface StrategyFinishedHandler {
40-
public suspend fun handle(strategyName: String, result: String)
58+
public fun interface StrategyFinishedHandler<FeatureT : Any> {
59+
public suspend fun handle(context: StrategyUpdateContext<FeatureT>, result: String)
4160
}

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
@file:OptIn(ExperimentalUuidApi::class)
2+
13
package ai.koog.agents.core.feature
24

35
import ai.koog.agents.core.agent.context.AIAgentContextBase
4-
import ai.koog.agents.core.tools.ToolDescriptor
6+
import ai.koog.agents.core.agent.entity.AIAgentNodeBase
57
import ai.koog.agents.core.agent.entity.AIAgentStorageKey
68
import ai.koog.agents.core.agent.entity.createStorageKey
9+
import ai.koog.agents.core.tools.ToolDescriptor
710
import ai.koog.agents.features.common.config.FeatureConfig
8-
import ai.koog.agents.core.agent.entity.AIAgentNodeBase
911
import ai.koog.prompt.dsl.Prompt
12+
import ai.koog.prompt.llm.LLModel
1013
import ai.koog.prompt.message.Message
14+
import kotlin.uuid.ExperimentalUuidApi
15+
import kotlin.uuid.Uuid
1116

1217
class TestFeature(val events: MutableList<String>) {
1318

@@ -40,11 +45,11 @@ class TestFeature(val events: MutableList<String>) {
4045
TestFeature(mutableListOf())
4146
}
4247

43-
pipeline.interceptBeforeLLMCall(this, feature) { prompt: Prompt, tools: List<ToolDescriptor> ->
48+
pipeline.interceptBeforeLLMCall(this, feature) { prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, sessionUuid: Uuid ->
4449
feature.events += "LLM: start LLM call (prompt: ${prompt.messages.firstOrNull { it.role == Message.Role.User }?.content}, tools: [${tools.joinToString { it.name }}])"
4550
}
4651

47-
pipeline.interceptAfterLLMCall(this, feature) { responses: List<Message.Response> ->
52+
pipeline.interceptAfterLLMCall(this, feature) { prompt: Prompt, tools: List<ToolDescriptor>, model: LLModel, responses: List<Message.Response>, sessionUuid: Uuid ->
4853
feature.events += "LLM: finish LLM call (responses: [${responses.joinToString(", ") { "${it.role.name}: ${it.content}" }}])"
4954
}
5055

0 commit comments

Comments
 (0)