diff --git a/decode.go b/decode.go index 9c168fe1..57a40ce6 100644 --- a/decode.go +++ b/decode.go @@ -757,7 +757,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr return err } if err := unmarshaler(ptrValue.Interface(), b); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -769,7 +769,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr return err } if err := unmarshaler.UnmarshalYAML(ctx, b); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -780,7 +780,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr return err } if err := unmarshaler.UnmarshalYAML(b); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -796,7 +796,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr } return nil }); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -812,14 +812,14 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr } return nil }); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } if unmarshaler, ok := iface.(NodeUnmarshaler); ok { if err := unmarshaler.UnmarshalYAML(src); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil @@ -827,7 +827,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr if unmarshaler, ok := iface.(NodeUnmarshalerContext); ok { if err := unmarshaler.UnmarshalYAML(ctx, src); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil @@ -845,7 +845,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr b, ok := d.unmarshalableText(src) if ok { if err := unmarshaler.UnmarshalText(b); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -863,7 +863,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr } jsonBytes = bytes.TrimRight(jsonBytes, "\n") if err := unmarshaler.UnmarshalJSON(jsonBytes); err != nil { - return err + return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken()) } return nil } @@ -872,9 +872,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr return errors.New("does not implemented Unmarshaler") } -var ( - astNodeType = reflect.TypeOf((*ast.Node)(nil)).Elem() -) +var astNodeType = reflect.TypeOf((*ast.Node)(nil)).Elem() func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.Node) error { d.stepIn() @@ -1381,8 +1379,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if foundErr != nil { continue } - var te *errors.TypeError - if errors.As(err, &te) { + if te, ok := err.(*errors.TypeError); ok { if te.StructFieldName != nil { fieldName := fmt.Sprintf("%s.%s", structType.Name(), *te.StructFieldName) te.StructFieldName = &fieldName @@ -1417,8 +1414,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N if foundErr != nil { continue } - var te *errors.TypeError - if errors.As(err, &te) { + if te, ok := err.(*errors.TypeError); ok { fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name) te.StructFieldName = &fieldName foundErr = te diff --git a/decode_test.go b/decode_test.go index 2623e397..d48fc70e 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2531,6 +2531,9 @@ func TestUnmarshalablePtrString(t *testing.T) { type unmarshalableIntValue int func (v *unmarshalableIntValue) UnmarshalYAML(raw []byte) error { + if string(raw) == "yamlerr" { + return yaml.Unmarshal(raw, (*int)(v)) + } i, err := strconv.Atoi(string(raw)) if err != nil { return err @@ -2604,6 +2607,38 @@ func TestUnmarshalablePtrInt(t *testing.T) { }) } +func TestUnmarshalableErrors(t *testing.T) { + t.Run("wrapped error", func(t *testing.T) { + t.Parallel() + var container unmarshalableIntContainer + err := yaml.Unmarshal([]byte(`value: atoierr`), &container) + if err == nil { + t.Fatal("expected to error") + } + var unmarshalerErr *errors.UnmarshalerError + if !errors.As(err, &unmarshalerErr) { + t.Fatalf("expected UnmarshalerError but got: %s", err) + } + expectedErr := `yaml_test.unmarshalableIntValue: strconv.Atoi: parsing "atoierr": invalid syntax` + if !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), expectedErr) + } + }) + + t.Run("nested yaml decode error", func(t *testing.T) { + t.Parallel() + var container unmarshalableIntContainer + err := yaml.Unmarshal([]byte(`value: yamlerr`), &container) + if err == nil { + t.Fatal("expected to error") + } + expectedErr := `cannot unmarshal into Go value of type yaml_test.unmarshalableIntValue: cannot unmarshal string into Go value of type int` + if !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("expected error message: %s to contain: %s", err.Error(), expectedErr) + } + }) +} + type literalContainer struct { v string } @@ -3206,6 +3241,9 @@ type unmarshableMapKey struct { } func (mk *unmarshableMapKey) UnmarshalYAML(b []byte) error { + if string(b) == "errkey" { + return errors.New("invalid map key") + } mk.Key = string(b) return nil } @@ -3401,6 +3439,15 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) { if val != "value" { t.Fatalf("expected to have value \"value\", but got %q", val) } + + expectErr := "invalid map key" + err := yaml.Unmarshal([]byte(`errkey: value`), &m) + if err == nil { + t.Fatal("expected error but got nil") + } + if !strings.Contains(err.Error(), expectErr) { + t.Fatalf("error message %q should contain %q", err.Error(), expectErr) + } } type bytesUnmershalerWithMapAlias struct{} diff --git a/error.go b/error.go index 52d3e7e6..6db2d785 100644 --- a/error.go +++ b/error.go @@ -26,6 +26,7 @@ type ( DuplicateKeyError = errors.DuplicateKeyError UnknownFieldError = errors.UnknownFieldError UnexpectedNodeTypeError = errors.UnexpectedNodeTypeError + UnmarshalerError = errors.UnmarshalerError Error = errors.Error ) diff --git a/internal/errors/error.go b/internal/errors/error.go index b08a3fc6..cdfe4a70 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -35,6 +35,7 @@ var ( _ Error = new(DuplicateKeyError) _ Error = new(UnknownFieldError) _ Error = new(UnexpectedNodeTypeError) + _ Error = new(UnmarshalerError) ) type SyntaxError struct { @@ -71,6 +72,12 @@ type UnexpectedNodeTypeError struct { Token *token.Token } +type UnmarshalerError struct { + Wrapped error + DstType reflect.Type + Token *token.Token +} + // ErrSyntax create syntax error instance with message and token func ErrSyntax(msg string, tk *token.Token) *SyntaxError { return &SyntaxError{ @@ -121,6 +128,24 @@ func ErrUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) *Unex } } +func ErrUnmarshaler(wrapped error, dstType reflect.Type, tk *token.Token) error { + if wrapped == nil { + return nil + } + if yamlErr, ok := wrapped.(Error); ok { + return &UnmarshalerError{ + Wrapped: errors.New(yamlErr.GetMessage()), + DstType: dstType, + Token: tk, + } + } + return &UnmarshalerError{ + Wrapped: wrapped, + DstType: dstType, + Token: tk, + } +} + func (e *SyntaxError) GetMessage() string { return e.Message } @@ -232,6 +257,30 @@ func (e *UnexpectedNodeTypeError) msg() string { return fmt.Sprintf("%s was used where %s is expected", e.Actual.YAMLName(), e.Expected.YAMLName()) } +func (e *UnmarshalerError) GetMessage() string { + return e.msg() +} + +func (e *UnmarshalerError) GetToken() *token.Token { + return e.Token +} + +func (e *UnmarshalerError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) +} + +func (e *UnmarshalerError) FormatError(colored, inclSource bool) string { + return FormatError(e.msg(), e.Token, colored, inclSource) +} + +func (e *UnmarshalerError) Unwrap() error { + return e.Wrapped +} + +func (e *UnmarshalerError) msg() string { + return fmt.Sprintf("cannot unmarshal into Go value of type %s: %s", e.DstType, e.Wrapped.Error()) +} + func FormatError(errMsg string, token *token.Token, colored, inclSource bool) string { var pp printer.Printer if token == nil {