Skip to content

Commit b8dba49

Browse files
authored
Consistently apply Unix socket settings (#277)
Previously, we only supported setting the group for the server-side socket. This change makes it possible to set it on the client side as well. Also fixes a bug where the gRPC broker on the server side would previously not consume the directory/group environment variables.
1 parent c1fefa8 commit b8dba49

8 files changed

+184
-40
lines changed

client.go

+45-5
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ type Client struct {
101101
// forcefully killed.
102102
processKilled bool
103103

104-
hostSocketDir string
104+
unixSocketCfg UnixSocketConfig
105105
}
106106

107107
// NegotiatedVersion returns the protocol version negotiated with the server.
@@ -240,6 +240,28 @@ type ClientConfig struct {
240240
// SkipHostEnv allows plugins to run without inheriting the parent process'
241241
// environment variables.
242242
SkipHostEnv bool
243+
244+
// UnixSocketConfig configures additional options for any Unix sockets
245+
// that are created. Not normally required. Not supported on Windows.
246+
UnixSocketConfig *UnixSocketConfig
247+
}
248+
249+
type UnixSocketConfig struct {
250+
// If set, go-plugin will change the owner of any Unix sockets created to
251+
// this group, and set them as group-writable. Can be a name or gid. The
252+
// client process must be a member of this group or chown will fail.
253+
Group string
254+
255+
// The directory to create Unix sockets in. Internally managed by go-plugin
256+
// and deleted when the plugin is killed.
257+
directory string
258+
}
259+
260+
func unixSocketConfigFromEnv() UnixSocketConfig {
261+
return UnixSocketConfig{
262+
Group: os.Getenv(EnvUnixSocketGroup),
263+
directory: os.Getenv(EnvUnixSocketDir),
264+
}
243265
}
244266

245267
// ReattachConfig is used to configure a client to reattach to an
@@ -445,7 +467,7 @@ func (c *Client) Kill() {
445467
c.l.Lock()
446468
runner := c.runner
447469
addr := c.address
448-
hostSocketDir := c.hostSocketDir
470+
hostSocketDir := c.unixSocketCfg.directory
449471
c.l.Unlock()
450472

451473
// If there is no runner or ID, there is nothing to kill.
@@ -629,15 +651,33 @@ func (c *Client) Start() (addr net.Addr, err error) {
629651
}
630652
}
631653

654+
if c.config.UnixSocketConfig != nil {
655+
c.unixSocketCfg.Group = c.config.UnixSocketConfig.Group
656+
}
657+
658+
if c.unixSocketCfg.Group != "" {
659+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketGroup, c.unixSocketCfg.Group))
660+
}
661+
632662
var runner runner.Runner
633663
switch {
634664
case c.config.RunnerFunc != nil:
635-
c.hostSocketDir, err = os.MkdirTemp("", "")
665+
c.unixSocketCfg.directory, err = os.MkdirTemp("", "plugin-dir")
636666
if err != nil {
637667
return nil, err
638668
}
639-
c.logger.Trace("created temporary directory for unix sockets", "dir", c.hostSocketDir)
640-
runner, err = c.config.RunnerFunc(c.logger, cmd, c.hostSocketDir)
669+
// os.MkdirTemp creates folders with 0o700, so if we have a group
670+
// configured we need to make it group-writable.
671+
if c.unixSocketCfg.Group != "" {
672+
err = setGroupWritable(c.unixSocketCfg.directory, c.unixSocketCfg.Group, 0o770)
673+
if err != nil {
674+
return nil, err
675+
}
676+
}
677+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.directory))
678+
c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.directory)
679+
680+
runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.directory)
641681
if err != nil {
642682
return nil, err
643683
}

client_unix_test.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
//go:build !windows
5+
// +build !windows
6+
7+
package plugin
8+
9+
import (
10+
"fmt"
11+
"os"
12+
"os/exec"
13+
"os/user"
14+
"runtime"
15+
"syscall"
16+
"testing"
17+
18+
"github.com/hashicorp/go-hclog"
19+
"github.com/hashicorp/go-plugin/internal/cmdrunner"
20+
"github.com/hashicorp/go-plugin/runner"
21+
)
22+
23+
func TestSetGroup(t *testing.T) {
24+
if runtime.GOOS == "windows" {
25+
t.Skip("go-plugin doesn't support unix sockets on Windows")
26+
}
27+
28+
group, err := user.LookupGroupId(fmt.Sprintf("%d", os.Getgid()))
29+
if err != nil {
30+
t.Fatal(err)
31+
}
32+
for name, tc := range map[string]struct {
33+
group string
34+
}{
35+
"as integer": {fmt.Sprintf("%d", os.Getgid())},
36+
"as name": {group.Name},
37+
} {
38+
t.Run(name, func(t *testing.T) {
39+
process := helperProcess("mock")
40+
c := NewClient(&ClientConfig{
41+
HandshakeConfig: testHandshake,
42+
Plugins: testPluginMap,
43+
UnixSocketConfig: &UnixSocketConfig{
44+
Group: tc.group,
45+
},
46+
RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) {
47+
// Run tests inside the RunnerFunc to ensure we don't race
48+
// with the code that deletes tmpDir when the client fails
49+
// to start properly.
50+
51+
// Test that it creates a directory with the proper owners and permissions.
52+
info, err := os.Lstat(tmpDir)
53+
if err != nil {
54+
t.Fatal(err)
55+
}
56+
if info.Mode()&os.ModePerm != 0o770 {
57+
t.Fatal(info.Mode())
58+
}
59+
stat, ok := info.Sys().(*syscall.Stat_t)
60+
if !ok {
61+
t.Fatal()
62+
}
63+
if stat.Gid != uint32(os.Getgid()) {
64+
t.Fatalf("Expected %d, but got %d", os.Getgid(), stat.Gid)
65+
}
66+
67+
// Check the correct environment variables were set to forward
68+
// Unix socket config onto the plugin.
69+
var foundUnixSocketDir, foundUnixSocketGroup bool
70+
for _, env := range cmd.Env {
71+
if env == fmt.Sprintf("%s=%s", EnvUnixSocketDir, tmpDir) {
72+
foundUnixSocketDir = true
73+
}
74+
if env == fmt.Sprintf("%s=%s", EnvUnixSocketGroup, tc.group) {
75+
foundUnixSocketGroup = true
76+
}
77+
}
78+
if !foundUnixSocketDir {
79+
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketDir, cmd.Env)
80+
}
81+
if !foundUnixSocketGroup {
82+
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketGroup, cmd.Env)
83+
}
84+
85+
process.Env = append(process.Env, cmd.Env...)
86+
return cmdrunner.NewCmdRunner(l, process)
87+
},
88+
})
89+
defer c.Kill()
90+
91+
_, err := c.Start()
92+
if err != nil {
93+
t.Fatalf("err should be nil, got %s", err)
94+
}
95+
})
96+
}
97+
}

grpc_broker.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ type GRPCBroker struct {
268268
doneCh chan struct{}
269269
o sync.Once
270270

271-
socketDir string
271+
unixSocketCfg UnixSocketConfig
272272
addrTranslator runner.AddrTranslator
273273

274274
sync.Mutex
@@ -279,14 +279,14 @@ type gRPCBrokerPending struct {
279279
doneCh chan struct{}
280280
}
281281

282-
func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator runner.AddrTranslator) *GRPCBroker {
282+
func newGRPCBroker(s streamer, tls *tls.Config, unixSocketCfg UnixSocketConfig, addrTranslator runner.AddrTranslator) *GRPCBroker {
283283
return &GRPCBroker{
284284
streamer: s,
285285
streams: make(map[uint32]*gRPCBrokerPending),
286286
tls: tls,
287287
doneCh: make(chan struct{}),
288288

289-
socketDir: socketDir,
289+
unixSocketCfg: unixSocketCfg,
290290
addrTranslator: addrTranslator,
291291
}
292292
}
@@ -295,7 +295,7 @@ func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator
295295
//
296296
// This should not be called multiple times with the same ID at one time.
297297
func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) {
298-
listener, err := serverListener(b.socketDir)
298+
listener, err := serverListener(b.unixSocketCfg)
299299
if err != nil {
300300
return nil, err
301301
}

grpc_client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) {
6363

6464
// Start the broker.
6565
brokerGRPCClient := newGRPCBrokerClient(conn)
66-
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.hostSocketDir, c.runner)
66+
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.unixSocketCfg, c.runner)
6767
go broker.Run()
6868
go brokerGRPCClient.StartStream()
6969

grpc_server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func (s *GRPCServer) Init() error {
8484
// Register the broker service
8585
brokerServer := newGRPCBrokerServer()
8686
plugin.RegisterGRPCBrokerServer(s.server, brokerServer)
87-
s.broker = newGRPCBroker(brokerServer, s.TLS, "", nil)
87+
s.broker = newGRPCBroker(brokerServer, s.TLS, unixSocketConfigFromEnv(), nil)
8888
go s.broker.Run()
8989

9090
// Register the controller

server.go

+33-24
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func Serve(opts *ServeConfig) {
273273
}
274274

275275
// Register a listener so we can accept a connection
276-
listener, err := serverListener(os.Getenv(EnvUnixSocketDir))
276+
listener, err := serverListener(unixSocketConfigFromEnv())
277277
if err != nil {
278278
logger.Error("plugin init error", "error", err)
279279
return
@@ -496,12 +496,12 @@ func Serve(opts *ServeConfig) {
496496
}
497497
}
498498

499-
func serverListener(dir string) (net.Listener, error) {
499+
func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
500500
if runtime.GOOS == "windows" {
501501
return serverListener_tcp()
502502
}
503503

504-
return serverListener_unix(dir)
504+
return serverListener_unix(unixSocketCfg)
505505
}
506506

507507
func serverListener_tcp() (net.Listener, error) {
@@ -546,8 +546,8 @@ func serverListener_tcp() (net.Listener, error) {
546546
return nil, errors.New("Couldn't bind plugin TCP listener")
547547
}
548548

549-
func serverListener_unix(dir string) (net.Listener, error) {
550-
tf, err := os.CreateTemp(dir, "plugin")
549+
func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
550+
tf, err := os.CreateTemp(unixSocketCfg.directory, "plugin")
551551
if err != nil {
552552
return nil, err
553553
}
@@ -569,25 +569,8 @@ func serverListener_unix(dir string) (net.Listener, error) {
569569

570570
// By default, unix sockets are only writable by the owner. Set up a custom
571571
// group owner and group write permissions if configured.
572-
if groupString := os.Getenv(EnvUnixSocketGroup); groupString != "" {
573-
groupID, err := strconv.Atoi(groupString)
574-
if err != nil {
575-
group, err := user.LookupGroup(groupString)
576-
if err != nil {
577-
return nil, fmt.Errorf("failed to find group ID from %s=%s environment variable: %w", EnvUnixSocketGroup, groupString, err)
578-
}
579-
groupID, err = strconv.Atoi(group.Gid)
580-
if err != nil {
581-
return nil, fmt.Errorf("failed to parse %q group's Gid as an integer: %w", groupString, err)
582-
}
583-
}
584-
585-
err = os.Chown(path, -1, groupID)
586-
if err != nil {
587-
return nil, err
588-
}
589-
590-
err = os.Chmod(path, 0o660)
572+
if unixSocketCfg.Group != "" {
573+
err = setGroupWritable(path, unixSocketCfg.Group, 0o660)
591574
if err != nil {
592575
return nil, err
593576
}
@@ -601,6 +584,32 @@ func serverListener_unix(dir string) (net.Listener, error) {
601584
}, nil
602585
}
603586

587+
func setGroupWritable(path, groupString string, mode os.FileMode) error {
588+
groupID, err := strconv.Atoi(groupString)
589+
if err != nil {
590+
group, err := user.LookupGroup(groupString)
591+
if err != nil {
592+
return fmt.Errorf("failed to find gid from %q: %w", groupString, err)
593+
}
594+
groupID, err = strconv.Atoi(group.Gid)
595+
if err != nil {
596+
return fmt.Errorf("failed to parse %q group's gid as an integer: %w", groupString, err)
597+
}
598+
}
599+
600+
err = os.Chown(path, -1, groupID)
601+
if err != nil {
602+
return err
603+
}
604+
605+
err = os.Chmod(path, mode)
606+
if err != nil {
607+
return err
608+
}
609+
610+
return nil
611+
}
612+
604613
// rmListener is an implementation of net.Listener that forwards most
605614
// calls to the listener but also removes a file as part of the close. We
606615
// use this to cleanup the unix domain socket on close.

server_unix_test.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,13 @@ func TestUnixSocketGroupPermissions(t *testing.T) {
2525
t.Fatal(err)
2626
}
2727
for name, tc := range map[string]struct {
28-
gid string
28+
group string
2929
}{
3030
"as integer": {fmt.Sprintf("%d", os.Getgid())},
3131
"as name": {group.Name},
3232
} {
3333
t.Run(name, func(t *testing.T) {
34-
t.Setenv(EnvUnixSocketGroup, tc.gid)
35-
36-
ln, err := serverListener_unix("")
34+
ln, err := serverListener_unix(UnixSocketConfig{Group: tc.group})
3735
if err != nil {
3836
t.Fatal(err)
3937
}

testing.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCSe
166166
}
167167

168168
brokerGRPCClient := newGRPCBrokerClient(conn)
169-
broker := newGRPCBroker(brokerGRPCClient, nil, "", nil)
169+
broker := newGRPCBroker(brokerGRPCClient, nil, UnixSocketConfig{}, nil)
170170
go broker.Run()
171171
go brokerGRPCClient.StartStream()
172172

0 commit comments

Comments
 (0)