Skip to content

Commit 350a937

Browse files
authored
Merge pull request #910 from rumpl/wd
Working directory
2 parents abf25e1 + b9c3611 commit 350a937

File tree

9 files changed

+46
-41
lines changed

9 files changed

+46
-41
lines changed

examples/golibrary/builtintool/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"syscall"
1010

1111
"github.com/docker/cagent/pkg/agent"
12+
"github.com/docker/cagent/pkg/config"
1213
"github.com/docker/cagent/pkg/config/latest"
1314
"github.com/docker/cagent/pkg/environment"
1415
"github.com/docker/cagent/pkg/model/provider/openai"
@@ -46,7 +47,7 @@ func run(ctx context.Context) error {
4647
"root",
4748
"You are an expert hacker",
4849
agent.WithModel(llm),
49-
agent.WithToolSets(builtin.NewShellTool(os.Environ())),
50+
agent.WithToolSets(builtin.NewShellTool(os.Environ(), &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})),
5051
),
5152
),
5253
)

pkg/creator/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func CreateAgent(ctx context.Context, baseDir, prompt string, runConfig *config.
104104
agentBuilderInstructions,
105105
agent.WithModel(llm),
106106
agent.WithToolSets(
107-
builtin.NewShellTool(os.Environ()),
107+
builtin.NewShellTool(os.Environ(), runConfig),
108108
&fsToolset,
109109
),
110110
)))

pkg/server/server.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[
101101

102102
group := e.Group("/api")
103103

104-
// Health check endpoint
105-
group.GET("/ping", s.ping)
106104
// List all available agents
107105
group.GET("/agents", s.getAgents)
108106
// Get an agent by id
@@ -127,6 +125,9 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[
127125
group.POST("/agents/push", s.pushAgent)
128126
// Delete an agent by file path
129127
group.DELETE("/agents", s.deleteAgent)
128+
129+
// SESSIONS
130+
130131
// List all sessions
131132
group.GET("/sessions", s.getSessions)
132133
// Get sessions by agent filename
@@ -135,18 +136,20 @@ func New(sessionStore session.Store, runConfig *config.RuntimeConfig, teams map[
135136
group.GET("/sessions/:id", s.getSession)
136137
// Resume a session by id
137138
group.POST("/sessions/:id/resume", s.resumeSession)
138-
// Create a new session and run an agent loop
139+
// Create a new session
139140
group.POST("/sessions", s.createSession)
140141
// Delete a session
141142
group.DELETE("/sessions/:id", s.deleteSession)
142-
143143
// Run an agent loop
144144
group.POST("/sessions/:id/agent/:agent", s.runAgent)
145145
group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent)
146-
147146
group.POST("/sessions/:id/elicitation", s.elicitation)
148147

148+
// MISC
149+
149150
group.GET("/desktop/token", s.getDesktopToken)
151+
// Health check endpoint
152+
group.GET("/ping", s.ping)
150153

151154
return s, nil
152155
}

pkg/teamloader/registry.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir str
8080
var memoryPath string
8181
if filepath.IsAbs(toolset.Path) {
8282
memoryPath = ""
83-
} else if wd, err := os.Getwd(); err == nil {
83+
} else if wd := runtimeConfig.WorkingDir; wd != "" {
8484
memoryPath = wd
8585
} else {
8686
memoryPath = parentDir
@@ -112,7 +112,7 @@ func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir stri
112112
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
113113
}
114114
env = append(env, os.Environ()...)
115-
return builtin.NewShellTool(env), nil
115+
return builtin.NewShellTool(env, runtimeConfig), nil
116116
}
117117

118118
func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, runtimeConfig *config.RuntimeConfig) (tools.ToolSet, error) {
@@ -193,10 +193,10 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string
193193

194194
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
195195
if serverSpec.Type == "remote" {
196-
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil), nil
196+
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.WorkingDir), nil
197197
}
198198

199-
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, runtimeConfig.EnvProvider())
199+
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, runtimeConfig.EnvProvider(), runtimeConfig.WorkingDir)
200200
}
201201

202202
if toolset.Command != "" {
@@ -205,7 +205,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string
205205
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
206206
}
207207
env = append(env, os.Environ()...)
208-
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
208+
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env, runtimeConfig.WorkingDir), nil
209209
}
210210

211211
if toolset.Remote.URL != "" {
@@ -219,7 +219,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string
219219
headers[k] = expanded
220220
}
221221

222-
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers), nil
222+
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.WorkingDir), nil
223223
}
224224

225225
return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration")

pkg/tools/builtin/shell.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"time"
1515

1616
"github.com/docker/cagent/pkg/concurrent"
17+
"github.com/docker/cagent/pkg/config"
1718
"github.com/docker/cagent/pkg/tools"
1819
)
1920

@@ -38,6 +39,7 @@ type shellHandler struct {
3839
shellArgsPrefix []string
3940
env []string
4041
timeout time.Duration
42+
workingDir string
4143
jobs *concurrent.Map[string, *backgroundJob]
4244
jobCounter atomic.Int64
4345
}
@@ -149,12 +151,9 @@ func (h *shellHandler) RunShell(ctx context.Context, toolCall tools.ToolCall) (*
149151

150152
cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
151153
cmd.Env = h.env
154+
cmd.Dir = h.workingDir
152155
if params.Cwd != "" {
153156
cmd.Dir = params.Cwd
154-
} else {
155-
if wd, err := os.Getwd(); err == nil {
156-
cmd.Dir = wd
157-
}
158157
}
159158

160159
cmd.SysProcAttr = platformSpecificSysProcAttr()
@@ -232,12 +231,9 @@ func (h *shellHandler) RunShellBackground(ctx context.Context, toolCall tools.To
232231
// Setup command (no context - background jobs run independently)
233232
cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
234233
cmd.Env = h.env
234+
cmd.Dir = h.workingDir
235235
if params.Cwd != "" {
236236
cmd.Dir = params.Cwd
237-
} else {
238-
if wd, err := os.Getwd(); err == nil {
239-
cmd.Dir = wd
240-
}
241237
}
242238

243239
cmd.SysProcAttr = platformSpecificSysProcAttr()
@@ -421,7 +417,7 @@ func (h *shellHandler) StopBackgroundJob(_ context.Context, toolCall tools.ToolC
421417
}, nil
422418
}
423419

424-
func NewShellTool(env []string) *ShellTool {
420+
func NewShellTool(env []string, runtimeConfig *config.RuntimeConfig) *ShellTool {
425421
var shell string
426422
var argsPrefix []string
427423

@@ -458,6 +454,7 @@ func NewShellTool(env []string) *ShellTool {
458454
env: env,
459455
timeout: 30 * time.Second,
460456
jobs: concurrent.NewMap[string, *backgroundJob](),
457+
workingDir: runtimeConfig.WorkingDir,
461458
},
462459
}
463460
}

pkg/tools/builtin/shell_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,28 @@ import (
88
"github.com/stretchr/testify/assert"
99
"github.com/stretchr/testify/require"
1010

11+
"github.com/docker/cagent/pkg/config"
1112
"github.com/docker/cagent/pkg/tools"
1213
)
1314

1415
func TestNewShellTool(t *testing.T) {
1516
t.Setenv("SHELL", "/bin/bash")
16-
tool := NewShellTool(nil)
17+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
1718

1819
assert.NotNil(t, tool)
1920
assert.NotNil(t, tool.handler)
2021
assert.Equal(t, "/bin/bash", tool.handler.shell)
2122

2223
t.Setenv("SHELL", "")
23-
tool = NewShellTool(nil)
24+
tool = NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
2425

2526
assert.NotNil(t, tool)
2627
assert.NotNil(t, tool.handler)
2728
assert.Equal(t, "/bin/sh", tool.handler.shell, "Should default to /bin/sh when SHELL is not set")
2829
}
2930

3031
func TestShellTool_Tools(t *testing.T) {
31-
tool := NewShellTool(nil)
32+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
3233

3334
allTools, err := tool.Tools(t.Context())
3435

@@ -68,7 +69,7 @@ func TestShellTool_Tools(t *testing.T) {
6869
}
6970

7071
func TestShellTool_DisplayNames(t *testing.T) {
71-
tool := NewShellTool(nil)
72+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
7273

7374
all, err := tool.Tools(t.Context())
7475
require.NoError(t, err)
@@ -81,7 +82,7 @@ func TestShellTool_DisplayNames(t *testing.T) {
8182

8283
func TestShellTool_HandlerEcho(t *testing.T) {
8384
// This is a simple test that should work on most systems
84-
tool := NewShellTool(nil)
85+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
8586

8687
tls, err := tool.Tools(t.Context())
8788
require.NoError(t, err)
@@ -111,7 +112,7 @@ func TestShellTool_HandlerEcho(t *testing.T) {
111112

112113
func TestShellTool_HandlerWithCwd(t *testing.T) {
113114
// This test verifies the cwd parameter works
114-
tool := NewShellTool(nil)
115+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
115116

116117
tls, err := tool.Tools(t.Context())
117118
require.NoError(t, err)
@@ -145,7 +146,7 @@ func TestShellTool_HandlerWithCwd(t *testing.T) {
145146

146147
func TestShellTool_HandlerError(t *testing.T) {
147148
// This test verifies error handling
148-
tool := NewShellTool(nil)
149+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
149150

150151
tls, err := tool.Tools(t.Context())
151152
require.NoError(t, err)
@@ -174,7 +175,7 @@ func TestShellTool_HandlerError(t *testing.T) {
174175
}
175176

176177
func TestShellTool_InvalidArguments(t *testing.T) {
177-
tool := NewShellTool(nil)
178+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
178179

179180
tls, err := tool.Tools(t.Context())
180181
require.NoError(t, err)
@@ -196,7 +197,7 @@ func TestShellTool_InvalidArguments(t *testing.T) {
196197
}
197198

198199
func TestShellTool_StartStop(t *testing.T) {
199-
tool := NewShellTool(nil)
200+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
200201

201202
err := tool.Start(t.Context())
202203
require.NoError(t, err)
@@ -206,7 +207,7 @@ func TestShellTool_StartStop(t *testing.T) {
206207
}
207208

208209
func TestShellTool_OutputSchema(t *testing.T) {
209-
tool := NewShellTool(nil)
210+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
210211

211212
allTools, err := tool.Tools(t.Context())
212213
require.NoError(t, err)
@@ -218,7 +219,7 @@ func TestShellTool_OutputSchema(t *testing.T) {
218219
}
219220

220221
func TestShellTool_ParametersAreObjects(t *testing.T) {
221-
tool := NewShellTool(nil)
222+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
222223

223224
allTools, err := tool.Tools(t.Context())
224225
require.NoError(t, err)
@@ -234,7 +235,7 @@ func TestShellTool_ParametersAreObjects(t *testing.T) {
234235

235236
// Minimal tests for background job features
236237
func TestShellTool_RunBackgroundJob(t *testing.T) {
237-
tool := NewShellTool(nil)
238+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
238239
err := tool.Start(t.Context())
239240
require.NoError(t, err)
240241
defer func() { _ = tool.Stop(t.Context()) }()
@@ -273,7 +274,7 @@ func TestShellTool_RunBackgroundJob(t *testing.T) {
273274
}
274275

275276
func TestShellTool_ListBackgroundJobs(t *testing.T) {
276-
tool := NewShellTool(nil)
277+
tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})
277278
err := tool.Start(t.Context())
278279
require.NoError(t, err)
279280
defer func() { _ = tool.Stop(t.Context()) }()

pkg/tools/mcp/gateway.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type GatewayToolset struct {
2222

2323
var _ tools.ToolSet = (*GatewayToolset)(nil)
2424

25-
func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, envProvider environment.Provider) (*GatewayToolset, error) {
25+
func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, envProvider environment.Provider, cwd string) (*GatewayToolset, error) {
2626
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName)
2727

2828
// Check which secrets (env vars) are required by the MCP server.
@@ -55,7 +55,7 @@ func NewGatewayToolset(ctx context.Context, mcpServerName string, config any, en
5555
}
5656

5757
return &GatewayToolset{
58-
cmdToolset: NewToolsetCommand("docker", args, nil),
58+
cmdToolset: NewToolsetCommand("docker", args, nil, cwd),
5959
cleanUp: func() error {
6060
return errors.Join(os.Remove(fileSecrets), os.Remove(fileConfig))
6161
},

pkg/tools/mcp/mcp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ type Toolset struct {
3737
var _ tools.ToolSet = (*Toolset)(nil)
3838

3939
// NewToolsetCommand creates a new MCP toolset from a command.
40-
func NewToolsetCommand(command string, args, env []string) *Toolset {
40+
func NewToolsetCommand(command string, args, env []string, cwd string) *Toolset {
4141
slog.Debug("Creating Stdio MCP toolset", "command", command, "args", args)
4242

4343
return &Toolset{
44-
mcpClient: newStdioCmdClient(command, args, env),
44+
mcpClient: newStdioCmdClient(command, args, env, cwd),
4545
logID: command,
4646
}
4747
}
4848

4949
// NewRemoteToolset creates a new MCP toolset from a remote MCP Server.
50-
func NewRemoteToolset(url, transport string, headers map[string]string) *Toolset {
50+
func NewRemoteToolset(url, transport string, headers map[string]string, cwd string) *Toolset {
5151
slog.Debug("Creating Remote MCP toolset", "url", url, "transport", transport, "headers", headers)
5252

5353
return &Toolset{

pkg/tools/mcp/stdio.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ type stdioMCPClient struct {
1818
args []string
1919
env []string
2020
session *mcp.ClientSession
21+
cwd string
2122
}
2223

23-
func newStdioCmdClient(command string, args, env []string) *stdioMCPClient {
24+
func newStdioCmdClient(command string, args, env []string, cwd string) *stdioMCPClient {
2425
return &stdioMCPClient{
2526
command: command,
2627
args: args,
2728
env: env,
29+
cwd: cwd,
2830
}
2931
}
3032

@@ -41,6 +43,7 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReques
4143

4244
cmd := exec.CommandContext(ctx, c.command, c.args...)
4345
cmd.Env = c.env
46+
cmd.Dir = c.cwd
4447
session, err := client.Connect(ctx, &mcp.CommandTransport{
4548
Command: cmd,
4649
}, nil)

0 commit comments

Comments
 (0)