Skip to content

Commit e372cc3

Browse files
committed
fix #171: add breaker to context converter to prevent critical bug with http package integration
1 parent ba0db78 commit e372cc3

File tree

6 files changed

+153
-75
lines changed

6 files changed

+153
-75
lines changed

context.go

+11-40
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,18 @@ package retry
22

33
import "context"
44

5-
type lite struct {
6-
context.Context
7-
signal <-chan struct{}
8-
}
9-
10-
func (ctx lite) Done() <-chan struct{} {
11-
return ctx.signal
12-
}
13-
14-
func (ctx lite) Err() error {
15-
select {
16-
case <-ctx.signal:
17-
return context.Canceled
18-
default:
19-
return nil
5+
func convert(breaker Breaker) context.Context {
6+
ctx, is := breaker.(context.Context)
7+
if !is {
8+
ctx = lite{context.Background(), breaker}
209
}
10+
return ctx
2111
}
2212

23-
// equal to go.octolab.org/errors.Unwrap
24-
func unwrap(err error) error {
25-
// compatible with github.com/pkg/errors
26-
type causer interface {
27-
Cause() error
28-
}
29-
// compatible with built-in errors since 1.13
30-
type wrapper interface {
31-
Unwrap() error
32-
}
33-
34-
for err != nil {
35-
layer, is := err.(wrapper)
36-
if is {
37-
err = layer.Unwrap()
38-
continue
39-
}
40-
cause, is := err.(causer)
41-
if is {
42-
err = cause.Cause()
43-
continue
44-
}
45-
break
46-
}
47-
return err
13+
type lite struct {
14+
context.Context
15+
breaker Breaker
4816
}
17+
18+
func (ctx lite) Done() <-chan struct{} { return ctx.breaker.Done() }
19+
func (ctx lite) Err() error { return ctx.breaker.Err() }

context_test.go

+41-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package retry
22

33
import (
44
"context"
5+
"reflect"
56
"runtime"
67
"testing"
78
"time"
@@ -13,7 +14,7 @@ func TestContext(t *testing.T) {
1314

1415
var (
1516
sig = make(chan struct{})
16-
ctx = context.Context(lite{context.TODO(), sig})
17+
ctx = context.Context(lite{context.TODO(), breaker(sig)})
1718
)
1819
if ctx.Err() != nil {
1920
t.Error("invalid state")
@@ -32,7 +33,7 @@ func TestContext(t *testing.T) {
3233

3334
var (
3435
sig = make(chan struct{})
35-
ctx = context.Context(lite{context.TODO(), sig})
36+
ctx = context.Context(lite{context.TODO(), breaker(sig)})
3637
)
3738
if ctx.Err() != nil {
3839
t.Error("invalid state")
@@ -51,7 +52,7 @@ func TestContext(t *testing.T) {
5152

5253
var (
5354
sig = make(chan struct{})
54-
ctx = context.Context(lite{context.TODO(), sig})
55+
ctx = context.Context(lite{context.TODO(), breaker(sig)})
5556
)
5657
if ctx.Err() != nil {
5758
t.Error("invalid state")
@@ -70,7 +71,7 @@ func TestContext(t *testing.T) {
7071

7172
var (
7273
sig = make(chan struct{})
73-
ctx = context.Context(lite{context.TODO(), sig})
74+
ctx = context.Context(lite{context.TODO(), breaker(sig)})
7475
)
7576
if ctx.Err() != nil {
7677
t.Error("invalid state")
@@ -85,6 +86,30 @@ func TestContext(t *testing.T) {
8586
})
8687
}
8788

89+
func TestConvert(t *testing.T) {
90+
t.Run("breaker", func(t *testing.T) {
91+
br := make(breaker)
92+
93+
ctx := convert(br)
94+
if ctx.Err() != nil {
95+
t.Error("invalid state")
96+
}
97+
98+
close(br)
99+
if ctx.Err() == nil {
100+
t.Error("invalid state")
101+
}
102+
})
103+
104+
t.Run("context", func(t *testing.T) {
105+
ctx := context.TODO()
106+
107+
if !reflect.DeepEqual(convert(ctx), ctx) {
108+
t.Error("unexpected behavior")
109+
}
110+
})
111+
}
112+
88113
// helpers
89114

90115
func stop(timer *time.Timer) {
@@ -116,4 +141,16 @@ func verify(t *testing.T, ctx context.Context, cancel context.CancelFunc, sig ch
116141

117142
type key struct{}
118143

144+
type breaker chan struct{}
145+
146+
func (br breaker) Done() <-chan struct{} { return br }
147+
func (br breaker) Err() error {
148+
select {
149+
case <-br:
150+
return context.Canceled
151+
default:
152+
return nil
153+
}
154+
}
155+
119156
var schedule = 10 * time.Duration(runtime.NumCPU()) * time.Millisecond

errors.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package retry
2+
3+
const internal Error = "have no any try"
4+
5+
// Error defines a string-based error without a different root cause.
6+
type Error string
7+
8+
// Error returns a string representation of an error.
9+
func (err Error) Error() string { return string(err) }
10+
11+
// Unwrap always returns nil means that an error doesn't have other root cause.
12+
func (err Error) Unwrap() error { return nil }
13+
14+
// equal to go.octolab.org/errors.Unwrap
15+
func unwrap(err error) error {
16+
// compatible with github.com/pkg/errors
17+
type causer interface {
18+
Cause() error
19+
}
20+
// compatible with built-in errors since 1.13
21+
type wrapper interface {
22+
Unwrap() error
23+
}
24+
25+
for err != nil {
26+
layer, is := err.(wrapper)
27+
if is {
28+
err = layer.Unwrap()
29+
continue
30+
}
31+
cause, is := err.(causer)
32+
if is {
33+
err = cause.Cause()
34+
continue
35+
}
36+
break
37+
}
38+
return err
39+
}

errors_test.go

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package retry
2+
3+
import (
4+
"errors"
5+
"reflect"
6+
"testing"
7+
)
8+
9+
func TestError(t *testing.T) {
10+
if internal.Error() != string(internal) {
11+
t.Error("unexpected behavior")
12+
}
13+
14+
if internal.Unwrap() != nil {
15+
t.Error("unexpected behavior")
16+
}
17+
}
18+
19+
func TestUnwrap(t *testing.T) {
20+
cause := errors.New("root")
21+
core := unwrap(causer{layer{cause}})
22+
if !reflect.DeepEqual(core, cause) {
23+
t.Error("unexpected behavior")
24+
}
25+
}
26+
27+
// helpers
28+
29+
type causer struct{ error }
30+
31+
func (causer causer) Cause() error { return causer.error }
32+
33+
type layer struct{ error }
34+
35+
func (layer layer) Unwrap() error { return layer.error }

retry.go

+21-24
Original file line numberDiff line numberDiff line change
@@ -8,53 +8,50 @@ package retry
88
import (
99
"context"
1010
"fmt"
11-
12-
"github.com/kamilsk/retry/v5/strategy"
1311
)
1412

1513
// Action defines a callable function that package retry can handle.
1614
type Action = func(context.Context) error
1715

18-
// Error defines a string-based error without a different root cause.
19-
type Error string
20-
21-
// Error returns a string representation of an error.
22-
func (err Error) Error() string { return string(err) }
23-
24-
// Unwrap always returns nil means that an error doesn't have other root cause.
25-
func (err Error) Unwrap() error { return nil }
16+
// A Breaker carries a cancellation signal to interrupt an action execution.
17+
//
18+
// It is a subset of the built-in context and github.com/kamilsk/breaker interfaces.
19+
type Breaker = interface {
20+
// Done returns a channel that's closed when a cancellation signal occurred.
21+
Done() <-chan struct{}
22+
// If Done is not yet closed, Err returns nil.
23+
// If Done is closed, Err returns a non-nil error.
24+
// After Err returns a non-nil error, successive calls to Err return the same error.
25+
Err() error
26+
}
2627

2728
// How is an alias for batch of Strategies.
2829
//
2930
// how := retry.How{
3031
// strategy.Limit(3),
3132
// }
3233
//
33-
type How = []func(strategy.Breaker, uint, error) bool
34+
type How = []func(Breaker, uint, error) bool
3435

3536
// Do takes the action and performs it, repetitively, until successful.
3637
//
3738
// Optionally, strategies may be passed that assess whether or not an attempt
3839
// should be made.
3940
func Do(
40-
breaker strategy.Breaker,
41+
breaker Breaker,
4142
action func(context.Context) error,
42-
strategies ...func(strategy.Breaker, uint, error) bool,
43+
strategies ...func(Breaker, uint, error) bool,
4344
) error {
4445
var (
45-
err error = Error("have no any try")
46-
clean error
46+
ctx = convert(breaker)
47+
err error = internal
48+
core error
4749
)
4850

49-
ctx, is := breaker.(context.Context)
50-
if !is {
51-
ctx = lite{context.Background(), breaker.Done()}
52-
}
53-
5451
for attempt, should := uint(0), true; should; attempt++ {
55-
clean = unwrap(err)
52+
core = unwrap(err)
5653
for i, repeat := 0, len(strategies); should && i < repeat; i++ {
57-
should = should && strategies[i](breaker, attempt, clean)
54+
should = should && strategies[i](breaker, attempt, core)
5855
}
5956

6057
select {
@@ -78,9 +75,9 @@ func Do(
7875
// Optionally, strategies may be passed that assess whether or not an attempt
7976
// should be made.
8077
func Go(
81-
breaker strategy.Breaker,
78+
breaker Breaker,
8279
action func(context.Context) error,
83-
strategies ...func(strategy.Breaker, uint, error) bool,
80+
strategies ...func(Breaker, uint, error) bool,
8481
) error {
8582
done := make(chan error, 1)
8683

retry_test.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package retry_test
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"reflect"
87
"testing"
@@ -100,11 +99,11 @@ func TestGo(t *testing.T) {
10099

101100
// helpers
102101

103-
func breaker() strategy.Breaker {
102+
func breaker() Breaker {
104103
return context.TODO()
105104
}
106105

107-
func interrupted() strategy.Breaker {
106+
func interrupted() Breaker {
108107
ctx, cancel := context.WithCancel(context.TODO())
109108
cancel()
110109
return ctx
@@ -127,7 +126,7 @@ func (layer layer) Unwrap() error { return layer.error }
127126

128127
type testCase struct {
129128
name string
130-
breaker strategy.Breaker
129+
breaker Breaker
131130
strategies How
132131
action func(context.Context) error
133132
expected expected
@@ -145,8 +144,8 @@ var testCases = []testCase{
145144
"failed action call",
146145
breaker(),
147146
How{strategy.Limit(10)},
148-
func(context.Context) error { return layer{causer{errors.New("failure")}} },
149-
expected{10, layer{causer{errors.New("failure")}}},
147+
func(context.Context) error { return layer{causer{Error("failure")}} },
148+
expected{10, layer{causer{Error("failure")}}},
150149
},
151150
{
152151
"action call with interrupted breaker",
@@ -159,7 +158,7 @@ var testCases = []testCase{
159158
"have no action call",
160159
breaker(),
161160
How{strategy.Limit(0)},
162-
func(context.Context) error { return layer{causer{errors.New("failure")}} },
161+
func(context.Context) error { return layer{causer{Error("failure")}} },
163162
expected{0, Error("have no any try")},
164163
},
165164
}

0 commit comments

Comments
 (0)