Skip to content

Commit 466d846

Browse files
authored
Merge pull request #1937 from dgageot/selector
Add model_picker toolset for dynamic model switching
2 parents ffe68e4 + 0e5f347 commit 466d846

File tree

8 files changed

+425
-34
lines changed

8 files changed

+425
-34
lines changed

agent-schema.json

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,8 @@
704704
"a2a",
705705
"lsp",
706706
"user_prompt",
707-
"openapi"
707+
"openapi",
708+
"model_picker"
708709
]
709710
},
710711
"instruction": {
@@ -840,6 +841,13 @@
840841
"items": {
841842
"type": "string"
842843
}
844+
},
845+
"models": {
846+
"type": "array",
847+
"description": "List of allowed models for the model_picker tool.",
848+
"items": {
849+
"type": "string"
850+
}
843851
}
844852
},
845853
"additionalProperties": false,
@@ -890,7 +898,8 @@
890898
"api",
891899
"a2a",
892900
"lsp",
893-
"user_prompt"
901+
"user_prompt",
902+
"model_picker"
894903
]
895904
}
896905
}
@@ -958,6 +967,22 @@
958967
]
959968
}
960969
]
970+
},
971+
{
972+
"allOf": [
973+
{
974+
"properties": {
975+
"type": {
976+
"const": "model_picker"
977+
}
978+
}
979+
},
980+
{
981+
"required": [
982+
"models"
983+
]
984+
}
985+
]
961986
}
962987
]
963988
},

examples/model_picker.yaml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env docker agent run
2+
3+
# This example demonstrates the model_picker toolset, which lets the agent
4+
# dynamically switch between models mid-conversation. The agent can pick the
5+
# best model for each sub-task (e.g. a fast model for simple questions, a
6+
# powerful one for complex reasoning) and revert back when done.
7+
8+
agents:
9+
root:
10+
model: google/gemini-2.5-flash-lite
11+
description: A versatile assistant that picks the best model for each task
12+
instruction: |
13+
You are a helpful assistant with access to multiple AI models.
14+
toolsets:
15+
- type: filesystem
16+
- type: shell
17+
- type: model_picker
18+
instruction: |
19+
{ORIGINAL_INSTRUCTIONS}
20+
21+
## Model selection policy
22+
23+
Your default model (`gemini-2.5-flash-lite`) is fast and cheap but
24+
limited. You MUST follow this policy for every user message:
25+
26+
1. **Classify first.** Decide whether the request is *trivial*
27+
(greetings, single-fact lookups, yes/no answers, short
28+
clarifications) or *non-trivial* (anything else: writing, coding,
29+
analysis, planning, multi-step reasoning, tool use, etc.).
30+
31+
2. **Trivial → stay on the default model.** Answer directly.
32+
33+
3. **Non-trivial → switch before you do any work.**
34+
Call `change_model` to `claude-haiku-4-5` as the very first action,
35+
*before* reasoning, planning, or calling any other tool.
36+
Then carry out the task.
37+
38+
4. **ALWAYS revert when done.** After completing a non-trivial task,
39+
you MUST call `revert_model` as your very last action so the next
40+
turn starts on the cheap default again. This is mandatory—treat
41+
it as the final step of every non-trivial request. Never end your
42+
turn on a non-default model.
43+
44+
**Important:** never start working on a non-trivial task while still
45+
on the default model. When in doubt, switch.
46+
models:
47+
- google/gemini-2.5-flash-lite
48+
- anthropic/claude-haiku-4-5

pkg/config/latest/types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ type Toolset struct {
575575

576576
// For the `fetch` tool
577577
Timeout int `json:"timeout,omitempty"`
578+
579+
// For the `model_picker` tool
580+
Models []string `json:"models,omitempty"`
578581
}
579582

580583
func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error {

pkg/config/latest/validate.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ func (t *Toolset) validate() error {
7676
if len(t.FileTypes) > 0 && t.Type != "lsp" {
7777
return errors.New("file_types can only be used with type 'lsp'")
7878
}
79+
if len(t.Models) > 0 && t.Type != "model_picker" {
80+
return errors.New("models can only be used with type 'model_picker'")
81+
}
7982
if t.Sandbox != nil && t.Type != "shell" {
8083
return errors.New("sandbox can only be used with type 'shell'")
8184
}
@@ -154,6 +157,10 @@ func (t *Toolset) validate() error {
154157
if t.URL == "" {
155158
return errors.New("openapi toolset requires a url to be set")
156159
}
160+
case "model_picker":
161+
if len(t.Models) == 0 {
162+
return errors.New("model_picker toolset requires at least one model in the 'models' list")
163+
}
157164
}
158165

159166
return nil

pkg/runtime/runtime.go

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,6 @@ func ResumeReject(reason string) ResumeRequest {
9696
// ToolHandlerFunc is a function type for handling tool calls
9797
type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error)
9898

99-
type ToolHandler struct {
100-
handler ToolHandlerFunc
101-
tool tools.Tool
102-
}
103-
10499
// ElicitationRequestHandler is a function type for handling elicitation requests
105100
type ElicitationRequestHandler func(ctx context.Context, message string, schema map[string]any) (map[string]any, error)
106101

@@ -196,7 +191,7 @@ type ToolsChangeSubscriber interface {
196191

197192
// LocalRuntime manages the execution of agents
198193
type LocalRuntime struct {
199-
toolMap map[string]ToolHandler
194+
toolMap map[string]ToolHandlerFunc
200195
team *team.Team
201196
currentAgent string
202197
resumeChan chan ResumeRequest
@@ -297,7 +292,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
297292
}
298293

299294
r := &LocalRuntime{
300-
toolMap: make(map[string]ToolHandler),
295+
toolMap: make(map[string]ToolHandlerFunc),
301296
team: agents,
302297
currentAgent: defaultAgent.Name(),
303298
resumeChan: make(chan ResumeRequest),
@@ -909,30 +904,14 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen
909904
send(ToolsetInfo(totalTools, false, r.currentAgent))
910905
}
911906

912-
// registerDefaultTools registers the default tool handlers
907+
// registerDefaultTools registers the runtime-managed tool handlers.
908+
// The tool definitions themselves come from the agent's toolsets; this only
909+
// maps tool names to the runtime handler functions that implement them.
913910
func (r *LocalRuntime) registerDefaultTools() {
914-
slog.Debug("Registering default tools")
915-
916-
tt := builtin.NewTransferTaskTool()
917-
ht := builtin.NewHandoffTool()
918-
ttTools, _ := tt.Tools(context.TODO())
919-
htTools, _ := ht.Tools(context.TODO())
920-
allTools := append(ttTools, htTools...)
921-
922-
handlers := map[string]ToolHandlerFunc{
923-
builtin.ToolNameTransferTask: r.handleTaskTransfer,
924-
builtin.ToolNameHandoff: r.handleHandoff,
925-
}
926-
927-
for _, t := range allTools {
928-
if h, exists := handlers[t.Name]; exists {
929-
r.toolMap[t.Name] = ToolHandler{handler: h, tool: t}
930-
} else {
931-
slog.Warn("No handler found for default tool", "tool", t.Name)
932-
}
933-
}
934-
935-
slog.Debug("Registered default tools", "count", len(r.toolMap))
911+
r.toolMap[builtin.ToolNameTransferTask] = r.handleTaskTransfer
912+
r.toolMap[builtin.ToolNameHandoff] = r.handleHandoff
913+
r.toolMap[builtin.ToolNameChangeModel] = r.handleChangeModel
914+
r.toolMap[builtin.ToolNameRevertModel] = r.handleRevertModel
936915
}
937916

938917
func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.Session, events chan Event) {
@@ -1579,8 +1558,8 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi
15791558
// Pick the handler: runtime-managed tools (transfer_task, handoff)
15801559
// have dedicated handlers; everything else goes through the toolset.
15811560
var runTool func()
1582-
if def, exists := r.toolMap[toolCall.Function.Name]; exists {
1583-
runTool = func() { r.runAgentTool(callCtx, def.handler, sess, toolCall, tool, events, a) }
1561+
if handler, exists := r.toolMap[toolCall.Function.Name]; exists {
1562+
runTool = func() { r.runAgentTool(callCtx, handler, sess, toolCall, tool, events, a) }
15841563
} else {
15851564
runTool = func() { r.runTool(callCtx, tool, toolCall, events, sess, a) }
15861565
}
@@ -2089,6 +2068,75 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool
20892068
return tools.ResultSuccess(handoffMessage), nil
20902069
}
20912070

2071+
// findModelPickerTool returns the ModelPickerTool from the current agent's
2072+
// toolsets, or nil if the agent has no model_picker configured.
2073+
func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool {
2074+
a, err := r.team.Agent(r.currentAgent)
2075+
if err != nil {
2076+
return nil
2077+
}
2078+
for _, ts := range a.ToolSets() {
2079+
if mpt, ok := tools.As[*builtin.ModelPickerTool](ts); ok {
2080+
return mpt
2081+
}
2082+
}
2083+
return nil
2084+
}
2085+
2086+
// handleChangeModel handles the change_model tool call by switching the current agent's model.
2087+
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
2088+
var params builtin.ChangeModelArgs
2089+
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil {
2090+
return nil, fmt.Errorf("invalid arguments: %w", err)
2091+
}
2092+
2093+
if params.Model == "" {
2094+
return tools.ResultError("model parameter is required"), nil
2095+
}
2096+
2097+
// Validate the requested model against the allowed list
2098+
mpt := r.findModelPickerTool()
2099+
if mpt == nil {
2100+
return tools.ResultError("model_picker is not configured for this agent"), nil
2101+
}
2102+
allowed := mpt.AllowedModels()
2103+
if !slices.Contains(allowed, params.Model) {
2104+
return tools.ResultError(fmt.Sprintf(
2105+
"model %q is not in the allowed list. Available models: %s",
2106+
params.Model, strings.Join(allowed, ", "),
2107+
)), nil
2108+
}
2109+
2110+
return r.setModelAndEmitInfo(ctx, params.Model, events)
2111+
}
2112+
2113+
// handleRevertModel handles the revert_model tool call by reverting the current agent to its default model.
2114+
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
2115+
return r.setModelAndEmitInfo(ctx, "", events)
2116+
}
2117+
2118+
// setModelAndEmitInfo sets the model for the current agent and emits an updated
2119+
// AgentInfo event so the UI reflects the change. An empty modelRef reverts to
2120+
// the agent's default model.
2121+
func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) {
2122+
if err := r.SetAgentModel(ctx, r.currentAgent, modelRef); err != nil {
2123+
return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil
2124+
}
2125+
2126+
if a, err := r.team.Agent(r.currentAgent); err == nil {
2127+
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
2128+
} else {
2129+
slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", r.currentAgent, "error", err)
2130+
}
2131+
2132+
if modelRef == "" {
2133+
slog.Info("Model reverted via model_picker tool", "agent", r.currentAgent)
2134+
return tools.ResultSuccess("Model reverted to the agent's default model"), nil
2135+
}
2136+
slog.Info("Model changed via model_picker tool", "agent", r.currentAgent, "model", modelRef)
2137+
return tools.ResultSuccess(fmt.Sprintf("Model changed to %s", modelRef)), nil
2138+
}
2139+
20922140
// Summarize generates a summary for the session based on the conversation history.
20932141
// The additionalPrompt parameter allows users to provide additional instructions
20942142
// for the summarization (e.g., "focus on code changes" or "include action items").

pkg/teamloader/registry.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry {
7272
r.Register("lsp", createLSPTool)
7373
r.Register("user_prompt", createUserPromptTool)
7474
r.Register("openapi", createOpenAPITool)
75+
r.Register("model_picker", createModelPickerTool)
7576
return r
7677
}
7778

@@ -327,3 +328,10 @@ func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, ru
327328

328329
return builtin.NewOpenAPITool(specURL, headers), nil
329330
}
331+
332+
func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) {
333+
if len(toolset.Models) == 0 {
334+
return nil, fmt.Errorf("model_picker toolset requires at least one model")
335+
}
336+
return builtin.NewModelPickerTool(toolset.Models), nil
337+
}

pkg/tools/builtin/model_picker.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package builtin
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/docker/cagent/pkg/tools"
9+
)
10+
11+
const (
12+
ToolNameChangeModel = "change_model"
13+
ToolNameRevertModel = "revert_model"
14+
)
15+
16+
// ModelPickerTool provides tools for dynamically switching the agent's model mid-conversation.
17+
type ModelPickerTool struct {
18+
models []string // list of available model references
19+
}
20+
21+
// Verify interface compliance
22+
var (
23+
_ tools.ToolSet = (*ModelPickerTool)(nil)
24+
_ tools.Instructable = (*ModelPickerTool)(nil)
25+
)
26+
27+
// ChangeModelArgs are the arguments for the change_model tool.
28+
type ChangeModelArgs struct {
29+
Model string `json:"model" jsonschema:"The model to switch to. Must be one of the available models."`
30+
}
31+
32+
// NewModelPickerTool creates a new ModelPickerTool with the given list of allowed models.
33+
func NewModelPickerTool(models []string) *ModelPickerTool {
34+
return &ModelPickerTool{models: models}
35+
}
36+
37+
// Instructions returns guidance for the LLM on when and how to use the model picker tools.
38+
func (t *ModelPickerTool) Instructions() string {
39+
return "## Model Switching\n\n" +
40+
"You have access to multiple models and can switch between them mid-conversation " +
41+
"using the `" + ToolNameChangeModel + "` and `" + ToolNameRevertModel + "` tools.\n\n" +
42+
"Available models: " + strings.Join(t.models, ", ") + ".\n\n" +
43+
"Use `" + ToolNameChangeModel + "` when the current task would benefit from a different model's strengths " +
44+
"(e.g., switching to a faster model for simple tasks or a more capable model for complex reasoning).\n" +
45+
"Use `" + ToolNameRevertModel + "` to return to the original model after the specialized task is complete."
46+
}
47+
48+
// AllowedModels returns the list of models this tool allows switching to.
49+
func (t *ModelPickerTool) AllowedModels() []string {
50+
return t.models
51+
}
52+
53+
// Tools returns the change_model and revert_model tool definitions.
54+
func (t *ModelPickerTool) Tools(context.Context) ([]tools.Tool, error) {
55+
return []tools.Tool{
56+
{
57+
Name: ToolNameChangeModel,
58+
Category: "model",
59+
Description: fmt.Sprintf(
60+
"Change the current model to one of the available models: %s. "+
61+
"Use this when you need a different model for the current task.",
62+
strings.Join(t.models, ", "),
63+
),
64+
Parameters: tools.MustSchemaFor[ChangeModelArgs](),
65+
Annotations: tools.ToolAnnotations{
66+
ReadOnlyHint: true,
67+
Title: "Change Model",
68+
},
69+
},
70+
{
71+
Name: ToolNameRevertModel,
72+
Category: "model",
73+
Description: "Revert to the agent's original/default model. " +
74+
"Use this after completing a task that required a different model.",
75+
Annotations: tools.ToolAnnotations{
76+
ReadOnlyHint: true,
77+
Title: "Revert Model",
78+
},
79+
},
80+
}, nil
81+
}

0 commit comments

Comments
 (0)