Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/golibrary/builtintool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"syscall"

"github.com/docker/cagent/pkg/agent"
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/environment"
"github.com/docker/cagent/pkg/model/provider/openai"
Expand Down Expand Up @@ -46,7 +47,7 @@ func run(ctx context.Context) error {
"root",
"You are an expert hacker",
agent.WithModel(llm),
agent.WithToolSets(builtin.NewShellTool(os.Environ())),
agent.WithToolSets(builtin.NewShellTool(os.Environ(), &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})),
),
),
)
Expand Down
2 changes: 1 addition & 1 deletion pkg/creator/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func CreateAgent(ctx context.Context, baseDir, prompt string, runConfig *config.
agentBuilderInstructions,
agent.WithModel(llm),
agent.WithToolSets(
builtin.NewShellTool(os.Environ()),
builtin.NewShellTool(os.Environ(), runConfig),
&fsToolset,
),
)))
Expand Down
13 changes: 8 additions & 5 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[

group := e.Group("/api")

// Health check endpoint
group.GET("/ping", s.ping)
// List all available agents
group.GET("/agents", s.getAgents)
// Get an agent by id
Expand All @@ -127,6 +125,9 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[
group.POST("/agents/push", s.pushAgent)
// Delete an agent by file path
group.DELETE("/agents", s.deleteAgent)

// SESSIONS

// List all sessions
group.GET("/sessions", s.getSessions)
// Get sessions by agent filename
Expand All @@ -135,18 +136,20 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[
group.GET("/sessions/:id", s.getSession)
// Resume a session by id
group.POST("/sessions/:id/resume", s.resumeSession)
// Create a new session and run an agent loop
// Create a new session
group.POST("/sessions", s.createSession)
// Delete a session
group.DELETE("/sessions/:id", s.deleteSession)

// Run an agent loop
group.POST("/sessions/:id/agent/:agent", s.runAgent)
group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent)

group.POST("/sessions/:id/elicitation", s.elicitation)

// MISC

group.GET("/desktop/token", s.getDesktopToken)
// Health check endpoint
group.GET("/ping", s.ping)

return s, nil
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/teamloader/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir str
var memoryPath string
if filepath.IsAbs(toolset.Path) {
memoryPath = ""
} else if wd, err := os.Getwd(); err == nil {
} else if wd := runtimeConfig.WorkingDir; wd != "" {
memoryPath = wd
} else {
memoryPath = parentDir
Expand Down Expand Up @@ -112,7 +112,7 @@ func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir stri
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
}
env = append(env, os.Environ()...)
return builtin.NewShellTool(env), nil
return builtin.NewShellTool(env, runtimeConfig), nil
}

func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, runtimeConfig *config.RuntimeConfig) (tools.ToolSet, error) {
Expand Down Expand Up @@ -193,10 +193,10 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string

// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
if serverSpec.Type == "remote" {
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil), nil
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.WorkingDir), nil
}

return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, runtimeConfig.EnvProvider())
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, runtimeConfig.EnvProvider(), runtimeConfig.WorkingDir)
}

if toolset.Command != "" {
Expand All @@ -205,7 +205,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
}
env = append(env, os.Environ()...)
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env, runtimeConfig.WorkingDir), nil
}

if toolset.Remote.URL != "" {
Expand All @@ -219,7 +219,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string
headers[k] = expanded
}

return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers), nil
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.WorkingDir), nil
}

return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration")
Expand Down
15 changes: 6 additions & 9 deletions pkg/tools/builtin/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/docker/cagent/pkg/concurrent"
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/tools"
)

Expand All @@ -38,6 +39,7 @@ type shellHandler struct {
shellArgsPrefix []string
env []string
timeout time.Duration
workingDir string
jobs *concurrent.Map[string, *backgroundJob]
jobCounter atomic.Int64
}
Expand Down Expand Up @@ -149,12 +151,9 @@ func (h *shellHandler) RunShell(ctx context.Context, toolCall tools.ToolCall) (*

cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
cmd.Env = h.env
cmd.Dir = h.workingDir
if params.Cwd != "" {
cmd.Dir = params.Cwd
} else {
if wd, err := os.Getwd(); err == nil {
cmd.Dir = wd
}
}

cmd.SysProcAttr = platformSpecificSysProcAttr()
Expand Down Expand Up @@ -232,12 +231,9 @@ func (h *shellHandler) RunShellBackground(ctx context.Context, toolCall tools.To
// Setup command (no context - background jobs run independently)
cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
cmd.Env = h.env
cmd.Dir = h.workingDir
if params.Cwd != "" {
cmd.Dir = params.Cwd
} else {
if wd, err := os.Getwd(); err == nil {
cmd.Dir = wd
}
}

cmd.SysProcAttr = platformSpecificSysProcAttr()
Expand Down Expand Up @@ -421,7 +417,7 @@ func (h *shellHandler) StopBackgroundJob(_ context.Context, toolCall tools.ToolC
}, nil
}

func NewShellTool(env []string) *ShellTool {
func NewShellTool(env []string, runtimeConfig *config.RuntimeConfig) *ShellTool {
var shell string
var argsPrefix []string

Expand Down Expand Up @@ -458,6 +454,7 @@ func NewShellTool(env []string) *ShellTool {
env: env,
timeout: 30 * time.Second,
jobs: concurrent.NewMap[string, *backgroundJob](),
workingDir: runtimeConfig.WorkingDir,
},
}
}
Expand Down
27 changes: 14 additions & 13 deletions pkg/tools/builtin/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/tools"
)

func TestNewShellTool(t *testing.T) {
t.Setenv("SHELL", "/bin/bash")
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

assert.NotNil(t, tool)
assert.NotNil(t, tool.handler)
assert.Equal(t, "/bin/bash", tool.handler.shell)

t.Setenv("SHELL", "")
tool = NewShellTool(nil)
tool = NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

assert.NotNil(t, tool)
assert.NotNil(t, tool.handler)
assert.Equal(t, "/bin/sh", tool.handler.shell, "Should default to /bin/sh when SHELL is not set")
}

func TestShellTool_Tools(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

allTools, err := tool.Tools(t.Context())

Expand Down Expand Up @@ -68,7 +69,7 @@ func TestShellTool_Tools(t *testing.T) {
}

func TestShellTool_DisplayNames(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

all, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand All @@ -81,7 +82,7 @@ func TestShellTool_DisplayNames(t *testing.T) {

func TestShellTool_HandlerEcho(t *testing.T) {
// This is a simple test that should work on most systems
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

tls, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand Down Expand Up @@ -111,7 +112,7 @@ func TestShellTool_HandlerEcho(t *testing.T) {

func TestShellTool_HandlerWithCwd(t *testing.T) {
// This test verifies the cwd parameter works
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

tls, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand Down Expand Up @@ -145,7 +146,7 @@ func TestShellTool_HandlerWithCwd(t *testing.T) {

func TestShellTool_HandlerError(t *testing.T) {
// This test verifies error handling
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

tls, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand Down Expand Up @@ -174,7 +175,7 @@ func TestShellTool_HandlerError(t *testing.T) {
}

func TestShellTool_InvalidArguments(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

tls, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand All @@ -196,7 +197,7 @@ func TestShellTool_InvalidArguments(t *testing.T) {
}

func TestShellTool_StartStop(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

err := tool.Start(t.Context())
require.NoError(t, err)
Expand All @@ -206,7 +207,7 @@ func TestShellTool_StartStop(t *testing.T) {
}

func TestShellTool_OutputSchema(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand All @@ -218,7 +219,7 @@ func TestShellTool_OutputSchema(t *testing.T) {
}

func TestShellTool_ParametersAreObjects(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})

allTools, err := tool.Tools(t.Context())
require.NoError(t, err)
Expand All @@ -234,7 +235,7 @@ func TestShellTool_ParametersAreObjects(t *testing.T) {

// Minimal tests for background job features
func TestShellTool_RunBackgroundJob(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
err := tool.Start(t.Context())
require.NoError(t, err)
defer func() { _ = tool.Stop(t.Context()) }()
Expand Down Expand Up @@ -273,7 +274,7 @@ func TestShellTool_RunBackgroundJob(t *testing.T) {
}

func TestShellTool_ListBackgroundJobs(t *testing.T) {
tool := NewShellTool(nil)
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
err := tool.Start(t.Context())
require.NoError(t, err)
defer func() { _ = tool.Stop(t.Context()) }()
Expand Down
4 changes: 2 additions & 2 deletions pkg/tools/mcp/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type GatewayToolset struct {

var _ tools.ToolSet = (*GatewayToolset)(nil)

func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, envProvider environment.Provider) (*GatewayToolset, error) {
func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName)

// Check which secrets (env vars) are required by the MCP server.
Expand Down Expand Up @@ -55,7 +55,7 @@ func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, en
}

return &GatewayToolset{
cmdToolset: NewToolsetCommand("docker", args, nil),
cmdToolset: NewToolsetCommand("docker", args, nil, cwd),
cleanUp: func() error {
return errors.Join(os.Remove(fileSecrets), os.Remove(fileConfig))
},
Expand Down
6 changes: 3 additions & 3 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ type Toolset struct {
var _ tools.ToolSet = (*Toolset)(nil)

// NewToolsetCommand creates a new MCP toolset from a command.
func NewToolsetCommand(command string, args, env []string) *Toolset {
func NewToolsetCommand(command string, args, env []string, cwd string) *Toolset {
slog.Debug("Creating Stdio MCP toolset", "command", command, "args", args)

return &Toolset{
mcpClient: newStdioCmdClient(command, args, env),
mcpClient: newStdioCmdClient(command, args, env, cwd),
logID: command,
}
}

// NewRemoteToolset creates a new MCP toolset from a remote MCP Server.
func NewRemoteToolset(url, transport string, headers map[string]string) *Toolset {
func NewRemoteToolset(url, transport string, headers map[string]string, cwd string) *Toolset {
slog.Debug("Creating Remote MCP toolset", "url", url, "transport", transport, "headers", headers)

return &Toolset{
Expand Down
5 changes: 4 additions & 1 deletion pkg/tools/mcp/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ type stdioMCPClient struct {
args []string
env []string
session *mcp.ClientSession
cwd string
}

func newStdioCmdClient(command string, args, env []string) *stdioMCPClient {
func newStdioCmdClient(command string, args, env []string, cwd string) *stdioMCPClient {
return &stdioMCPClient{
command: command,
args: args,
env: env,
cwd: cwd,
}
}

Expand All @@ -41,6 +43,7 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReques

cmd := exec.CommandContext(ctx, c.command, c.args...)
cmd.Env = c.env
cmd.Dir = c.cwd
session, err := client.Connect(ctx, &mcp.CommandTransport{
Command: cmd,
}, nil)
Expand Down
Loading