Skip to content

Commit 4f43c51

Browse files
Rizzenvova-jb
authored andcommitted
Removed requrement for unique node names for Peristency feature (#1288)
<!-- Thank you for opening a pull request! Please add a brief description of the proposed change here. Also, please tick the appropriate points in the checklist below. --> ## Motivation and Context Aim of the PR is to remove requirement (and check) for unique graph node names and migrate to node path usage. ## Breaking Changes <!-- Will users need to update their code or configurations? --> Yes --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [x] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent 3e7df1f commit 4f43c51

File tree

32 files changed

+296
-192
lines changed

32 files changed

+296
-192
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AgentContextData.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import kotlinx.serialization.json.JsonElement
99
@InternalAgentsApi
1010
public class AgentContextData(
1111
internal val messageHistory: List<Message>,
12-
internal val nodeId: String,
12+
internal val nodePath: String,
1313
internal val lastInput: JsonElement,
1414
internal val rollbackStrategy: RollbackStrategy,
1515
internal val additionalRollbackActions: suspend (AIAgentContext) -> Unit = {}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentGraphStrategy.kt

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import ai.koog.agents.core.agent.context.RollbackStrategy
77
import ai.koog.agents.core.agent.context.getAgentContextData
88
import ai.koog.agents.core.agent.context.removeAgentContextData
99
import ai.koog.agents.core.agent.context.with
10+
import ai.koog.agents.core.agent.execution.DEFAULT_AGENT_PATH_SEPARATOR
1011
import ai.koog.agents.core.annotation.InternalAgentsApi
1112
import ai.koog.agents.core.utils.runCatchingCancellable
1213
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -97,13 +98,13 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
9798

9899
@OptIn(InternalAgentsApi::class)
99100
private suspend fun restoreDefault(agentContext: AIAgentContext, data: AgentContextData) {
100-
val nodeId = data.nodeId
101+
val nodePath = data.nodePath
101102

102103
// Perform additional cleanup (ex: rollback tools):
103104
data.additionalRollbackActions(agentContext)
104105

105106
// Set current graph node:
106-
setExecutionPoint(nodeId, data.lastInput)
107+
setExecutionPoint(nodePath, data.lastInput)
107108

108109
// Reset the message history:
109110
agentContext.llm.withPrompt {
@@ -114,17 +115,15 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
114115
/**
115116
* Finds and sets the node for the strategy based on the provided context.
116117
*/
117-
public fun setExecutionPoint(nodeId: String, input: JsonElement) {
118-
val fullPath = metadata.nodesMap.keys.firstOrNull {
119-
val segments = it.split(":")
120-
segments.last() == nodeId
121-
} ?: throw IllegalArgumentException("Node $nodeId not found")
118+
public fun setExecutionPoint(nodePath: String, input: JsonElement) {
119+
// we drop first because it's agent's id, we don't need it here
120+
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)
122121

123-
val segments = fullPath.split(":")
124122
if (segments.isEmpty()) {
125-
throw IllegalArgumentException("Invalid node path: $fullPath")
123+
throw IllegalArgumentException("Invalid node path: $nodePath")
126124
}
127125

126+
val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
128127
val strategyName = segments.firstOrNull() ?: return
129128

130129
// getting the very first segment (it should be a root strategy node)
@@ -134,25 +133,25 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
134133
// restoring the current node for each subgraph including strategy
135134
val segmentsInbetween = segments.drop(1).dropLast(1)
136135
for (segment in segmentsInbetween) {
137-
currentNode as? ExecutionPointNode
138-
?: throw IllegalStateException("Node ${currentNode?.name} does not have subnodes")
136+
val currNode = currentNode as? ExecutionPointNode
137+
?: throw IllegalStateException("Restore for path $nodePath failed: one of middle segments is not a subgraph")
139138

140-
currentPath = "$currentPath:$segment"
139+
currentPath = "$currentPath${DEFAULT_AGENT_PATH_SEPARATOR}$segment"
141140
val nextNode = metadata.nodesMap[currentPath]
142141
if (nextNode is ExecutionPointNode) {
143-
currentNode.enforceExecutionPoint(nextNode, input)
142+
currNode.enforceExecutionPoint(nextNode, input)
144143
currentNode = nextNode
145144
}
146145
}
147146

148147
// forcing the very last segment to the latest pre-leaf node to complete the chain
149-
val leaf = metadata.nodesMap[fullPath] ?: throw IllegalStateException("Node ${segments.last()} not found")
148+
val leaf = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")
150149
val inputType = leaf.inputType
151150

152151
val actualInput = serializer.decodeFromJsonElement(serializer.serializersModule.serializer(inputType), input)
153152
leaf.let {
154153
currentNode as? ExecutionPointNode
155-
?: throw IllegalStateException("Node ${currentNode?.name} does not have subnodes")
154+
?: throw IllegalStateException("Node ${currentNode?.name} is not a valid leaf node")
156155
currentNode.enforceExecutionPoint(it, actualInput)
157156
}
158157
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/execution/AgentExecutionInfo.kt

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,6 @@ public data class AgentExecutionInfo(
1616
public val parent: AgentExecutionInfo?,
1717
public val partName: String
1818
) {
19-
20-
/**
21-
* A companion object for the `AgentExecutionInfo` class that provides utility constants.
22-
*/
23-
public companion object {
24-
25-
/**
26-
* The default separator used to join parts of a path.
27-
*/
28-
public val defaultPathSeparator: CharSequence = "/"
29-
}
30-
3119
/**
3220
* Constructs a path string representing the sequence of `partName` values from the current
3321
* `AgentExecutionInfo` instance to the top-most parent, joined by the specified separator.
@@ -36,7 +24,7 @@ public data class AgentExecutionInfo(
3624
* @return A string representing the path constructed from `partName` values.
3725
*/
3826
public fun path(separator: String? = null): String {
39-
val separator = separator ?: defaultPathSeparator
27+
val separator = separator ?: DEFAULT_AGENT_PATH_SEPARATOR
4028

4129
return buildList {
4230
var current: AgentExecutionInfo? = this@AgentExecutionInfo
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package ai.koog.agents.core.agent.execution
2+
3+
/**
4+
* The default separator used to join parts of an agent's execution path.
5+
*/
6+
public const val DEFAULT_AGENT_PATH_SEPARATOR: String = "/"
7+
8+
/**
9+
* Joins the given parts into a single path string using the specified separator.
10+
*
11+
* @param parts The parts to join into a path.
12+
* @param separator The separator to use between parts. Defaults to [DEFAULT_AGENT_PATH_SEPARATOR].
13+
* @return A string representing the joined path.
14+
*/
15+
public fun path(vararg parts: String, separator: String = DEFAULT_AGENT_PATH_SEPARATOR): String {
16+
return parts.joinToString(separator)
17+
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentSubgraphBuilder.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import ai.koog.agents.core.agent.entity.FinishNode
1212
import ai.koog.agents.core.agent.entity.StartNode
1313
import ai.koog.agents.core.agent.entity.SubgraphMetadata
1414
import ai.koog.agents.core.agent.entity.ToolSelectionStrategy
15+
import ai.koog.agents.core.agent.execution.DEFAULT_AGENT_PATH_SEPARATOR
1516
import ai.koog.agents.core.annotation.InternalAgentsApi
1617
import ai.koog.agents.core.tools.Tool
1718
import ai.koog.prompt.llm.LLModel
@@ -229,7 +230,7 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
229230
}
230231

231232
private fun getNodePath(node: AIAgentNodeBase<*, *>, parentPath: String): String {
232-
return "$parentPath:${node.id}"
233+
return "$parentPath${DEFAULT_AGENT_PATH_SEPARATOR}${node.id}"
233234
}
234235

235236
internal fun buildSubgraphMetadata(
@@ -248,12 +249,11 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
248249
}
249250

250251
// Validate that all nodes have unique names within the subgraph
251-
val names = subgraphNodes.keys.map { it.split(":").last() }
252-
val uniqueNames = names.toSet().size == names.size
252+
val names = subgraphNodes.keys
253253

254254
return SubgraphMetadata(
255255
nodesMap = subgraphNodes,
256-
uniqueNames = uniqueNames
256+
uniqueNames = names.toSet().size == names.size
257257
)
258258
}
259259

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/execution/AgentExecutionInfoTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class AgentExecutionInfoTest {
4848

4949
@Test
5050
fun testDefaultPathSeparator() {
51-
assertEquals("/", AgentExecutionInfo.defaultPathSeparator)
51+
assertEquals("/", DEFAULT_AGENT_PATH_SEPARATOR)
5252
}
5353

5454
@Test

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/AIAgentPipelineTest.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import ai.koog.agents.core.agent.config.AIAgentConfig
88
import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy
99
import ai.koog.agents.core.agent.entity.AIAgentSubgraph.Companion.FINISH_NODE_PREFIX
1010
import ai.koog.agents.core.agent.entity.AIAgentSubgraph.Companion.START_NODE_PREFIX
11+
import ai.koog.agents.core.agent.execution.DEFAULT_AGENT_PATH_SEPARATOR
1112
import ai.koog.agents.core.dsl.builder.forwardTo
1213
import ai.koog.agents.core.dsl.builder.strategy
1314
import ai.koog.agents.core.dsl.extension.nodeDoNothing
@@ -32,7 +33,6 @@ import ai.koog.agents.core.feature.handler.AgentLifecycleEventType.ToolCallFaile
3233
import ai.koog.agents.core.feature.handler.AgentLifecycleEventType.ToolCallStarting
3334
import ai.koog.agents.core.feature.handler.AgentLifecycleEventType.ToolValidationFailed
3435
import ai.koog.agents.core.tools.ToolRegistry
35-
import ai.koog.agents.testing.agent.agentExecutionPath
3636
import ai.koog.agents.testing.tools.DummyTool
3737
import ai.koog.agents.testing.tools.getMockExecutor
3838
import ai.koog.prompt.dsl.prompt
@@ -817,5 +817,7 @@ class AIAgentPipelineTest {
817817
)
818818
}
819819

820+
private fun agentExecutionPath(vararg parts: String) = parts.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
821+
820822
//endregion Private Methods
821823
}

agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ import kotlin.uuid.Uuid
2121
*
2222
* @property checkpointId The unique identifier of the checkpoint. This allows tracking and restoring the agent's session to a specific state.
2323
* @property messageHistory A list of messages exchanged in the session up to the checkpoint. Messages include interactions between the user, system, assistant, and tools.
24-
* @property nodeId The identifier of the node where the checkpoint was created.
25-
* @property lastInput Serialized input received for node with [nodeId]
24+
* @property nodePath The identifier of the node where the checkpoint was created.
25+
* @property lastInput Serialized input received for node with [nodePath]
2626
* @property properties Additional data associated with the checkpoint. This can be used to store additional information about the agent's state.
27+
* @property createdAt The timestamp when the checkpoint was created.
28+
* @property version The version of the checkpoint data structure
2729
*/
2830
@Serializable
2931
public data class AgentCheckpointData(
3032
val checkpointId: String,
3133
val createdAt: Instant,
32-
val nodeId: String,
34+
val nodePath: String,
3335
val lastInput: JsonElement,
3436
val messageHistory: List<Message>,
3537
val version: Long,
@@ -48,7 +50,7 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
4850
return AgentCheckpointData(
4951
checkpointId = Uuid.random().toString(),
5052
createdAt = time,
51-
nodeId = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME,
53+
nodePath = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME,
5254
lastInput = JsonNull,
5355
messageHistory = emptyList(),
5456
properties = mapOf(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)),
@@ -67,11 +69,12 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
6769
*/
6870
public fun AgentCheckpointData.toAgentContextData(
6971
rollbackStrategy: RollbackStrategy,
72+
agentId: String,
7073
additionalRollbackAction: suspend (AIAgentContext) -> Unit = {}
7174
): AgentContextData {
7275
return AgentContextData(
7376
messageHistory = messageHistory,
74-
nodeId = nodeId,
77+
nodePath = nodePath,
7578
lastInput = lastInput,
7679
rollbackStrategy,
7780
additionalRollbackAction

agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt renamed to agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ai.koog.agents.core.agent.context.store
1010
import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy
1111
import ai.koog.agents.core.agent.entity.AIAgentStorageKey
1212
import ai.koog.agents.core.agent.entity.AIAgentSubgraph
13+
import ai.koog.agents.core.agent.execution.DEFAULT_AGENT_PATH_SEPARATOR
1314
import ai.koog.agents.core.agent.featureOrThrow
1415
import ai.koog.agents.core.annotation.InternalAgentsApi
1516
import ai.koog.agents.core.feature.AIAgentGraphFeature
@@ -80,20 +81,6 @@ public class Persistence(
8081
*/
8182
public var rollbackToolRegistry: RollbackToolRegistry = RollbackToolRegistry {}
8283

83-
/**
84-
* Represents the identifier of the current node being executed within the agent pipeline.
85-
*
86-
* This property is used to track the state of the agent's execution and is updated whenever
87-
* the agent begins processing a new node.
88-
* It plays a crucial role in maintaining the agent's
89-
* state across checkpoints and ensuring accurate state restoration during rollbacks.
90-
*
91-
* The value is nullable, indicating that there might be no current node under execution
92-
* (e.g., when the pipeline is idle or has not started).
93-
*/
94-
public var currentNodeId: String? = null
95-
private set
96-
9784
/**
9885
* Companion object implementing agent feature, handling [Persistence] creation and installation.
9986
*/
@@ -122,7 +109,7 @@ public class Persistence(
122109
val checkpoint = persistence.rollbackToLatestCheckpoint(ctx.context)
123110

124111
if (checkpoint != null) {
125-
logger.info { "Restoring checkpoint: ${checkpoint.checkpointId} to node ${checkpoint.nodeId}" }
112+
logger.info { "Restoring checkpoint: ${checkpoint.checkpointId} to node ${checkpoint.nodePath}" }
126113
} else {
127114
logger.info { "No non-tombstone checkpoint found, starting from the beginning" }
128115
}
@@ -137,18 +124,14 @@ public class Persistence(
137124
val parent = persistence.getLatestCheckpoint(eventCtx.context.agentId)
138125
persistence.createCheckpoint(
139126
agentContext = eventCtx.context,
140-
nodeId = eventCtx.node.id,
127+
nodePath = eventCtx.context.executionInfo.path(),
141128
lastInput = eventCtx.input,
142129
lastInputType = eventCtx.inputType,
143130
version = parent?.version?.plus(1) ?: 0L,
144131
)
145132
}
146133
}
147134

148-
pipeline.interceptNodeExecutionStarting(this) { eventCtx ->
149-
persistence.currentNodeId = eventCtx.node.id
150-
}
151-
152135
pipeline.interceptStrategyCompleted(this) { ctx ->
153136
if (config.enableAutomaticPersistence && config.rollbackStrategy == RollbackStrategy.Default) {
154137
val parent = persistence.getLatestCheckpoint(ctx.context.agentId)
@@ -182,7 +165,7 @@ public class Persistence(
182165
*/
183166
public suspend fun createCheckpoint(
184167
agentContext: AIAgentContext,
185-
nodeId: String,
168+
nodePath: String,
186169
lastInput: Any?,
187170
lastInputType: KType,
188171
version: Long,
@@ -192,7 +175,7 @@ public class Persistence(
192175

193176
if (inputJson == null) {
194177
logger.warn {
195-
"Failed to serialize input of type $lastInputType for checkpoint creation for $nodeId, skipping..."
178+
"Failed to serialize input of type $lastInputType for checkpoint creation for $nodePath, skipping..."
196179
}
197180
return null
198181
}
@@ -201,7 +184,7 @@ public class Persistence(
201184
return@readSession AgentCheckpointData(
202185
checkpointId = checkpointId ?: Uuid.random().toString(),
203186
messageHistory = prompt.messages,
204-
nodeId = nodeId,
187+
nodePath = agentContext.executionInfo.path(),
205188
lastInput = inputJson,
206189
createdAt = Clock.System.now(),
207190
version = version,
@@ -263,17 +246,17 @@ public class Persistence(
263246
* with the specified message history and input data.
264247
*
265248
* @param agentContext The context of the agent to modify
266-
* @param nodeId The ID of the node to set as the current execution point
249+
* @param nodePath The path to the node inside the agent's graph
267250
* @param messageHistory The message history to set for the agent
268251
* @param input The input data to set for the agent
269252
*/
270253
public fun setExecutionPoint(
271254
agentContext: AIAgentContext,
272-
nodeId: String,
255+
nodePath: String,
273256
messageHistory: List<Message>,
274257
input: JsonElement
275258
) {
276-
agentContext.store(AgentContextData(messageHistory, nodeId, input, rollbackStrategy))
259+
agentContext.store(AgentContextData(messageHistory, agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath, input, rollbackStrategy))
277260
}
278261

279262
/**
@@ -297,7 +280,7 @@ public class Persistence(
297280
val checkpoint: AgentCheckpointData? = getCheckpointById(agentContext.agentId, checkpointId)
298281
if (checkpoint != null) {
299282
agentContext.store(
300-
checkpoint.toAgentContextData(rollbackStrategy) { context ->
283+
checkpoint.toAgentContextData(rollbackStrategy, agentContext.agentId) { context ->
301284
messageHistoryDiff(
302285
currentMessages = context.llm.prompt.messages,
303286
checkpointMessages = checkpoint.messageHistory
@@ -356,7 +339,7 @@ public class Persistence(
356339
return null
357340
}
358341

359-
agentContext.store(checkpoint.toAgentContextData(rollbackStrategy))
342+
agentContext.store(checkpoint.toAgentContextData(rollbackStrategy, agentContext.agentId))
360343
return checkpoint
361344
}
362345
}

agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt renamed to agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistenceFeatureConfig.kt

File renamed without changes.

0 commit comments

Comments
 (0)