@@ -82,120 +82,155 @@ func IterateFields(strct any, callback FieldCallback, convert bool, convertToPtr
8282 return err
8383 }
8484
85+ var iterator func (
86+ reflectVal reflect.Value ,
87+ callback FieldCallback ,
88+ convert bool ,
89+ convertToPtr bool ,
90+ ) error
91+
8592 switch {
8693 case chain .equalTo (reflect .Struct ):
8794 strType = fmt .Sprintf ("%T" , reflect .Zero (reflectVal .Type ()).Interface ())
95+ iterator = iterateStruct
8896
89- for i := 0 ; i < reflectVal .Type ().NumField (); i ++ {
90- result := callback (reflectVal .Type ().Field (i ), valueFromField (reflectVal , i ))
97+ case chain .equalTo (reflect .Ptr , reflect .Struct ):
98+ strType = fmt .Sprintf ("%T" , reflect .Zero (reflectVal .Elem ().Type ()).Interface ())
99+ iterator = iteratePtrStruct
91100
92- if result . set {
93- return fmt .Errorf ( "pointer is required to set fields" )
94- }
101+ case chain . equalTo ( reflect . Ptr , reflect . Interface , reflect . Struct ):
102+ strType = fmt .Sprintf ( "%T" , reflect . Zero ( reflectVal . Type ()). Interface () )
103+ iterator = iteratePtrInterfaceStruct
95104
96- if result . stop {
97- return nil
98- }
105+ default :
106+ if err := ptrToNilStructError ( strct ); err != nil {
107+ return err
99108 }
100109
101- case chain .equalTo (reflect .Ptr , reflect .Struct ):
102- strType = fmt .Sprintf ("%T" , reflect .Zero (reflectVal .Elem ().Type ()).Interface ())
103-
104- for i := 0 ; i < reflectVal .Elem ().Type ().NumField (); i ++ {
105- result := callback (reflectVal .Elem ().Type ().Field (i ), valueFromField (reflectVal .Elem (), i ))
110+ return fmt .Errorf ("expected struct or pointer to struct, %T given" , strct )
111+ }
106112
107- if result .set {
108- f := reflectVal .Elem ().Field (i )
109- if ! f .CanSet () {
110- f = reflect .NewAt (f .Type (), unsafe .Pointer (f .UnsafeAddr ())).Elem ()
111- }
113+ if err := iterator (reflectVal , callback , convert , convertToPtr ); err != nil {
114+ return err
115+ }
112116
113- newVal := result .value
117+ return nil
118+ }
114119
115- newRefVal , err := func () (reflect.Value , error ) {
116- if convertToPtr && f .Kind () == reflect .Ptr && (newVal != nil || reflect .ValueOf (newVal ).Kind () != reflect .Ptr ) {
117- val , err := ValueOf (newVal , f .Type ().Elem (), convert )
118- if err != nil {
119- return reflect.Value {}, err
120- }
120+ func valueFromField (strct reflect.Value , i int ) any { //nolint:ireturn
121+ f := strct .Field (i )
121122
122- ptr := reflect .New (val .Type ())
123- ptr .Elem ().Set (val )
123+ if ! f .CanSet () { // handle unexported fields
124+ if ! f .CanAddr () {
125+ tmpReflectVal := reflect .New (strct .Type ()).Elem ()
126+ tmpReflectVal .Set (strct )
127+ f = tmpReflectVal .Field (i )
128+ }
124129
125- return ptr , nil
126- }
130+ f = reflect . NewAt ( f . Type (), unsafe . Pointer ( f . UnsafeAddr ())). Elem ()
131+ }
127132
128- return ValueOf ( newVal , f . Type (), convert )
129- }()
133+ return f . Interface ( )
134+ }
130135
131- if err != nil {
132- return fmt . Errorf ( "field %d %+q: %w" , i , reflectVal .Elem (). Type ().Field ( i ). Name , err )
133- }
136+ func iterateStruct ( reflectVal reflect. Value , callback FieldCallback , convert bool , convertToPtr bool ) error {
137+ for i := 0 ; i < reflectVal .Type ().NumField (); i ++ {
138+ result := callback ( reflectVal . Type (). Field ( i ), valueFromField ( reflectVal , i ))
134139
135- f .Set (newRefVal )
136- }
140+ if result .set {
141+ return fmt .Errorf ("pointer is required to set fields" )
142+ }
137143
138- if result .stop {
139- return nil
140- }
144+ if result .stop {
145+ return nil
141146 }
147+ }
142148
143- case chain .equalTo (reflect .Ptr , reflect .Interface , reflect .Struct ):
144- strType = fmt .Sprintf ("%T" , reflect .Zero (reflectVal .Type ()).Interface ())
145- v := reflectVal .Elem ()
146- tmp := reflect .New (v .Elem ().Type ())
147- tmp .Elem ().Set (v .Elem ())
148-
149- // TODO find a better solution
150- stop := false
151-
152- newCallback := func (f reflect.StructField , value any ) FieldCallbackResult {
153- if stop {
154- return FieldCallbackResult {
155- value : nil ,
156- set : false ,
157- stop : true ,
158- }
159- }
149+ return nil
150+ }
160151
161- result := callback (f , value )
152+ func iteratePtrStruct (reflectVal reflect.Value , callback FieldCallback , convert bool , convertToPtr bool ) error {
153+ for i := 0 ; i < reflectVal .Elem ().Type ().NumField (); i ++ {
154+ result := callback (reflectVal .Elem ().Type ().Field (i ), valueFromField (reflectVal .Elem (), i ))
162155
163- if result .stop {
164- stop = true
156+ if result .set {
157+ f := reflectVal .Elem ().Field (i )
158+ if ! f .CanSet () {
159+ f = reflect .NewAt (f .Type (), unsafe .Pointer (f .UnsafeAddr ())).Elem ()
165160 }
166161
167- return result
168- }
162+ newVal := result .value
169163
170- if err := IterateFields (tmp .Interface (), newCallback , convert , convertToPtr ); err != nil {
171- return err
172- }
164+ newRefVal , err := func () (reflect.Value , error ) {
165+ if convertToPtr && f .Kind () == reflect .Ptr && (newVal != nil || reflect .ValueOf (newVal ).Kind () != reflect .Ptr ) {
166+ val , err := ValueOf (newVal , f .Type ().Elem (), convert )
167+ if err != nil {
168+ return reflect.Value {}, err
169+ }
173170
174- v .Set (tmp .Elem ())
171+ ptr := reflect .New (val .Type ())
172+ ptr .Elem ().Set (val )
175173
176- default :
177- if err := ptrToNilStructError (strct ); err != nil {
178- return err
174+ return ptr , nil
175+ }
176+
177+ return ValueOf (newVal , f .Type (), convert )
178+ }()
179+
180+ if err != nil {
181+ return fmt .Errorf ("field %d %+q: %w" , i , reflectVal .Elem ().Type ().Field (i ).Name , err )
182+ }
183+
184+ f .Set (newRefVal )
179185 }
180186
181- return fmt .Errorf ("expected struct or pointer to struct, %T given" , strct )
187+ if result .stop {
188+ return nil
189+ }
182190 }
183191
184192 return nil
185193}
186194
187- func valueFromField (strct reflect.Value , i int ) any { //nolint:ireturn
188- f := strct .Field (i )
195+ func iteratePtrInterfaceStruct (reflectVal reflect.Value , callback FieldCallback , convert bool , convertToPtr bool ) error {
196+ v := reflectVal .Elem ()
197+ tmp := reflect .New (v .Elem ().Type ())
198+ tmp .Elem ().Set (v .Elem ())
199+
200+ var (
201+ stop = false
202+ set = false
203+ )
204+
205+ newCallback := func (f reflect.StructField , value any ) FieldCallbackResult {
206+ if stop {
207+ return FieldCallbackResult {
208+ value : nil ,
209+ set : false ,
210+ stop : true ,
211+ }
212+ }
189213
190- if ! f .CanSet () { // handle unexported fields
191- if ! f .CanAddr () {
192- tmpReflectVal := reflect .New (strct .Type ()).Elem ()
193- tmpReflectVal .Set (strct )
194- f = tmpReflectVal .Field (i )
214+ result := callback (f , value )
215+
216+ if result .stop {
217+ stop = true
195218 }
196219
197- f = reflect .NewAt (f .Type (), unsafe .Pointer (f .UnsafeAddr ())).Elem ()
220+ if result .set {
221+ set = true
222+ }
223+
224+ return result
198225 }
199226
200- return f .Interface ()
227+ if err := IterateFields (tmp .Interface (), newCallback , convert , convertToPtr ); err != nil {
228+ return err
229+ }
230+
231+ if set {
232+ v .Set (tmp .Elem ())
233+ }
234+
235+ return nil
201236}
0 commit comments