Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 44 additions & 41 deletions core/conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func LoadConfig(file string, v any, opts ...Option) error {

// LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v any) error {
info, err := buildFieldsInfo(reflect.TypeOf(v), "")
info, err := buildFieldsInfo(reflect.TypeOf(v), "", make(map[reflect.Type]*fieldInfo))
if err != nil {
return err
}
Expand Down Expand Up @@ -152,10 +152,11 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil
}

func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
fullName string, visited map[reflect.Type]*fieldInfo) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft, fullName)
fields, err := buildFieldsInfo(ft, fullName, visited)
if err != nil {
return err
}
Expand All @@ -166,7 +167,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName, visited)
if err != nil {
return err
}
Expand All @@ -192,14 +193,44 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
return nil
}

func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildFieldsInfo(tp reflect.Type, fullName string,
visited map[reflect.Type]*fieldInfo) (*fieldInfo, error) {
tp = mapping.Deref(tp)
if finfo, ok := visited[tp]; ok {
return finfo, nil
}

switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp, fullName)
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
visited[tp] = info

for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
if !field.IsExported() {
continue
}

name := getTagName(field)
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName), visited); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName), visited); err != nil {
return nil, err
}
}

return info, nil
case reflect.Array, reflect.Slice, reflect.Map:
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName, visited)
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName)
default:
Expand All @@ -209,23 +240,24 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
}
}

func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
fullName string, visited map[reflect.Type]*fieldInfo) error {
var finfo *fieldInfo
var err error

switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft, fullName, visited)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem(), fullName)
finfo, err = buildFieldsInfo(ft.Elem(), fullName, visited)
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName, visited)
if err != nil {
return err
}
Expand All @@ -235,7 +267,7 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft, fullName, visited)
if err != nil {
return err
}
Expand All @@ -244,35 +276,6 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
return addOrMergeFields(info, lowerCaseName, finfo, fullName)
}

func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}

for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
if !field.IsExported() {
continue
}

name := getTagName(field)
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
return nil, err
}
}

return info, nil
}

// getTagName get the tag name of the given field, if no tag name, use file.Name.
// field.Name is returned on tags like `json:""` and `json:",optional"`.
func getTagName(field reflect.StructField) string {
Expand Down
28 changes: 27 additions & 1 deletion core/conf/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ func Test_buildFieldsInfo(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := buildFieldsInfo(tt.t, "")
_, err := buildFieldsInfo(tt.t, "", make(map[reflect.Type]*fieldInfo))
if tt.ok {
assert.NoError(t, err)
} else {
Expand All @@ -1339,6 +1339,32 @@ func Test_buildFieldsInfo(t *testing.T) {
}
}

func TestLoadWithCycleReference(t *testing.T) {
type Node struct {
Name string `json:"name"`
Children []*Node `json:"children,optional"`
}

var c Node
input := []byte(`
name: root
children:
- name: child1
children:
- name: grandchild1
- name: child2
`)
err := LoadFromYamlBytes(input, &c)
assert.NoError(t, err)

assert.Equal(t, "root", c.Name)
assert.Len(t, c.Children, 2)
assert.Equal(t, "child1", c.Children[0].Name)
assert.Equal(t, "child2", c.Children[1].Name)
assert.Len(t, c.Children[0].Children, 1)
assert.Equal(t, "grandchild1", c.Children[0].Children[0].Name)
}

func createTempFile(t *testing.T, ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
Expand Down
Loading