Skip to content

Commit c0358e0

Browse files
authored
Merge pull request #104 from vimeo/flags_skip_intermediate
flatten mangler: propagate an allow-list of tags
2 parents f5bc6a2 + 5d5d885 commit c0358e0

6 files changed

Lines changed: 123 additions & 16 deletions

File tree

sources/env/env.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ var _ dials.Source = (*Source)(nil)
3131
func (e *Source) Value(_ context.Context, t *dials.Type) (reflect.Value, error) {
3232
// flatten the nested fields
3333
flattenMangler := transform.NewFlattenMangler(common.DialsTagName, caseconversion.EncodeUpperCamelCase, caseconversion.EncodeUpperCamelCase)
34+
flattenMangler.AddPropogateTags(common.DialsEnvTagName)
3435
// reformat the tags so they are SCREAMING_SNAKE_CASE
3536
reformatTagMangler := tagformat.NewTagReformattingMangler(common.DialsTagName, caseconversion.DecodeGoTags, caseconversion.EncodeUpperSnakeCase)
3637
// copy tags from "dials" to "dialsenv" tag

sources/flag/flag.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func Must(s *Set, err error) *Set {
167167

168168
const (
169169
// HelpTextTag is the name of the struct tags for flag descriptions
170-
HelpTextTag = "dialsdesc"
170+
HelpTextTag = common.DialsHelpTextTag
171171
// DefaultFlagHelpText is the default help-text for fields with an
172172
// unset dialsdesc tag.
173173
DefaultFlagHelpText = "unset description (`" + HelpTextTag + "` struct tag)"
@@ -202,6 +202,7 @@ func (s *Set) parse() error {
202202

203203
func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
204204
fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing)
205+
fm.AddPropogateTags(common.DialsFlagTagName, HelpTextTag)
205206
tfmr := transform.NewTransformer(ptyp, transform.NewAliasMangler(common.DialsTagName, common.DialsFlagTagName), fm)
206207
val, TrnslErr := tfmr.Translate()
207208
if TrnslErr != nil {

sources/flag/flag_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func TestDirectBasic(t *testing.T) {
3939

4040
d, err := dials.Config(ctx, &Config{Hello: "nothing"}, src)
4141
if err != nil {
42+
t.Log(buf.String())
43+
fs.PrintDefaults()
4244
t.Fatal(err)
4345
}
4446
src.Flags.Usage()

sources/pflag/pflag.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ func (s *Set) parse() error {
210210

211211
func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error {
212212
fm := transform.NewFlattenMangler(common.DialsTagName, s.NameCfg.FieldNameEncodeCasing, s.NameCfg.TagEncodeCasing)
213+
fm.AddPropogateTags(common.DialsPFlagTag, common.DialsPFlagShortTag, common.DialsHelpTextTag)
213214
tfmr := transform.NewTransformer(ptyp, transform.NewAliasMangler(common.DialsTagName, common.DialsPFlagTag, common.DialsPFlagShortTag), fm)
214215
val, TrnslErr := tfmr.Translate()
215216
if TrnslErr != nil {

transform/flatten_mangler.go

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package transform
33
import (
44
"encoding"
55
"fmt"
6+
"maps"
67
"reflect"
78
"strings"
89

@@ -26,24 +27,36 @@ type FlattenMangler struct {
2627
tag string
2728
nameEncodeCasing caseconversion.EncodeCasingFunc
2829
tagEncodeCasing caseconversion.EncodeCasingFunc
30+
31+
propagateChildTags map[string]struct{}
32+
}
33+
34+
// AddPropogateTags inserts an additional set of tags to propagate down to the
35+
// final leaf fields from intermediate-level struct-fields.
36+
func (f *FlattenMangler) AddPropogateTags(tags ...string) {
37+
for _, tag := range tags {
38+
f.propagateChildTags[tag] = struct{}{}
39+
}
2940
}
3041

3142
// DefaultFlattenMangler returns a FlattenMangler with preset values for tag,
3243
// nameEncodeCasing, and tagEncodeCasing
3344
func DefaultFlattenMangler() *FlattenMangler {
3445
return &FlattenMangler{
35-
tag: common.DialsTagName,
36-
nameEncodeCasing: caseconversion.EncodeUpperCamelCase,
37-
tagEncodeCasing: caseconversion.EncodeCasePreservingSnakeCase,
46+
tag: common.DialsTagName,
47+
nameEncodeCasing: caseconversion.EncodeUpperCamelCase,
48+
tagEncodeCasing: caseconversion.EncodeCasePreservingSnakeCase,
49+
propagateChildTags: map[string]struct{}{},
3850
}
3951
}
4052

4153
// NewFlattenMangler is the constructor for FlattenMangler
4254
func NewFlattenMangler(tag string, nameEnc, tagEnc caseconversion.EncodeCasingFunc) *FlattenMangler {
4355
return &FlattenMangler{
44-
tag: tag,
45-
nameEncodeCasing: nameEnc,
46-
tagEncodeCasing: tagEnc,
56+
tag: tag,
57+
nameEncodeCasing: nameEnc,
58+
tagEncodeCasing: tagEnc,
59+
propagateChildTags: map[string]struct{}{},
4760
}
4861
}
4962

@@ -64,7 +77,9 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
6477
out := []reflect.StructField{}
6578
fieldPath := []string{sf.Name}
6679

67-
tag, prefixTag, tagErr := f.getTag(&sf, nil, fieldPath)
80+
propagatedTags := map[string]string{}
81+
82+
tag, prefixTag, tagErr := f.getTag(&sf, nil, fieldPath, propagatedTags)
6883
if tagErr != nil {
6984
return out, tagErr
7085
}
@@ -79,7 +94,7 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
7994
if !sf.Anonymous {
8095
fieldPrefix = append(fieldPrefix, sf.Name)
8196
}
82-
return f.flattenStruct(fieldPrefix, prefixTag, fieldPath, sf)
97+
return f.flattenStruct(fieldPrefix, prefixTag, fieldPath, maps.Clone(propagatedTags), sf)
8398
default:
8499
}
85100

@@ -99,7 +114,7 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
99114

100115
// flattenStruct takes a struct and flattens all the fields and makes a recursive
101116
// call if the field is a struct too
102-
func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []string, sf reflect.StructField) ([]reflect.StructField, error) {
117+
func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []string, propagatedTags map[string]string, sf reflect.StructField) ([]reflect.StructField, error) {
103118

104119
// get underlying type after removing pointers. Ignoring the kind
105120
_, ft := getUnderlyingKindType(sf.Type)
@@ -122,8 +137,10 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
122137
// embedded fields to the slice so we can iterate through and get the original field
123138
flattenedPath := append(fieldPath[:len(fieldPath):len(fieldPath)], nestedsf.Name)
124139

140+
propagatedFieldTags := maps.Clone(propagatedTags)
141+
125142
// add the tag of the current field to the list of flattened tags
126-
tag, flattenedTags, tagErr := f.getTag(&nestedsf, tagPrefix, flattenedPath)
143+
tag, flattenedTags, tagErr := f.getTag(&nestedsf, tagPrefix, flattenedPath, propagatedFieldTags)
127144
if tagErr != nil {
128145
return out, tagErr
129146
}
@@ -137,7 +154,7 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
137154
if nestedT.Implements(textMReflectType) || reflect.PointerTo(nestedT).Implements(textMReflectType) {
138155
break
139156
}
140-
flattened, err := f.flattenStruct(flattenedNames, flattenedTags, flattenedPath, nestedsf)
157+
flattened, err := f.flattenStruct(flattenedNames, flattenedTags, flattenedPath, propagatedFieldTags, nestedsf)
141158
if err != nil {
142159
return out, err
143160
}
@@ -162,7 +179,7 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
162179
// getTag uses the tag if one already exists or creates one based on the
163180
// configured EncodingCasing function and fieldName. It returns the new parsed
164181
// StructTag, the updated slice of tags, and any error encountered
165-
func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []string) (reflect.StructTag, []string, error) {
182+
func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []string, propagatedTags map[string]string) (reflect.StructTag, []string, error) {
166183
tag, ok := sf.Tag.Lookup(f.tag)
167184

168185
// tag already exists so use the existing tag and append to prefix tags
@@ -199,6 +216,19 @@ func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []s
199216
Name: strings.Join(flattenedPath, ","),
200217
})
201218

219+
for presTagName := range f.propagateChildTags {
220+
if v, ok := sf.Tag.Lookup(presTagName); ok {
221+
propagatedTags[presTagName] = v
222+
continue
223+
}
224+
if val, ok := propagatedTags[presTagName]; ok {
225+
parsedTag.Set(&structtag.Tag{
226+
Key: presTagName,
227+
Name: val,
228+
})
229+
}
230+
}
231+
202232
return reflect.StructTag(parsedTag.String()), tags, nil
203233
}
204234

transform/flatten_mangler_test.go

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func (u tu) UnmarshalText(data []byte) error {
2525

2626
func TestFlattenMangler(t *testing.T) {
2727
type Foo struct {
28-
Location string `dials:"Location"`
28+
Location string `dials:"Location" littlebiddle:"boop"`
2929
Coordinates int `dials:"Coordinates"`
3030
SomeTime time.Duration
3131
}
@@ -48,6 +48,11 @@ func TestFlattenMangler(t *testing.T) {
4848
AnotherField int `dials:"AnotherField"`
4949
}
5050

51+
type intermediateLevel struct {
52+
Name string `dials:"Name"`
53+
FizzleBit bar `dials:"Bit" ooble:"ooops"`
54+
}
55+
5156
b := bar{
5257
Name: "test",
5358
Foobar: &Foo{
@@ -78,8 +83,9 @@ func TestFlattenMangler(t *testing.T) {
7883
}
7984

8085
testCases := []struct {
81-
name string
82-
testStruct any
86+
name string
87+
testStruct any
88+
additionalTags []string
8389
// modify will fill the flatten struct value after Mangling
8490
modify func(t testing.TB, val reflect.Value)
8591
assertion func(t testing.TB, i any)
@@ -135,6 +141,69 @@ func TestFlattenMangler(t *testing.T) {
135141
assert.EqualValues(t, curTime, *i.(*time.Time))
136142
},
137143
},
144+
{
145+
name: "intermediate level struct with an auxillary tag",
146+
testStruct: intermediateLevel{
147+
Name: "foobar",
148+
FizzleBit: b,
149+
},
150+
additionalTags: []string{"ooble"},
151+
modify: func(t testing.TB, val reflect.Value) {
152+
vt := val.Type()
153+
for _, field := range []string{"ConfigFieldFizzleBitName", "ConfigFieldFizzleBitAnotherField",
154+
"ConfigFieldFizzleBitFoobarLocation", "ConfigFieldFizzleBitFoobarCoordinates",
155+
"ConfigFieldFizzleBitFoobarSomeTime"} {
156+
if oobleField, ok := vt.FieldByName(field); ok {
157+
if tagVal, tagOK := oobleField.Tag.Lookup("ooble"); tagOK {
158+
if tagVal != "ooops" {
159+
t.Errorf("unexpected tag value for field %q; got %q; want %q", field, tagVal, "ooops")
160+
}
161+
} else {
162+
t.Errorf("missing ooble tag on %q; only %q", field, oobleField.Tag)
163+
}
164+
} else {
165+
t.Errorf("missing field %q type %s", field, vt)
166+
}
167+
}
168+
},
169+
assertion: func(t testing.TB, i any) {},
170+
},
171+
{
172+
name: "leaf auxilliary tag",
173+
testStruct: Foo{
174+
Location: "boop",
175+
Coordinates: 3333,
176+
},
177+
additionalTags: []string{"littlebiddle"},
178+
modify: func(t testing.TB, val reflect.Value) {
179+
vt := val.Type()
180+
// check the fields we expect to see the littlebiddle tag on
181+
for _, field := range []string{"ConfigFieldLocation"} {
182+
if littlebiddleField, ok := vt.FieldByName(field); ok {
183+
if tagVal, tagOK := littlebiddleField.Tag.Lookup("littlebiddle"); tagOK {
184+
if tagVal != "boop" {
185+
t.Errorf("unexpected tag value for field %q; got %q; want %q", field, tagVal, "ooops")
186+
}
187+
} else {
188+
t.Errorf("missing littlebiddle tag on %q; only %q", field, littlebiddleField.Tag)
189+
}
190+
} else {
191+
t.Errorf("missing field %q type %s", field, vt)
192+
}
193+
}
194+
// check the ones that we don't expect to see the littlebiddle tag on
195+
for _, field := range []string{"ConfigFieldCoordinates", "ConfigFieldSomeTime"} {
196+
if oobleField, ok := vt.FieldByName(field); ok {
197+
if tagVal, tagOK := oobleField.Tag.Lookup("littlebiddle"); tagOK {
198+
t.Errorf("unexpectedly present littlebiddle tag on %q; only %q; with value %q", field, oobleField.Tag, tagVal)
199+
}
200+
} else {
201+
t.Errorf("missing field %q type %s", field, vt)
202+
}
203+
}
204+
},
205+
assertion: func(t testing.TB, i any) {},
206+
},
138207
{
139208
name: "one level nested struct unexposed fields",
140209
testStruct: struct {
@@ -671,6 +740,9 @@ func TestFlattenMangler(t *testing.T) {
671740

672741
ptrifiedConfigType := ptrify.Pointerify(configStructType, reflect.New(configStructType).Elem())
673742
f := DefaultFlattenMangler()
743+
if tc.additionalTags != nil {
744+
f.AddPropogateTags(tc.additionalTags...)
745+
}
674746
tfmr := NewTransformer(ptrifiedConfigType, f)
675747
val, err := tfmr.Translate()
676748
require.NoError(t, err)

0 commit comments

Comments
 (0)