Skip to content

Commit c348c34

Browse files
authored
Merge pull request #13 from peg/staging
Security hardening: token exposure, resource limits, WebSocket deadlines
2 parents 95b585c + 4fd4ef0 commit c348c34

File tree

9 files changed

+163
-13
lines changed

9 files changed

+163
-13
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ rules:
414414
415415
The webhook receives the full tool call context and returns `{"decision": "allow"}` or `{"decision": "deny", "reason": "..."}`. Fail-open by default so a down webhook doesn't break your agent.
416416

417+
**Reference implementation**: See [`rampart-verify`](https://github.com/peg/rampart-verify) — an optional sidecar that uses LLMs (gpt-4o-mini, Claude Haiku, or local Ollama) to classify ambiguous commands. Pattern matching handles 95% of decisions for free; the sidecar reviews the rest at ~$0.0001/call.
418+
417419
---
418420

419421
## Integration

cmd/rampart/cli/daemon.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Example:
7070
}
7171

7272
// Ensure audit directory exists.
73-
if err := os.MkdirAll(auditDir, 0o755); err != nil {
73+
if err := os.MkdirAll(auditDir, 0o700); err != nil {
7474
return fmt.Errorf("daemon: create audit dir: %w", err)
7575
}
7676

cmd/rampart/cli/wrap.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func newWrapCmd(opts *rootOptions, deps *wrapDeps) *cobra.Command {
9595
}
9696
auditDir = filepath.Join(home, ".rampart", "audit")
9797
}
98-
if err := os.MkdirAll(auditDir, 0o755); err != nil {
98+
if err := os.MkdirAll(auditDir, 0o700); err != nil {
9999
return fmt.Errorf("wrap: create audit dir %s: %w", auditDir, err)
100100
}
101101

@@ -167,6 +167,7 @@ func newWrapCmd(opts *rootOptions, deps *wrapDeps) *cobra.Command {
167167
}
168168
defer func() {
169169
_ = os.Remove(shimPath)
170+
_ = os.Remove(shimPath + ".tok")
170171
}()
171172

172173
child := exec.Command(args[0], args[1:]...)
@@ -185,7 +186,7 @@ func newWrapCmd(opts *rootOptions, deps *wrapDeps) *cobra.Command {
185186
childEnv = append(childEnv, e)
186187
}
187188
// Create PATH-based shell wrappers for agents that ignore $SHELL
188-
shimDir, err := createShellWrappers(proxyURL, proxyServer.Token(), mode)
189+
shimDir, err := createShellWrappers(proxyURL, shimPath+".tok", mode)
189190
if err != nil {
190191
return fmt.Errorf("wrap: create shell wrappers: %w", err)
191192
}
@@ -355,11 +356,20 @@ func createShellShim(proxyURL, token, mode, realShell string) (string, error) {
355356
return "", fmt.Errorf("wrap: create shell shim: %w", err)
356357
}
357358

359+
// Write token to a separate file (0600) so it's not visible in the shim
360+
// script or /proc/*/cmdline.
361+
tokenFile := tmp.Name() + ".tok"
362+
if err := os.WriteFile(tokenFile, []byte(token), 0o600); err != nil {
363+
_ = tmp.Close()
364+
_ = os.Remove(tmp.Name())
365+
return "", fmt.Errorf("wrap: write token file: %w", err)
366+
}
367+
358368
script := fmt.Sprintf(`#!/usr/bin/env bash
359369
# Rampart shell shim - auto-generated.
360370
REAL_SHELL=%q
361371
RAMPART_URL=%q
362-
RAMPART_TOKEN=%q
372+
RAMPART_TOKEN=$(cat %q 2>/dev/null)
363373
RAMPART_MODE=%q
364374
365375
# Parse shell flags — collect flags before -c, extract the command after -c.
@@ -420,7 +430,7 @@ if [ "$FOUND_C" = "true" ]; then
420430
fi
421431
422432
exec "$REAL_SHELL" $ORIG_ARGS
423-
`, realShell, proxyURL, token, mode)
433+
`, realShell, proxyURL, tokenFile, mode)
424434

425435
if _, err := io.WriteString(tmp, script); err != nil {
426436
_ = tmp.Close()
@@ -446,11 +456,16 @@ exec "$REAL_SHELL" $ORIG_ARGS
446456
// through Rampart policy before executing. If not set, passes straight through
447457
// to the real shell. This catches agents that hardcode /bin/bash or /bin/zsh
448458
// instead of reading $SHELL.
449-
func createShellWrappers(proxyURL, token, mode string) (string, error) {
459+
func createShellWrappers(proxyURL, tokenFile, mode string) (string, error) {
450460
dir, err := os.MkdirTemp("", "rampart-shells-*")
451461
if err != nil {
452462
return "", fmt.Errorf("create shell wrapper dir: %w", err)
453463
}
464+
// Restrict directory permissions — contains scripts with token file paths.
465+
if err := os.Chmod(dir, 0o700); err != nil {
466+
_ = os.RemoveAll(dir)
467+
return "", fmt.Errorf("chmod shell wrapper dir: %w", err)
468+
}
454469

455470
shells := []struct {
456471
name string
@@ -496,8 +511,9 @@ if [ "$FOUND_C" = "true" ]; then
496511
497512
ENCODED=$(printf '%%s' "$CMD" | base64 | tr -d '\n\r')
498513
PAYLOAD=$(printf '{"agent":"wrapped","session":"wrap","params":{"command_b64":"%%s"}}' "$ENCODED")
514+
RAMPART_TOKEN=$(cat %q 2>/dev/null)
499515
DECISION=$(curl -sfS -X POST "%s/v1/preflight/exec" \
500-
-H "Authorization: Bearer %s" \
516+
-H "Authorization: Bearer ${RAMPART_TOKEN}" \
501517
-H "Content-Type: application/json" \
502518
-d "$PAYLOAD" 2>/dev/null)
503519
@@ -520,7 +536,7 @@ fi
520536
521537
# Interactive or non -c usage — pass through directly
522538
exec "$REAL" $ORIG_ARGS
523-
`, s.name, s.realPath, proxyURL, token, mode)
539+
`, s.name, s.realPath, tokenFile, proxyURL, mode)
524540

525541
wrapperPath := filepath.Join(dir, s.name)
526542
if err := os.WriteFile(wrapperPath, []byte(script), 0o755); err != nil {

internal/approval/store.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ func (s *Store) watchExpiry(req *Request) {
253253
select {
254254
case <-req.done:
255255
return // Already resolved.
256+
case <-s.stop:
257+
return // Store is shutting down.
256258
case <-timer.C:
257259
s.mu.Lock()
258260
if req.Status == StatusPending {

internal/daemon/daemon.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,37 @@ func (d *Daemon) connectAndListen(ctx context.Context) error {
142142
d.seq.Store(0)
143143
d.mu.Unlock()
144144

145+
// Set up ping/pong to detect dead connections.
146+
const pongWait = 90 * time.Second
147+
const pingInterval = 30 * time.Second
148+
conn.SetReadDeadline(time.Now().Add(pongWait))
149+
conn.SetPongHandler(func(string) error {
150+
conn.SetReadDeadline(time.Now().Add(pongWait))
151+
return nil
152+
})
153+
154+
// Start ping ticker in background.
155+
pingDone := make(chan struct{})
156+
go func() {
157+
defer close(pingDone)
158+
ticker := time.NewTicker(pingInterval)
159+
defer ticker.Stop()
160+
for {
161+
select {
162+
case <-ctx.Done():
163+
return
164+
case <-ticker.C:
165+
d.mu.Lock()
166+
err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(10*time.Second))
167+
d.mu.Unlock()
168+
if err != nil {
169+
return
170+
}
171+
}
172+
}
173+
}()
174+
defer func() { <-pingDone }()
175+
145176
// Perform handshake.
146177
if err := d.handshake(ctx); err != nil {
147178
return fmt.Errorf("daemon: handshake: %w", err)
@@ -162,6 +193,8 @@ func (d *Daemon) connectAndListen(ctx context.Context) error {
162193
return fmt.Errorf("daemon: read: %w", err)
163194
}
164195

196+
// Reset deadline on every successful read.
197+
conn.SetReadDeadline(time.Now().Add(pongWait))
165198
d.handleMessage(ctx, message)
166199
}
167200
}

internal/engine/policy.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,13 @@ type WebhookActionConfig struct {
135135
}
136136

137137
// EffectiveTimeout returns the configured timeout or 5s default.
138+
// Capped at 30s to prevent resource exhaustion.
138139
func (c *WebhookActionConfig) EffectiveTimeout() time.Duration {
140+
const maxTimeout = 30 * time.Second
139141
if c.Timeout.Duration > 0 {
142+
if c.Timeout.Duration > maxTimeout {
143+
return maxTimeout
144+
}
140145
return c.Timeout.Duration
141146
}
142147
return 5 * time.Second

internal/mcp/proxy.go

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ type Proxy struct {
100100
}
101101

102102
type pendingCall struct {
103-
call engine.ToolCall
104-
request map[string]any
103+
call engine.ToolCall
104+
request map[string]any
105+
createdAt time.Time
105106
}
106107

107108
// NewProxy creates a new MCP stdio proxy.
@@ -287,13 +288,41 @@ func (p *Proxy) handleToolsCall(req Request, rawLine []byte) error {
287288
if HasID(req.ID) {
288289
id := NormalizedID(req.ID)
289290
p.pendingMu.Lock()
290-
p.pendingCalls[id] = pendingCall{call: call, request: requestData}
291+
p.pendingCalls[id] = pendingCall{call: call, request: requestData, createdAt: time.Now()}
292+
p.evictStalePendingCalls()
291293
p.pendingMu.Unlock()
292294
}
293295

294296
return p.writeToChild(rawLine)
295297
}
296298

299+
// evictStalePendingCalls removes pending calls older than 5 minutes.
300+
// Must be called with pendingMu held.
301+
func (p *Proxy) evictStalePendingCalls() {
302+
const maxAge = 5 * time.Minute
303+
const maxPending = 1000
304+
now := time.Now()
305+
for id, pc := range p.pendingCalls {
306+
if now.Sub(pc.createdAt) > maxAge {
307+
delete(p.pendingCalls, id)
308+
}
309+
}
310+
// Hard cap: if still too many, evict oldest
311+
if len(p.pendingCalls) > maxPending {
312+
var oldestID string
313+
var oldestTime time.Time
314+
for id, pc := range p.pendingCalls {
315+
if oldestID == "" || pc.createdAt.Before(oldestTime) {
316+
oldestID = id
317+
oldestTime = pc.createdAt
318+
}
319+
}
320+
if oldestID != "" {
321+
delete(p.pendingCalls, oldestID)
322+
}
323+
}
324+
}
325+
297326
func (p *Proxy) handleChildLine(line []byte, parentOut io.Writer) error {
298327
trimmed := bytes.TrimSpace(line)
299328

internal/proxy/server.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ func New(eng *engine.Engine, sink audit.AuditSink, opts ...Option) *Server {
110110
s.token = generateToken(s.logger)
111111
}
112112

113-
if len(s.token) > 8 {
114-
s.logger.Info("proxy: auth token", "token", s.token[:8]+"...")
113+
if len(s.token) > 4 {
114+
s.logger.Info("proxy: auth token", "prefix", s.token[:4]+"")
115115
}
116116
return s
117117
}
@@ -683,6 +683,11 @@ func enrichParams(toolName string, params map[string]any) {
683683
if cmd, ok := decodeBase64Command(params); ok {
684684
params["command"] = cmd
685685
}
686+
// Strip leading shell comment lines (e.g. "# description\nactual command")
687+
// so that command_matches patterns work against the real command.
688+
if cmd, ok := params["command"].(string); ok {
689+
params["command"] = stripLeadingComments(cmd)
690+
}
686691
}
687692

688693
if toolName == "fetch" || toolName == "http" || toolName == "web_fetch" {
@@ -706,12 +711,42 @@ func enrichParams(toolName string, params map[string]any) {
706711
}
707712
}
708713

714+
// stripLeadingComments removes leading lines that start with # (shell comments)
715+
// from multi-line command strings. Agent frameworks often prepend descriptive
716+
// comments (e.g. "# Check disk space\ndf -h") which break command_matches
717+
// patterns that expect the actual command at the start of the string.
718+
func stripLeadingComments(cmd string) string {
719+
lines := strings.Split(cmd, "\n")
720+
start := 0
721+
for start < len(lines) {
722+
trimmed := strings.TrimSpace(lines[start])
723+
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
724+
start++
725+
continue
726+
}
727+
break
728+
}
729+
if start == 0 {
730+
return cmd
731+
}
732+
if start >= len(lines) {
733+
return cmd // all comments — return original
734+
}
735+
return strings.Join(lines[start:], "\n")
736+
}
737+
709738
func decodeBase64Command(params map[string]any) (string, bool) {
710739
encoded, _ := params["command_b64"].(string)
711740
if strings.TrimSpace(encoded) == "" {
712741
return "", false
713742
}
714743

744+
// Cap encoded input at 1MB to prevent memory exhaustion.
745+
const maxBase64Len = 1 << 20
746+
if len(encoded) > maxBase64Len {
747+
return "", false
748+
}
749+
715750
decoded, err := base64.StdEncoding.DecodeString(encoded)
716751
if err != nil {
717752
return "", false

internal/proxy/server_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,31 @@ func TestListenAndServeTimeouts(t *testing.T) {
371371
assert.Equal(t, 30*time.Second, httpSrv.WriteTimeout)
372372
assert.Equal(t, 120*time.Second, httpSrv.IdleTimeout)
373373
}
374+
375+
func TestStripLeadingComments(t *testing.T) {
376+
tests := []struct {
377+
name string
378+
input string
379+
want string
380+
}{
381+
{"no comments", "ls -la", "ls -la"},
382+
{"single comment", "# list files\nls -la", "ls -la"},
383+
{"multiple comments", "# step 1\n# step 2\nls -la", "ls -la"},
384+
{"comment with blank line", "# desc\n\nls -la", "ls -la"},
385+
{"no stripping needed", "git push origin main", "git push origin main"},
386+
{"all comments returns original", "# just a comment\n# another", "# just a comment\n# another"},
387+
{"inline comment preserved", "ls -la # list files", "ls -la # list files"},
388+
{"multiline command", "# build\ndocker build -t app .\ndocker push app", "docker build -t app .\ndocker push app"},
389+
{"empty string", "", ""},
390+
{"whitespace comment", " # padded comment\necho hi", "echo hi"},
391+
}
392+
393+
for _, tt := range tests {
394+
t.Run(tt.name, func(t *testing.T) {
395+
got := stripLeadingComments(tt.input)
396+
if got != tt.want {
397+
t.Errorf("stripLeadingComments(%q) = %q, want %q", tt.input, got, tt.want)
398+
}
399+
})
400+
}
401+
}

0 commit comments

Comments
 (0)