diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..98137af --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,122 @@ +package errors + +import ( + "context" + "fmt" + "runtime" + + bugsnag_errors "github.com/bugsnag/bugsnag-go/v2/errors" + "github.com/sirupsen/logrus" +) + +// Fields can be attached to errors like this: +// errors.Wrap(err, "invalid value", errors.Fields{"value": value}) +type Fields map[string]interface{} + +type ErrorWithDetails interface { + ErrorDetails() map[string]interface{} +} + +type baseError struct { + err error + message string + fields Fields + stack []uintptr +} + +func (e *baseError) Error() string { + return e.message +} + +func (e *baseError) Unwrap() error { + return e.err +} + +func (e *baseError) Callers() []uintptr { + return findStackFromError(e) // lazy stack lookup in wrapped errors +} + +func (e *baseError) LogFields() logrus.Fields { + return logrus.Fields(e.fields) +} + +var _ bugsnag_errors.ErrorWithCallers = &baseError{} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &baseError{ + message: fmt.Sprintf(format, args...), + stack: captureStack(), + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string, fields ...Fields) error { + if err == nil { + return nil + } + + return &baseError{ + err: err, + message: fmt.Sprintf("%s: %s", message, err.Error()), + fields: mergeFieldsCtx(nil, err, fields...), + stack: captureStack(), + } +} + +// WrapCtx returns an error annotating err with a stack trace and log fields. +// The log fields are captured from context.Context and arguments. +// If err is nil, WrapCtx returns nil. +func WrapCtx(ctx context.Context, err error, message string, fields ...Fields) error { + if err == nil { + return nil + } + + return &baseError{ + err: err, + message: message + ": " + err.Error(), + fields: mergeFieldsCtx(ctx, err, fields...), + stack: captureStack(), + } +} + +// With returns an error annotating err with a stack trace and log fields. +// If err is nil, With returns nil. +func With(err error, fields ...Fields) error { + if err == nil { + return nil + } + + return &baseError{ + err: err, + message: err.Error(), + fields: mergeFieldsCtx(nil, err, fields...), + stack: captureStack(), + } +} + +// WithCtx returns an error annotating err with a stack trace and log fields. +// The log fields are captured from context.Context and arguments. +// If err is nil, WithCtx returns nil. +func WithCtx(ctx context.Context, err error, fields ...Fields) error { + if err == nil { + return nil + } + + return &baseError{ + err: err, + message: err.Error(), + fields: mergeFieldsCtx(ctx, err, fields...), + stack: captureStack(), + } +} + +func captureStack() []uintptr { + var pcs [32]uintptr + n := runtime.Callers(3, pcs[:]) // Wrap/WrapCtx -> findStack -> runtime.Callers + return pcs[0:n] +} diff --git a/errors/errors_test.go b/errors/errors_test.go new file mode 100644 index 0000000..86ef320 --- /dev/null +++ b/errors/errors_test.go @@ -0,0 +1,94 @@ +package errors + +import ( + "context" + stderrors "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + err := New("MSG") + require.Equal(t, "MSG", err.Error()) +} + +func TestErrorf(t *testing.T) { + err := Errorf("MSG %s", "ME") + require.Equal(t, "MSG ME", err.Error()) +} + +func TestWrap_Nil(t *testing.T) { + err := Wrap(nil, "") + require.Nil(t, err) +} + +func TestWrap(t *testing.T) { + err := Wrap(stderrors.New("inner"), "outer") + require.NotNil(t, err) + require.Equal(t, "outer: inner", err.Error()) +} + +func TestWrapJoined(t *testing.T) { + err := Wrap(stderrors.Join(New("inner 1"), New("inner 2")), "outer") + require.NotNil(t, err) + require.Equal(t, "outer: inner 1\ninner 2", err.Error()) +} + +func TestJoinWrapped(t *testing.T) { + err := stderrors.Join(New("first"), Wrap(New("inner"), "outer")) + require.NotNil(t, err) + require.Equal(t, "first\nouter: inner", err.Error()) +} + +func TestWrapCtx_Nil(t *testing.T) { + ctx := context.Background() + err := WrapCtx(ctx, nil, "") + require.Nil(t, err) +} + +func TestWrapCtx(t *testing.T) { + ctx := context.Background() + err := WrapCtx(ctx, stderrors.New("inner"), "outer") + require.NotNil(t, err) + require.Equal(t, "outer: inner", err.Error()) +} + +func TestWrapCtxJoined(t *testing.T) { + ctx := context.Background() + err := WrapCtx(ctx, stderrors.Join(New("first"), New("second")), "outer") + require.NotNil(t, err) + require.Equal(t, "outer: first\nsecond", err.Error()) +} + +func TestWithCtx(t *testing.T) { + ctx := context.Background() + err := WithCtx(ctx, stderrors.New("inner"), Fields{"key": "val"}) + require.NotNil(t, err) + require.Equal(t, "inner", err.Error()) + + require.Equal(t, Fields{"key": "val"}, FieldsFromError(err)) +} + +func TestWithCtxJoined(t *testing.T) { + ctx := context.Background() + err := WithCtx(ctx, stderrors.Join(New("first"), New("second")), Fields{"key": "val"}) + require.NotNil(t, err) + require.Equal(t, "first\nsecond", err.Error()) + require.Equal(t, Fields{"key": "val"}, FieldsFromError(err)) +} + +func TestWith(t *testing.T) { + err := With(stderrors.New("inner"), Fields{"key": "val"}) + require.NotNil(t, err) + require.Equal(t, "inner", err.Error()) + + require.Equal(t, Fields{"key": "val"}, FieldsFromError(err)) +} + +func TestWithJoined(t *testing.T) { + err := With(stderrors.Join(New("first"), New("second")), Fields{"key": "val"}) + require.NotNil(t, err) + require.Equal(t, "first\nsecond", err.Error()) + require.Equal(t, Fields{"key": "val"}, FieldsFromError(err)) +} diff --git a/errors/logfields.go b/errors/logfields.go new file mode 100644 index 0000000..fac79c2 --- /dev/null +++ b/errors/logfields.go @@ -0,0 +1,46 @@ +package errors + +import ( + "github.com/pkg/errors" + + "github.com/Shopify/goose/v2/logger" +) + +type LoggableError interface { + error + logger.Loggable +} + +func FieldsFromError(err error) Fields { + var loggable LoggableError + if joined, ok := err.(interface{ Unwrap() []error }); ok { + fs := []Fields{} + for _, e := range joined.Unwrap() { + if errors.As(e, &loggable) { + fs = append(fs, Fields(loggable.LogFields())) + } + } + return mergeFields(fs) + } + if errors.As(err, &loggable) { + return Fields(loggable.LogFields()) + } + + return Fields{} +} + +func mergeFieldsCtx(ctx logger.Valuer, err error, fieldsList ...Fields) Fields { + fieldsList = append([]Fields{FieldsFromError(err)}, fieldsList...) + fieldsList = append(fieldsList, Fields(logger.GetLoggableValues(ctx))) + return mergeFields(fieldsList) +} + +func mergeFields(fieldsList []Fields) Fields { + fields := Fields{} + for _, fs := range fieldsList { + for k, v := range fs { + fields[k] = v + } + } + return fields +} diff --git a/errors/logfields_test.go b/errors/logfields_test.go new file mode 100644 index 0000000..74add83 --- /dev/null +++ b/errors/logfields_test.go @@ -0,0 +1,76 @@ +package errors + +import ( + "context" + stderrors "errors" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/Shopify/goose/v2/logger" +) + +func Test_FieldsFromError_NoFields(t *testing.T) { + require.Empty(t, FieldsFromError(nil)) + require.Empty(t, FieldsFromError(New(""))) +} + +type testError struct { + error + fields logrus.Fields +} + +func (e *testError) Unwrap() error { + return e.error +} + +func (e *testError) LogFields() logrus.Fields { + return e.fields +} + +func Test_FieldsFromError(t *testing.T) { + err1 := &testError{error: New("foo"), fields: logrus.Fields{"KEY1": "VAL1", "KEY2": "VAL2"}} + err2 := Wrap(err1, "bar", Fields{"KEY2": "VAL3", "KEY3": "VAL3"}) + + // Fields from outer error have precedence + require.Equal(t, Fields{"KEY1": "VAL1", "KEY2": "VAL3", "KEY3": "VAL3"}, FieldsFromError(err2)) +} + +func Test_FieldsFromError_From_Context(t *testing.T) { + originalErr := Wrap(New(""), "", Fields{"KEY1": "VAL1"}) + + ctx := context.Background() + ctx = logger.WithFields(ctx, logrus.Fields{"KEY1": "VAL1", "KEY2": "VAL2"}) + + err := WrapCtx(ctx, originalErr, "", Fields{"KEY2": "EXTRA", "EXTRA": "EXTRA"}) // KEY2 overlap + require.Equal(t, Fields{ + "KEY1": "VAL1", + "KEY2": "VAL2", // fields from inner error have precedence. + "EXTRA": "EXTRA", + }, FieldsFromError(err)) +} + +func Test_FieldsFromJoinedError(t *testing.T) { + err1 := Wrap(New(""), "", Fields{"FOO": "BAR"}) + err2 := stderrors.Join(Wrap(err1, "", Fields{"BAZ": "BOO"}), New("second")) + + extracted := FieldsFromError(err2) + require.Equal(t, Fields{"FOO": "BAR", "BAZ": "BOO"}, extracted) + + err3 := stderrors.Join(Wrap(err1, "", Fields{"BAZ": "BOO"}), Wrap(New(""), "", Fields{"FOO": "BAR"})) + + extracted = FieldsFromError(err3) + require.Equal(t, Fields{"FOO": "BAR", "BAZ": "BOO"}, extracted) + + err4 := stderrors.Join(Wrap(New(""), "", Fields{"FOO": "BAR"}), Wrap(New(""), "", Fields{"BAZ": "BOO"})) + + extracted = FieldsFromError(err4) + require.Equal(t, Fields{"FOO": "BAR", "BAZ": "BOO"}, extracted) + + err5 := stderrors.Join(Wrap(New(""), "", Fields{"FRUIT": "BANANA"}), New("")) + err6 := Wrap(stderrors.Join(Wrap(err5, "", Fields{"BAZ": "BOO"}), Wrap(New(""), "", Fields{"FOO": "BAR"})), "", Fields{"JOINED": "YES"}) + + extracted = FieldsFromError(err6) + require.Equal(t, Fields{"FOO": "BAR", "BAZ": "BOO", "JOINED": "YES", "FRUIT": "BANANA"}, extracted) +} diff --git a/errors/stack.go b/errors/stack.go new file mode 100644 index 0000000..9b86309 --- /dev/null +++ b/errors/stack.go @@ -0,0 +1,78 @@ +package errors + +import ( + stderrors "errors" + "fmt" + "runtime" + "strings" + + pkgerrors "github.com/pkg/errors" +) + +// findStackFromError returns the deepest stacktrace found. Support Courier errors and pkg/errors. +func findStackFromError(err error) []uintptr { + if joined, ok := err.(interface{ Unwrap() []error }); ok { + for _, e := range joined.Unwrap() { + stack := findStackFromError(e) // recursion + if stack != nil { + return stack // an inner error has a stacktrace, escape the recursion + } + } + } else if wrappedErr := stderrors.Unwrap(err); wrappedErr != nil { + stack := findStackFromError(wrappedErr) // recursion + if stack != nil { + return stack // an inner error has a stacktrace, escape the recursion + } + } + + // starting from the inner error, look for stacktrace + if stack := stackFromBaseError(err); stack != nil { + return stack + } + return stackFromPkgError(err) +} + +type pkgErrorWithStacktrace interface { + StackTrace() pkgerrors.StackTrace +} + +func stackFromPkgError(err error) []uintptr { + var errWithStacktrace pkgErrorWithStacktrace + if !stderrors.As(err, &errWithStacktrace) { + return nil + } + + stacktrace := errWithStacktrace.StackTrace() + callers := make([]uintptr, len(stacktrace)) + for i, pc := range stacktrace { + callers[i] = uintptr(pc) // de-alias from pkgerrors.Frame + } + return callers +} + +func stackFromBaseError(err error) []uintptr { + var errWithStacktrace *baseError + if !stderrors.As(err, &errWithStacktrace) { + return nil + } + + return errWithStacktrace.stack +} + +func formatStack(s []uintptr) string { // useful for debugging and writing tests. + var builder strings.Builder + + frames := runtime.CallersFrames(s) + for { + frame, more := frames.Next() + + builder.WriteString(fmt.Sprintf("%s\n\t%s:%d", frame.Function, frame.File, frame.Line)) + builder.WriteString("\n") + + if !more { + break + } + } + + return builder.String() +} diff --git a/errors/stack_test.go b/errors/stack_test.go new file mode 100644 index 0000000..233b5c3 --- /dev/null +++ b/errors/stack_test.go @@ -0,0 +1,73 @@ +package errors + +import ( + stderrors "errors" + "testing" + + bugsnagerrors "github.com/bugsnag/bugsnag-go/v2/errors" + pkgErrors "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func rawStdError() error { + return New("") +} + +func nestedPkgError() error { + return pkgErrors.Wrap(New(""), "") +} + +func nestedBaseError() error { + return Wrap(New(""), "") +} + +func joinedError() error { + return stderrors.Join(New("first"), New("second")) +} + +func nestedJoinedError() error { + return Wrap(stderrors.Join(New("second"), New("third")), "first") +} + +func Test_baseError_Callers(t *testing.T) { + tests := []struct { + test string + wrappedErr func() error + stackLen int + }{ + { + test: "baseError-stdError", + wrappedErr: rawStdError, + stackLen: 3, // asm_amd64.s - testing.go - Wrap(tt.wrappedErr(), "") + }, + { + test: "baseError-baseError-stdError", + wrappedErr: nestedBaseError, + stackLen: 4, // asm_amd64.s - testing.go - Wrap(tt.wrappedErr(), "") - nestedBaseError + }, + { + test: "baseError-pkgError-stdError", + wrappedErr: nestedPkgError, + stackLen: 4, // asm_amd64.s - testing.go - Wrap(tt.wrappedErr(), "") - nestedPkgError + }, + { + test: "baseError-stdJoinedError", + wrappedErr: joinedError, + stackLen: 3, + }, + { + test: "baseError-stdNestedJoinedError", + wrappedErr: nestedJoinedError, + stackLen: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.test, func(t *testing.T) { + err := Wrap(tt.wrappedErr(), "") + + stack := err.(bugsnagerrors.ErrorWithCallers).Callers() //nolint:errorlint + require.Len(t, stack, tt.stackLen, formatStack(stack)) + }) + } +} diff --git a/errors/stdlib.go b/errors/stdlib.go new file mode 100644 index 0000000..4965b8c --- /dev/null +++ b/errors/stdlib.go @@ -0,0 +1,56 @@ +package errors + +import stderrors "errors" + +// New returns an error that formats as the given text. +// Each call to New returns a distinct error value even if the text is identical. +func New(text string) error { + return stderrors.New(text) +} + +// Is reports whether any error in err's chain matches target. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error is considered to match a target if it is equal to that target or if +// it implements a method Is(error) bool such that Is(target) returns true. +// +// An error type might provide an Is method so it can be treated as equivalent +// to an existing error. For example, if MyError defines +// +// func (m MyError) Is(target error) bool { return target == os.ErrExist } +// +// then Is(MyError{}, os.ErrExist) returns true. See syscall.Errno.Is for +// an example in the standard library. +func Is(err, target error) bool { + return stderrors.Is(err, target) +} + +// As finds the first error in err's chain that matches target, and if so, sets +// target to that error value and returns true. Otherwise, it returns false. +// +// The chain consists of err itself followed by the sequence of errors obtained by +// repeatedly calling Unwrap. +// +// An error matches target if the error's concrete value is assignable to the value +// pointed to by target, or if the error has a method As(interface{}) bool such that +// As(target) returns true. In the latter case, the As method is responsible for +// setting target. +// +// An error type might provide an As method so it can be treated as if it were a +// a different error type. +// +// As panics if target is not a non-nil pointer to either a type that implements +// error, or to any interface type. +func As(err error, target interface{}) bool { + return stderrors.As(err, target) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's +// type contains an Unwrap method returning error. +// Otherwise, Unwrap returns nil. +// Unwrap returns nil if the Unwrap method returns []error. +func Unwrap(err error) error { + return stderrors.Unwrap(err) +} diff --git a/errors/stdlib_test.go b/errors/stdlib_test.go new file mode 100644 index 0000000..ebb7ca4 --- /dev/null +++ b/errors/stdlib_test.go @@ -0,0 +1,15 @@ +package errors + +import ( + stderrors "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnwrapJoinedErrors(t *testing.T) { + joined := stderrors.Join(New("first"), New("second")) + + unwrapped := Unwrap(joined) + assert.Nil(t, unwrapped) +}