Skip to content

Commit 9824150

Browse files
committed
Make tool choice a parameter of the agent
Signed-off-by: Djordje Lukic <djordje.lukic@docker.com>
1 parent b355a29 commit 9824150

File tree

17 files changed

+258
-9
lines changed

17 files changed

+258
-9
lines changed

agent-schema.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,15 @@
365365
"type": "boolean",
366366
"description": "Whether to add a 'description' parameter to tool calls, allowing the LLM to provide context about why it is calling a tool"
367367
},
368+
"tool_choice": {
369+
"type": "string",
370+
"description": "Controls how the model selects tools. 'auto' (default) lets the model decide, 'required' forces the model to always call a tool, 'none' prevents tool use.",
371+
"enum": [
372+
"auto",
373+
"required",
374+
"none"
375+
]
376+
},
368377
"hooks": {
369378
"$ref": "#/definitions/HooksConfig",
370379
"description": "Lifecycle hooks for executing shell commands at various points in the agent's execution"

examples/tool_choice_required.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
agents:
2+
root:
3+
model: anthropic/claude-opus-4-6
4+
instruction: >
5+
You are a coding agent. Use your tools to read, write, and modify files,
6+
and to run shell commands. Always use tools to accomplish tasks rather than
7+
just describing what to do.
8+
tool_choice: required
9+
toolsets:
10+
- type: filesystem
11+
- type: shell

pkg/agent/agent.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Agent struct {
3232
addDate bool
3333
addEnvironmentInfo bool
3434
addDescriptionParameter bool
35+
toolChoice string
3536
maxIterations int
3637
maxConsecutiveToolCalls int
3738
maxOldToolCallTokens int
@@ -206,6 +207,11 @@ func (a *Agent) Hooks() *latest.HooksConfig {
206207
return a.hooks
207208
}
208209

210+
// ToolChoice returns the tool choice mode for this agent (e.g., "auto", "required", "none").
211+
func (a *Agent) ToolChoice() string {
212+
return a.toolChoice
213+
}
214+
209215
// Tools returns the tools available to this agent
210216
func (a *Agent) Tools(ctx context.Context) ([]tools.Tool, error) {
211217
a.ensureToolSetsAreStarted(ctx)

pkg/agent/opts.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ func WithAddDescriptionParameter(addDescriptionParameter bool) Opt {
115115
}
116116
}
117117

118+
func WithToolChoice(toolChoice string) Opt {
119+
return func(a *Agent) {
120+
a.toolChoice = toolChoice
121+
}
122+
}
123+
118124
func WithAddPromptFiles(addPromptFiles []string) Opt {
119125
return func(a *Agent) {
120126
a.addPromptFiles = addPromptFiles

pkg/config/latest/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ type AgentConfig struct {
366366
AddEnvironmentInfo bool `json:"add_environment_info,omitempty"`
367367
CodeModeTools bool `json:"code_mode_tools,omitempty"`
368368
AddDescriptionParameter bool `json:"add_description_parameter,omitempty"`
369+
ToolChoice string `json:"tool_choice,omitempty"`
369370
MaxIterations int `json:"max_iterations,omitempty"`
370371
MaxConsecutiveToolCalls int `json:"max_consecutive_tool_calls,omitempty"`
371372
MaxOldToolCallTokens int `json:"max_old_tool_call_tokens,omitempty"`

pkg/config/latest/validate.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ func (t *Config) validate() error {
2323
return err
2424
}
2525

26+
// Validate tool_choice
27+
if err := agent.validateToolChoice(); err != nil {
28+
return err
29+
}
30+
2631
for j := range agent.Toolsets {
2732
if err := agent.Toolsets[j].validate(); err != nil {
2833
return err
@@ -38,6 +43,19 @@ func (t *Config) validate() error {
3843
return nil
3944
}
4045

46+
// validateToolChoice validates the tool_choice configuration for an agent
47+
func (a *AgentConfig) validateToolChoice() error {
48+
if a.ToolChoice == "" {
49+
return nil
50+
}
51+
switch a.ToolChoice {
52+
case "auto", "required", "none":
53+
return nil
54+
default:
55+
return errors.New("tool_choice must be one of: auto, required, none")
56+
}
57+
}
58+
4159
// validateFallback validates the fallback configuration for an agent
4260
func (a *AgentConfig) validateFallback() error {
4361
if a.Fallback == nil {

pkg/config/latest/validate_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,79 @@ agents:
115115
})
116116
}
117117
}
118+
119+
func TestAgentConfig_Validate_ToolChoice(t *testing.T) {
120+
t.Parallel()
121+
122+
tests := []struct {
123+
name string
124+
config string
125+
wantErr string
126+
}{
127+
{
128+
name: "valid tool_choice auto",
129+
config: `
130+
agents:
131+
root:
132+
model: "openai/gpt-4"
133+
tool_choice: auto
134+
`,
135+
wantErr: "",
136+
},
137+
{
138+
name: "valid tool_choice required",
139+
config: `
140+
agents:
141+
root:
142+
model: "openai/gpt-4"
143+
tool_choice: required
144+
`,
145+
wantErr: "",
146+
},
147+
{
148+
name: "valid tool_choice none",
149+
config: `
150+
agents:
151+
root:
152+
model: "openai/gpt-4"
153+
tool_choice: none
154+
`,
155+
wantErr: "",
156+
},
157+
{
158+
name: "no tool_choice set",
159+
config: `
160+
agents:
161+
root:
162+
model: "openai/gpt-4"
163+
`,
164+
wantErr: "",
165+
},
166+
{
167+
name: "invalid tool_choice value",
168+
config: `
169+
agents:
170+
root:
171+
model: "openai/gpt-4"
172+
tool_choice: force
173+
`,
174+
wantErr: "tool_choice must be one of: auto, required, none",
175+
},
176+
}
177+
178+
for _, tt := range tests {
179+
t.Run(tt.name, func(t *testing.T) {
180+
t.Parallel()
181+
182+
var cfg Config
183+
err := yaml.Unmarshal([]byte(tt.config), &cfg)
184+
185+
if tt.wantErr != "" {
186+
require.Error(t, err)
187+
require.Contains(t, err.Error(), tt.wantErr)
188+
} else {
189+
require.NoError(t, err)
190+
}
191+
})
192+
}
193+
}

pkg/model/provider/anthropic/beta_client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ func (c *Client) createBetaStream(
109109

110110
if len(requestTools) > 0 {
111111
slog.Debug("Anthropic Beta API: Adding tools to request", "tool_count", len(requestTools))
112+
113+
// Apply tool_choice from agent config
114+
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
115+
switch toolChoice {
116+
case "required":
117+
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAny: &anthropic.BetaToolChoiceAnyParam{}}
118+
case "none":
119+
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfNone: &anthropic.BetaToolChoiceNoneParam{}}
120+
default: // "auto" or any other value
121+
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAuto: &anthropic.BetaToolChoiceAutoParam{}}
122+
}
123+
slog.Debug("Anthropic Beta API request using tool_choice", "tool_choice", toolChoice)
124+
}
112125
}
113126

114127
slog.Debug("Anthropic Beta API chat completion stream request",

pkg/model/provider/anthropic/client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,19 @@ func (c *Client) CreateChatCompletionStream(
330330

331331
if len(requestTools) > 0 {
332332
slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))
333+
334+
// Apply tool_choice from agent config
335+
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
336+
switch toolChoice {
337+
case "required":
338+
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}}
339+
case "none":
340+
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}}
341+
default: // "auto" or any other value
342+
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}}
343+
}
344+
slog.Debug("Anthropic request using tool_choice", "tool_choice", toolChoice)
345+
}
333346
}
334347

335348
// Log the request details for debugging

pkg/model/provider/bedrock/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
249249

250250
// Convert and set tools
251251
if len(requestTools) > 0 {
252-
input.ToolConfig = convertToolConfig(requestTools, enableCaching)
252+
input.ToolConfig = convertToolConfig(requestTools, enableCaching, c.ModelOptions.ToolChoice())
253253
}
254254

255255
return input

0 commit comments

Comments
 (0)