Skip to content

Commit e5cdc36

Browse files
Heuristic required tool choice (#1323)
<!-- Thank you for opening a pull request! Please add a brief description of the proposed change here. Also, please tick the appropriate points in the checklist below. --> ## Motivation and Context <!-- Why is this change needed? What problem does it solve? --> Models without the tool choice capability (e.g. most of the ollama models) cannot rely on the `toolChoice` in the model api. This makes some workflows inconsistent, in particular, `subgraphWithTask` does not work with such models since it relies on the `toolChoice` parameter. This pr introduces a workaround for this issue using `LLMBasedToolCallFixProcessor`, which now will use prompts to force LLM to generate a tool when the `toolChoice` is set to `Required`. ## Breaking Changes <!-- Will users need to update their code or configurations? --> --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [x] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [x] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated
1 parent 860c198 commit e5cdc36

File tree

9 files changed

+218
-45
lines changed

9 files changed

+218
-45
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentSubgraph.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ public open class AIAgentSubgraph<TInput, TOutput>(
170170
tools = newTools,
171171
model = llmModel ?: context.llm.model,
172172
prompt = context.llm.prompt.copy(params = llmParams ?: context.llm.prompt.params),
173-
responseProcessor = responseProcessor
173+
responseProcessor = responseProcessor ?: context.llm.responseProcessor,
174174
),
175175
),
176176
)

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import ai.koog.prompt.message.LLMChoice
1616
import ai.koog.prompt.message.Message
1717
import ai.koog.prompt.params.LLMParams
1818
import ai.koog.prompt.processor.ResponseProcessor
19+
import ai.koog.prompt.processor.executeProcessed
1920
import ai.koog.prompt.streaming.StreamFrame
2021
import ai.koog.prompt.structure.StructureFixingParser
2122
import ai.koog.prompt.structure.StructuredRequestConfig
@@ -64,7 +65,7 @@ public open class AIAgentLLMSession(
6465
@InternalAgentsApi
6566
public open override suspend fun executeMultiple(prompt: Prompt, tools: List<ToolDescriptor>): List<Message.Response> {
6667
val preparedPrompt = preparePrompt(prompt, tools)
67-
return executor.execute(preparedPrompt, model, tools)
68+
return executor.executeProcessed(preparedPrompt, model, tools, responseProcessor)
6869
}
6970

7071
@InternalAgentsApi
@@ -96,11 +97,15 @@ public open class AIAgentLLMSession(
9697

9798
public open override suspend fun requestLLMOnlyCallingTools(): Message.Response {
9899
validateSession()
99-
// We use the multiple-response method to ensure we capture all context (e.g. thinking)
100-
// even though we only return the specific tool call.
101-
val responses = requestLLMMultipleOnlyCallingTools()
100+
val promptWithOnlyCallingTools = prompt.withUpdatedParams {
101+
toolChoice = LLMParams.ToolChoice.Required
102+
}
103+
val responses = executeMultiple(promptWithOnlyCallingTools, tools)
104+
105+
// some models might fail to produce a tool call
106+
// it's better to not fail here and allow the user to handle that
102107
return responses.firstOrNull { it is Message.Tool.Call }
103-
?: error("requestLLMOnlyCallingTools expected at least one Tool.Call but received: ${responses.map { it::class.simpleName }}")
108+
?: responses.first { it is Message.Assistant }
104109
}
105110

106111
public open override suspend fun requestLLMMultipleOnlyCallingTools(): List<Message.Response> {

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSessionImpl.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ internal class AIAgentLLMWriteSessionImpl internal constructor(
9696
return super<AIAgentLLMSession>.requestLLMWithoutTools().also { response -> appendPrompt { message(response) } }
9797
}
9898

99+
override suspend fun requestLLMOnlyCallingTools(): Message.Response {
100+
return super<AIAgentLLMSession>.requestLLMOnlyCallingTools()
101+
.also { response -> appendPrompt { message(response) } }
102+
}
103+
99104
override suspend fun requestLLMMultipleOnlyCallingTools(): List<Message.Response> {
100105
return super<AIAgentLLMSession>.requestLLMMultipleOnlyCallingTools()
101106
.also { responses ->

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public inline fun <reified T> AIAgentSubgraphBuilderBase<*, *>.nodeUpdatePrompt(
8282
* @param name Optional name for the node.
8383
*/
8484
@AIAgentBuilderDslMarker
85-
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
85+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestOnlyCallingTools(
8686
name: String? = null
8787
): AIAgentNodeDelegate<String, Message.Response> =
8888
node(name) { message ->
@@ -95,14 +95,48 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
9595
}
9696
}
9797

98+
/**
99+
* A node that appends a user message to the LLM prompt and gets a response where the LLM can only call tools.
100+
*
101+
* @param name Optional name for the node.
102+
*/
103+
@Deprecated(
104+
"Please use nodeLLMRequestOnlyCallingTools instead.",
105+
ReplaceWith("nodeLLMRequestOnlyCallingTools(name)")
106+
)
107+
@AIAgentBuilderDslMarker
108+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools(
109+
name: String? = null
110+
): AIAgentNodeDelegate<String, Message.Response> =
111+
nodeLLMRequestOnlyCallingTools(name)
112+
113+
/**
114+
* A node that appends a user message to the LLM prompt and gets multiple LLM responses where the LLM can only call tools.
115+
*
116+
* @param name Optional name for the node.
117+
*/
118+
@AIAgentBuilderDslMarker
119+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestMultipleOnlyCallingTools(
120+
name: String? = null
121+
): AIAgentNodeDelegate<String, List<Message.Response>> =
122+
node(name) { message ->
123+
llm.writeSession {
124+
appendPrompt {
125+
user(message)
126+
}
127+
128+
requestLLMMultipleOnlyCallingTools()
129+
}
130+
}
131+
98132
/**
99133
* A node that that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
100134
*
101135
* @param name Optional node name.
102136
* @param tool Tool descriptor the LLM is required to use.
103137
*/
104138
@AIAgentBuilderDslMarker
105-
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
139+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestForceOneTool(
106140
name: String? = null,
107141
tool: ToolDescriptor
108142
): AIAgentNodeDelegate<String, Message.Response> =
@@ -116,18 +150,52 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
116150
}
117151
}
118152

153+
/**
154+
* A node that that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
155+
*
156+
* @param name Optional node name.
157+
* @param tool Tool descriptor the LLM is required to use.
158+
*/
159+
@Deprecated(
160+
"Please use nodeLLMRequestForceOneTool instead.",
161+
ReplaceWith("nodeLLMRequestForceOneTool(name, tool)")
162+
)
163+
@AIAgentBuilderDslMarker
164+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
165+
name: String? = null,
166+
tool: ToolDescriptor
167+
): AIAgentNodeDelegate<String, Message.Response> =
168+
nodeLLMRequestForceOneTool(name, tool)
169+
119170
/**
120171
* A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
121172
*
122173
* @param name Optional node name.
123174
* @param tool Tool the LLM is required to use.
124175
*/
125176
@AIAgentBuilderDslMarker
177+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestForceOneTool(
178+
name: String? = null,
179+
tool: Tool<*, *>
180+
): AIAgentNodeDelegate<String, Message.Response> =
181+
nodeLLMRequestForceOneTool(name, tool.descriptor)
182+
183+
/**
184+
* A node that appends a user message to the LLM prompt and forces the LLM to use a specific tool.
185+
*
186+
* @param name Optional node name.
187+
* @param tool Tool the LLM is required to use.
188+
*/
189+
@Deprecated(
190+
"Please use nodeLLMRequestForceOneTool instead.",
191+
ReplaceWith("nodeLLMRequestForceOneTool(name, tool)")
192+
)
193+
@AIAgentBuilderDslMarker
126194
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool(
127195
name: String? = null,
128196
tool: Tool<*, *>
129197
): AIAgentNodeDelegate<String, Message.Response> =
130-
nodeLLMSendMessageForceOneTool(name, tool.descriptor)
198+
nodeLLMRequestForceOneTool(name, tool)
131199

132200
/**
133201
* A node that appends a user message to the LLM prompt and gets a response with optional tool usage.
@@ -407,6 +475,27 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResult(
407475
}
408476
}
409477

478+
/**
479+
* A node that adds a tool result to the prompt and gets an LLM response where the LLM can only call tools.
480+
*
481+
* @param name Optional node name.
482+
*/
483+
@AIAgentBuilderDslMarker
484+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResultOnlyCallingTools(
485+
name: String? = null
486+
): AIAgentNodeDelegate<List<ReceivedToolResult>, Message.Response> =
487+
node(name) { results ->
488+
llm.writeSession {
489+
appendPrompt {
490+
tool {
491+
results.forEach { result(it) }
492+
}
493+
}
494+
495+
requestLLMOnlyCallingTools()
496+
}
497+
}
498+
410499
/**
411500
* A node that executes multiple tool calls. These calls can optionally be executed in parallel.
412501
*
@@ -481,6 +570,27 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResults(
481570
}
482571
}
483572

573+
/**
574+
* A node that adds multiple tool results to the prompt and gets multiple LLM responses where the LLM can only call tools.
575+
*
576+
* @param name Optional node name.
577+
*/
578+
@AIAgentBuilderDslMarker
579+
public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResultsOnlyCallingTools(
580+
name: String? = null
581+
): AIAgentNodeDelegate<List<ReceivedToolResult>, List<Message.Response>> =
582+
node(name) { results ->
583+
llm.writeSession {
584+
appendPrompt {
585+
tool {
586+
results.forEach { result(it) }
587+
}
588+
}
589+
590+
requestLLMMultipleOnlyCallingTools()
591+
}
592+
}
593+
484594
/**
485595
* A node that calls a specific tool directly using the provided arguments.
486596
*

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentSubgraphExt.kt

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker
1111
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
1212
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphDelegate
1313
import ai.koog.agents.core.dsl.builder.forwardTo
14+
import ai.koog.agents.core.dsl.extension.nodeLLMRequest
1415
import ai.koog.agents.core.dsl.extension.nodeLLMRequestMultiple
1516
import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults
17+
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
1618
import ai.koog.agents.core.dsl.extension.setToolChoiceRequired
1719
import ai.koog.agents.core.environment.ReceivedToolResult
1820
import ai.koog.agents.core.environment.ToolResultKind
@@ -491,7 +493,12 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
491493
// Helper node to overcome problems of the current api and repeat less code when writing routing conditions
492494
val nodeDecide by node<List<Message.Response>, List<Message.Response>> { it }
493495

494-
val nodeCallLLM by nodeLLMRequestMultiple()
496+
val nodeCallLLMDelegate = if (runMode == ToolCalls.SINGLE_RUN_SEQUENTIAL) {
497+
nodeLLMRequest().transform { listOf(it) }
498+
} else {
499+
nodeLLMRequestMultiple()
500+
}
501+
val nodeCallLLM by nodeCallLLMDelegate
495502

496503
val callToolsHacked by node<List<Message.Tool.Call>, List<ReceivedToolResult>> { toolCalls ->
497504
val (finishToolCalls, regularToolCalls) = toolCalls.partition { it.tool == finishTool.name }
@@ -521,8 +528,6 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
521528
}
522529
}
523530

524-
val sendToolsResults by nodeLLMSendMultipleToolResults()
525-
526531
@OptIn(DetachedPromptExecutorAPI::class)
527532
val handleAssistantMessage by node<Message.Assistant, List<Message.Response>> { response ->
528533
if (llm.model.capabilities.contains(LLMCapability.ToolChoice)) {
@@ -590,9 +595,14 @@ public inline fun <reified Input, reified Output, reified OutputTransformed> AIA
590595
transformed { toolsResults -> toolsResults.first() }
591596
)
592597

593-
edge(callToolsHacked forwardTo sendToolsResults)
594-
595-
edge(sendToolsResults forwardTo nodeDecide)
598+
if (runMode == ToolCalls.SINGLE_RUN_SEQUENTIAL) {
599+
val sendToolResult by nodeLLMSendToolResult()
600+
edge(callToolsHacked forwardTo sendToolResult transformed { it.first() })
601+
edge(sendToolResult forwardTo nodeDecide transformed { listOf(it) })
602+
} else {
603+
val sendToolsResults by nodeLLMSendMultipleToolResults()
604+
callToolsHacked then sendToolsResults then nodeDecide
605+
}
596606

597607
edge(finalizeTask forwardTo nodeFinish)
598608
}

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSessionTest.kt

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ import ai.koog.prompt.params.LLMParams
2626
import ai.koog.prompt.processor.ResponseProcessor
2727
import kotlinx.coroutines.test.runTest
2828
import kotlinx.serialization.Serializable
29+
import kotlin.test.Ignore
2930
import kotlin.test.Test
31+
import kotlin.test.assertContains
3032
import kotlin.test.assertEquals
33+
import kotlin.test.assertIs
3134
import kotlin.test.assertNotNull
3235
import kotlin.test.assertTrue
3336

@@ -392,6 +395,11 @@ class AIAgentLLMWriteSessionTest {
392395
}
393396

394397
@Test
398+
// This behavior is not supported for non-list responses from "requestLLM..." methods
399+
// The test was passing due to a bug in the requestLLMOnlyCallingTools implementation
400+
// See KG-663
401+
// TODO(): remove the test after deprecating non-list responses from LLM
402+
@Ignore
395403
fun testRequestLLMOnlyCallingToolsWithThinking() = runTest {
396404
val thinkingContent = "<thinking>Checking file...</thinking>"
397405
val testTool = TestTool()
@@ -417,30 +425,6 @@ class AIAgentLLMWriteSessionTest {
417425
assertEquals("test-tool", (lastTwoMessages[1] as Message.Tool.Call).tool)
418426
}
419427

420-
@Test
421-
fun testRequestLLMOnlyCallingToolsNoToolCallThrowsException() = runTest {
422-
val mockExecutor = getMockExecutor(clock = testClock) {
423-
// Simulate model refusing to use tools and just responding with text
424-
mockLLMAnswer("I cannot use tools for this request.").asDefaultResponse
425-
}
426-
427-
val session = createSession(mockExecutor, listOf(TestTool()))
428-
429-
val exception = kotlin.runCatching {
430-
session.requestLLMOnlyCallingTools()
431-
}.exceptionOrNull()
432-
433-
assertNotNull(exception, "Expected an exception when no tool call is found")
434-
assertTrue(
435-
exception is IllegalStateException,
436-
"Expected IllegalStateException but got ${exception::class.simpleName}"
437-
)
438-
assertTrue(
439-
exception.message?.contains("expected at least one Tool.Call") == true,
440-
"Exception message should indicate missing tool call"
441-
)
442-
}
443-
444428
@Test
445429
fun testRequestLLMOnlyCallingToolsWithMultipleToolCalls() = runTest {
446430
val testTool = TestTool()
@@ -464,8 +448,9 @@ class AIAgentLLMWriteSessionTest {
464448
assertTrue(response is Message.Tool.Call, "Expected response to be a Tool Call")
465449
assertEquals("test-tool", response.tool)
466450

467-
// Both tool calls should be in history
468-
val lastTwoMessages = session.prompt.messages.takeLast(2)
469-
assertTrue(lastTwoMessages.all { it is Message.Tool.Call })
451+
// Only the first tool call should be added to the history
452+
val lastMessage = session.prompt.messages.last()
453+
assertIs<Message.Tool.Call>(lastMessage)
454+
assertContains(lastMessage.content, "first")
470455
}
471456
}

0 commit comments

Comments
 (0)