Skip to content

Commit 8436c96

Browse files
feat: retry with context (#63)
## Description Retry with context
1 parent 0c8dc51 commit 8436c96

File tree

2 files changed

+154
-36
lines changed

2 files changed

+154
-36
lines changed

helpers/misc.go

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
package helpers
55

66
import (
7+
"context"
8+
"errors"
79
"fmt"
810
"maps"
911
"math"
@@ -18,26 +20,64 @@ func BoolPtr(b bool) *bool {
1820
return &b
1921
}
2022

21-
// Retry will retry a function until it succeeds or the timeout is reached. timeout == 2^attempt * delay.
22-
func Retry(fn func() error, retries int, delay time.Duration, logger func(format string, args ...any)) error {
23+
// RetryWithContext retries a function until it succeeds, the timeout is reached, or the context is done.
24+
// The delay between attempts increases exponentially as (2^(attempt-1)) * delay.
25+
// For example, with a delay of one second and three attempts, the timing would be:
26+
// - First attempt: immediate
27+
// - Second attempt: after one second
28+
// - Third attempt: after two seconds
29+
func RetryWithContext(ctx context.Context, fn func() error, attempts int, delay time.Duration, logger func(format string, args ...any)) error {
30+
if attempts < 1 {
31+
return errors.New("invalid number of attempts, must be at least 1")
32+
}
2333
var err error
24-
for r := 0; r < retries; r++ {
25-
err = fn()
26-
if err == nil {
27-
break
28-
}
34+
for r := 0; r < attempts; r++ {
35+
select {
36+
case <-ctx.Done():
37+
return ctx.Err()
38+
default:
39+
err = fn()
40+
if err == nil {
41+
return nil
42+
}
43+
44+
logger("Attempt (%d/%d) failed with: %s", r+1, attempts, err.Error())
45+
46+
// No reason to wait when we aren't going to retry again
47+
if r+1 == attempts {
48+
return err
49+
}
2950

30-
pow := math.Pow(2, float64(r))
31-
backoff := delay * time.Duration(pow)
51+
pow := math.Pow(2, float64(r))
52+
backoff := delay * time.Duration(pow)
3253

33-
logger("Retrying (%d/%d) in %s: %s", r+1, retries, backoff, err.Error())
54+
logger("Retrying in %s", backoff)
3455

35-
time.Sleep(backoff)
56+
timer := time.NewTimer(backoff)
57+
select {
58+
case <-timer.C:
59+
case <-ctx.Done():
60+
if !timer.Stop() {
61+
<-timer.C
62+
}
63+
return ctx.Err()
64+
}
65+
}
3666
}
3767

3868
return err
3969
}
4070

71+
// Retry retries a function until it succeeds, the timeout is reached, or the context is done.
72+
// The delay between attempts increases exponentially as (2^(attempt-1)) * delay.
73+
// For example, with a delay of one second and three attempts, the timing would be:
74+
// - First attempt: immediate
75+
// - Second attempt: after one second
76+
// - Third attempt: after two seconds
77+
func Retry(fn func() error, attempts int, delay time.Duration, logger func(format string, args ...any)) error {
78+
return RetryWithContext(context.Background(), fn, attempts, delay, logger)
79+
}
80+
4181
// TransformMapKeys takes a map and transforms its keys using the provided function.
4282
func TransformMapKeys[T any](m map[string]T, transform func(string) string) (r map[string]T) {
4383
r = map[string]T{}

helpers/misc_test.go

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
package helpers
55

66
import (
7+
"context"
78
"errors"
89
"reflect"
910
"strings"
1011
"testing"
12+
"time"
1113

1214
"github.com/stretchr/testify/require"
1315
"github.com/stretchr/testify/suite"
@@ -46,33 +48,109 @@ func (suite *TestMiscSuite) SetupSuite() {
4648
}
4749
}
4850

49-
func (suite *TestMiscSuite) TestRetry() {
50-
var count int
51-
countFn := func() error {
52-
count++
53-
if count < 4 {
54-
return errors.New("count exceeded")
51+
func TestRetry(t *testing.T) {
52+
t.Run("RetriesWhenThereAreFailures", func(t *testing.T) {
53+
count := 0
54+
logCount := 0
55+
returnedErr := errors.New("always fail")
56+
countFn := func() error {
57+
count++
58+
return returnedErr
5559
}
56-
return nil
57-
}
58-
var logCount int
59-
loggerFn := func(_ string, _ ...any) {
60-
logCount++
61-
}
60+
loggerFn := func(_ string, _ ...any) {
61+
logCount++
62+
}
63+
64+
err := Retry(countFn, 3, 0, loggerFn)
65+
require.ErrorIs(t, err, returnedErr)
66+
require.Equal(t, 3, count)
67+
require.Equal(t, 5, logCount)
68+
})
69+
70+
t.Run("ContextCancellationBeforeStart", func(t *testing.T) {
71+
ctx, cancel := context.WithCancel(context.Background())
72+
cancel()
73+
count := 0
74+
fn := func() error {
75+
count++
76+
return errors.New("Never here since context got cancelled")
77+
}
78+
logger := func(_ string, _ ...any) {}
79+
80+
waitThatsNotCalled := 1000000 * time.Minute
81+
err := RetryWithContext(ctx, fn, 5, waitThatsNotCalled, logger)
82+
require.Equal(t, 0, count)
83+
require.ErrorIs(t, err, context.Canceled)
84+
})
85+
86+
t.Run("ContextCancellationDuringExecution", func(t *testing.T) {
87+
ctx, cancel := context.WithCancel(context.Background())
88+
89+
count := 0
90+
fn := func() error {
91+
count++
92+
if count < 2 {
93+
return errors.New("fail")
94+
}
95+
cancel()
96+
return errors.New("don't care about this error since we've cancelled and there is still another retry")
97+
}
98+
99+
logger := func(_ string, _ ...any) {}
100+
101+
err := RetryWithContext(ctx, fn, 3, 0, logger)
102+
require.Equal(t, 2, count)
103+
require.ErrorIs(t, err, context.Canceled)
104+
})
105+
106+
t.Run("NoErr", func(t *testing.T) {
107+
count := 0
108+
fn := func() error {
109+
count++
110+
return nil
111+
}
112+
113+
logger := func(_ string, _ ...any) {}
114+
115+
err := RetryWithContext(context.TODO(), fn, 3, 0, logger)
116+
require.ErrorIs(t, err, nil)
117+
require.Equal(t, 1, count)
118+
})
119+
120+
t.Run("InvalidAttempts", func(t *testing.T) {
121+
count := 0
122+
fn := func() error {
123+
count++
124+
return nil
125+
}
126+
127+
logger := func(_ string, _ ...any) {}
128+
129+
err := RetryWithContext(context.TODO(), fn, 0, 0, logger)
130+
require.Error(t, err)
131+
require.Equal(t, 0, count)
132+
})
133+
134+
t.Run("ContextCancellationDeadline", func(t *testing.T) {
135+
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(2*time.Second))
136+
defer cancel()
137+
138+
count := 0
139+
fn := func() error {
140+
count++
141+
return errors.New("Always fail")
142+
}
143+
144+
logger := func(_ string, _ ...any) {}
145+
146+
err := RetryWithContext(ctx, fn, 3, 1*time.Second, logger)
147+
// fn should be called twice, it will wait one second after the first attempt
148+
// and tries to wait two seconds after the second attempt
149+
// but the context will cancel before the third attempt is called
150+
require.Equal(t, 2, count)
151+
require.ErrorIs(t, err, context.DeadlineExceeded)
152+
})
62153

63-
count = 0
64-
logCount = 0
65-
err := Retry(countFn, 3, 0, loggerFn)
66-
suite.Error(err)
67-
suite.Equal(3, count)
68-
suite.Equal(3, logCount)
69-
70-
count = 0
71-
logCount = 0
72-
err = Retry(countFn, 4, 0, loggerFn)
73-
suite.NoError(err)
74-
suite.Equal(4, count)
75-
suite.Equal(3, logCount)
76154
}
77155

78156
func (suite *TestMiscSuite) TestTransformMapKeys() {

0 commit comments

Comments
 (0)