Skip to content

Commit f06b44e

Browse files
Ololoshechkin authored and karloti committed
Make sure last successful node is not executed twice when restoring from Persistence checkpoints (JetBrains#1308)
Make sure last successful node is not executed twice when restoring from Persistence checkpoints ## Motivation and Context Previously, if the latest node was a successful tool call or any other operation with a side effect, it was executed again on restore, although the latest successful result could have been reused. The reason was that, internally, the node's input was required to restore the full subgraph structure. In this PR, the `lastInput` field of the checkpoints was changed to `lastOutput`, and the `resolveEdge()` method is used to restore the NEXT node with the correct input instead of the current node. ## Breaking Changes In the checkpoint structure, the `lastInput` field was changed to `lastOutput` --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [x] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent 561a2e6 commit f06b44e

File tree

19 files changed

+811
-77
lines changed

19 files changed

+811
-77
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AgentContextData.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,17 @@ import kotlinx.serialization.json.JsonElement
1010
public class AgentContextData(
1111
internal val messageHistory: List<Message>,
1212
internal val nodePath: String,
13-
internal val lastInput: JsonElement,
13+
@Deprecated("Use lastOutput instead, lastOutput will be removed in future versions")
14+
internal val lastInput: JsonElement? = null,
15+
internal val lastOutput: JsonElement? = null,
1416
internal val rollbackStrategy: RollbackStrategy,
1517
internal val additionalRollbackActions: suspend (AIAgentContext) -> Unit = {}
16-
)
18+
) {
19+
init {
20+
require(lastInput == null || lastOutput == null) { "`lastInput` and `lastOutput` cannot be both set" }
21+
require(lastInput != null || lastOutput != null) { "`lastInput` (until 0.6.0) or `lastOutput` (since 0.6.1) must be set" }
22+
}
23+
}
1724

1825
public enum class RollbackStrategy {
1926
/**

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentGraphStrategy.kt

Lines changed: 91 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
7878

7979
@OptIn(InternalAgentsApi::class)
8080
private suspend fun restoreStateIfNeeded(
81-
agentContext: AIAgentContext
81+
agentContext: AIAgentGraphContextBase
8282
) {
8383
val additionalContextData: AgentContextData = agentContext.getAgentContextData() ?: return
8484

@@ -97,44 +97,44 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
9797
}
9898

9999
@OptIn(InternalAgentsApi::class)
100-
private suspend fun restoreDefault(agentContext: AIAgentContext, data: AgentContextData) {
100+
private suspend fun restoreDefault(agentContext: AIAgentGraphContextBase, data: AgentContextData) {
101101
val nodePath = data.nodePath
102102

103103
// Perform additional cleanup (ex: rollback tools):
104104
data.additionalRollbackActions(agentContext)
105105

106106
// Set current graph node:
107-
setExecutionPoint(nodePath, data.lastInput)
107+
@Suppress("DEPRECATION")
108+
when {
109+
data.lastInput != null -> setExecutionPoint(nodePath, data.lastInput)
110+
data.lastOutput != null -> setExecutionPointAfterNode(nodePath, data.lastOutput, agentContext)
111+
112+
// Unexpected state: either input (before 0.6.1) or output (since 0.6.1) should be saved in checkpoints:
113+
else -> {}
114+
}
108115

109116
// Reset the message history:
110117
agentContext.llm.withPrompt {
111118
this.withMessages { (data.messageHistory) }
112119
}
113120
}
114121

115-
/**
116-
* Finds and sets the node for the strategy based on the provided context.
117-
*/
118-
public fun setExecutionPoint(nodePath: String, input: JsonElement) {
119-
// we drop first because it's agent's id, we don't need it here
120-
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)
121-
122-
if (segments.isEmpty()) {
123-
throw IllegalArgumentException("Invalid node path: $nodePath")
124-
}
125-
126-
val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
127-
val strategyName = segments.firstOrNull() ?: return
122+
private fun setExecutionPointImpl(pathSegments: List<String>, node: AIAgentNodeBase<*, *>, input: Any?) {
123+
val strategyName = pathSegments.firstOrNull() ?: return
128124

129125
// getting the very first segment (it should be a root strategy node)
130126
var currentNode: AIAgentNodeBase<*, *>? = metadata.nodesMap[strategyName]
131127
var currentPath = strategyName
132128

133129
// restoring the current node for each subgraph including strategy
134-
val segmentsInbetween = segments.drop(1).dropLast(1)
130+
val segmentsInbetween = pathSegments.drop(1).dropLast(1)
135131
for (segment in segmentsInbetween) {
136132
val currNode = currentNode as? ExecutionPointNode
137-
?: throw IllegalStateException("Restore for path $nodePath failed: one of middle segments is not a subgraph")
133+
?: throw IllegalStateException(
134+
"Restore for path " +
135+
"${pathSegments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)} failed: " +
136+
"one of middle segments is not a subgraph"
137+
)
138138

139139
currentPath = "$currentPath${DEFAULT_AGENT_PATH_SEPARATOR}$segment"
140140
val nextNode = metadata.nodesMap[currentPath]
@@ -144,15 +144,81 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
144144
}
145145
}
146146

147-
// forcing the very last segment to the latest pre-leaf node to complete the chain
148-
val leaf = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")
149-
val inputType = leaf.inputType
150-
151-
val actualInput = serializer.decodeFromJsonElement(serializer.serializersModule.serializer(inputType), input)
152-
leaf.let {
147+
val leaf = node
148+
node.let {
153149
currentNode as? ExecutionPointNode
154150
?: throw IllegalStateException("Node ${currentNode?.name} is not a valid leaf node")
155-
currentNode.enforceExecutionPoint(it, actualInput)
151+
currentNode.enforceExecutionPoint(it, input)
152+
}
153+
}
154+
155+
/**
156+
* Finds and sets the node for the strategy based on the provided context.
157+
*/
158+
@Deprecated("Use setExecutionPointAfterNode instead, setExecutionPoint will be removed in future versions")
159+
public suspend fun setExecutionPoint(nodePath: String, input: JsonElement) {
160+
// we drop first because it's agent's id, we don't need it here
161+
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)
162+
163+
if (segments.isEmpty()) {
164+
throw IllegalArgumentException("Invalid node path: $nodePath")
165+
}
166+
167+
val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
168+
169+
val completedNode = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")
170+
171+
val actualInput = serializer.decodeFromJsonElement(
172+
serializer.serializersModule.serializer(completedNode.inputType),
173+
input
174+
)
175+
176+
// Note: completed node will be re-executed because the output wasn't saved in checkpoints
177+
// (this was the original behavior before 0.6.1)
178+
setExecutionPointImpl(segments, completedNode, actualInput)
179+
}
180+
181+
/**
182+
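* Sets the execution point to the node that follows the completed node at [nodePath], resolving the outgoing edge from the saved [output] (a finish node is restarted itself, with the output used as its input).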
183+
*/
184+
public suspend fun setExecutionPointAfterNode(
185+
nodePath: String,
186+
output: JsonElement,
187+
agentContext: AIAgentGraphContextBase
188+
) {
189+
// we drop first because it's agent's id, we don't need it here
190+
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)
191+
192+
if (segments.isEmpty()) {
193+
throw IllegalArgumentException("Invalid node path: $nodePath")
194+
}
195+
196+
val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
197+
198+
val completedNode = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")
199+
200+
val actualOutput = serializer.decodeFromJsonElement(
201+
serializer.serializersModule.serializer(completedNode.outputType),
202+
output
203+
)
204+
205+
if (completedNode is FinishNode<*>) {
206+
// A finish node (of some subgraph) has no outgoing edges, and its input equals its output, so it's safe to restart it:
207+
setExecutionPointImpl(
208+
pathSegments = segments,
209+
node = completedNode,
210+
input = actualOutput
211+
)
212+
} else {
213+
val resolvedEdge = completedNode.resolveEdgeUnsafe(agentContext, actualOutput)
214+
val nextNode = resolvedEdge?.edge?.toNode ?: throw IllegalStateException("Node $nodePath not found")
215+
val nextNodeInput = resolvedEdge.output
216+
217+
setExecutionPointImpl(
218+
pathSegments = segments.dropLast(1) + nextNode.name,
219+
node = nextNode,
220+
input = nextNodeInput
221+
)
156222
}
157223
}
158224
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/message/FeatureMessage.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ package ai.koog.agents.core.feature.message
88
*/
99
public interface FeatureMessage {
1010

11-
/**
11+
/**
1212
* Represents the time, in milliseconds, when the feature message or event was created or occurred.
1313
*
1414
* The timestamp is used to track the exact time of the message's creation or event's occurrence,

agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlin.uuid.Uuid
2323
* @property messageHistory A list of messages exchanged in the session up to the checkpoint. Messages include interactions between the user, system, assistant, and tools.
2424
* @property nodePath The identifier of the node where the checkpoint was created.
2525
* @property lastInput Serialized input received for node with [nodePath]
26+
* @property lastOutput Serialized output received from node with [nodePath]
2627
* @property properties Additional data associated with the checkpoint. This can be used to store additional information about the agent's state.
2728
* @property createdAt The timestamp when the checkpoint was created.
2829
* @property version The version of the checkpoint data structure
@@ -32,11 +33,36 @@ public data class AgentCheckpointData(
3233
val checkpointId: String,
3334
val createdAt: Instant,
3435
val nodePath: String,
35-
val lastInput: JsonElement,
36+
@Deprecated("Use lastOutput instead, lastOutput will be removed in future versions")
37+
val lastInput: JsonElement? = null,
38+
val lastOutput: JsonElement? = null,
3639
val messageHistory: List<Message>,
3740
val version: Long,
3841
val properties: Map<String, JsonElement>? = null
39-
)
42+
) {
43+
init {
44+
if (nodePath != PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME) {
45+
require(lastInput == null || lastOutput == null) { "`lastInput` and `lastOutput` cannot be both set" }
46+
require(lastInput != null || lastOutput != null) { "`lastInput` (until 0.6.0) or `lastOutput` (since 0.6.1) must be set" }
47+
}
48+
}
49+
50+
private fun eq(json1: JsonElement?, json2: JsonElement?): Boolean =
51+
json1 == json2 || ((json1 == null || json1 == JsonNull) && (json2 == null || json2 == JsonNull))
52+
53+
override fun equals(other: Any?): Boolean {
54+
if (this === other) return true
55+
if (other !is AgentCheckpointData) return false
56+
return checkpointId == other.checkpointId &&
57+
nodePath == other.nodePath &&
58+
createdAt == other.createdAt &&
59+
eq(lastInput, other.lastInput) &&
60+
eq(lastOutput, other.lastOutput) &&
61+
messageHistory == other.messageHistory &&
62+
version == other.version &&
63+
properties == other.properties
64+
}
65+
}
4066

4167
/**
4268
* Creates a tombstone checkpoint for an agent's session.
@@ -51,7 +77,7 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
5177
checkpointId = Uuid.random().toString(),
5278
createdAt = time,
5379
nodePath = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME,
54-
lastInput = JsonNull,
80+
lastOutput = JsonNull,
5581
messageHistory = emptyList(),
5682
properties = mapOf(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)),
5783
version = version
@@ -61,7 +87,7 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
6187
/**
6288
* Converts an instance of [AgentCheckpointData] to [AgentContextData].
6389
*
64-
* The conversion maps the `messageHistory`, `nodeId`, and `lastInput` properties of
90+
* The conversion maps the `messageHistory`, `nodePath`, `lastInput`, and `lastOutput` properties of
6591
* [AgentCheckpointData] directly to a new [AgentContextData] instance.
6692
*
6793
* @return A new [AgentContextData] instance containing the message history, node ID,
@@ -72,10 +98,12 @@ public fun AgentCheckpointData.toAgentContextData(
7298
agentId: String,
7399
additionalRollbackAction: suspend (AIAgentContext) -> Unit = {}
74100
): AgentContextData {
101+
@Suppress("DEPRECATION")
75102
return AgentContextData(
76103
messageHistory = messageHistory,
77104
nodePath = nodePath,
78105
lastInput = lastInput,
106+
lastOutput = lastOutput,
79107
rollbackStrategy,
80108
additionalRollbackAction
81109
)

agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ public class Persistence(
122122

123123
if (config.enableAutomaticPersistence) {
124124
val parent = persistence.getLatestCheckpoint(eventCtx.context.agentId)
125-
persistence.createCheckpoint(
125+
persistence.createCheckpointAfterNode(
126126
agentContext = eventCtx.context,
127127
nodePath = eventCtx.context.executionInfo.path(),
128-
lastInput = eventCtx.input,
129-
lastInputType = eventCtx.inputType,
128+
lastOutput = eventCtx.output,
129+
lastOutputType = eventCtx.outputType,
130130
version = parent?.version?.plus(1) ?: 0L,
131131
)
132132
}
@@ -159,10 +159,11 @@ public class Persistence(
159159
*
160160
* @param agentContext The context of the agent containing the state to checkpoint
161161
* @param nodePath The path of the node where the checkpoint is created
162-
* @param lastInput The input data to include in the checkpoint
162+
* @param lastInput The latest node input data to include in the checkpoint
163163
* @param checkpointId Optional ID for the checkpoint; a random UUID is generated if not provided
164164
* @return The created checkpoint data
165165
*/
166+
@Deprecated("Use `createCheckpointAfterNode` instead")
166167
public suspend fun createCheckpoint(
167168
agentContext: AIAgentContext,
168169
nodePath: String,
@@ -195,6 +196,50 @@ public class Persistence(
195196
return checkpoint
196197
}
197198

199+
/**
200+
* Creates a checkpoint of the agent's current state.
201+
*
202+
* This method captures the agent's message history, current node, and the latest node output
203+
* and stores it as a checkpoint using the configured storage provider.
204+
*
205+
* @param agentContext The context of the agent containing the state to checkpoint
206+
* @param nodePath The path of the node where the checkpoint is created
207+
* @param lastOutput The latest node output data to include in the checkpoint
208+
* @param checkpointId Optional ID for the checkpoint; a random UUID is generated if not provided
209+
* @return The created checkpoint data
210+
*/
211+
public suspend fun createCheckpointAfterNode(
212+
agentContext: AIAgentContext,
213+
nodePath: String,
214+
lastOutput: Any?,
215+
lastOutputType: KType,
216+
version: Long,
217+
checkpointId: String? = null,
218+
): AgentCheckpointData? {
219+
val outputJson = SerializationUtils.encodeDataToJsonElementOrNull(lastOutput, lastOutputType)
220+
221+
if (outputJson == null) {
222+
logger.warn {
223+
"Failed to serialize output of type $lastOutputType for checkpoint creation for $nodePath, skipping..."
224+
}
225+
return null
226+
}
227+
228+
val checkpoint = agentContext.llm.readSession {
229+
return@readSession AgentCheckpointData(
230+
checkpointId = checkpointId ?: Uuid.random().toString(),
231+
messageHistory = prompt.messages,
232+
nodePath = agentContext.executionInfo.path(),
233+
lastOutput = outputJson,
234+
createdAt = Clock.System.now(),
235+
version = version,
236+
)
237+
}
238+
239+
saveCheckpoint(agentContext.agentId, checkpoint)
240+
return checkpoint
241+
}
242+
198243
/**
199244
* Creates and saves a tombstone checkpoint for an agent's session.
200245
*
@@ -256,7 +301,41 @@ public class Persistence(
256301
messageHistory: List<Message>,
257302
input: JsonElement
258303
) {
259-
agentContext.store(AgentContextData(messageHistory, agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath, input, rollbackStrategy))
304+
agentContext.store(
305+
AgentContextData(
306+
messageHistory,
307+
agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath,
308+
lastInput = input,
309+
rollbackStrategy = rollbackStrategy
310+
)
311+
)
312+
}
313+
314+
/**
315+
* Sets the execution point of an agent to a specified state.
316+
*
317+
* This method updates the agent's context to start execution from a specific point
318+
* in its graph, using the provided message history and finished node output data.
319+
*
320+
* @param agentContext The context of the agent to modify.
321+
* @param nodePath The path to the node inside the agent's graph where execution will begin.
322+
* @param messageHistory The sequence of messages representing the agent's prior interactions.
323+
* @param output The output data to associate with the specified execution point.
324+
*/
325+
public fun setExecutionPointAfterNode(
326+
agentContext: AIAgentContext,
327+
nodePath: String,
328+
messageHistory: List<Message>,
329+
output: JsonElement
330+
) {
331+
agentContext.store(
332+
AgentContextData(
333+
messageHistory,
334+
agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath,
335+
lastOutput = output,
336+
rollbackStrategy = rollbackStrategy
337+
)
338+
)
260339
}
261340

262341
/**

0 commit comments

Comments
 (0)