Skip to content

Commit d0fa47e

Browse files
committed
feat(fields): read/set all the fields from the given struct
refactor
1 parent aa03e94 commit d0fa47e

File tree

2 files changed

+114
-78
lines changed

2 files changed

+114
-78
lines changed

fields/iterate.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ func iterate(strct any, cfg *config, path []reflect.StructField) error {
126126
return intReflect.FieldCallbackResultStop()
127127
}
128128

129+
// TODO do not use DeepEqual, return in the result whether the value has changed
129130
valueHasChanged = valueHasChanged || !reflect.DeepEqual(original, value)
130131
}
131132

internal/reflect/iterate.go

Lines changed: 113 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -82,120 +82,155 @@ func IterateFields(strct any, callback FieldCallback, convert bool, convertToPtr
8282
return err
8383
}
8484

85+
var iterator func(
86+
reflectVal reflect.Value,
87+
callback FieldCallback,
88+
convert bool,
89+
convertToPtr bool,
90+
) error
91+
8592
switch {
8693
case chain.equalTo(reflect.Struct):
8794
strType = fmt.Sprintf("%T", reflect.Zero(reflectVal.Type()).Interface())
95+
iterator = iterateStruct
8896

89-
for i := 0; i < reflectVal.Type().NumField(); i++ {
90-
result := callback(reflectVal.Type().Field(i), valueFromField(reflectVal, i))
97+
case chain.equalTo(reflect.Ptr, reflect.Struct):
98+
strType = fmt.Sprintf("%T", reflect.Zero(reflectVal.Elem().Type()).Interface())
99+
iterator = iteratePtrStruct
91100

92-
if result.set {
93-
return fmt.Errorf("pointer is required to set fields")
94-
}
101+
case chain.equalTo(reflect.Ptr, reflect.Interface, reflect.Struct):
102+
strType = fmt.Sprintf("%T", reflect.Zero(reflectVal.Type()).Interface())
103+
iterator = iteratePtrInterfaceStruct
95104

96-
if result.stop {
97-
return nil
98-
}
105+
default:
106+
if err := ptrToNilStructError(strct); err != nil {
107+
return err
99108
}
100109

101-
case chain.equalTo(reflect.Ptr, reflect.Struct):
102-
strType = fmt.Sprintf("%T", reflect.Zero(reflectVal.Elem().Type()).Interface())
103-
104-
for i := 0; i < reflectVal.Elem().Type().NumField(); i++ {
105-
result := callback(reflectVal.Elem().Type().Field(i), valueFromField(reflectVal.Elem(), i))
110+
return fmt.Errorf("expected struct or pointer to struct, %T given", strct)
111+
}
106112

107-
if result.set {
108-
f := reflectVal.Elem().Field(i)
109-
if !f.CanSet() {
110-
f = reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
111-
}
113+
if err := iterator(reflectVal, callback, convert, convertToPtr); err != nil {
114+
return err
115+
}
112116

113-
newVal := result.value
117+
return nil
118+
}
114119

115-
newRefVal, err := func() (reflect.Value, error) {
116-
if convertToPtr && f.Kind() == reflect.Ptr && (newVal != nil || reflect.ValueOf(newVal).Kind() != reflect.Ptr) {
117-
val, err := ValueOf(newVal, f.Type().Elem(), convert)
118-
if err != nil {
119-
return reflect.Value{}, err
120-
}
120+
func valueFromField(strct reflect.Value, i int) any { //nolint:ireturn
121+
f := strct.Field(i)
121122

122-
ptr := reflect.New(val.Type())
123-
ptr.Elem().Set(val)
123+
if !f.CanSet() { // handle unexported fields
124+
if !f.CanAddr() {
125+
tmpReflectVal := reflect.New(strct.Type()).Elem()
126+
tmpReflectVal.Set(strct)
127+
f = tmpReflectVal.Field(i)
128+
}
124129

125-
return ptr, nil
126-
}
130+
f = reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
131+
}
127132

128-
return ValueOf(newVal, f.Type(), convert)
129-
}()
133+
return f.Interface()
134+
}
130135

131-
if err != nil {
132-
return fmt.Errorf("field %d %+q: %w", i, reflectVal.Elem().Type().Field(i).Name, err)
133-
}
136+
func iterateStruct(reflectVal reflect.Value, callback FieldCallback, convert bool, convertToPtr bool) error {
137+
for i := 0; i < reflectVal.Type().NumField(); i++ {
138+
result := callback(reflectVal.Type().Field(i), valueFromField(reflectVal, i))
134139

135-
f.Set(newRefVal)
136-
}
140+
if result.set {
141+
return fmt.Errorf("pointer is required to set fields")
142+
}
137143

138-
if result.stop {
139-
return nil
140-
}
144+
if result.stop {
145+
return nil
141146
}
147+
}
142148

143-
case chain.equalTo(reflect.Ptr, reflect.Interface, reflect.Struct):
144-
strType = fmt.Sprintf("%T", reflect.Zero(reflectVal.Type()).Interface())
145-
v := reflectVal.Elem()
146-
tmp := reflect.New(v.Elem().Type())
147-
tmp.Elem().Set(v.Elem())
148-
149-
// TODO find a better solution
150-
stop := false
151-
152-
newCallback := func(f reflect.StructField, value any) FieldCallbackResult {
153-
if stop {
154-
return FieldCallbackResult{
155-
value: nil,
156-
set: false,
157-
stop: true,
158-
}
159-
}
149+
return nil
150+
}
160151

161-
result := callback(f, value)
152+
func iteratePtrStruct(reflectVal reflect.Value, callback FieldCallback, convert bool, convertToPtr bool) error {
153+
for i := 0; i < reflectVal.Elem().Type().NumField(); i++ {
154+
result := callback(reflectVal.Elem().Type().Field(i), valueFromField(reflectVal.Elem(), i))
162155

163-
if result.stop {
164-
stop = true
156+
if result.set {
157+
f := reflectVal.Elem().Field(i)
158+
if !f.CanSet() {
159+
f = reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
165160
}
166161

167-
return result
168-
}
162+
newVal := result.value
169163

170-
if err := IterateFields(tmp.Interface(), newCallback, convert, convertToPtr); err != nil {
171-
return err
172-
}
164+
newRefVal, err := func() (reflect.Value, error) {
165+
if convertToPtr && f.Kind() == reflect.Ptr && (newVal != nil || reflect.ValueOf(newVal).Kind() != reflect.Ptr) {
166+
val, err := ValueOf(newVal, f.Type().Elem(), convert)
167+
if err != nil {
168+
return reflect.Value{}, err
169+
}
173170

174-
v.Set(tmp.Elem())
171+
ptr := reflect.New(val.Type())
172+
ptr.Elem().Set(val)
175173

176-
default:
177-
if err := ptrToNilStructError(strct); err != nil {
178-
return err
174+
return ptr, nil
175+
}
176+
177+
return ValueOf(newVal, f.Type(), convert)
178+
}()
179+
180+
if err != nil {
181+
return fmt.Errorf("field %d %+q: %w", i, reflectVal.Elem().Type().Field(i).Name, err)
182+
}
183+
184+
f.Set(newRefVal)
179185
}
180186

181-
return fmt.Errorf("expected struct or pointer to struct, %T given", strct)
187+
if result.stop {
188+
return nil
189+
}
182190
}
183191

184192
return nil
185193
}
186194

187-
func valueFromField(strct reflect.Value, i int) any { //nolint:ireturn
188-
f := strct.Field(i)
195+
func iteratePtrInterfaceStruct(reflectVal reflect.Value, callback FieldCallback, convert bool, convertToPtr bool) error {
196+
v := reflectVal.Elem()
197+
tmp := reflect.New(v.Elem().Type())
198+
tmp.Elem().Set(v.Elem())
199+
200+
var (
201+
stop = false
202+
set = false
203+
)
204+
205+
newCallback := func(f reflect.StructField, value any) FieldCallbackResult {
206+
if stop {
207+
return FieldCallbackResult{
208+
value: nil,
209+
set: false,
210+
stop: true,
211+
}
212+
}
189213

190-
if !f.CanSet() { // handle unexported fields
191-
if !f.CanAddr() {
192-
tmpReflectVal := reflect.New(strct.Type()).Elem()
193-
tmpReflectVal.Set(strct)
194-
f = tmpReflectVal.Field(i)
214+
result := callback(f, value)
215+
216+
if result.stop {
217+
stop = true
195218
}
196219

197-
f = reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
220+
if result.set {
221+
set = true
222+
}
223+
224+
return result
198225
}
199226

200-
return f.Interface()
227+
if err := IterateFields(tmp.Interface(), newCallback, convert, convertToPtr); err != nil {
228+
return err
229+
}
230+
231+
if set {
232+
v.Set(tmp.Elem())
233+
}
234+
235+
return nil
201236
}

0 commit comments

Comments
 (0)