Skip to content

Commit d613e9c

Browse files
authored
transaction retry with backoff strategies (#43)
1 parent 5ccc9a7 commit d613e9c

File tree

5 files changed

+464
-2
lines changed

5 files changed

+464
-2
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- 5432:5432
2020
strategy:
2121
matrix:
22-
go: ['1.20', '1.21']
22+
go: ['1.22', '1.23', '1.24']
2323
name: Go ${{ matrix.go }}
2424
steps:
2525
- uses: actions/checkout@v3

backoff/backoff.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package backoff
2+
3+
import (
4+
"math"
5+
"math/rand/v2"
6+
"time"
7+
8+
"github.com/acoshift/pgsql"
9+
)
10+
11+
// Config contains common configuration for all backoff strategies
12+
type Config struct {
13+
BaseDelay time.Duration // Base delay for backoff
14+
MaxDelay time.Duration // Maximum delay cap
15+
}
16+
17+
// ExponentialConfig contains configuration for exponential backoff
18+
type ExponentialConfig struct {
19+
Config
20+
Multiplier float64 // Multiplier for exponential growth
21+
JitterType JitterType
22+
}
23+
24+
// LinearConfig contains configuration for linear backoff
25+
type LinearConfig struct {
26+
Config
27+
Increment time.Duration // Amount to increase delay each attempt
28+
}
29+
30+
// JitterType defines the type of jitter to apply
31+
type JitterType int
32+
33+
const (
34+
// NoJitter applies no jitter
35+
NoJitter JitterType = iota
36+
// FullJitter applies full jitter (0 to calculated delay)
37+
FullJitter
38+
// EqualJitter applies equal jitter (half fixed + half random)
39+
EqualJitter
40+
)
41+
42+
// NewExponential creates a new exponential backoff function
43+
func NewExponential(config ExponentialConfig) pgsql.BackoffDelayFunc {
44+
return func(attempt int) time.Duration {
45+
baseDelay := time.Duration(float64(config.BaseDelay) * math.Pow(config.Multiplier, float64(attempt)))
46+
if baseDelay > config.MaxDelay {
47+
baseDelay = config.MaxDelay
48+
}
49+
50+
var delay time.Duration
51+
switch config.JitterType {
52+
case FullJitter:
53+
// Full jitter: random delay between 0 and calculated delay
54+
if baseDelay > 0 {
55+
delay = time.Duration(rand.Int64N(int64(baseDelay)))
56+
} else {
57+
delay = baseDelay
58+
}
59+
case EqualJitter:
60+
// Equal jitter: half fixed + half random
61+
half := baseDelay / 2
62+
if half > 0 {
63+
delay = half + time.Duration(rand.Int64N(int64(half)))
64+
} else {
65+
delay = baseDelay
66+
}
67+
default:
68+
delay = baseDelay
69+
}
70+
71+
return delay
72+
}
73+
}
74+
75+
// NewLinear creates a new linear backoff function
76+
func NewLinear(config LinearConfig) pgsql.BackoffDelayFunc {
77+
return func(attempt int) time.Duration {
78+
delay := config.BaseDelay + time.Duration(attempt)*config.Increment
79+
if delay > config.MaxDelay {
80+
delay = config.MaxDelay
81+
}
82+
return delay
83+
}
84+
}
85+
86+
func DefaultExponential() pgsql.BackoffDelayFunc {
87+
return NewExponential(ExponentialConfig{
88+
Config: Config{
89+
BaseDelay: 100 * time.Millisecond,
90+
MaxDelay: 5 * time.Second,
91+
},
92+
Multiplier: 2.0,
93+
JitterType: NoJitter,
94+
})
95+
}
96+
97+
func DefaultExponentialWithFullJitter() pgsql.BackoffDelayFunc {
98+
return NewExponential(ExponentialConfig{
99+
Config: Config{
100+
BaseDelay: 100 * time.Millisecond,
101+
MaxDelay: 5 * time.Second,
102+
},
103+
Multiplier: 2.0,
104+
JitterType: FullJitter,
105+
})
106+
}
107+
108+
func DefaultExponentialWithEqualJitter() pgsql.BackoffDelayFunc {
109+
return NewExponential(ExponentialConfig{
110+
Config: Config{
111+
BaseDelay: 100 * time.Millisecond,
112+
MaxDelay: 5 * time.Second,
113+
},
114+
Multiplier: 2.0,
115+
JitterType: EqualJitter,
116+
})
117+
}
118+
119+
func DefaultLinear() pgsql.BackoffDelayFunc {
120+
return NewLinear(LinearConfig{
121+
Config: Config{
122+
BaseDelay: 100 * time.Millisecond,
123+
MaxDelay: 5 * time.Second,
124+
},
125+
Increment: 100 * time.Millisecond,
126+
})
127+
}

backoff/backoff_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package backoff_test
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/acoshift/pgsql/backoff"
8+
)
9+
10+
func TestExponential(t *testing.T) {
11+
t.Parallel()
12+
13+
config := backoff.ExponentialConfig{
14+
Config: backoff.Config{
15+
BaseDelay: 10 * time.Millisecond,
16+
MaxDelay: 1 * time.Second,
17+
},
18+
Multiplier: 2.0,
19+
}
20+
backoff := backoff.NewExponential(config)
21+
22+
// Test exponential growth
23+
delays := []time.Duration{}
24+
for i := 0; i < 10; i++ {
25+
delay := backoff(i)
26+
delays = append(delays, delay)
27+
}
28+
29+
// Verify exponential growth
30+
for i := 1; i < len(delays); i++ {
31+
if delays[i] < delays[i-1] {
32+
t.Errorf("Expected delay[%d] >= delay[%d], got %v < %v", i, i-1, delays[i], delays[i-1])
33+
}
34+
}
35+
36+
// Verify max delay
37+
for i := 0; i < 10; i++ {
38+
delay := backoff(i)
39+
if delay > config.MaxDelay {
40+
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
41+
}
42+
}
43+
}
44+
45+
func TestExponentialWithFullJitter(t *testing.T) {
46+
t.Parallel()
47+
48+
config := backoff.ExponentialConfig{
49+
Config: backoff.Config{
50+
BaseDelay: 100 * time.Millisecond,
51+
MaxDelay: 1 * time.Second,
52+
},
53+
Multiplier: 2.0,
54+
JitterType: backoff.FullJitter,
55+
}
56+
backoff := backoff.NewExponential(config)
57+
58+
// Test that jitter introduces randomness
59+
var delays []time.Duration
60+
for i := 0; i < 10; i++ {
61+
delay := backoff(3) // Use same attempt number
62+
delays = append(delays, delay)
63+
}
64+
65+
// Check that not all delays are the same (indicating jitter is working)
66+
allSame := true
67+
for i := 1; i < len(delays); i++ {
68+
if delays[i] != delays[0] {
69+
allSame = false
70+
break
71+
}
72+
}
73+
if allSame {
74+
t.Error("Expected jitter to produce different delays, but all delays were the same")
75+
}
76+
77+
// Verify max delay
78+
for i := 0; i < 15; i++ {
79+
delay := backoff(i)
80+
if delay > config.MaxDelay {
81+
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
82+
}
83+
}
84+
}
85+
86+
func TestExponentialWithEqualJitter(t *testing.T) {
87+
t.Parallel()
88+
89+
config := backoff.ExponentialConfig{
90+
Config: backoff.Config{
91+
BaseDelay: 100 * time.Millisecond,
92+
MaxDelay: 1 * time.Second,
93+
},
94+
Multiplier: 2.0,
95+
JitterType: backoff.EqualJitter,
96+
}
97+
backoff := backoff.NewExponential(config)
98+
99+
delay := backoff(2)
100+
101+
// With equal jitter, delay should be at least half of the calculated delay
102+
expectedMin := 200 * time.Millisecond // (100ms * 2^2) / 2 = 200ms
103+
if delay < expectedMin {
104+
t.Errorf("Expected delay >= %v with equal jitter, got %v", expectedMin, delay)
105+
}
106+
107+
// Verify max delay
108+
for i := 0; i < 15; i++ {
109+
delay := backoff(i)
110+
if delay > config.MaxDelay {
111+
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
112+
}
113+
}
114+
}
115+
116+
func TestLinearBackoff(t *testing.T) {
117+
t.Parallel()
118+
119+
config := backoff.LinearConfig{
120+
Config: backoff.Config{
121+
BaseDelay: 100 * time.Millisecond,
122+
MaxDelay: 1 * time.Second,
123+
},
124+
Increment: 100 * time.Millisecond,
125+
}
126+
backoff := backoff.NewLinear(config)
127+
128+
// Test linear growth
129+
delays := []time.Duration{}
130+
for i := 0; i < 5; i++ {
131+
delay := backoff(i)
132+
delays = append(delays, delay)
133+
}
134+
135+
// Verify linear growth
136+
for i := 1; i < len(delays); i++ {
137+
expectedIncrease := 100 * time.Millisecond
138+
actualIncrease := delays[i] - delays[i-1]
139+
140+
if actualIncrease != expectedIncrease {
141+
t.Errorf("Expected linear increase of %v, got %v", expectedIncrease, actualIncrease)
142+
}
143+
}
144+
145+
// Verify max delay
146+
for i := 0; i < 15; i++ {
147+
delay := backoff(i)
148+
if delay > config.MaxDelay {
149+
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
150+
}
151+
}
152+
}

tx.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"errors"
7+
"time"
78
)
89

910
// ErrAbortTx rollbacks transaction and return nil error
@@ -14,10 +15,14 @@ type BeginTxer interface {
1415
BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
1516
}
1617

18+
// BackoffDelayFunc is a function type that defines the delay for backoff
19+
type BackoffDelayFunc func(attempt int) time.Duration
20+
1721
// TxOptions is the transaction options
1822
type TxOptions struct {
1923
sql.TxOptions
20-
MaxAttempts int
24+
MaxAttempts int
25+
BackoffDelayFunc BackoffDelayFunc
2126
}
2227

2328
const (
@@ -54,6 +59,8 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func(
5459
if opts.Isolation == sql.LevelDefault {
5560
option.Isolation = sql.LevelSerializable
5661
}
62+
63+
option.BackoffDelayFunc = opts.BackoffDelayFunc
5764
}
5865

5966
f := func() error {
@@ -80,7 +87,27 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func(
8087
if !IsSerializationFailure(err) {
8188
return err
8289
}
90+
91+
if i < option.MaxAttempts-1 && option.BackoffDelayFunc != nil {
92+
if err = wait(ctx, i, option.BackoffDelayFunc); err != nil {
93+
return err
94+
}
95+
}
8396
}
8497

8598
return err
8699
}
100+
101+
func wait(ctx context.Context, attempt int, backOffDelayFunc BackoffDelayFunc) error {
102+
delay := backOffDelayFunc(attempt)
103+
if delay <= 0 {
104+
return nil
105+
}
106+
107+
select {
108+
case <-ctx.Done():
109+
return ctx.Err()
110+
case <-time.After(delay):
111+
return nil
112+
}
113+
}

0 commit comments

Comments
 (0)