@@ -3,6 +3,7 @@ package transform
33import (
44 "encoding"
55 "fmt"
6+ "maps"
67 "reflect"
78 "strings"
89
@@ -26,24 +27,36 @@ type FlattenMangler struct {
2627 tag string
2728 nameEncodeCasing caseconversion.EncodeCasingFunc
2829 tagEncodeCasing caseconversion.EncodeCasingFunc
30+
31+ propagateChildTags map [string ]struct {}
32+ }
33+
34+ // AddPropogateTags inserts an additional set of tags to propagate down to the
35+ // final leaf fields from intermediate-level struct-fields.
36+ func (f * FlattenMangler ) AddPropogateTags (tags ... string ) {
37+ for _ , tag := range tags {
38+ f .propagateChildTags [tag ] = struct {}{}
39+ }
2940}
3041
3142// DefaultFlattenMangler returns a FlattenMangler with preset values for tag,
3243// nameEncodeCasing, and tagEncodeCasing
3344func DefaultFlattenMangler () * FlattenMangler {
3445 return & FlattenMangler {
35- tag : common .DialsTagName ,
36- nameEncodeCasing : caseconversion .EncodeUpperCamelCase ,
37- tagEncodeCasing : caseconversion .EncodeCasePreservingSnakeCase ,
46+ tag : common .DialsTagName ,
47+ nameEncodeCasing : caseconversion .EncodeUpperCamelCase ,
48+ tagEncodeCasing : caseconversion .EncodeCasePreservingSnakeCase ,
49+ propagateChildTags : map [string ]struct {}{},
3850 }
3951}
4052
4153// NewFlattenMangler is the constructor for FlattenMangler
4254func NewFlattenMangler (tag string , nameEnc , tagEnc caseconversion.EncodeCasingFunc ) * FlattenMangler {
4355 return & FlattenMangler {
44- tag : tag ,
45- nameEncodeCasing : nameEnc ,
46- tagEncodeCasing : tagEnc ,
56+ tag : tag ,
57+ nameEncodeCasing : nameEnc ,
58+ tagEncodeCasing : tagEnc ,
59+ propagateChildTags : map [string ]struct {}{},
4760 }
4861}
4962
@@ -64,7 +77,9 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
6477 out := []reflect.StructField {}
6578 fieldPath := []string {sf .Name }
6679
67- tag , prefixTag , tagErr := f .getTag (& sf , nil , fieldPath )
80+ propagatedTags := map [string ]string {}
81+
82+ tag , prefixTag , tagErr := f .getTag (& sf , nil , fieldPath , propagatedTags )
6883 if tagErr != nil {
6984 return out , tagErr
7085 }
@@ -79,7 +94,7 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
7994 if ! sf .Anonymous {
8095 fieldPrefix = append (fieldPrefix , sf .Name )
8196 }
82- return f .flattenStruct (fieldPrefix , prefixTag , fieldPath , sf )
97+ return f .flattenStruct (fieldPrefix , prefixTag , fieldPath , maps . Clone ( propagatedTags ), sf )
8398 default :
8499 }
85100
@@ -99,7 +114,7 @@ func (f *FlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField,
99114
100115// flattenStruct takes a struct and flattens all the fields and makes a recursive
101116// call if the field is a struct too
102- func (f * FlattenMangler ) flattenStruct (fieldPrefix , tagPrefix , fieldPath []string , sf reflect.StructField ) ([]reflect.StructField , error ) {
117+ func (f * FlattenMangler ) flattenStruct (fieldPrefix , tagPrefix , fieldPath []string , propagatedTags map [ string ] string , sf reflect.StructField ) ([]reflect.StructField , error ) {
103118
104119 // get underlying type after removing pointers. Ignoring the kind
105120 _ , ft := getUnderlyingKindType (sf .Type )
@@ -122,8 +137,10 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
122137 // embedded fields to the slice so we can iterate through and get the original field
123138 flattenedPath := append (fieldPath [:len (fieldPath ):len (fieldPath )], nestedsf .Name )
124139
140+ propagatedFieldTags := maps .Clone (propagatedTags )
141+
125142 // add the tag of the current field to the list of flattened tags
126- tag , flattenedTags , tagErr := f .getTag (& nestedsf , tagPrefix , flattenedPath )
143+ tag , flattenedTags , tagErr := f .getTag (& nestedsf , tagPrefix , flattenedPath , propagatedFieldTags )
127144 if tagErr != nil {
128145 return out , tagErr
129146 }
@@ -137,7 +154,7 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
137154 if nestedT .Implements (textMReflectType ) || reflect .PointerTo (nestedT ).Implements (textMReflectType ) {
138155 break
139156 }
140- flattened , err := f .flattenStruct (flattenedNames , flattenedTags , flattenedPath , nestedsf )
157+ flattened , err := f .flattenStruct (flattenedNames , flattenedTags , flattenedPath , propagatedFieldTags , nestedsf )
141158 if err != nil {
142159 return out , err
143160 }
@@ -162,7 +179,7 @@ func (f *FlattenMangler) flattenStruct(fieldPrefix, tagPrefix, fieldPath []strin
162179// getTag uses the tag if one already exists or creates one based on the
163180// configured EncodingCasing function and fieldName. It returns the new parsed
164181// StructTag, the updated slice of tags, and any error encountered
165- func (f * FlattenMangler ) getTag (sf * reflect.StructField , tags , flattenedPath []string ) (reflect.StructTag , []string , error ) {
182+ func (f * FlattenMangler ) getTag (sf * reflect.StructField , tags , flattenedPath []string , propagatedTags map [ string ] string ) (reflect.StructTag , []string , error ) {
166183 tag , ok := sf .Tag .Lookup (f .tag )
167184
168185 // tag already exists so use the existing tag and append to prefix tags
@@ -199,6 +216,19 @@ func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []s
199216 Name : strings .Join (flattenedPath , "," ),
200217 })
201218
219+ for presTagName := range f .propagateChildTags {
220+ if v , ok := sf .Tag .Lookup (presTagName ); ok {
221+ propagatedTags [presTagName ] = v
222+ continue
223+ }
224+ if val , ok := propagatedTags [presTagName ]; ok {
225+ parsedTag .Set (& structtag.Tag {
226+ Key : presTagName ,
227+ Name : val ,
228+ })
229+ }
230+ }
231+
202232 return reflect .StructTag (parsedTag .String ()), tags , nil
203233}
204234
0 commit comments