Skip to content

Commit 1ad05b1

Browse files
committed
fix(root): validate rootful execution with the actual sudo program
1 parent d2a6e50 commit 1ad05b1

4 files changed

Lines changed: 72 additions & 12 deletions

File tree

internal/cli/assemble.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func assembleAction(ctx context.Context, cmd *cli.Command, cfg *config.Values, d
102102
// if at least one item in the manifest requires root, validate sudo before proceeding
103103
for _, item := range manifest {
104104
if item.Root {
105-
if err := rootful.Validate(ctx); err != nil {
105+
if err := rootful.Validate(ctx, cmd.String("sudo-command")); err != nil {
106106
return fmt.Errorf("cannot run in root mode: %w", err)
107107
}
108108
break

internal/cli/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func withRoot(_ *config.Values, cmd *cli.Command) *cli.Command {
146146
}
147147
}
148148
if c.Bool("root") {
149-
if err := rootful.Validate(ctx); err != nil {
149+
if err := rootful.Validate(ctx, c.String("sudo-command")); err != nil {
150150
return nil, fmt.Errorf("cannot run in root mode: %w", err)
151151
}
152152
}

internal/rootful/rootful.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Package rootful provides utilities for running operations that require
2-
// root privileges via sudo.
2+
// root privileges via a configurable sudo command.
33
package rootful
44

55
import (
@@ -12,22 +12,28 @@ import (
1212

1313
//nolint:gochecknoglobals // singleton: process-wide memoization is the intent
1414
var (
15-
validateOnce sync.Once
16-
errValidate error
15+
validateOnce sync.Once
16+
errValidate error
17+
cachedSudoCommand string
1718
)
1819

19-
// Validate ensures that sudo is available and the user can elevate.
20-
// It runs `sudo -v` at most once per process: the first call performs the
21-
// check and caches the result; subsequent calls return the cached result
22-
// without re-running the command.
23-
func Validate(ctx context.Context) error {
20+
// Validate ensures that sudoCommand is available and the user can elevate.
21+
// It runs `<sudoCommand> -v` at most once per process: the first call performs
22+
// the check and caches the result; subsequent calls return the cached result
23+
// without re-running the command. Passing a different sudoCommand after the
24+
// first call is a programming error and returns an explicit error.
25+
func Validate(ctx context.Context, sudoCommand string) error {
26+
if cachedSudoCommand != "" && cachedSudoCommand != sudoCommand {
27+
return fmt.Errorf("sudoCommand mismatch: already validated with %q, got %q", cachedSudoCommand, sudoCommand)
28+
}
2429
validateOnce.Do(func() {
25-
cmd := exec.CommandContext(ctx, "sudo", "-v")
30+
cachedSudoCommand = sudoCommand
31+
cmd := exec.CommandContext(ctx, sudoCommand, "-v")
2632
cmd.Stdin = os.Stdin
2733
cmd.Stdout = os.Stdout
2834
cmd.Stderr = os.Stderr
2935
if err := cmd.Run(); err != nil {
30-
errValidate = fmt.Errorf("failed to validate sudo: %w", err)
36+
errValidate = fmt.Errorf("failed to validate %q: %w", sudoCommand, err)
3137
}
3238
})
3339
return errValidate
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package rootful
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func resetState() {
13+
validateOnce = sync.Once{}
14+
errValidate = nil
15+
cachedSudoCommand = ""
16+
}
17+
18+
func TestValidate_Success(t *testing.T) {
19+
resetState()
20+
t.Cleanup(resetState)
21+
22+
err := Validate(context.Background(), "true")
23+
require.NoError(t, err)
24+
}
25+
26+
func TestValidate_Failure(t *testing.T) {
27+
resetState()
28+
t.Cleanup(resetState)
29+
30+
err := Validate(context.Background(), "false")
31+
require.Error(t, err)
32+
assert.Contains(t, err.Error(), `"false"`)
33+
}
34+
35+
func TestValidate_CachesResult(t *testing.T) {
36+
resetState()
37+
t.Cleanup(resetState)
38+
39+
require.NoError(t, Validate(context.Background(), "true"))
40+
require.NoError(t, Validate(context.Background(), "true"))
41+
}
42+
43+
func TestValidate_MismatchedCommand(t *testing.T) {
44+
resetState()
45+
t.Cleanup(resetState)
46+
47+
require.NoError(t, Validate(context.Background(), "true"))
48+
49+
err := Validate(context.Background(), "false")
50+
require.Error(t, err)
51+
assert.Contains(t, err.Error(), "mismatch")
52+
assert.Contains(t, err.Error(), `"true"`)
53+
assert.Contains(t, err.Error(), `"false"`)
54+
}

0 commit comments

Comments
 (0)