Skip to content

Commit 848875a

Browse files
Add flag type to allow literal or file content values (#187)
* Add flag type to allow String literal or file content values
1 parent c2556e7 commit 848875a

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

flagx/stringfile.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package flagx
2+
3+
import (
4+
"fmt"
5+
"os"
6+
)
7+
8+
// StringFile acts like the native flag.String by storing a string from the
9+
// given argument. Additionally, StringFile may specify a file to read the string value from when
10+
// prefixed with an '@', e.g. -flag=@value.txt
11+
type StringFile struct {
12+
Value string
13+
file string
14+
}
15+
16+
// Set records the string in Value. When the first character of the parameter is
17+
// prefixed with "@", i.e. "@file1", Set reads the file content for the value.
18+
func (fs *StringFile) Set(v string) error {
19+
if len(v) > 0 && v[0] == '@' {
20+
fname := v[1:]
21+
b, err := os.ReadFile(fname)
22+
if err != nil {
23+
return err
24+
}
25+
*fs = StringFile{Value: string(b), file: fname}
26+
} else {
27+
*fs = StringFile{Value: v}
28+
}
29+
return nil
30+
}
31+
32+
// String returns the flags in a form similiar to how they were added from the
33+
// command line.
34+
func (fs *StringFile) String() string {
35+
if fs.file != "" {
36+
return fmt.Sprintf("@%s", fs.file)
37+
}
38+
return fs.Value
39+
}
40+
41+
// Get returns the flag value.
42+
func (fs *StringFile) Get() string {
43+
return fs.Value
44+
}

flagx/stringfile_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package flagx_test
2+
3+
import (
4+
"os"
5+
"path"
6+
"testing"
7+
8+
"github.com/m-lab/go/flagx"
9+
"github.com/m-lab/go/testingx"
10+
)
11+
12+
func TestStringFile(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
value string
16+
useFile bool
17+
wantErr bool
18+
}{
19+
{
20+
name: "success-string",
21+
value: "value12345",
22+
useFile: false,
23+
},
24+
{
25+
name: "success-file",
26+
value: "1234567890abcdef",
27+
useFile: true,
28+
},
29+
{
30+
name: "error-file",
31+
value: "@error-bad-filename",
32+
useFile: false,
33+
wantErr: true,
34+
},
35+
}
36+
for _, tt := range tests {
37+
t.Run(tt.value, func(t *testing.T) {
38+
value := tt.value
39+
40+
if !tt.wantErr && tt.useFile {
41+
// This is a file read - so create a file in a temp directory.
42+
dir := t.TempDir()
43+
name := path.Join(dir, "file.txt")
44+
testingx.Must(t, os.WriteFile(name, []byte(tt.value), 0664), "failed to write test file")
45+
defer os.Remove(name)
46+
value = "@" + name // reset name to include directory.
47+
}
48+
49+
fb := &flagx.StringFile{}
50+
if err := fb.Set(value); (err != nil) != tt.wantErr {
51+
t.Errorf("StringFile.Set() error = %v, wantErr %v", err, tt.wantErr)
52+
}
53+
if tt.wantErr {
54+
return
55+
}
56+
if tt.value != fb.Get() {
57+
t.Errorf("StringFile.Get() want = %q, got %q", tt.value, fb.Get())
58+
}
59+
if fb.String()[0] != '@' && tt.useFile {
60+
t.Errorf("StringFile.String() want = @<file>, got %q", fb.String())
61+
}
62+
})
63+
}
64+
}

0 commit comments

Comments
 (0)