|
1 | 1 | package codec
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "errors" |
4 | 5 | "fmt"
|
5 | 6 | "reflect"
|
6 | 7 | "strings"
|
@@ -376,7 +377,18 @@ func extractElement(src any, field string) (reflect.Value, error) {
|
376 | 377 | return reflect.Value{}, err
|
377 | 378 | }
|
378 | 379 |
|
379 |
| - if len(extractMaps) != 1 { |
| 380 | + // if extract maps is empty, check if the underlying field is an uninitialised slice, if so initialise it and return extracted elem. |
| 381 | + if len(extractMaps) == 0 { |
| 382 | + typ, err := initSliceForFieldPath(reflect.TypeOf(src), field) |
| 383 | + if errors.Is(err, &NoSliceUnderFieldPathError{}) { |
| 384 | + return reflect.Value{}, fmt.Errorf("%w: cannot find %q in type: %q for extraction", types.ErrInvalidType, field, reflect.TypeOf(src).String()) |
| 385 | + } else if err != nil { |
| 386 | + return reflect.Value{}, fmt.Errorf("%w: cannot find %q in type: %q for extraction, tried to check if path leads to an uninitialised slice, but failed with %w", types.ErrInvalidType, field, reflect.TypeOf(src).String(), err) |
| 387 | + } |
| 388 | + return typ, nil |
| 389 | + } |
| 390 | + |
| 391 | + if len(extractMaps) > 1 { |
380 | 392 | var sliceValue reflect.Value
|
381 | 393 | var sliceInitialized bool
|
382 | 394 | for _, fields := range extractMaps {
|
@@ -407,7 +419,7 @@ func extractElement(src any, field string) (reflect.Value, error) {
|
407 | 419 |
|
408 | 420 | item, ok := em[name]
|
409 | 421 | if !ok {
|
410 |
| - return reflect.Value{}, fmt.Errorf("%w: cannot find %s", types.ErrInvalidType, field) |
| 422 | + return reflect.Value{}, fmt.Errorf("%w: cannot find %q", types.ErrInvalidType, field) |
411 | 423 | }
|
412 | 424 |
|
413 | 425 | return reflect.ValueOf(item), nil
|
@@ -452,3 +464,68 @@ func pathAndName(field string) ([]string, string) {
|
452 | 464 |
|
453 | 465 | return path, name
|
454 | 466 | }
|
| 467 | + |
| 468 | +type NoSliceUnderFieldPathError struct { |
| 469 | + Err error |
| 470 | +} |
| 471 | + |
| 472 | +func (e *NoSliceUnderFieldPathError) Error() string { |
| 473 | + return fmt.Sprintf("field path did not resolve to a slice") |
| 474 | +} |
| 475 | + |
| 476 | +func initSliceForFieldPath(rootType reflect.Type, fieldPath string) (reflect.Value, error) { |
| 477 | + parts := strings.Split(fieldPath, ".") |
| 478 | + var prevIsSlice bool |
| 479 | + |
| 480 | + if rootType == nil { |
| 481 | + return reflect.Value{}, fmt.Errorf("root type is nil") |
| 482 | + } |
| 483 | + |
| 484 | + typ := derefTypePtr(rootType) |
| 485 | + |
| 486 | + for i, p := range parts { |
| 487 | + if typ.Kind() != reflect.Struct { |
| 488 | + return reflect.Value{}, fmt.Errorf("expected a struct when processing field %q, got %s", p, typ.Kind()) |
| 489 | + } |
| 490 | + |
| 491 | + fieldByName, ok := typ.FieldByName(p) |
| 492 | + if !ok { |
| 493 | + return reflect.Value{}, fmt.Errorf("field %q not found in type %s", p, typ.Name()) |
| 494 | + } |
| 495 | + |
| 496 | + fieldType := derefTypePtr(fieldByName.Type) |
| 497 | + |
| 498 | + // at end of path return a slice or a slice of slice if parent was a slice |
| 499 | + if i == len(parts)-1 { |
| 500 | + if prevIsSlice { |
| 501 | + fieldType = reflect.SliceOf(fieldType) |
| 502 | + } |
| 503 | + |
| 504 | + if fieldType.Kind() != reflect.Slice { |
| 505 | + return reflect.Value{}, &NoSliceUnderFieldPathError{} |
| 506 | + } |
| 507 | + |
| 508 | + return reflect.MakeSlice(fieldType, 0, 0), nil |
| 509 | + } |
| 510 | + |
| 511 | + if fieldType.Kind() == reflect.Slice { |
| 512 | + if prevIsSlice { |
| 513 | + return reflect.Value{}, fmt.Errorf("multiple nested slices are not allowed: found a slice at field %q, but parent in path is already a slice", p) |
| 514 | + } |
| 515 | + prevIsSlice = true |
| 516 | + |
| 517 | + newTyp := fieldType.Elem() |
| 518 | + newTyp = derefTypePtr(newTyp) |
| 519 | + |
| 520 | + if newTyp.Kind() == reflect.Slice { |
| 521 | + return reflect.Value{}, fmt.Errorf("multiple nested slices are not allowed: field %q in path contains a nested slice", p) |
| 522 | + } |
| 523 | + |
| 524 | + typ = newTyp |
| 525 | + } else { |
| 526 | + typ = derefTypePtr(fieldByName.Type) |
| 527 | + } |
| 528 | + } |
| 529 | + |
| 530 | + return reflect.Value{}, &NoSliceUnderFieldPathError{} |
| 531 | +} |
0 commit comments