Skip to content

[WIP] feat: make Result's failure branch compile-time aware #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func checkDAGRecursive[S any](step Step[S], visited map[string]struct{}) error {

visited[ptr] = struct{}{}

// TODO: Handle stepWithErr
switch s := step.(type) {
case interface{ Unwrap() Step[S] }:
return checkDAGRecursive(s.Unwrap(), visited)
Expand Down
29 changes: 10 additions & 19 deletions dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

func TestExecutor_Use(t *testing.T) {
type useState struct{ indent int }
type useStep = Step[useState]

validateResource := func(ctx context.Context, state useState) error { return nil }
createResource := func(ctx context.Context, state useState) error { return nil }
Expand All @@ -26,9 +25,7 @@ func TestExecutor_Use(t *testing.T) {
Result(
NewStep(createResource),
NewStep(reportSuccess),
func(ctx context.Context, state useState, err error) useStep {
return NewStep(reportFailure)
},
NewStep(reportFailure),
),
),
)
Expand Down Expand Up @@ -70,9 +67,7 @@ dagger:seriesStep[useState·1]
Result(
NewStep(createResource),
NewStep(reportSuccess),
func(ctx context.Context, state useState, err error) useStep {
return NewStep(reportFailure)
},
NewStep(reportFailure),
),
Series(
If(
Expand Down Expand Up @@ -125,11 +120,11 @@ dagger:TestExecutor_Use.func1
dagger:TestExecutor_Use.func3
dagger:TestExecutor_Use.func2
dagger:TestExecutor_Use.func3
dagger:TestExecutor_Use.func6.3
dagger:TestExecutor_Use.func6.5
dagger:TestExecutor_Use.func6.7
dagger:TestExecutor_Use.func6.2
dagger:TestExecutor_Use.func6.4
dagger:TestExecutor_Use.func6.6
dagger:TestExecutor_Use.func6.8
dagger:TestExecutor_Use.func6.9
dagger:TestExecutor_Use.func6.10
`, buf.String())
})

Expand All @@ -141,9 +136,7 @@ dagger:TestExecutor_Use.func3
Result(
NewStep(createResource),
NewStep(reportSuccess),
func(ctx context.Context, state useState, err error) useStep {
return NewStep(reportFailure)
},
NewStep(reportFailure),
),
),
)
Expand Down Expand Up @@ -210,11 +203,9 @@ func Test_buildDAG(t *testing.T) {
step4 := NewStep(publishKafka)
step5 := NewStep(updateDB)
resultStep := &resultStep[dummyState]{
mainStep: step2,
successStep: step3,
failureHandler: func(ctx context.Context, state dummyState, err error) Step[dummyState] {
return step4
},
mainStep: step2,
successStep: step3,
failureHandler: step4,
}

rootStep := &ifElseStep[dummyState]{
Expand Down
13 changes: 2 additions & 11 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func ExampleResult() {
reportFailure := func(ctx context.Context, state exampleState) error { return nil }

// Tip: Create a type alias like this to avoid using go generic syntax everywhere.
type exampleStateStep = dagger.Step[exampleState]
// type exampleStateStep = dagger.Step[exampleState]

dag, err := dagger.New(
dagger.Result(
Expand All @@ -106,16 +106,7 @@ func ExampleResult() {
// It will then run the success Step, if main Step returned no error
dagger.NewStep(reportSuccess),
// Otherwise, it will run the success Step, if main Step returned an error
func(ctx context.Context, state exampleState, err error) exampleStateStep {
// Note: It is encouraged to test that `failureProcedure` has no cycles
// via unit tests.
// ```go
// _, err := dagger.New(failureProcedure)
// assertNoError(err)
// ```
failureProcedure := dagger.NewStep(reportFailure)
return failureProcedure
},
dagger.NewStep(reportFailure),
),
)
if err != nil {
Expand Down
21 changes: 21 additions & 0 deletions result_ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package dagger

import (
"context"
)

type resultCtxKey int

const resultErrKey resultCtxKey = iota

func resultErrToContext(ctx context.Context, err error) context.Context {
return context.WithValue(ctx, resultErrKey, err)
}

func resultErrFromContext(ctx context.Context) error {
if err, ok := ctx.Value(resultErrKey).(error); ok {
return err
}

return nil

Check warning on line 20 in result_ctx.go

View check run for this annotation

Codecov / codecov/patch

result_ctx.go#L20

Added line #L20 was not covered by tests
}
194 changes: 194 additions & 0 deletions result_step.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package dagger

import (
"context"
)

// FailureSelector is used to define closures that act as
// branch selector for Result's failure Step(s).
type FailureSelector[S any] func(ctx context.Context, err error) bool

// stepWithErr is used to define closures that act as Step with error.
type stepWithErr[S any] func(ctx context.Context, state S, resErr error) error

// NewResultErrStep creates a new Step that also has access to the error from
// the Result Step's main Step.
func NewResultErrStep[S any](f func(ctx context.Context, state S, resErr error) error) Step[S] {
return stepWithErr[S](f)
}

func (f stepWithErr[S]) Exec(ctx context.Context, state S) error {
var s Step[S] = StepFunc[S](f.exec)

c, ok := ctx.Value(middlewareKey).(MiddlewareChain[S])
if ok {
si := stepInfo(s)
si.CanSkip = true
s = c.apply(s, si)
}

return s.Exec(ctx, state)
}

func (f stepWithErr[S]) exec(ctx context.Context, state S) error {
return f(ctx, state, resultErrFromContext(ctx))
}

// ResultFailureHandler is used to define entities that act as
// failure handler for Result Step.
type ResultFailureHandler[S any] interface {
// selectStep is used to select the Step to be executed
// based on the error returned by the mainStep.
selectStep(ctx context.Context, err error) Step[S]
}

type resultStep[S any] struct {
mainStep Step[S]
successStep Step[S]
failureHandler ResultFailureHandler[S]
}

var _ middlewareSkipper = (*resultStep[any])(nil)

func (s *resultStep[S]) canSkip() bool {
return true
}

func (s *resultStep[S]) Exec(ctx context.Context, state S) error {
if err := execWithContext(ctx, s.mainStep, state); err != nil {
return s.handleErr(ctx, state, err)
}

return execWithContext(ctx, s.successStep, state)
}

func (s *resultStep[S]) Unwrap() []Step[S] {
return []Step[S]{
s.mainStep,
s.successStep,
// TODO: Make failure handler a part of the DAG, update Unwrap to return it.
}
}

func (s *resultStep[S]) handleErr(ctx context.Context, state S, err error) error {
if s.failureHandler == nil {
return err
}

Check warning on line 76 in result_step.go

View check run for this annotation

Codecov / codecov/patch

result_step.go#L75-L76

Added lines #L75 - L76 were not covered by tests

if step := s.failureHandler.selectStep(ctx, err); step != nil {
return execWithContext(resultErrToContext(ctx, err), step, state)
}

return err
}

var _ ResultFailureHandler[any] = StepFunc[any](nil)

//nolint:unused
func (f StepFunc[S]) selectStep(_ context.Context, _ error) Step[S] { return f }

// Result executes the mainStep and uses the returned value to
// - execute successStep, if the returned error is nil
// - execute failureHandler, if the returned error is not nil
//
// Note: The failureHandler is used to define the failure branch,
// if the failureHandler returns a nil Step, Result's Step.Exec
// returns the mainStep's error.
func Result[S any](mainStep, successStep Step[S], failureHandler ResultFailureHandler[S]) Step[S] {
return &resultStep[S]{
mainStep: mainStep,
successStep: successStep,
failureHandler: failureHandler,
}
}

type failureBranch[S any] struct {
selector FailureSelector[S]
step Step[S]
}

var _ FailureBranch[any] = (*failureBranch[any])(nil)

//nolint:unused
func (s *failureBranch[S]) isFailureBranch(S) {}

Check warning on line 113 in result_step.go

View check run for this annotation

Codecov / codecov/patch

result_step.go#L113

Added line #L113 was not covered by tests

var _ ResultFailureHandler[any] = (*failureBranch[any])(nil)

//nolint:unused
func (s *failureBranch[S]) selectStep(ctx context.Context, err error) Step[S] {
if s.selector(ctx, err) {
return s.step
}

return nil
}

// NewBranch creates a new FailureBranch with the given FailureSelector and Step.
func NewBranch[S any](selector FailureSelector[S], step Step[S]) FailureBranch[S] {
return &failureBranch[S]{selector: selector, step: step}
}

type defaultBranch[S any] struct {
step Step[S]
}

var _ FailureBranch[any] = (*defaultBranch[any])(nil)

//nolint:unused
func (d *defaultBranch[S]) isFailureBranch(S) {}

Check warning on line 138 in result_step.go

View check run for this annotation

Codecov / codecov/patch

result_step.go#L138

Added line #L138 was not covered by tests

var _ ResultFailureHandler[any] = (*defaultBranch[any])(nil)

//nolint:unused
func (d *defaultBranch[S]) selectStep(_ context.Context, _ error) Step[S] { return d.step }

func DefaultBranch[S any](step Step[S]) FailureBranch[S] {
return &defaultBranch[S]{step: step}
}

// FailureBranch is used to define entities that act as branches
// in a ResultFailureHandler.
//
// Note: This is used to prevent misuse of the HandleMultiFailure function.
type FailureBranch[S any] interface{ isFailureBranch(S) }

type multiFailureHandler[S any] struct {
branches []ResultFailureHandler[S]
}

var _ ResultFailureHandler[any] = (*multiFailureHandler[any])(nil)

//nolint:unused
func (m *multiFailureHandler[S]) selectStep(ctx context.Context, err error) Step[S] {
for _, branch := range m.branches {
if step := branch.selectStep(ctx, err); step != nil {
return step
}
}

return nil
}

// HandleMultiFailure takes in FailureBranch(s) and returns a ResultFailureHandler.
// It is used to handle multiple failure branches in a Result Step.
// The branches are evaluated in the order they are passed.
//
// The behavior is as follows:
// - if a FailureBranch is eligible, it is executed and the remaining branches
// are ignored.
// - if no FailureBranch is eligible, the DefaultBranch is executed, if provided,
// otherwise the mainStep's error is returned.
func HandleMultiFailure[S any](branches ...FailureBranch[S]) ResultFailureHandler[S] {
res := make([]ResultFailureHandler[S], len(branches))

for i, branch := range branches {
switch b := branch.(type) {
case *failureBranch[S]:
res[i] = b
case *defaultBranch[S]:
res[i] = b
}
}

return &multiFailureHandler[S]{branches: res}
}
Loading
Loading