Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
return err
}
if err := unmarshaler(ptrValue.Interface(), b); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -769,7 +769,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
return err
}
if err := unmarshaler.UnmarshalYAML(ctx, b); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -780,7 +780,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
return err
}
if err := unmarshaler.UnmarshalYAML(b); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -796,7 +796,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}
return nil
}); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -812,22 +812,22 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}
return nil
}); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}

if unmarshaler, ok := iface.(NodeUnmarshaler); ok {
if err := unmarshaler.UnmarshalYAML(src); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}

return nil
}

if unmarshaler, ok := iface.(NodeUnmarshalerContext); ok {
if err := unmarshaler.UnmarshalYAML(ctx, src); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}

return nil
Expand All @@ -845,7 +845,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
b, ok := d.unmarshalableText(src)
if ok {
if err := unmarshaler.UnmarshalText(b); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -863,7 +863,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}
jsonBytes = bytes.TrimRight(jsonBytes, "\n")
if err := unmarshaler.UnmarshalJSON(jsonBytes); err != nil {
return err
return errors.ErrUnmarshaler(err, dst.Type(), src.GetToken())
}
return nil
}
Expand All @@ -872,9 +872,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
return errors.New("does not implemented Unmarshaler")
}

var (
astNodeType = reflect.TypeOf((*ast.Node)(nil)).Elem()
)
var astNodeType = reflect.TypeOf((*ast.Node)(nil)).Elem()

func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.Node) error {
d.stepIn()
Expand Down Expand Up @@ -1381,8 +1379,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
if foundErr != nil {
continue
}
var te *errors.TypeError
if errors.As(err, &te) {
if te, ok := err.(*errors.TypeError); ok {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will stop working if the err is wrapped somewhere. Why did you change the sequence using As ?

if te.StructFieldName != nil {
fieldName := fmt.Sprintf("%s.%s", structType.Name(), *te.StructFieldName)
te.StructFieldName = &fieldName
Expand Down Expand Up @@ -1417,8 +1414,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
if foundErr != nil {
continue
}
var te *errors.TypeError
if errors.As(err, &te) {
if te, ok := err.(*errors.TypeError); ok {
fieldName := fmt.Sprintf("%s.%s", structType.Name(), field.Name)
te.StructFieldName = &fieldName
foundErr = te
Expand Down
47 changes: 47 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,9 @@ func TestUnmarshalablePtrString(t *testing.T) {
type unmarshalableIntValue int

func (v *unmarshalableIntValue) UnmarshalYAML(raw []byte) error {
if string(raw) == "yamlerr" {
return yaml.Unmarshal(raw, (*int)(v))
}
i, err := strconv.Atoi(string(raw))
if err != nil {
return err
Expand Down Expand Up @@ -2604,6 +2607,38 @@ func TestUnmarshalablePtrInt(t *testing.T) {
})
}

func TestUnmarshalableErrors(t *testing.T) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is only a test case for UnmarshalYAML([]byte) error, so please add tests for other cases as well. In the tests, ensure that you can obtain an UnmarshalerError and verify that the content of its Error() matches the expected message.

t.Run("wrapped error", func(t *testing.T) {
t.Parallel()
var container unmarshalableIntContainer
err := yaml.Unmarshal([]byte(`value: atoierr`), &container)
if err == nil {
t.Fatal("expected to error")
}
var unmarshalerErr *errors.UnmarshalerError
if !errors.As(err, &unmarshalerErr) {
t.Fatalf("expected UnmarshalerError but got: %s", err)
}
expectedErr := `yaml_test.unmarshalableIntValue: strconv.Atoi: parsing "atoierr": invalid syntax`
if !strings.Contains(err.Error(), expectedErr) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), expectedErr)
}
})

t.Run("nested yaml decode error", func(t *testing.T) {
t.Parallel()
var container unmarshalableIntContainer
err := yaml.Unmarshal([]byte(`value: yamlerr`), &container)
if err == nil {
t.Fatal("expected to error")
}
expectedErr := `cannot unmarshal into Go value of type yaml_test.unmarshalableIntValue: cannot unmarshal string into Go value of type int`
if !strings.Contains(err.Error(), expectedErr) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), expectedErr)
}
})
}

type literalContainer struct {
v string
}
Expand Down Expand Up @@ -3206,6 +3241,9 @@ type unmarshableMapKey struct {
}

func (mk *unmarshableMapKey) UnmarshalYAML(b []byte) error {
if string(b) == "errkey" {
return errors.New("invalid map key")
}
mk.Key = string(b)
return nil
}
Expand Down Expand Up @@ -3401,6 +3439,15 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) {
if val != "value" {
t.Fatalf("expected to have value \"value\", but got %q", val)
}

expectErr := "invalid map key"
err := yaml.Unmarshal([]byte(`errkey: value`), &m)
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), expectErr) {
t.Fatalf("error message %q should contain %q", err.Error(), expectErr)
}
}

type bytesUnmershalerWithMapAlias struct{}
Expand Down
1 change: 1 addition & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type (
DuplicateKeyError = errors.DuplicateKeyError
UnknownFieldError = errors.UnknownFieldError
UnexpectedNodeTypeError = errors.UnexpectedNodeTypeError
UnmarshalerError = errors.UnmarshalerError
Error = errors.Error
)

Expand Down
49 changes: 49 additions & 0 deletions internal/errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
_ Error = new(DuplicateKeyError)
_ Error = new(UnknownFieldError)
_ Error = new(UnexpectedNodeTypeError)
_ Error = new(UnmarshalerError)
)

type SyntaxError struct {
Expand Down Expand Up @@ -71,6 +72,12 @@ type UnexpectedNodeTypeError struct {
Token *token.Token
}

type UnmarshalerError struct {
Wrapped error
DstType reflect.Type
Token *token.Token
}

// ErrSyntax create syntax error instance with message and token
func ErrSyntax(msg string, tk *token.Token) *SyntaxError {
return &SyntaxError{
Expand Down Expand Up @@ -121,6 +128,24 @@ func ErrUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) *Unex
}
}

func ErrUnmarshaler(wrapped error, dstType reflect.Type, tk *token.Token) error {
if wrapped == nil {
return nil
}
if yamlErr, ok := wrapped.(Error); ok {
return &UnmarshalerError{
Wrapped: errors.New(yamlErr.GetMessage()),
DstType: dstType,
Token: tk,
}
}
return &UnmarshalerError{
Wrapped: wrapped,
DstType: dstType,
Token: tk,
}
}

func (e *SyntaxError) GetMessage() string {
return e.Message
}
Expand Down Expand Up @@ -232,6 +257,30 @@ func (e *UnexpectedNodeTypeError) msg() string {
return fmt.Sprintf("%s was used where %s is expected", e.Actual.YAMLName(), e.Expected.YAMLName())
}

func (e *UnmarshalerError) GetMessage() string {
return e.msg()
}

func (e *UnmarshalerError) GetToken() *token.Token {
return e.Token
}

func (e *UnmarshalerError) Error() string {
return e.FormatError(defaultFormatColor, defaultIncludeSource)
}

func (e *UnmarshalerError) FormatError(colored, inclSource bool) string {
return FormatError(e.msg(), e.Token, colored, inclSource)
}

func (e *UnmarshalerError) Unwrap() error {
return e.Wrapped
}

func (e *UnmarshalerError) msg() string {
return fmt.Sprintf("cannot unmarshal into Go value of type %s: %s", e.DstType, e.Wrapped.Error())
}

func FormatError(errMsg string, token *token.Token, colored, inclSource bool) string {
var pp printer.Printer
if token == nil {
Expand Down