Skip to content

Commit 2d312c1

Browse files
feat: Make CircuitBreaker context aware
1 parent 7bc9a40 commit 2d312c1

File tree

3 files changed

+178
-43
lines changed

3 files changed

+178
-43
lines changed

v2/gobreaker.go

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package gobreaker
44

55
import (
6+
"context"
67
"errors"
78
"fmt"
89
"sync"
@@ -67,32 +68,34 @@ func (s State) String() string {
6768
// Default ReadyToTrip returns true when the number of consecutive failures is more than 5.
6869
//
6970
// OnStateChange is called whenever the state of the CircuitBreaker changes.
71+
// OnStateChangeCtx is like OnStateChange but accepts a context which is propagated from the context-aware methods.
7072
//
7173
// IsSuccessful is called with the error returned from a request.
7274
// If IsSuccessful returns true, the error is counted as a success.
7375
// Otherwise the error is counted as a failure.
7476
// If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors.
7577
type Settings struct {
76-
Name string
77-
MaxRequests uint32
78-
Interval time.Duration
79-
BucketPeriod time.Duration
80-
Timeout time.Duration
81-
ReadyToTrip func(counts Counts) bool
82-
OnStateChange func(name string, from State, to State)
83-
IsSuccessful func(err error) bool
78+
Name string
79+
MaxRequests uint32
80+
Interval time.Duration
81+
BucketPeriod time.Duration
82+
Timeout time.Duration
83+
ReadyToTrip func(counts Counts) bool
84+
OnStateChange func(name string, from State, to State)
85+
OnStateChangeCtx func(ctx context.Context, name string, from State, to State)
86+
IsSuccessful func(err error) bool
8487
}
8588

8689
// CircuitBreaker is a state machine to prevent sending requests that are likely to fail.
8790
type CircuitBreaker[T any] struct {
88-
name string
89-
maxRequests uint32
90-
interval time.Duration
91-
bucketPeriod time.Duration
92-
timeout time.Duration
93-
readyToTrip func(counts Counts) bool
94-
isSuccessful func(err error) bool
95-
onStateChange func(name string, from State, to State)
91+
name string
92+
maxRequests uint32
93+
interval time.Duration
94+
bucketPeriod time.Duration
95+
timeout time.Duration
96+
readyToTrip func(counts Counts) bool
97+
isSuccessful func(err error) bool
98+
onStateChangeCtx func(ctx context.Context, name string, from State, to State)
9699

97100
mutex sync.Mutex
98101
state State
@@ -107,7 +110,15 @@ func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] {
107110
cb := new(CircuitBreaker[T])
108111

109112
cb.name = st.Name
110-
cb.onStateChange = st.OnStateChange
113+
114+
if st.OnStateChange != nil {
115+
cb.onStateChangeCtx = func(_ context.Context, name string, from State, to State) {
116+
st.OnStateChange(name, from, to)
117+
}
118+
}
119+
if st.OnStateChangeCtx != nil {
120+
cb.onStateChangeCtx = st.OnStateChangeCtx
121+
}
111122

112123
if st.MaxRequests == 0 {
113124
cb.maxRequests = 1
@@ -173,11 +184,16 @@ func (cb *CircuitBreaker[T]) Name() string {
173184

174185
// State returns the current state of the CircuitBreaker.
175186
func (cb *CircuitBreaker[T]) State() State {
187+
return cb.StateCtx(context.Background())
188+
}
189+
190+
// StateCtx is like State but accepts a context which will be propagated to state change callbacks.
191+
func (cb *CircuitBreaker[T]) StateCtx(ctx context.Context) State {
176192
cb.mutex.Lock()
177193
defer cb.mutex.Unlock()
178194

179195
now := time.Now()
180-
state, _, _ := cb.currentState(now)
196+
state, _, _ := cb.currentState(ctx, now)
181197
return state
182198
}
183199

@@ -195,7 +211,12 @@ func (cb *CircuitBreaker[T]) Counts() Counts {
195211
// If a panic occurs in the request, the CircuitBreaker handles it as an error
196212
// and causes the same panic again.
197213
func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) {
198-
generation, age, err := cb.beforeRequest()
214+
return cb.ExecuteCtx(context.Background(), req)
215+
}
216+
217+
// ExecuteCtx is like Execute but accepts a context which will be propagated to state change callbacks.
218+
func (cb *CircuitBreaker[T]) ExecuteCtx(ctx context.Context, req func() (T, error)) (T, error) {
219+
generation, age, err := cb.beforeRequest(ctx)
199220
if err != nil {
200221
var defaultValue T
201222
return defaultValue, err
@@ -204,22 +225,22 @@ func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) {
204225
defer func() {
205226
e := recover()
206227
if e != nil {
207-
cb.afterRequest(generation, age, false)
228+
cb.afterRequest(ctx, generation, age, false)
208229
panic(e)
209230
}
210231
}()
211232

212233
result, err := req()
213-
cb.afterRequest(generation, age, cb.isSuccessful(err))
234+
cb.afterRequest(ctx, generation, age, cb.isSuccessful(err))
214235
return result, err
215236
}
216237

217-
func (cb *CircuitBreaker[T]) beforeRequest() (uint64, uint64, error) {
238+
func (cb *CircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, uint64, error) {
218239
cb.mutex.Lock()
219240
defer cb.mutex.Unlock()
220241

221242
now := time.Now()
222-
state, generation, age := cb.currentState(now)
243+
state, generation, age := cb.currentState(ctx, now)
223244

224245
if state == StateOpen {
225246
return generation, age, ErrOpenState
@@ -231,48 +252,48 @@ func (cb *CircuitBreaker[T]) beforeRequest() (uint64, uint64, error) {
231252
return generation, age, nil
232253
}
233254

234-
func (cb *CircuitBreaker[T]) afterRequest(previous uint64, age uint64, success bool) {
255+
func (cb *CircuitBreaker[T]) afterRequest(ctx context.Context, previous uint64, age uint64, success bool) {
235256
cb.mutex.Lock()
236257
defer cb.mutex.Unlock()
237258

238259
now := time.Now()
239-
state, generation, _ := cb.currentState(now)
260+
state, generation, _ := cb.currentState(ctx, now)
240261
if generation != previous {
241262
return
242263
}
243264

244265
if success {
245-
cb.onSuccess(state, age, now)
266+
cb.onSuccess(ctx, state, age, now)
246267
} else {
247-
cb.onFailure(state, age, now)
268+
cb.onFailure(ctx, state, age, now)
248269
}
249270
}
250271

251-
func (cb *CircuitBreaker[T]) onSuccess(state State, age uint64, now time.Time) {
272+
func (cb *CircuitBreaker[T]) onSuccess(ctx context.Context, state State, age uint64, now time.Time) {
252273
switch state {
253274
case StateClosed:
254275
cb.counts.onSuccess(age)
255276
case StateHalfOpen:
256277
cb.counts.onSuccess(age)
257278
if cb.counts.ConsecutiveSuccesses >= cb.maxRequests {
258-
cb.setState(StateClosed, now)
279+
cb.setState(ctx, StateClosed, now)
259280
}
260281
}
261282
}
262283

263-
func (cb *CircuitBreaker[T]) onFailure(state State, age uint64, now time.Time) {
284+
func (cb *CircuitBreaker[T]) onFailure(ctx context.Context, state State, age uint64, now time.Time) {
264285
switch state {
265286
case StateClosed:
266287
cb.counts.onFailure(age)
267288
if cb.readyToTrip(cb.counts.Counts) {
268-
cb.setState(StateOpen, now)
289+
cb.setState(ctx, StateOpen, now)
269290
}
270291
case StateHalfOpen:
271-
cb.setState(StateOpen, now)
292+
cb.setState(ctx, StateOpen, now)
272293
}
273294
}
274295

275-
func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64, uint64) {
296+
func (cb *CircuitBreaker[T]) currentState(ctx context.Context, now time.Time) (State, uint64, uint64) {
276297
switch cb.state {
277298
case StateClosed:
278299
if !cb.expiry.IsZero() && cb.expiry.Before(now) {
@@ -282,7 +303,7 @@ func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64, uint64)
282303
}
283304
case StateOpen:
284305
if cb.expiry.Before(now) {
285-
cb.setState(StateHalfOpen, now)
306+
cb.setState(ctx, StateHalfOpen, now)
286307
}
287308
}
288309
return cb.state, cb.generation, cb.counts.age
@@ -301,7 +322,7 @@ func (cb *CircuitBreaker[T]) age(now time.Time) uint64 {
301322
return uint64(age)
302323
}
303324

304-
func (cb *CircuitBreaker[T]) setState(state State, now time.Time) {
325+
func (cb *CircuitBreaker[T]) setState(ctx context.Context, state State, now time.Time) {
305326
if cb.state == state {
306327
return
307328
}
@@ -311,8 +332,8 @@ func (cb *CircuitBreaker[T]) setState(state State, now time.Time) {
311332

312333
cb.toNewGeneration(now)
313334

314-
if cb.onStateChange != nil {
315-
cb.onStateChange(cb.name, prev, state)
335+
if cb.onStateChangeCtx != nil {
336+
cb.onStateChangeCtx(ctx, cb.name, prev, state)
316337
}
317338
}
318339

v2/gobreaker_test.go

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gobreaker
22

33
import (
4+
"context"
45
"errors"
56
"runtime"
67
"sync"
@@ -22,6 +23,8 @@ type StateChange struct {
2223

2324
var stateChange StateChange
2425

26+
type ctxKey string
27+
2528
func pseudoSleep(cb *CircuitBreaker[bool], period time.Duration) {
2629
cb.start = cb.start.Add(-period)
2730
if !cb.expiry.IsZero() {
@@ -132,7 +135,7 @@ func TestNewCircuitBreaker(t *testing.T) {
132135
assert.Equal(t, time.Duration(0), defaultCB.interval)
133136
assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout)
134137
assert.NotNil(t, defaultCB.readyToTrip)
135-
assert.Nil(t, defaultCB.onStateChange)
138+
assert.Nil(t, defaultCB.onStateChangeCtx)
136139
assert.Equal(t, StateClosed, defaultCB.state)
137140
assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.Counts())
138141
assert.True(t, defaultCB.expiry.IsZero())
@@ -143,7 +146,7 @@ func TestNewCircuitBreaker(t *testing.T) {
143146
assert.Equal(t, time.Duration(30)*time.Second, customCB.interval)
144147
assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout)
145148
assert.NotNil(t, customCB.readyToTrip)
146-
assert.NotNil(t, customCB.onStateChange)
149+
assert.NotNil(t, customCB.onStateChangeCtx)
147150
assert.Equal(t, StateClosed, customCB.state)
148151
assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.Counts())
149152
assert.False(t, customCB.expiry.IsZero())
@@ -155,7 +158,7 @@ func TestNewCircuitBreaker(t *testing.T) {
155158
assert.Equal(t, 10, len(rollingWindowCB.counts.buckets))
156159
assert.Equal(t, time.Duration(90)*time.Second, rollingWindowCB.timeout)
157160
assert.NotNil(t, rollingWindowCB.readyToTrip)
158-
assert.NotNil(t, rollingWindowCB.onStateChange)
161+
assert.NotNil(t, rollingWindowCB.onStateChangeCtx)
159162
assert.Equal(t, StateClosed, rollingWindowCB.state)
160163
assert.Equal(t, Counts{0, 0, 0, 0, 0}, rollingWindowCB.Counts())
161164
assert.True(t, rollingWindowCB.expiry.IsZero())
@@ -166,7 +169,7 @@ func TestNewCircuitBreaker(t *testing.T) {
166169
assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval)
167170
assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout)
168171
assert.NotNil(t, negativeDurationCB.readyToTrip)
169-
assert.Nil(t, negativeDurationCB.onStateChange)
172+
assert.Nil(t, negativeDurationCB.onStateChangeCtx)
170173
assert.Equal(t, StateClosed, negativeDurationCB.state)
171174
assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.Counts())
172175
assert.True(t, negativeDurationCB.expiry.IsZero())
@@ -472,3 +475,107 @@ func TestRollingWindowCircuitBreakerInParallel(t *testing.T) {
472475

473476
wg.Wait()
474477
}
478+
479+
func TestOnStateChangeCtx_ExecuteCtx(t *testing.T) {
480+
var got struct {
481+
name string
482+
from State
483+
to State
484+
val any
485+
}
486+
487+
st := Settings{
488+
Name: "ctxcb1",
489+
MaxRequests: 3,
490+
Timeout: 2 * time.Second,
491+
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
492+
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
493+
got.name = name
494+
got.from = from
495+
got.to = to
496+
got.val = ctx.Value(ctxKey("id"))
497+
},
498+
}
499+
cb := NewCircuitBreaker[bool](st)
500+
501+
ctx := context.WithValue(context.Background(), ctxKey("id"), "exec1")
502+
_, err := cb.ExecuteCtx(ctx, func() (bool, error) { return false, assert.AnError })
503+
assert.Error(t, err)
504+
505+
assert.Equal(t, "ctxcb1", got.name)
506+
assert.Equal(t, StateClosed, got.from)
507+
assert.Equal(t, StateOpen, got.to)
508+
assert.Equal(t, "exec1", got.val)
509+
}
510+
511+
func TestOnStateChangeCtx_StateCtx_TimeoutTransition(t *testing.T) {
512+
var got struct {
513+
val any
514+
from, to State
515+
}
516+
517+
st := Settings{
518+
Name: "ctxcb2",
519+
Timeout: time.Second,
520+
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
521+
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
522+
if name == "ctxcb2" {
523+
got.from = from
524+
got.to = to
525+
got.val = ctx.Value(ctxKey("poll"))
526+
}
527+
},
528+
}
529+
cb := NewCircuitBreaker[bool](st)
530+
// Trip to open
531+
_, _ = cb.ExecuteCtx(context.Background(), func() (bool, error) { return false, assert.AnError })
532+
assert.Equal(t, StateOpen, cb.State())
533+
534+
// Move time and call StateCtx to trigger HalfOpen with provided ctx
535+
pseudoSleep(cb, st.Timeout+time.Millisecond)
536+
pollCtx := context.WithValue(context.Background(), ctxKey("poll"), "state-call")
537+
state := cb.StateCtx(pollCtx)
538+
assert.Equal(t, StateHalfOpen, state)
539+
assert.Equal(t, StateOpen, got.from)
540+
assert.Equal(t, StateHalfOpen, got.to)
541+
assert.Equal(t, "state-call", got.val)
542+
}
543+
544+
func TestTwoStep_AllowCtx_ContextPropagation(t *testing.T) {
545+
var got struct {
546+
from, to State
547+
val any
548+
}
549+
st := Settings{
550+
Name: "twostep",
551+
MaxRequests: 2,
552+
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
553+
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
554+
if name == "twostep" {
555+
got.from = from
556+
got.to = to
557+
got.val = ctx.Value(ctxKey("step"))
558+
}
559+
},
560+
}
561+
tscb := NewTwoStepCircuitBreaker[bool](st)
562+
563+
ctx := context.WithValue(context.Background(), ctxKey("step"), "allow-ctx")
564+
done, err := tscb.AllowCtx(ctx)
565+
assert.NoError(t, err)
566+
done(false) // cause failure to trip to open
567+
568+
assert.Equal(t, StateClosed, got.from)
569+
assert.Equal(t, StateOpen, got.to)
570+
assert.Equal(t, "allow-ctx", got.val)
571+
}
572+
573+
func TestNoCallbacks_NoPanic(t *testing.T) {
574+
// Ensure no callbacks set does not panic on transitions.
575+
cb := NewCircuitBreaker[bool](Settings{ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 }})
576+
// Trip to open and then to half-open
577+
_, _ = cb.Execute(func() (bool, error) { return false, assert.AnError })
578+
assert.Equal(t, StateOpen, cb.State())
579+
pseudoSleep(cb, time.Second*2)
580+
_ = cb.State() // should transition as needed without panic
581+
}

0 commit comments

Comments
 (0)