Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@ import kotlinx.serialization.json.JsonElement
public class AgentContextData(
internal val messageHistory: List<Message>,
internal val nodePath: String,
internal val lastInput: JsonElement,
@Deprecated("Use lastOutput instead, lastOutput will be removed in future versions")
internal val lastInput: JsonElement? = null,
internal val lastOutput: JsonElement? = null,
internal val rollbackStrategy: RollbackStrategy,
internal val additionalRollbackActions: suspend (AIAgentContext) -> Unit = {}
)
) {
init {
require(lastInput == null || lastOutput == null) { "`lastInput` and `lastOutput` cannot be both set" }
require(lastInput != null || lastOutput != null) { "`lastInput` (until 0.6.0) or `lastOutput` (since 0.6.1) must be set" }
}
}

public enum class RollbackStrategy {
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public class AIAgentGraphStrategy<TInput, TOutput>(

@OptIn(InternalAgentsApi::class)
private suspend fun restoreStateIfNeeded(
agentContext: AIAgentContext
agentContext: AIAgentGraphContextBase
) {
val additionalContextData: AgentContextData = agentContext.getAgentContextData() ?: return

Expand All @@ -97,44 +97,44 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
}

@OptIn(InternalAgentsApi::class)
private suspend fun restoreDefault(agentContext: AIAgentContext, data: AgentContextData) {
private suspend fun restoreDefault(agentContext: AIAgentGraphContextBase, data: AgentContextData) {
val nodePath = data.nodePath

// Perform additional cleanup (ex: rollback tools):
data.additionalRollbackActions(agentContext)

// Set current graph node:
setExecutionPoint(nodePath, data.lastInput)
@Suppress("DEPRECATION")
when {
data.lastInput != null -> setExecutionPoint(nodePath, data.lastInput)
data.lastOutput != null -> setExecutionPointAfterNode(nodePath, data.lastOutput, agentContext)

// Unexpected state, either input (before 0.6.1) or output (since 0.6.1) should be saved in checkpiints:
else -> {}
}

// Reset the message history:
agentContext.llm.withPrompt {
this.withMessages { (data.messageHistory) }
}
}

/**
* Finds and sets the node for the strategy based on the provided context.
*/
public fun setExecutionPoint(nodePath: String, input: JsonElement) {
// we drop first because it's agent's id, we don't need it here
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)

if (segments.isEmpty()) {
throw IllegalArgumentException("Invalid node path: $nodePath")
}

val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)
val strategyName = segments.firstOrNull() ?: return
private fun setExecutionPointImpl(pathSegments: List<String>, node: AIAgentNodeBase<*, *>, input: Any?) {
val strategyName = pathSegments.firstOrNull() ?: return

// getting the very first segment (it should be a root strategy node)
var currentNode: AIAgentNodeBase<*, *>? = metadata.nodesMap[strategyName]
var currentPath = strategyName

// restoring the current node for each subgraph including strategy
val segmentsInbetween = segments.drop(1).dropLast(1)
val segmentsInbetween = pathSegments.drop(1).dropLast(1)
for (segment in segmentsInbetween) {
val currNode = currentNode as? ExecutionPointNode
?: throw IllegalStateException("Restore for path $nodePath failed: one of middle segments is not a subgraph")
?: throw IllegalStateException(
"Restore for path " +
"${pathSegments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)} failed: " +
"one of middle segments is not a subgraph"
)

currentPath = "$currentPath${DEFAULT_AGENT_PATH_SEPARATOR}$segment"
val nextNode = metadata.nodesMap[currentPath]
Expand All @@ -144,15 +144,81 @@ public class AIAgentGraphStrategy<TInput, TOutput>(
}
}

// forcing the very last segment to the latest pre-leaf node to complete the chain
val leaf = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")
val inputType = leaf.inputType

val actualInput = serializer.decodeFromJsonElement(serializer.serializersModule.serializer(inputType), input)
leaf.let {
val leaf = node
node.let {
currentNode as? ExecutionPointNode
?: throw IllegalStateException("Node ${currentNode?.name} is not a valid leaf node")
currentNode.enforceExecutionPoint(it, actualInput)
currentNode.enforceExecutionPoint(it, input)
}
}

/**
* Finds and sets the node for the strategy based on the provided context.
*/
@Deprecated("Use setExecutionPointAfterNode instead, setExecutionPoint will be removed in future versions")
public suspend fun setExecutionPoint(nodePath: String, input: JsonElement) {
// we drop first because it's agent's id, we don't need it here
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)

if (segments.isEmpty()) {
throw IllegalArgumentException("Invalid node path: $nodePath")
}

val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)

val completedNode = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")

val actualInput = serializer.decodeFromJsonElement(
serializer.serializersModule.serializer(completedNode.inputType),
input
)

// Note: completed node will be re-executed because the output wasn't saved in checkpoints
// (this was the original behavior before 0.6.1)
setExecutionPointImpl(segments, completedNode, actualInput)
}

/**
* Finds and sets the node for the strategy based on the provided context.
*/
public suspend fun setExecutionPointAfterNode(
nodePath: String,
output: JsonElement,
agentContext: AIAgentGraphContextBase
) {
// we drop first because it's agent's id, we don't need it here
val segments = nodePath.split(DEFAULT_AGENT_PATH_SEPARATOR).drop(1)

if (segments.isEmpty()) {
throw IllegalArgumentException("Invalid node path: $nodePath")
}

val actualPath = segments.joinToString(DEFAULT_AGENT_PATH_SEPARATOR)

val completedNode = metadata.nodesMap[actualPath] ?: throw IllegalStateException("Node $actualPath not found")

val actualOutput = serializer.decodeFromJsonElement(
serializer.serializersModule.serializer(completedNode.outputType),
output
)

if (completedNode is FinishNode<*>) {
// finish node (of some subgraph) doesn't have next edges, and it's input equals output, so it's safe to re-start it:
setExecutionPointImpl(
pathSegments = segments,
node = completedNode,
input = actualOutput
)
} else {
val resolvedEdge = completedNode.resolveEdgeUnsafe(agentContext, actualOutput)
val nextNode = resolvedEdge?.edge?.toNode ?: throw IllegalStateException("Node $nodePath not found")
val nextNodeInput = resolvedEdge.output

setExecutionPointImpl(
pathSegments = segments.dropLast(1) + nextNode.name,
node = nextNode,
input = nextNodeInput
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package ai.koog.agents.core.feature.message
*/
public interface FeatureMessage {

/**
/**n
* Represents the time, in milliseconds, when the feature message or event was created or occurred.
*
* The timestamp is used to track the exact time of the message's creation or event's occurrence,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import kotlin.uuid.Uuid
* @property messageHistory A list of messages exchanged in the session up to the checkpoint. Messages include interactions between the user, system, assistant, and tools.
* @property nodePath The identifier of the node where the checkpoint was created.
* @property lastInput Serialized input received for node with [nodePath]
* @property lastOutput Serialized output received from node with [nodePath]
* @property properties Additional data associated with the checkpoint. This can be used to store additional information about the agent's state.
* @property createdAt The timestamp when the checkpoint was created.
* @property version The version of the checkpoint data structure
Expand All @@ -32,11 +33,36 @@ public data class AgentCheckpointData(
val checkpointId: String,
val createdAt: Instant,
val nodePath: String,
val lastInput: JsonElement,
@Deprecated("Use lastOutput instead, lastOutput will be removed in future versions")
val lastInput: JsonElement? = null,
val lastOutput: JsonElement? = null,
val messageHistory: List<Message>,
val version: Long,
val properties: Map<String, JsonElement>? = null
)
) {
init {
if (nodePath != PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME) {
require(lastInput == null || lastOutput == null) { "`lastInput` and `lastOutput` cannot be both set" }
require(lastInput != null || lastOutput != null) { "`lastInput` (until 0.6.0) or `lastOutput` (since 0.6.1) must be set" }
}
}

private fun eq(json1: JsonElement?, json2: JsonElement?): Boolean =
json1 == json2 || ((json1 == null || json1 == JsonNull) && (json2 == null || json2 == JsonNull))

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is AgentCheckpointData) return false
return checkpointId == other.checkpointId &&
nodePath == other.nodePath &&
createdAt == other.createdAt &&
eq(lastInput, other.lastInput) &&
eq(lastOutput, other.lastOutput) &&
messageHistory == other.messageHistory &&
version == other.version &&
properties == other.properties
}
}

/**
* Creates a tombstone checkpoint for an agent's session.
Expand All @@ -51,7 +77,7 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
checkpointId = Uuid.random().toString(),
createdAt = time,
nodePath = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME,
lastInput = JsonNull,
lastOutput = JsonNull,
messageHistory = emptyList(),
properties = mapOf(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)),
version = version
Expand All @@ -61,7 +87,7 @@ public fun tombstoneCheckpoint(time: Instant, version: Long): AgentCheckpointDat
/**
* Converts an instance of [AgentCheckpointData] to [AgentContextData].
*
* The conversion maps the `messageHistory`, `nodeId`, and `lastInput` properties of
* The conversion maps the `messageHistory`, `nodeId`, and `lastOutput` properties of
* [AgentCheckpointData] directly to a new [AgentContextData] instance.
*
* @return A new [AgentContextData] instance containing the message history, node ID,
Expand All @@ -72,10 +98,12 @@ public fun AgentCheckpointData.toAgentContextData(
agentId: String,
additionalRollbackAction: suspend (AIAgentContext) -> Unit = {}
): AgentContextData {
@Suppress("DEPRECATION")
return AgentContextData(
messageHistory = messageHistory,
nodePath = nodePath,
lastInput = lastInput,
lastOutput = lastOutput,
rollbackStrategy,
additionalRollbackAction
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@

if (config.enableAutomaticPersistence) {
val parent = persistence.getLatestCheckpoint(eventCtx.context.agentId)
persistence.createCheckpoint(
persistence.createCheckpointAfterNode(
agentContext = eventCtx.context,
nodePath = eventCtx.context.executionInfo.path(),
lastInput = eventCtx.input,
lastInputType = eventCtx.inputType,
lastOutput = eventCtx.output,
lastOutputType = eventCtx.outputType,
version = parent?.version?.plus(1) ?: 0L,
)
}
Expand Down Expand Up @@ -159,11 +159,12 @@
*
* @param agentContext The context of the agent containing the state to checkpoint
* @param nodeId The ID of the node where the checkpoint is created
* @param lastInput The input data to include in the checkpoint
* @param lastInput The latest node input data to include in the checkpoint
* @param checkpointId Optional ID for the checkpoint; a random UUID is generated if not provided
* @return The created checkpoint data
*/
@Deprecated("Use `createCheckpointAfterNode` instead")
public suspend fun createCheckpoint(

Check warning on line 167 in agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `createCheckpoint` coverage is below the threshold 50%
agentContext: AIAgentContext,
nodePath: String,
lastInput: Any?,
Expand Down Expand Up @@ -195,6 +196,50 @@
return checkpoint
}

/**
* Creates a checkpoint of the agent's current state.
*
* This method captures the agent's message history, current node, and input data
* and stores it as a checkpoint using the configured storage provider.
*
* @param agentContext The context of the agent containing the state to checkpoint
* @param nodeId The ID of the node where the checkpoint is created
* @param lastOutput The latest node output data to include in the checkpoint
* @param checkpointId Optional ID for the checkpoint; a random UUID is generated if not provided
* @return The created checkpoint data
*/
public suspend fun createCheckpointAfterNode(
agentContext: AIAgentContext,
nodePath: String,
lastOutput: Any?,
lastOutputType: KType,
version: Long,
checkpointId: String? = null,
): AgentCheckpointData? {
val outputJson = SerializationUtils.encodeDataToJsonElementOrNull(lastOutput, lastOutputType)

if (outputJson == null) {
logger.warn {
"Failed to serialize output of type $lastOutputType for checkpoint creation for $nodePath, skipping..."
}
return null
}

val checkpoint = agentContext.llm.readSession {
return@readSession AgentCheckpointData(
checkpointId = checkpointId ?: Uuid.random().toString(),
messageHistory = prompt.messages,
nodePath = agentContext.executionInfo.path(),
lastOutput = outputJson,
createdAt = Clock.System.now(),
version = version,
)
}

saveCheckpoint(agentContext.agentId, checkpoint)
return checkpoint
}

/**
* Creates and saves a tombstone checkpoint for an agent's session.
*
Expand Down Expand Up @@ -256,7 +301,41 @@
messageHistory: List<Message>,
input: JsonElement
) {
agentContext.store(AgentContextData(messageHistory, agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath, input, rollbackStrategy))
agentContext.store(
AgentContextData(
messageHistory,
agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath,
lastInput = input,
rollbackStrategy = rollbackStrategy
)
)
}

/**
* Sets the execution point of an agent to a specified state.
*
* This method updates the agent's context to start execution from a specific point
* in its graph, using the provided message history and finished node output data.
*
* @param agentContext The context of the agent to modify.
* @param nodePath The path to the node inside the agent's graph where execution will begin.
* @param messageHistory The sequence of messages representing the agent's prior interactions.
* @param output The output data to associate with the specified execution point.
*/
public fun setExecutionPointAfterNode(

Check warning on line 325 in agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistence.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `setExecutionPointAfterNode` coverage is below the threshold 50%
agentContext: AIAgentContext,
nodePath: String,
messageHistory: List<Message>,
output: JsonElement
) {
agentContext.store(
AgentContextData(
messageHistory,
agentContext.agentId + DEFAULT_AGENT_PATH_SEPARATOR + nodePath,
lastOutput = output,
rollbackStrategy = rollbackStrategy
)
)
}

/**
Expand Down
Loading
Loading