Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions v2/gobreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package gobreaker

import (
"context"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -67,32 +68,34 @@ func (s State) String() string {
// Default ReadyToTrip returns true when the number of consecutive failures is more than 5.
//
// OnStateChange is called whenever the state of the CircuitBreaker changes.
// OnStateChangeCtx is like OnStateChange but accepts a context which is propagated from the context-aware methods.
//
// IsSuccessful is called with the error returned from a request.
// If IsSuccessful returns true, the error is counted as a success.
// Otherwise the error is counted as a failure.
// If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors.
type Settings struct {
Name string
MaxRequests uint32
Interval time.Duration
BucketPeriod time.Duration
Timeout time.Duration
ReadyToTrip func(counts Counts) bool
OnStateChange func(name string, from State, to State)
IsSuccessful func(err error) bool
Name string
MaxRequests uint32
Interval time.Duration
BucketPeriod time.Duration
Timeout time.Duration
ReadyToTrip func(counts Counts) bool
OnStateChange func(name string, from State, to State)
OnStateChangeCtx func(ctx context.Context, name string, from State, to State)
IsSuccessful func(err error) bool
}

// CircuitBreaker is a state machine to prevent sending requests that are likely to fail.
type CircuitBreaker[T any] struct {
name string
maxRequests uint32
interval time.Duration
bucketPeriod time.Duration
timeout time.Duration
readyToTrip func(counts Counts) bool
isSuccessful func(err error) bool
onStateChange func(name string, from State, to State)
name string
maxRequests uint32
interval time.Duration
bucketPeriod time.Duration
timeout time.Duration
readyToTrip func(counts Counts) bool
isSuccessful func(err error) bool
onStateChangeCtx func(ctx context.Context, name string, from State, to State)

mutex sync.Mutex
state State
Expand All @@ -107,7 +110,15 @@ func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] {
cb := new(CircuitBreaker[T])

cb.name = st.Name
cb.onStateChange = st.OnStateChange

if st.OnStateChange != nil {
cb.onStateChangeCtx = func(_ context.Context, name string, from State, to State) {
st.OnStateChange(name, from, to)
}
}
if st.OnStateChangeCtx != nil {
cb.onStateChangeCtx = st.OnStateChangeCtx
}

if st.MaxRequests == 0 {
cb.maxRequests = 1
Expand Down Expand Up @@ -173,11 +184,16 @@ func (cb *CircuitBreaker[T]) Name() string {

// State returns the current state of the CircuitBreaker.
func (cb *CircuitBreaker[T]) State() State {
return cb.StateCtx(context.Background())
}

// StateCtx is like State but accepts a context which will be propagated to state change callbacks.
func (cb *CircuitBreaker[T]) StateCtx(ctx context.Context) State {
cb.mutex.Lock()
defer cb.mutex.Unlock()

now := time.Now()
state, _, _ := cb.currentState(now)
state, _, _ := cb.currentState(ctx, now)
return state
}

Expand All @@ -195,7 +211,12 @@ func (cb *CircuitBreaker[T]) Counts() Counts {
// If a panic occurs in the request, the CircuitBreaker handles it as an error
// and causes the same panic again.
func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) {
generation, age, err := cb.beforeRequest()
return cb.ExecuteCtx(context.Background(), req)
}

// ExecuteCtx is like Execute but accepts a context which will be propagated to state change callbacks.
func (cb *CircuitBreaker[T]) ExecuteCtx(ctx context.Context, req func() (T, error)) (T, error) {
generation, age, err := cb.beforeRequest(ctx)
if err != nil {
var defaultValue T
return defaultValue, err
Expand All @@ -204,22 +225,22 @@ func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) {
defer func() {
e := recover()
if e != nil {
cb.afterRequest(generation, age, false)
cb.afterRequest(ctx, generation, age, false)
panic(e)
}
}()

result, err := req()
cb.afterRequest(generation, age, cb.isSuccessful(err))
cb.afterRequest(ctx, generation, age, cb.isSuccessful(err))
return result, err
}

func (cb *CircuitBreaker[T]) beforeRequest() (uint64, uint64, error) {
func (cb *CircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, uint64, error) {
cb.mutex.Lock()
defer cb.mutex.Unlock()

now := time.Now()
state, generation, age := cb.currentState(now)
state, generation, age := cb.currentState(ctx, now)

if state == StateOpen {
return generation, age, ErrOpenState
Expand All @@ -231,48 +252,48 @@ func (cb *CircuitBreaker[T]) beforeRequest() (uint64, uint64, error) {
return generation, age, nil
}

func (cb *CircuitBreaker[T]) afterRequest(previous uint64, age uint64, success bool) {
func (cb *CircuitBreaker[T]) afterRequest(ctx context.Context, previous uint64, age uint64, success bool) {
cb.mutex.Lock()
defer cb.mutex.Unlock()

now := time.Now()
state, generation, _ := cb.currentState(now)
state, generation, _ := cb.currentState(ctx, now)
if generation != previous {
return
}

if success {
cb.onSuccess(state, age, now)
cb.onSuccess(ctx, state, age, now)
} else {
cb.onFailure(state, age, now)
cb.onFailure(ctx, state, age, now)
}
}

func (cb *CircuitBreaker[T]) onSuccess(state State, age uint64, now time.Time) {
func (cb *CircuitBreaker[T]) onSuccess(ctx context.Context, state State, age uint64, now time.Time) {
switch state {
case StateClosed:
cb.counts.onSuccess(age)
case StateHalfOpen:
cb.counts.onSuccess(age)
if cb.counts.ConsecutiveSuccesses >= cb.maxRequests {
cb.setState(StateClosed, now)
cb.setState(ctx, StateClosed, now)
}
}
}

func (cb *CircuitBreaker[T]) onFailure(state State, age uint64, now time.Time) {
func (cb *CircuitBreaker[T]) onFailure(ctx context.Context, state State, age uint64, now time.Time) {
switch state {
case StateClosed:
cb.counts.onFailure(age)
if cb.readyToTrip(cb.counts.Counts) {
cb.setState(StateOpen, now)
cb.setState(ctx, StateOpen, now)
}
case StateHalfOpen:
cb.setState(StateOpen, now)
cb.setState(ctx, StateOpen, now)
}
}

func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64, uint64) {
func (cb *CircuitBreaker[T]) currentState(ctx context.Context, now time.Time) (State, uint64, uint64) {
switch cb.state {
case StateClosed:
if !cb.expiry.IsZero() && cb.expiry.Before(now) {
Expand All @@ -282,7 +303,7 @@ func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64, uint64)
}
case StateOpen:
if cb.expiry.Before(now) {
cb.setState(StateHalfOpen, now)
cb.setState(ctx, StateHalfOpen, now)
}
}
return cb.state, cb.generation, cb.counts.age
Expand All @@ -301,7 +322,7 @@ func (cb *CircuitBreaker[T]) age(now time.Time) uint64 {
return uint64(age)
}

func (cb *CircuitBreaker[T]) setState(state State, now time.Time) {
func (cb *CircuitBreaker[T]) setState(ctx context.Context, state State, now time.Time) {
if cb.state == state {
return
}
Expand All @@ -311,8 +332,8 @@ func (cb *CircuitBreaker[T]) setState(state State, now time.Time) {

cb.toNewGeneration(now)

if cb.onStateChange != nil {
cb.onStateChange(cb.name, prev, state)
if cb.onStateChangeCtx != nil {
cb.onStateChangeCtx(ctx, cb.name, prev, state)
}
}

Expand Down
115 changes: 111 additions & 4 deletions v2/gobreaker_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gobreaker

import (
"context"
"errors"
"runtime"
"sync"
Expand All @@ -22,6 +23,8 @@ type StateChange struct {

var stateChange StateChange

type ctxKey string

func pseudoSleep(cb *CircuitBreaker[bool], period time.Duration) {
cb.start = cb.start.Add(-period)
if !cb.expiry.IsZero() {
Expand Down Expand Up @@ -132,7 +135,7 @@ func TestNewCircuitBreaker(t *testing.T) {
assert.Equal(t, time.Duration(0), defaultCB.interval)
assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout)
assert.NotNil(t, defaultCB.readyToTrip)
assert.Nil(t, defaultCB.onStateChange)
assert.Nil(t, defaultCB.onStateChangeCtx)
assert.Equal(t, StateClosed, defaultCB.state)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.Counts())
assert.True(t, defaultCB.expiry.IsZero())
Expand All @@ -143,7 +146,7 @@ func TestNewCircuitBreaker(t *testing.T) {
assert.Equal(t, time.Duration(30)*time.Second, customCB.interval)
assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout)
assert.NotNil(t, customCB.readyToTrip)
assert.NotNil(t, customCB.onStateChange)
assert.NotNil(t, customCB.onStateChangeCtx)
assert.Equal(t, StateClosed, customCB.state)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.Counts())
assert.False(t, customCB.expiry.IsZero())
Expand All @@ -155,7 +158,7 @@ func TestNewCircuitBreaker(t *testing.T) {
assert.Equal(t, 10, len(rollingWindowCB.counts.buckets))
assert.Equal(t, time.Duration(90)*time.Second, rollingWindowCB.timeout)
assert.NotNil(t, rollingWindowCB.readyToTrip)
assert.NotNil(t, rollingWindowCB.onStateChange)
assert.NotNil(t, rollingWindowCB.onStateChangeCtx)
assert.Equal(t, StateClosed, rollingWindowCB.state)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, rollingWindowCB.Counts())
assert.True(t, rollingWindowCB.expiry.IsZero())
Expand All @@ -166,7 +169,7 @@ func TestNewCircuitBreaker(t *testing.T) {
assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval)
assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout)
assert.NotNil(t, negativeDurationCB.readyToTrip)
assert.Nil(t, negativeDurationCB.onStateChange)
assert.Nil(t, negativeDurationCB.onStateChangeCtx)
assert.Equal(t, StateClosed, negativeDurationCB.state)
assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.Counts())
assert.True(t, negativeDurationCB.expiry.IsZero())
Expand Down Expand Up @@ -472,3 +475,107 @@ func TestRollingWindowCircuitBreakerInParallel(t *testing.T) {

wg.Wait()
}

func TestOnStateChangeCtx_ExecuteCtx(t *testing.T) {
var got struct {
name string
from State
to State
val any
}

st := Settings{
Name: "ctxcb1",
MaxRequests: 3,
Timeout: 2 * time.Second,
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
got.name = name
got.from = from
got.to = to
got.val = ctx.Value(ctxKey("id"))
},
}
cb := NewCircuitBreaker[bool](st)

ctx := context.WithValue(context.Background(), ctxKey("id"), "exec1")
_, err := cb.ExecuteCtx(ctx, func() (bool, error) { return false, assert.AnError })
assert.Error(t, err)

assert.Equal(t, "ctxcb1", got.name)
assert.Equal(t, StateClosed, got.from)
assert.Equal(t, StateOpen, got.to)
assert.Equal(t, "exec1", got.val)
}

func TestOnStateChangeCtx_StateCtx_TimeoutTransition(t *testing.T) {
var got struct {
val any
from, to State
}

st := Settings{
Name: "ctxcb2",
Timeout: time.Second,
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
if name == "ctxcb2" {
got.from = from
got.to = to
got.val = ctx.Value(ctxKey("poll"))
}
},
}
cb := NewCircuitBreaker[bool](st)
// Trip to open
_, _ = cb.ExecuteCtx(context.Background(), func() (bool, error) { return false, assert.AnError })
assert.Equal(t, StateOpen, cb.State())

// Move time and call StateCtx to trigger HalfOpen with provided ctx
pseudoSleep(cb, st.Timeout+time.Millisecond)
pollCtx := context.WithValue(context.Background(), ctxKey("poll"), "state-call")
state := cb.StateCtx(pollCtx)
assert.Equal(t, StateHalfOpen, state)
assert.Equal(t, StateOpen, got.from)
assert.Equal(t, StateHalfOpen, got.to)
assert.Equal(t, "state-call", got.val)
}

func TestTwoStep_AllowCtx_ContextPropagation(t *testing.T) {
var got struct {
from, to State
val any
}
st := Settings{
Name: "twostep",
MaxRequests: 2,
ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 },
OnStateChangeCtx: func(ctx context.Context, name string, from State, to State) {
if name == "twostep" {
got.from = from
got.to = to
got.val = ctx.Value(ctxKey("step"))
}
},
}
tscb := NewTwoStepCircuitBreaker[bool](st)

ctx := context.WithValue(context.Background(), ctxKey("step"), "allow-ctx")
done, err := tscb.AllowCtx(ctx)
assert.NoError(t, err)
done(false) // cause failure to trip to open

assert.Equal(t, StateClosed, got.from)
assert.Equal(t, StateOpen, got.to)
assert.Equal(t, "allow-ctx", got.val)
}

func TestNoCallbacks_NoPanic(t *testing.T) {
// Ensure no callbacks set does not panic on transitions.
cb := NewCircuitBreaker[bool](Settings{ReadyToTrip: func(c Counts) bool { return c.ConsecutiveFailures >= 1 }})
// Trip to open and then to half-open
_, _ = cb.Execute(func() (bool, error) { return false, assert.AnError })
assert.Equal(t, StateOpen, cb.State())
pseudoSleep(cb, time.Second*2)
_ = cb.State() // should transition as needed without panic
}
Loading
Loading