diff --git a/dag.go b/dag.go index 9fe9a75..54aa2c2 100644 --- a/dag.go +++ b/dag.go @@ -80,6 +80,7 @@ func checkDAGRecursive[S any](step Step[S], visited map[string]struct{}) error { visited[ptr] = struct{}{} + // TODO: Handle stepWithErr switch s := step.(type) { case interface{ Unwrap() Step[S] }: return checkDAGRecursive(s.Unwrap(), visited) diff --git a/dag_test.go b/dag_test.go index a043c7b..59ee112 100644 --- a/dag_test.go +++ b/dag_test.go @@ -11,7 +11,6 @@ import ( func TestExecutor_Use(t *testing.T) { type useState struct{ indent int } - type useStep = Step[useState] validateResource := func(ctx context.Context, state useState) error { return nil } createResource := func(ctx context.Context, state useState) error { return nil } @@ -26,9 +25,7 @@ func TestExecutor_Use(t *testing.T) { Result( NewStep(createResource), NewStep(reportSuccess), - func(ctx context.Context, state useState, err error) useStep { - return NewStep(reportFailure) - }, + NewStep(reportFailure), ), ), ) @@ -70,9 +67,7 @@ dagger:seriesStep[useState·1] Result( NewStep(createResource), NewStep(reportSuccess), - func(ctx context.Context, state useState, err error) useStep { - return NewStep(reportFailure) - }, + NewStep(reportFailure), ), Series( If( @@ -125,11 +120,11 @@ dagger:TestExecutor_Use.func1 dagger:TestExecutor_Use.func3 dagger:TestExecutor_Use.func2 dagger:TestExecutor_Use.func3 - dagger:TestExecutor_Use.func6.3 - dagger:TestExecutor_Use.func6.5 - dagger:TestExecutor_Use.func6.7 + dagger:TestExecutor_Use.func6.2 + dagger:TestExecutor_Use.func6.4 + dagger:TestExecutor_Use.func6.6 + dagger:TestExecutor_Use.func6.8 dagger:TestExecutor_Use.func6.9 - dagger:TestExecutor_Use.func6.10 `, buf.String()) }) @@ -141,9 +136,7 @@ dagger:TestExecutor_Use.func3 Result( NewStep(createResource), NewStep(reportSuccess), - func(ctx context.Context, state useState, err error) useStep { - return NewStep(reportFailure) - }, + NewStep(reportFailure), ), ), ) @@ -210,11 +203,9 @@ func Test_buildDAG(t *testing.T) { step4 := NewStep(publishKafka) step5 := NewStep(updateDB) resultStep := &resultStep[dummyState]{ - mainStep: step2, - successStep: step3, - failureHandler: func(ctx context.Context, state dummyState, err error) Step[dummyState] { - return step4 - }, + mainStep: step2, + successStep: step3, + failureHandler: step4, } rootStep := &ifElseStep[dummyState]{ diff --git a/example_test.go b/example_test.go index 3946540..8fad3b8 100644 --- a/example_test.go +++ b/example_test.go @@ -97,7 +97,7 @@ func ExampleResult() { reportFailure := func(ctx context.Context, state exampleState) error { return nil } // Tip: Create a type alias like this to avoid using go generic syntax everywhere. - type exampleStateStep = dagger.Step[exampleState] + // type exampleStateStep = dagger.Step[exampleState] dag, err := dagger.New( dagger.Result( @@ -106,16 +106,7 @@ func ExampleResult() { // It will then run the success Step, if main Step returned no error dagger.NewStep(reportSuccess), // Otherwise, it will run the success Step, if main Step returned an error - func(ctx context.Context, state exampleState, err error) exampleStateStep { - // Note: It is encouraged to test that `failureProcedure` has no cycles - // via unit tests. - // ```go - // _, err := dagger.New(failureProcedure) - // assertNoError(err) - // ``` - failureProcedure := dagger.NewStep(reportFailure) - return failureProcedure - }, + dagger.NewStep(reportFailure), ), ) if err != nil { diff --git a/result_ctx.go b/result_ctx.go new file mode 100644 index 0000000..d091ff0 --- /dev/null +++ b/result_ctx.go @@ -0,0 +1,21 @@ +package dagger + +import ( + "context" +) + +type resultCtxKey int + +const resultErrKey resultCtxKey = iota + +func resultErrToContext(ctx context.Context, err error) context.Context { + return context.WithValue(ctx, resultErrKey, err) +} + +func resultErrFromContext(ctx context.Context) error { + if err, ok := ctx.Value(resultErrKey).(error); ok { + return err + } + + return nil +} diff --git a/result_step.go b/result_step.go new file mode 100644 index 0000000..8b66e8a --- /dev/null +++ b/result_step.go @@ -0,0 +1,194 @@ +package dagger + +import ( + "context" +) + +// FailureSelector is used to define closures that act as +// branch selector for Result's failure Step(s). +type FailureSelector[S any] func(ctx context.Context, err error) bool + +// stepWithErr is used to define closures that act as Step with error. +type stepWithErr[S any] func(ctx context.Context, state S, resErr error) error + +// NewResultErrStep creates a new Step that also has access to the error from +// the Result Step's main Step. +func NewResultErrStep[S any](f func(ctx context.Context, state S, resErr error) error) Step[S] { + return stepWithErr[S](f) +} + +func (f stepWithErr[S]) Exec(ctx context.Context, state S) error { + var s Step[S] = StepFunc[S](f.exec) + + c, ok := ctx.Value(middlewareKey).(MiddlewareChain[S]) + if ok { + si := stepInfo(s) + si.CanSkip = true + s = c.apply(s, si) + } + + return s.Exec(ctx, state) +} + +func (f stepWithErr[S]) exec(ctx context.Context, state S) error { + return f(ctx, state, resultErrFromContext(ctx)) +} + +// ResultFailureHandler is used to define entities that act as +// failure handler for Result Step. +type ResultFailureHandler[S any] interface { + // selectStep is used to select the Step to be executed + // based on the error returned by the mainStep. + selectStep(ctx context.Context, err error) Step[S] +} + +type resultStep[S any] struct { + mainStep Step[S] + successStep Step[S] + failureHandler ResultFailureHandler[S] +} + +var _ middlewareSkipper = (*resultStep[any])(nil) + +func (s *resultStep[S]) canSkip() bool { + return true +} + +func (s *resultStep[S]) Exec(ctx context.Context, state S) error { + if err := execWithContext(ctx, s.mainStep, state); err != nil { + return s.handleErr(ctx, state, err) + } + + return execWithContext(ctx, s.successStep, state) +} + +func (s *resultStep[S]) Unwrap() []Step[S] { + return []Step[S]{ + s.mainStep, + s.successStep, + // TODO: Make failure handler a part of the DAG, update Unwrap to return it. + } +} + +func (s *resultStep[S]) handleErr(ctx context.Context, state S, err error) error { + if s.failureHandler == nil { + return err + } + + if step := s.failureHandler.selectStep(ctx, err); step != nil { + return execWithContext(resultErrToContext(ctx, err), step, state) + } + + return err +} + +var _ ResultFailureHandler[any] = StepFunc[any](nil) + +//nolint:unused +func (f StepFunc[S]) selectStep(_ context.Context, _ error) Step[S] { return f } + +// Result executes the mainStep and uses the returned value to +// - execute successStep, if the returned error is nil +// - execute failureHandler, if the returned error is not nil +// +// Note: The failureHandler is used to define the failure branch, +// if the failureHandler returns a nil Step, Result's Step.Exec +// returns the mainStep's error. +func Result[S any](mainStep, successStep Step[S], failureHandler ResultFailureHandler[S]) Step[S] { + return &resultStep[S]{ + mainStep: mainStep, + successStep: successStep, + failureHandler: failureHandler, + } +} + +type failureBranch[S any] struct { + selector FailureSelector[S] + step Step[S] +} + +var _ FailureBranch[any] = (*failureBranch[any])(nil) + +//nolint:unused +func (s *failureBranch[S]) isFailureBranch(S) {} + +var _ ResultFailureHandler[any] = (*failureBranch[any])(nil) + +//nolint:unused +func (s *failureBranch[S]) selectStep(ctx context.Context, err error) Step[S] { + if s.selector(ctx, err) { + return s.step + } + + return nil +} + +// NewBranch creates a new FailureBranch with the given FailureSelector and Step. +func NewBranch[S any](selector FailureSelector[S], step Step[S]) FailureBranch[S] { + return &failureBranch[S]{selector: selector, step: step} +} + +type defaultBranch[S any] struct { + step Step[S] +} + +var _ FailureBranch[any] = (*defaultBranch[any])(nil) + +//nolint:unused +func (d *defaultBranch[S]) isFailureBranch(S) {} + +var _ ResultFailureHandler[any] = (*defaultBranch[any])(nil) + +//nolint:unused +func (d *defaultBranch[S]) selectStep(_ context.Context, _ error) Step[S] { return d.step } + +func DefaultBranch[S any](step Step[S]) FailureBranch[S] { + return &defaultBranch[S]{step: step} +} + +// FailureBranch is used to define entities that act as branches +// in a ResultFailureHandler. +// +// Note: This is used to prevent misuse of the HandleMultiFailure function. +type FailureBranch[S any] interface{ isFailureBranch(S) } + +type multiFailureHandler[S any] struct { + branches []ResultFailureHandler[S] +} + +var _ ResultFailureHandler[any] = (*multiFailureHandler[any])(nil) + +//nolint:unused +func (m *multiFailureHandler[S]) selectStep(ctx context.Context, err error) Step[S] { + for _, branch := range m.branches { + if step := branch.selectStep(ctx, err); step != nil { + return step + } + } + + return nil +} + +// HandleMultiFailure takes in FailureBranch(s) and returns a ResultFailureHandler. +// It is used to handle multiple failure branches in a Result Step. +// The branches are evaluated in the order they are passed. +// +// The behavior is as follows: +// - if a FailureBranch is eligible, it is executed and the remaining branches +// are ignored. +// - if no FailureBranch is eligible, the DefaultBranch is executed, if provided, +// otherwise the mainStep's error is returned. +func HandleMultiFailure[S any](branches ...FailureBranch[S]) ResultFailureHandler[S] { + res := make([]ResultFailureHandler[S], len(branches)) + + for i, branch := range branches { + switch b := branch.(type) { + case *failureBranch[S]: + res[i] = b + case *defaultBranch[S]: + res[i] = b + } + } + + return &multiFailureHandler[S]{branches: res} +} diff --git a/result_step_test.go b/result_step_test.go new file mode 100644 index 0000000..65f4f1f --- /dev/null +++ b/result_step_test.go @@ -0,0 +1,197 @@ +package dagger + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResult(t *testing.T) { + t.Run("SuccessBranch", func(t *testing.T) { + success, failure := 0, 0 + + ss := NewStep(func(ctx context.Context, state testState) error { success++; return nil }) + fs := NewStep(func(ctx context.Context, state testState) error { failure++; return nil }) + ms := NewStep(func(ctx context.Context, state testState) error { return nil }) + + err := Result(ms, ss, fs).Exec(context.TODO(), testState{}) + assert.NoError(t, err) + assert.Equal(t, 1, success) + assert.Equal(t, 0, failure) + }) + + t.Run("FailureSingleBranch", func(t *testing.T) { + success, failure := 0, 0 + + ss := NewStep(func(ctx context.Context, state testState) error { success++; return nil }) + fs := NewStep(func(ctx context.Context, state testState) error { failure++; return nil }) + ms := NewStep(func(ctx context.Context, state testState) error { return testErrStep }) + + err := Result(ms, ss, fs).Exec(context.TODO(), testState{}) + assert.NoError(t, err) + assert.Equal(t, 0, success) + assert.Equal(t, 1, failure) + }) + + t.Run("FailureMultipleBranch", func(t *testing.T) { + success, failure := 0, 0 + + err1 := errors.New("error 1") + err2 := errors.New("error 2") + + branch1Selected, branch2Selected, defSelected := 0, 0, 0 + branch1Selector := func(ctx context.Context, err error) bool { + if errors.Is(err, err1) { + branch1Selected += 1 + return true + } + return false + } + branch2Selector := func(ctx context.Context, err error) bool { + if errors.Is(err, err2) { + branch2Selected += 1 + return true + } + return false + } + + ss := NewStep(func(ctx context.Context, state testState) error { success++; return nil }) + fs := NewStep(func(ctx context.Context, state testState) error { failure++; return nil }) + ds := NewStep(func(ctx context.Context, state testState) error { defSelected += 1; return nil }) + + t.Run("DefaultBranch", func(t *testing.T) { + ms := NewStep(func(ctx context.Context, state testState) error { return errors.New("error random") }) + + mfh := HandleMultiFailure( + NewBranch(branch1Selector, fs), + NewBranch(branch2Selector, fs), + DefaultBranch(ds), + ) + + err := Result(ms, ss, mfh).Exec(context.TODO(), testState{}) + + assert.NoError(t, err) + assert.Equal(t, 0, success) + assert.Equal(t, 0, failure) + assert.Equal(t, 0, branch1Selected) + assert.Equal(t, 0, branch2Selected) + assert.Equal(t, 1, defSelected) + }) + + t.Run("Branch1", func(t *testing.T) { + ms := NewStep(func(ctx context.Context, state testState) error { return err1 }) + + mfh := HandleMultiFailure( + NewBranch(branch1Selector, fs), + NewBranch(branch2Selector, fs), + DefaultBranch(ds), + ) + + err := Result(ms, ss, mfh).Exec(context.TODO(), testState{}) + + assert.NoError(t, err) + assert.Equal(t, 0, success) + assert.Equal(t, 1, failure) + assert.Equal(t, 1, branch1Selected) + assert.Equal(t, 0, branch2Selected) + assert.Equal(t, 1, defSelected) + }) + + t.Run("Branch2", func(t *testing.T) { + ms := NewStep(func(ctx context.Context, state testState) error { return err2 }) + + mfh := HandleMultiFailure( + NewBranch(branch1Selector, fs), + NewBranch(branch2Selector, fs), + DefaultBranch(ds), + ) + + err := Result(ms, ss, mfh).Exec(context.TODO(), testState{}) + + assert.NoError(t, err) + assert.Equal(t, 0, success) + assert.Equal(t, 2, failure) + assert.Equal(t, 1, branch1Selected) + assert.Equal(t, 1, branch2Selected) + assert.Equal(t, 1, defSelected) + }) + + t.Run("NoBranchMatch", func(t *testing.T) { + ms := NewStep(func(ctx context.Context, state testState) error { return errors.New("random") }) + + mfh := HandleMultiFailure( + NewBranch( + func(ctx context.Context, err error) bool { return false }, // no match + fs, + ), + ) + + err := Result(ms, ss, mfh).Exec(context.TODO(), testState{}) + + assert.Error(t, err) + assert.Equal(t, 0, success) + assert.Equal(t, 2, failure) + assert.Equal(t, 1, branch1Selected) + assert.Equal(t, 1, branch2Selected) + assert.Equal(t, 1, defSelected) + }) + }) +} + +func TestStepWithErr_Exec(t *testing.T) { + t.Run("Success", func(t *testing.T) { + type useState struct{ indent int } + + errCount := 0 + + ss := NewResultErrStep(func(ctx context.Context, state useState, err error) error { + if err != nil { + errCount++ + } + + return err + }) + ms := NewStep(func(ctx context.Context, state useState) error { return errors.New("random") }) + mfh := HandleMultiFailure( + NewBranch[useState]( + func(ctx context.Context, err error) bool { return true }, // no match + ss, + ), + ) + + step := Result( + ms, + ss, + mfh, + ) + + dag, err := New(step) + assert.NoError(t, err) + + buf := new(bytes.Buffer) + dag.Use(func(next Step[useState], info Info) Step[useState] { + return NewStep(func(ctx context.Context, state useState) error { + if info.CanSkip { + return next.Exec(ctx, useState{indent: state.indent + 1}) + } + + buf.WriteString(strings.Repeat("\t", state.indent-1)) + buf.WriteString(info.Name.String()) + buf.WriteString("\n") + + return next.Exec(ctx, useState{indent: state.indent + 1}) + }) + }) + + err = dag.Exec(context.TODO(), useState{}) + assert.Error(t, err) + assert.Equal(t, 1, errCount) + assert.Equal(t, `dagger:TestStepWithErr_Exec.func1.2 +dagger:stepWithErr[useState·2] +`, buf.String()) + }) +} diff --git a/step.go b/step.go index 5e8d4a0..9475d44 100644 --- a/step.go +++ b/step.go @@ -25,8 +25,6 @@ var _ Step[any] = (*StepFunc[any])(nil) // branch selector for Step(s). type Selector[S any] func(state S) bool -type StepErrorHandler[S any] func(ctx context.Context, state S, err error) Step[S] - type ifStep[S any] struct { condition Selector[S] thenStep Step[S] @@ -87,49 +85,6 @@ func IfElse[S any](condition Selector[S], thenStep, elseStep Step[S]) Step[S] { return &ifElseStep[S]{condition: condition, thenStep: thenStep, elseStep: elseStep} } -type resultStep[S any] struct { - mainStep Step[S] - successStep Step[S] - failureHandler StepErrorHandler[S] -} - -var _ middlewareSkipper = (*resultStep[any])(nil) - -func (s *resultStep[S]) canSkip() bool { - return true -} - -func (s *resultStep[S]) Exec(ctx context.Context, state S) error { - if err := execWithContext(ctx, s.mainStep, state); err != nil { - return execWithContext(ctx, s.failureHandler(ctx, state, err), state) - } - - return execWithContext(ctx, s.successStep, state) -} - -func (s *resultStep[S]) Unwrap() []Step[S] { - return []Step[S]{ - s.mainStep, - s.successStep, - // TODO: Make failure handler a part of the DAG, update Unwrap to return it. - } -} - -// Result Step executes the mainStep and uses the returned value to -// - execute successStep, if the returned error is nil -// - call failureHandler to execute returned step, if the returned error is not nil -// -// Note: It is recommended to make sure that the Step returned by -// failureHandler does not contain any cycles, use New on all possible -// return Step(s) to assert it in unit tests. -func Result[S any](mainStep, successStep Step[S], failureHandler StepErrorHandler[S]) Step[S] { - return &resultStep[S]{ - mainStep: mainStep, - successStep: successStep, - failureHandler: failureHandler, - } -} - type seriesStep[S any] struct { steps []Step[S] } diff --git a/step_test.go b/step_test.go index 78bf81f..3fa10e0 100644 --- a/step_test.go +++ b/step_test.go @@ -67,38 +67,6 @@ func TestIfElse(t *testing.T) { assert.Equal(t, 3, count) } -func TestResult(t *testing.T) { - t.Run("SuccessBranch", func(t *testing.T) { - success, failure := 0, 0 - - ss := NewStep(func(ctx context.Context, state testState) error { success++; return nil }) - fs := NewStep(func(ctx context.Context, state testState) error { failure++; return nil }) - ms := NewStep(func(ctx context.Context, state testState) error { return nil }) - - err := Result(ms, ss, func(ctx context.Context, state testState, err error) Step[testState] { - return fs - }).Exec(context.TODO(), testState{}) - assert.NoError(t, err) - assert.Equal(t, 1, success) - assert.Equal(t, 0, failure) - }) - - t.Run("FailureBranch", func(t *testing.T) { - success, failure := 0, 0 - - ss := NewStep(func(ctx context.Context, state testState) error { success++; return nil }) - fs := NewStep(func(ctx context.Context, state testState) error { failure++; return nil }) - ms := NewStep(func(ctx context.Context, state testState) error { return testErrStep }) - - err := Result(ms, ss, func(ctx context.Context, state testState, err error) Step[testState] { - return fs - }).Exec(context.TODO(), testState{}) - assert.NoError(t, err) - assert.Equal(t, 0, success) - assert.Equal(t, 1, failure) - }) -} - func TestSeries(t *testing.T) { appendStepIn := func(res *[]string) func(string) Step[testState] { return func(name string) Step[testState] { @@ -209,9 +177,7 @@ func Test_canSkip(t *testing.T) { step: Result( NewStep(func(context.Context, testState) error { return nil }), NewStep(func(context.Context, testState) error { return nil }), - func(context.Context, testState, error) Step[testState] { - return NewStep(func(context.Context, testState) error { return nil }) - }, + NewStep(func(context.Context, testState) error { return nil }), ), }, {