Skip to content

Commit 9072fea

Browse files
committed
KG-178. Drop runId from agent execution path.
- runId is not a part of the path of an agent graph execution. It is a detail of the agent execution flow and can be added to agent events separately.
1 parent d6560ef commit 9072fea

File tree

2 files changed

+58
-83
lines changed

2 files changed

+58
-83
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/StatefulSingleUseAIAgent.kt

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import ai.koog.agents.core.agent.AIAgent.Companion.State
44
import ai.koog.agents.core.agent.AIAgent.Companion.State.NotStarted
55
import ai.koog.agents.core.agent.context.AIAgentContext
66
import ai.koog.agents.core.agent.context.element.AgentRunInfoContextElement
7-
import ai.koog.agents.core.agent.context.with
87
import ai.koog.agents.core.agent.entity.AIAgentStrategy
98
import ai.koog.agents.core.annotation.InternalAgentsApi
109
import ai.koog.agents.core.feature.AIAgentFeature
@@ -99,38 +98,36 @@ public abstract class StatefulSingleUseAIAgent<Input, Output, TContext : AIAgent
9998
)
10099
) {
101100
context.withPreparedPipeline {
102-
context.with(partName = runId) { executionInfo ->
103-
agentStateMutex.withLock {
104-
@OptIn(InternalAgentsApi::class)
105-
state = State.Running(context.parentContext ?: context)
106-
}
101+
agentStateMutex.withLock {
102+
@OptIn(InternalAgentsApi::class)
103+
state = State.Running(context.parentContext ?: context)
104+
}
107105

108-
logger.debug { formatLog(id, runId, "Starting agent execution") }
109-
pipeline.onAgentStarting<Input, Output>(executionInfo, runId, this@StatefulSingleUseAIAgent, context)
110-
111-
val result = try {
112-
@Suppress("UNCHECKED_CAST")
113-
strategy.execute(context = context, input = agentInput)
114-
} catch (e: Throwable) {
115-
logger.error(e) { "Execution exception reported by server!" }
116-
pipeline.onAgentExecutionFailed(executionInfo, id, runId, e)
117-
agentStateMutex.withLock { state = State.Failed(e) }
118-
throw e
119-
}
106+
logger.debug { formatLog(id, runId, "Starting agent execution") }
107+
pipeline.onAgentStarting<Input, Output>(context.executionInfo, runId, this@StatefulSingleUseAIAgent, context)
108+
109+
val result = try {
110+
@Suppress("UNCHECKED_CAST")
111+
strategy.execute(context = context, input = agentInput)
112+
} catch (e: Throwable) {
113+
logger.error(e) { "Execution exception reported by server!" }
114+
pipeline.onAgentExecutionFailed(context.executionInfo, id, runId, e)
115+
agentStateMutex.withLock { state = State.Failed(e) }
116+
throw e
117+
}
120118

121-
logger.debug { formatLog(id, runId, "Finished agent execution") }
122-
pipeline.onAgentCompleted(executionInfo, id, runId, result)
119+
logger.debug { formatLog(id, runId, "Finished agent execution") }
120+
pipeline.onAgentCompleted(context.executionInfo, id, runId, result)
123121

124-
agentStateMutex.withLock {
125-
state = if (result != null) {
126-
State.Finished(result)
127-
} else {
128-
State.Failed(Exception("result is null"))
129-
}
122+
agentStateMutex.withLock {
123+
state = if (result != null) {
124+
State.Finished(result)
125+
} else {
126+
State.Failed(Exception("result is null"))
130127
}
131-
132-
result ?: error("result is null")
133128
}
129+
130+
result ?: error("result is null")
134131
}
135132
}
136133
}

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/AIAgentPipelineTest.kt

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,13 @@ class AIAgentPipelineTest {
9797
collectedEvent.startsWith(NodeExecutionCompleted::class.simpleName.toString())
9898
}
9999

100-
val runId = interceptedRunIds.first()
101-
102100
val expectedEvents = listOf(
103-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
104-
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
105-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, dummyNodeName)}, name: $dummyNodeName, input: kotlin.Unit)",
106-
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, dummyNodeName)}, name: $dummyNodeName, input: kotlin.Unit, output: kotlin.Unit)",
107-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, FINISH_NODE_PREFIX)}, name: $FINISH_NODE_PREFIX, input: $agentResult)",
108-
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, FINISH_NODE_PREFIX)}, name: $FINISH_NODE_PREFIX, input: $agentResult, output: $agentResult)",
101+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
102+
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
103+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, dummyNodeName)}, name: $dummyNodeName, input: kotlin.Unit)",
104+
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, dummyNodeName)}, name: $dummyNodeName, input: kotlin.Unit, output: kotlin.Unit)",
105+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, FINISH_NODE_PREFIX)}, name: $FINISH_NODE_PREFIX, input: $agentResult)",
106+
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, FINISH_NODE_PREFIX)}, name: $FINISH_NODE_PREFIX, input: $agentResult, output: $agentResult)",
109107
)
110108

111109
assertEquals(
@@ -155,13 +153,11 @@ class AIAgentPipelineTest {
155153
collectedEvent.startsWith(NodeExecutionCompleted::class.simpleName.toString())
156154
}
157155

158-
val runId = interceptedRunIds.first()
159-
160156
val expectedEvents = listOf(
161-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
162-
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
163-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeName)}, name: $nodeName, input: $agentInput)",
164-
"${NodeExecutionFailed::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeName)}, name: $nodeName, error: $testErrorMessage)",
157+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
158+
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
159+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeName)}, name: $nodeName, input: $agentInput)",
160+
"${NodeExecutionFailed::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeName)}, name: $nodeName, error: $testErrorMessage)",
165161
)
166162

167163
assertEquals(
@@ -211,12 +207,10 @@ class AIAgentPipelineTest {
211207
collectedEvent.startsWith(NodeExecutionCompleted::class.simpleName.toString())
212208
}
213209

214-
val runId = interceptedRunIds.first()
215-
216210
val expectedEvents = listOf(
217-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
218-
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
219-
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeWithErrorName)}, name: $nodeWithErrorName, input: $agentInput)",
211+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput)",
212+
"${NodeExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, START_NODE_PREFIX)}, name: $START_NODE_PREFIX, input: $agentInput, output: $agentInput)",
213+
"${NodeExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeWithErrorName)}, name: $nodeWithErrorName, input: $agentInput)",
220214
)
221215

222216
assertEquals(
@@ -265,11 +259,9 @@ class AIAgentPipelineTest {
265259
collectedEvent.startsWith(SubgraphExecutionFailed::class.simpleName.toString())
266260
}
267261

268-
val runId = interceptedRunIds.first()
269-
270262
val expectedEvents = listOf(
271-
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput)",
272-
"${SubgraphExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput, output: $subgraphOutput)",
263+
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput)",
264+
"${SubgraphExecutionCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput, output: $subgraphOutput)",
273265
)
274266

275267
assertEquals(
@@ -327,11 +319,9 @@ class AIAgentPipelineTest {
327319
collectedEvent.startsWith(SubgraphExecutionFailed::class.simpleName.toString())
328320
}
329321

330-
val runId = interceptedRunIds.first()
331-
332322
val expectedEvents = listOf(
333-
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput)",
334-
"${SubgraphExecutionFailed::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, subgraphName)}, name: $subgraphName, error: $testErrorMessage)",
323+
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, subgraphName)}, name: $subgraphName, input: $agentInput)",
324+
"${SubgraphExecutionFailed::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, subgraphName)}, name: $subgraphName, error: $testErrorMessage)",
335325
)
336326

337327
assertEquals(
@@ -386,10 +376,8 @@ class AIAgentPipelineTest {
386376
collectedEvent.startsWith(SubgraphExecutionFailed::class.simpleName.toString())
387377
}
388378

389-
val runId = interceptedRunIds.first()
390-
391379
val expectedEvents = listOf(
392-
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, subgraphWithErrorName)}, name: $subgraphWithErrorName, input: $agentInput)",
380+
"${SubgraphExecutionStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, subgraphWithErrorName)}, name: $subgraphWithErrorName, input: $agentInput)",
393381
)
394382

395383
assertEquals(
@@ -443,13 +431,11 @@ class AIAgentPipelineTest {
443431
collectedEvent.startsWith(LLMCallCompleted::class.simpleName.toString())
444432
}
445433

446-
val runId = interceptedRunIds.first()
447-
448434
val expectedEvents = listOf(
449-
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCallWithoutToolsName)}, prompt: $testLLMResponse, tools: [])",
450-
"${LLMCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCallWithoutToolsName)}, responses: [${Role.Assistant.name}: $DEFAULT_ASSISTANT_RESPONSE])",
451-
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCall)}, prompt: $llmCallWithToolsResponse, tools: [${DummyTool().name}])",
452-
"${LLMCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCall)}, responses: [${Role.Assistant.name}: $DEFAULT_ASSISTANT_RESPONSE])",
435+
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCallWithoutToolsName)}, prompt: $testLLMResponse, tools: [])",
436+
"${LLMCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCallWithoutToolsName)}, responses: [${Role.Assistant.name}: $DEFAULT_ASSISTANT_RESPONSE])",
437+
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCall)}, prompt: $llmCallWithToolsResponse, tools: [${DummyTool().name}])",
438+
"${LLMCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCall)}, responses: [${Role.Assistant.name}: $DEFAULT_ASSISTANT_RESPONSE])",
453439
)
454440

455441
assertEquals(
@@ -506,13 +492,11 @@ class AIAgentPipelineTest {
506492
collectedEvent.startsWith(ToolCallCompleted::class.simpleName.toString())
507493
}
508494

509-
val runId = interceptedRunIds.first()
510-
511495
val expectedEvents = listOf(
512-
"${ToolCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, args: ${
496+
"${ToolCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, args: ${
513497
CalculatorTools.PlusTool.encodeArgs(CalculatorTools.CalculatorTool.Args(2.2F, 2.2F))
514498
})",
515-
"${ToolCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, result: ${
499+
"${ToolCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, result: ${
516500
CalculatorTools.PlusTool.encodeResult(CalculatorTools.CalculatorTool.Result(4.4F))
517501
})"
518502
)
@@ -561,8 +545,8 @@ class AIAgentPipelineTest {
561545
val runId = interceptedRunIds.first()
562546

563547
val expectedEvents = listOf(
564-
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId)}, id: $agentId, run id: $runId)",
565-
"${AgentCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId)}, id: $agentId, run id: $runId, result: $agentOutput)",
548+
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agentId)}, id: $agentId, run id: $runId)",
549+
"${AgentCompleted::class.simpleName} (path: ${agentExecutionPath(agentId)}, id: $agentId, run id: $runId, result: $agentOutput)",
566550
"${AgentClosing::class.simpleName} (path: ${agentExecutionPath(agentId)}, id: $agentId)"
567551
)
568552

@@ -603,10 +587,8 @@ class AIAgentPipelineTest {
603587
collectedEvent.startsWith(StrategyStarting::class.simpleName.toString())
604588
}
605589

606-
val runId = interceptedRunIds.first()
607-
608590
val expectedEvents = listOf(
609-
"${StrategyStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName)}, strategy: $strategyName)",
591+
"${StrategyStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName)}, strategy: $strategyName)",
610592
)
611593

612594
assertEquals(
@@ -667,8 +649,8 @@ class AIAgentPipelineTest {
667649
val runId2 = runIds[1]
668650

669651
val expectedEvents = listOf(
670-
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agent1Id, runId1)}, id: $agent1Id, run id: $runId1)",
671-
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agent2Id, runId2)}, id: $agent2Id, run id: $runId2)",
652+
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agent1Id)}, id: $agent1Id, run id: $runId1)",
653+
"${AgentStarting::class.simpleName} (path: ${agentExecutionPath(agent2Id)}, id: $agent2Id, run id: $runId2)",
672654
)
673655

674656
assertEquals(
@@ -718,11 +700,9 @@ class AIAgentPipelineTest {
718700
collectedEvent.startsWith(LLMCallStarting::class.simpleName.toString())
719701
}
720702

721-
val runId = interceptedRunIds.first()
722-
723703
val expectedEvents = listOf(
724-
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCallWithoutToolsName)}, prompt: $testLLMResponse, tools: [])",
725-
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeLLMCallName)}, prompt: $llmCallWithToolsResponse, tools: [${DummyTool().name}])",
704+
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCallWithoutToolsName)}, prompt: $testLLMResponse, tools: [])",
705+
"${LLMCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeLLMCallName)}, prompt: $llmCallWithToolsResponse, tools: [${DummyTool().name}])",
726706
)
727707

728708
assertEquals(
@@ -780,13 +760,11 @@ class AIAgentPipelineTest {
780760
collectedEvent.startsWith(ToolCallFailed::class.simpleName.toString())
781761
}
782762

783-
val runId = interceptedRunIds.first()
784-
785763
val expectedEvents = listOf(
786-
"${ToolCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, args: ${
764+
"${ToolCallStarting::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, args: ${
787765
CalculatorTools.PlusTool.encodeArgs(CalculatorTools.CalculatorTool.Args(2.2F, 2.2F))
788766
})",
789-
"${ToolCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, runId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, result: ${
767+
"${ToolCallCompleted::class.simpleName} (path: ${agentExecutionPath(agentId, strategyName, nodeToolCallName)}, tool: ${CalculatorTools.PlusTool.name}, result: ${
790768
CalculatorTools.PlusTool.encodeResult(CalculatorTools.CalculatorTool.Result(4.4F))
791769
})"
792770
)

0 commit comments

Comments
 (0)