Skip to content

Commit ffd5249

Browse files
authored
Search for config.yaml/yml in both service and cli mode (slackhq#1717)
1 parent 625f58b commit ffd5249

5 files changed

Lines changed: 110 additions & 21 deletions

File tree

cmd/nebula-service/main.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,12 @@ func main() {
6161
}
6262

6363
if *configPath == "" {
64-
fmt.Println("-config flag must be set")
65-
flag.Usage()
66-
os.Exit(1)
64+
p, err := config.DefaultPath()
65+
if err != nil {
66+
fmt.Println(err)
67+
os.Exit(1)
68+
}
69+
*configPath = p
6770
}
6871

6972
c := config.NewC(l)

cmd/nebula-service/service.go

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package main
33
import (
44
"fmt"
55
"log"
6-
"os"
7-
"path/filepath"
86

97
"github.com/kardianos/service"
108
"github.com/slackhq/nebula"
@@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error {
5755
return nil
5856
}
5957

60-
func fileExists(filename string) bool {
61-
_, err := os.Stat(filename)
62-
if os.IsNotExist(err) {
63-
return false
64-
}
65-
return true
66-
}
67-
6858
func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error {
6959
if *configPath == "" {
70-
ex, err := os.Executable()
60+
p, err := config.DefaultPath()
7161
if err != nil {
7262
return err
7363
}
74-
*configPath = filepath.Dir(ex) + "/config.yaml"
75-
if !fileExists(*configPath) {
76-
*configPath = filepath.Dir(ex) + "/config.yml"
77-
}
64+
*configPath = p
7865
}
7966

8067
svcConfig := &service.Config{

cmd/nebula/main.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ func main() {
5050
}
5151

5252
if *configPath == "" {
53-
fmt.Println("-config flag must be set")
54-
flag.Usage()
55-
os.Exit(1)
53+
p, err := config.DefaultPath()
54+
if err != nil {
55+
fmt.Println(err)
56+
os.Exit(1)
57+
}
58+
*configPath = p
5659
}
5760

5861
l := logging.NewLogger(os.Stdout)

config/default.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package config
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"path/filepath"
7+
)
8+
9+
// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml over config.yml.
10+
// If neither file exists an error is returned that names both paths checked.
11+
func DefaultPath() (string, error) {
12+
ex, err := os.Executable()
13+
if err != nil {
14+
return "", err
15+
}
16+
return defaultPathInDir(filepath.Dir(ex))
17+
}
18+
19+
func defaultPathInDir(dir string) (string, error) {
20+
yamlPath := filepath.Join(dir, "config.yaml")
21+
if _, err := os.Stat(yamlPath); err == nil {
22+
return yamlPath, nil
23+
}
24+
ymlPath := filepath.Join(dir, "config.yml")
25+
if _, err := os.Stat(ymlPath); err == nil {
26+
return ymlPath, nil
27+
}
28+
return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath)
29+
}

config/default_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package config
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestDefaultPathInDir(t *testing.T) {
13+
t.Run("prefers config.yaml when both exist", func(t *testing.T) {
14+
dir := t.TempDir()
15+
want := filepath.Join(dir, "config.yaml")
16+
other := filepath.Join(dir, "config.yml")
17+
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
18+
require.NoError(t, os.WriteFile(other, []byte("a: 2"), 0644))
19+
20+
got, err := defaultPathInDir(dir)
21+
require.NoError(t, err)
22+
assert.Equal(t, want, got)
23+
})
24+
25+
t.Run("returns config.yaml when only it exists", func(t *testing.T) {
26+
dir := t.TempDir()
27+
want := filepath.Join(dir, "config.yaml")
28+
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
29+
30+
got, err := defaultPathInDir(dir)
31+
require.NoError(t, err)
32+
assert.Equal(t, want, got)
33+
})
34+
35+
t.Run("falls back to config.yml when only it exists", func(t *testing.T) {
36+
dir := t.TempDir()
37+
want := filepath.Join(dir, "config.yml")
38+
require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644))
39+
40+
got, err := defaultPathInDir(dir)
41+
require.NoError(t, err)
42+
assert.Equal(t, want, got)
43+
})
44+
45+
t.Run("errors when neither exists and names both paths", func(t *testing.T) {
46+
dir := t.TempDir()
47+
got, err := defaultPathInDir(dir)
48+
assert.Empty(t, got)
49+
require.Error(t, err)
50+
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml"))
51+
assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml"))
52+
})
53+
}
54+
55+
func TestDefaultPath(t *testing.T) {
56+
got, err := DefaultPath()
57+
if err != nil {
58+
ex, exErr := os.Executable()
59+
require.NoError(t, exErr)
60+
assert.Contains(t, err.Error(), filepath.Dir(ex))
61+
return
62+
}
63+
ex, err := os.Executable()
64+
require.NoError(t, err)
65+
assert.Equal(t, filepath.Dir(ex), filepath.Dir(got))
66+
assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got))
67+
}

0 commit comments

Comments
 (0)