Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import ai.koog.agents.core.dsl.extension.setToolChoiceRequired
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.ToolResultKind
import ai.koog.agents.core.environment.executeTools
import ai.koog.agents.core.environment.toSafeResult
import ai.koog.agents.core.tools.Tool
import ai.koog.agents.core.tools.ToolDescriptor
Expand Down Expand Up @@ -99,6 +98,17 @@
* to prevent redundancy in responses and ensure conciseness in communication.
*/
public const val ASSISTANT_RESPONSE_REPEAT_MAX: Int = 3

/**
* A message shown to the model when it does not return a tool call during the subgraphWithTask execution.
*
* The message clarifies to the model that a tool call is required here,
* And if the task is finished, the finish tool has to be called.
*/
public fun messageOnAssistantResponse(finishToolName: String): String = markdown {
h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.")
h2("IF YOU HAVE FINISHED, CALL `$finishToolName` TOOL!")
}
}

/**
Expand Down Expand Up @@ -405,7 +415,7 @@
)
)
@InternalAgentsApi
public inline fun <reified Input, reified Output, reified OutputTransformed> AIAgentSubgraphBuilderBase<Input, OutputTransformed>.setupSubgraphWithTask(

Check warning on line 418 in agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `setupSubgraphWithTask` coverage is below the threshold 50%
finishTool: Tool<Output, OutputTransformed>,
assistantResponseRepeatMax: Int? = null,
noinline defineTask: suspend AIAgentGraphContextBase.(Input) -> String
Expand Down Expand Up @@ -433,7 +443,7 @@
* context of an AI agent graph and based on the given input data.
*/
@InternalAgentsApi
public inline fun <reified Input, reified Output, reified OutputTransformed> AIAgentSubgraphBuilderBase<Input, OutputTransformed>.setupSubgraphWithTask(

Check warning on line 446 in agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `setupSubgraphWithTask` coverage is below the threshold 50%
finishTool: Tool<Output, OutputTransformed>,
runMode: ToolCalls,
assistantResponseRepeatMax: Int? = null,
Expand Down Expand Up @@ -478,36 +488,18 @@
val nodeCallLLM by nodeLLMRequestMultiple()

val callToolsHacked by node<List<Message.Tool.Call>, List<ReceivedToolResult>> { toolCalls ->
val (finishToolCalls, regularToolCalls) = toolCalls.partition { it.tool == finishTool.name }

// Execute finish tool
val finishToolResult = finishToolCalls.firstOrNull()?.let { toolCall ->
executeFinishTool<Output, OutputTransformed>(toolCall, finishTool)
}

// Execute regular tools
val regularToolsResults = when (runMode) {
ToolCalls.PARALLEL -> {
environment.executeTools(regularToolCalls)
}
ToolCalls.SEQUENTIAL,
ToolCalls.SINGLE_RUN_SEQUENTIAL -> {
regularToolCalls.map { toolCall ->
environment.executeTool(toolCall)
}
}
}

buildList {
finishToolResult?.let { add(it) }
addAll(regularToolsResults)
}
// use a method for the subtask to avoid code duplication
executeMultipleToolsHacked<Output, OutputTransformed>(
toolCalls,
finishTool,
runMode == ToolCalls.PARALLEL
)
}

val sendToolsResults by nodeLLMSendMultipleToolResults()

@OptIn(DetachedPromptExecutorAPI::class)
val handleAssistantMessage by node<Message.Assistant, List<Message.Response>> { response ->
val handleAssistantMessage by node<Message.Assistant, String> { response ->
if (llm.model.capabilities.contains(LLMCapability.ToolChoice)) {
error(
"Subgraph with task must always call tools, but no ${Message.Tool.Call::class.simpleName} was generated, " +
Expand All @@ -526,19 +518,7 @@
)
}

llm.writeSession {
// append a new message to the history with feedback:
appendPrompt {
user {
markdown {
h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.")
h2("IF YOU HAVE FINISHED, CALL `${finishTool.name}` TOOL!")
}
}
}

requestLLMMultiple()
}
SubgraphWithTaskUtils.messageOnAssistantResponse(finishTool.name)
}

nodeStart then setupTask then nodeCallLLM then nodeDecide
Expand All @@ -555,7 +535,7 @@
transformed { responses -> responses.first() as Message.Assistant }
)

edge(handleAssistantMessage forwardTo nodeDecide)
edge(handleAssistantMessage forwardTo nodeCallLLM)

// throw to terminate the agent early with exception
edge(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.koog.agents.ext.agent

import ai.koog.agents.core.agent.ToolCalls
import ai.koog.agents.core.agent.context.AIAgentContext
import ai.koog.agents.core.agent.context.AIAgentFunctionalContext
import ai.koog.agents.core.agent.context.DetachedPromptExecutorAPI
import ai.koog.agents.core.annotation.InternalAgentsApi
Expand All @@ -13,6 +14,7 @@
import ai.koog.agents.core.dsl.extension.setToolChoiceRequired
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.executeTools
import ai.koog.agents.core.environment.result
import ai.koog.agents.core.environment.toSafeResult
import ai.koog.agents.core.tools.Tool
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
Expand Down Expand Up @@ -123,7 +125,7 @@
* @return The transformed final result of executing the finishing tool to complete the subtask.
*/
@OptIn(InternalAgentToolsApi::class, DetachedPromptExecutorAPI::class, InternalAgentsApi::class)
public suspend inline fun <reified Input, reified Output, reified OutputTransformed> AIAgentFunctionalContext.subtask(

Check warning on line 128 in agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubtaskExt.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `subtask` coverage is below the threshold 50%
input: Input,
tools: List<Tool<*, *>>? = null,
finishTool: Tool<Output, OutputTransformed>,
Expand Down Expand Up @@ -208,12 +210,13 @@
val toolCalls = extractToolCalls(responses)
val toolResults =
executeMultipleToolsHacked(toolCalls, finishTool, parallelTools = runMode == ToolCalls.PARALLEL)
responses = sendMultipleToolResults(toolResults)

toolResults.firstOrNull { it.tool == finishTool.descriptor.name }
?.let { finishResult ->
return finishResult.toSafeResult(finishTool).asSuccessful().result
}

responses = sendMultipleToolResults(toolResults)
}

else -> {
Expand Down Expand Up @@ -249,11 +252,12 @@
when {
response is Message.Tool.Call -> {
val toolResult = executeToolHacked(response, finishTool)
response = sendToolResult(toolResult)

if (toolResult.tool == finishTool.descriptor.name) {
return toolResult.toSafeResult(finishTool).asSuccessful().result
}

response = sendToolResult(toolResult)
}

else -> {
Expand All @@ -266,10 +270,7 @@
}

response = requestLLM(
message = markdown {
h1("DO NOT CHAT WITH ME DIRECTLY! CALL TOOLS, INSTEAD.")
h2("IF YOU HAVE FINISHED, CALL `${finishTool.name}` TOOL!")
}
SubgraphWithTaskUtils.messageOnAssistantResponse(finishTool.name)
)
}
}
Expand All @@ -278,25 +279,37 @@

@OptIn(InternalAgentToolsApi::class, InternalAgentsApi::class)
@PublishedApi
internal suspend inline fun <reified Output, reified OutputTransformed> AIAgentFunctionalContext.executeMultipleToolsHacked(
internal suspend inline fun <reified Output, reified OutputTransformed> AIAgentContext.executeMultipleToolsHacked(
toolCalls: List<Message.Tool.Call>,
finishTool: Tool<Output, OutputTransformed>,
parallelTools: Boolean = false
): List<ReceivedToolResult> {
val finishTools = toolCalls.filter { it.tool == finishTool.descriptor.name }
val normalTools = toolCalls.filterNot { it.tool == finishTool.descriptor.name }

val finishToolResults = finishTools.map { toolCall ->
executeFinishTool(toolCall, finishTool)
}
val (finishTools, normalTools) = toolCalls.partition { it.tool == finishTool.name }

val normalToolResults = if (parallelTools) {
environment.executeTools(normalTools)
} else {
normalTools.map { environment.executeTool(it) }
}

return finishToolResults + normalToolResults
// if a finish tool was called, the subtask execution will be finished,
// and the normal tool results have to be appended to the prompt here,
// otherwise they will be lost
if (finishTools.isNotEmpty()) {
llm.writeSession {
appendPrompt {
tool {
normalToolResults.forEach { result(it) }
}
}
}
}

val finishToolResults = finishTools.map { toolCall ->
executeFinishTool(toolCall, finishTool)
}

return normalToolResults + finishToolResults
}

@OptIn(InternalAgentToolsApi::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ai.koog.agents.core.agent.AIAgent
import ai.koog.agents.core.agent.ToolCalls
import ai.koog.agents.core.agent.config.AIAgentConfig
import ai.koog.agents.core.agent.execution.path
import ai.koog.agents.core.agent.functionalStrategy
import ai.koog.agents.core.agent.singleRunStrategy
import ai.koog.agents.core.dsl.builder.ParallelNodeExecutionResult
import ai.koog.agents.core.dsl.builder.forwardTo
Expand All @@ -15,8 +16,11 @@ import ai.koog.agents.core.dsl.extension.nodeLLMRequest
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
import ai.koog.agents.core.dsl.extension.onAssistantMessage
import ai.koog.agents.core.dsl.extension.onToolCall
import ai.koog.agents.core.dsl.extension.requestLLM
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.tools.ToolRegistry
import ai.koog.agents.ext.agent.reActStrategy
import ai.koog.agents.ext.agent.subtask
import ai.koog.agents.features.eventHandler.feature.EventHandler
import ai.koog.agents.features.eventHandler.feature.EventHandlerConfig
import ai.koog.agents.snapshot.feature.Persistence
Expand All @@ -26,6 +30,7 @@ import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
import ai.koog.integration.tests.utils.Models
import ai.koog.integration.tests.utils.RetryUtils.withRetry
import ai.koog.integration.tests.utils.tools.CalculateSumTool
import ai.koog.integration.tests.utils.tools.CalculatorTool
import ai.koog.integration.tests.utils.tools.CalculatorToolNoArgs
import ai.koog.integration.tests.utils.tools.DelayTool
import ai.koog.integration.tests.utils.tools.GetTransactionsTool
Expand Down Expand Up @@ -69,6 +74,7 @@ import java.util.Base64
import java.util.stream.Stream
import kotlin.io.path.readBytes
import kotlin.reflect.typeOf
import kotlin.test.assertContains
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

Expand Down Expand Up @@ -111,6 +117,13 @@ class AIAgentIntegrationTest : AIAgentTestBase() {
Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)")
)
}

@JvmStatic
fun runModes(): Stream<ToolCalls> = Stream.of(
ToolCalls.SEQUENTIAL,
ToolCalls.PARALLEL,
ToolCalls.SINGLE_RUN_SEQUENTIAL,
)
}

val twoToolsRegistry = ToolRegistry {
Expand Down Expand Up @@ -171,7 +184,7 @@ class AIAgentIntegrationTest : AIAgentTestBase() {
name = "compress_history",
strategy = strategy
)
val compressToolResult by nodeLLMCompressHistory<ai.koog.agents.core.environment.ReceivedToolResult>(
val compressToolResult by nodeLLMCompressHistory<ReceivedToolResult>(
name = "compress_history",
strategy = strategy
)
Expand Down Expand Up @@ -1074,7 +1087,7 @@ class AIAgentIntegrationTest : AIAgentTestBase() {
agent.run("Hi")

with(state) {
errors.shouldBeEmpty() // There should be no errors during parallel execution}
errors.shouldBeEmpty() // There should be no errors during parallel execution
results.shouldNotBeEmpty().first() as String should {
contain("Math result: 56")
contain("Text result: Hello World")
Expand Down Expand Up @@ -1302,4 +1315,38 @@ class AIAgentIntegrationTest : AIAgentTestBase() {
}
}
}

@ParameterizedTest
@MethodSource("runModes")
fun integration_testSubtaskCorrectlySavesToolMessages(runMode: ToolCalls) = runTest(timeout = 3600.seconds) {
withRetry {
val model = OpenAIModels.Chat.GPT4o
val executor = getExecutor(model)
val toolRegistry = ToolRegistry {
tool(CalculatorTool)
}

val strategy = functionalStrategy<String, String>("subtask-test") { input ->
subtask<String, Int>(input, runMode = runMode) { it }
requestLLM("What's the result?").content
}

val agent = AIAgent(
strategy = strategy,
promptExecutor = executor,
agentConfig = AIAgentConfig(
prompt = prompt("subtask-test") {
system("You are a helpful assistant specialized in simple calculations.")
},
model = model,
maxAgentIterations = 10
),
toolRegistry = toolRegistry
)

val result = agent.run("2 * 7")

assertContains(result, "14")
}
}
}
Loading