33package gobreaker
44
55import (
6+ "context"
67 "errors"
78 "fmt"
89 "sync"
@@ -67,32 +68,34 @@ func (s State) String() string {
6768// Default ReadyToTrip returns true when the number of consecutive failures is more than 5.
6869//
6970// OnStateChange is called whenever the state of the CircuitBreaker changes.
71+ // OnStateChangeCtx is like OnStateChange but accepts a context which is propagated from the context-aware methods.
7072//
7173// IsSuccessful is called with the error returned from a request.
7274// If IsSuccessful returns true, the error is counted as a success.
7375// Otherwise the error is counted as a failure.
7476// If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors.
7577type Settings struct {
76- Name string
77- MaxRequests uint32
78- Interval time.Duration
79- BucketPeriod time.Duration
80- Timeout time.Duration
81- ReadyToTrip func (counts Counts ) bool
82- OnStateChange func (name string , from State , to State )
83- IsSuccessful func (err error ) bool
78+ Name string
79+ MaxRequests uint32
80+ Interval time.Duration
81+ BucketPeriod time.Duration
82+ Timeout time.Duration
83+ ReadyToTrip func (counts Counts ) bool
84+ OnStateChange func (name string , from State , to State )
85+ OnStateChangeCtx func (ctx context.Context , name string , from State , to State )
86+ IsSuccessful func (err error ) bool
8487}
8588
8689// CircuitBreaker is a state machine to prevent sending requests that are likely to fail.
8790type CircuitBreaker [T any ] struct {
88- name string
89- maxRequests uint32
90- interval time.Duration
91- bucketPeriod time.Duration
92- timeout time.Duration
93- readyToTrip func (counts Counts ) bool
94- isSuccessful func (err error ) bool
95- onStateChange func (name string , from State , to State )
91+ name string
92+ maxRequests uint32
93+ interval time.Duration
94+ bucketPeriod time.Duration
95+ timeout time.Duration
96+ readyToTrip func (counts Counts ) bool
97+ isSuccessful func (err error ) bool
98+ onStateChangeCtx func (ctx context. Context , name string , from State , to State )
9699
97100 mutex sync.Mutex
98101 state State
@@ -107,7 +110,15 @@ func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] {
107110 cb := new (CircuitBreaker [T ])
108111
109112 cb .name = st .Name
110- cb .onStateChange = st .OnStateChange
113+
114+ if st .OnStateChange != nil {
115+ cb .onStateChangeCtx = func (_ context.Context , name string , from State , to State ) {
116+ st .OnStateChange (name , from , to )
117+ }
118+ }
119+ if st .OnStateChangeCtx != nil {
120+ cb .onStateChangeCtx = st .OnStateChangeCtx
121+ }
111122
112123 if st .MaxRequests == 0 {
113124 cb .maxRequests = 1
@@ -173,11 +184,16 @@ func (cb *CircuitBreaker[T]) Name() string {
173184
174185// State returns the current state of the CircuitBreaker.
175186func (cb * CircuitBreaker [T ]) State () State {
187+ return cb .StateCtx (context .Background ())
188+ }
189+
190+ // StateCtx is like State but accepts a context which will be propagated to state change callbacks.
191+ func (cb * CircuitBreaker [T ]) StateCtx (ctx context.Context ) State {
176192 cb .mutex .Lock ()
177193 defer cb .mutex .Unlock ()
178194
179195 now := time .Now ()
180- state , _ , _ := cb .currentState (now )
196+ state , _ , _ := cb .currentState (ctx , now )
181197 return state
182198}
183199
@@ -195,7 +211,12 @@ func (cb *CircuitBreaker[T]) Counts() Counts {
195211// If a panic occurs in the request, the CircuitBreaker handles it as an error
196212// and causes the same panic again.
197213func (cb * CircuitBreaker [T ]) Execute (req func () (T , error )) (T , error ) {
198- generation , age , err := cb .beforeRequest ()
214+ return cb .ExecuteCtx (context .Background (), req )
215+ }
216+
217+ // ExecuteCtx is like Execute but accepts a context which will be propagated to state change callbacks.
218+ func (cb * CircuitBreaker [T ]) ExecuteCtx (ctx context.Context , req func () (T , error )) (T , error ) {
219+ generation , age , err := cb .beforeRequest (ctx )
199220 if err != nil {
200221 var defaultValue T
201222 return defaultValue , err
@@ -204,22 +225,22 @@ func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) {
204225 defer func () {
205226 e := recover ()
206227 if e != nil {
207- cb .afterRequest (generation , age , false )
228+ cb .afterRequest (ctx , generation , age , false )
208229 panic (e )
209230 }
210231 }()
211232
212233 result , err := req ()
213- cb .afterRequest (generation , age , cb .isSuccessful (err ))
234+ cb .afterRequest (ctx , generation , age , cb .isSuccessful (err ))
214235 return result , err
215236}
216237
217- func (cb * CircuitBreaker [T ]) beforeRequest () (uint64 , uint64 , error ) {
238+ func (cb * CircuitBreaker [T ]) beforeRequest (ctx context. Context ) (uint64 , uint64 , error ) {
218239 cb .mutex .Lock ()
219240 defer cb .mutex .Unlock ()
220241
221242 now := time .Now ()
222- state , generation , age := cb .currentState (now )
243+ state , generation , age := cb .currentState (ctx , now )
223244
224245 if state == StateOpen {
225246 return generation , age , ErrOpenState
@@ -231,48 +252,48 @@ func (cb *CircuitBreaker[T]) beforeRequest() (uint64, uint64, error) {
231252 return generation , age , nil
232253}
233254
234- func (cb * CircuitBreaker [T ]) afterRequest (previous uint64 , age uint64 , success bool ) {
255+ func (cb * CircuitBreaker [T ]) afterRequest (ctx context. Context , previous uint64 , age uint64 , success bool ) {
235256 cb .mutex .Lock ()
236257 defer cb .mutex .Unlock ()
237258
238259 now := time .Now ()
239- state , generation , _ := cb .currentState (now )
260+ state , generation , _ := cb .currentState (ctx , now )
240261 if generation != previous {
241262 return
242263 }
243264
244265 if success {
245- cb .onSuccess (state , age , now )
266+ cb .onSuccess (ctx , state , age , now )
246267 } else {
247- cb .onFailure (state , age , now )
268+ cb .onFailure (ctx , state , age , now )
248269 }
249270}
250271
251- func (cb * CircuitBreaker [T ]) onSuccess (state State , age uint64 , now time.Time ) {
272+ func (cb * CircuitBreaker [T ]) onSuccess (ctx context. Context , state State , age uint64 , now time.Time ) {
252273 switch state {
253274 case StateClosed :
254275 cb .counts .onSuccess (age )
255276 case StateHalfOpen :
256277 cb .counts .onSuccess (age )
257278 if cb .counts .ConsecutiveSuccesses >= cb .maxRequests {
258- cb .setState (StateClosed , now )
279+ cb .setState (ctx , StateClosed , now )
259280 }
260281 }
261282}
262283
263- func (cb * CircuitBreaker [T ]) onFailure (state State , age uint64 , now time.Time ) {
284+ func (cb * CircuitBreaker [T ]) onFailure (ctx context. Context , state State , age uint64 , now time.Time ) {
264285 switch state {
265286 case StateClosed :
266287 cb .counts .onFailure (age )
267288 if cb .readyToTrip (cb .counts .Counts ) {
268- cb .setState (StateOpen , now )
289+ cb .setState (ctx , StateOpen , now )
269290 }
270291 case StateHalfOpen :
271- cb .setState (StateOpen , now )
292+ cb .setState (ctx , StateOpen , now )
272293 }
273294}
274295
275- func (cb * CircuitBreaker [T ]) currentState (now time.Time ) (State , uint64 , uint64 ) {
296+ func (cb * CircuitBreaker [T ]) currentState (ctx context. Context , now time.Time ) (State , uint64 , uint64 ) {
276297 switch cb .state {
277298 case StateClosed :
278299 if ! cb .expiry .IsZero () && cb .expiry .Before (now ) {
@@ -282,7 +303,7 @@ func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64, uint64)
282303 }
283304 case StateOpen :
284305 if cb .expiry .Before (now ) {
285- cb .setState (StateHalfOpen , now )
306+ cb .setState (ctx , StateHalfOpen , now )
286307 }
287308 }
288309 return cb .state , cb .generation , cb .counts .age
@@ -301,7 +322,7 @@ func (cb *CircuitBreaker[T]) age(now time.Time) uint64 {
301322 return uint64 (age )
302323}
303324
304- func (cb * CircuitBreaker [T ]) setState (state State , now time.Time ) {
325+ func (cb * CircuitBreaker [T ]) setState (ctx context. Context , state State , now time.Time ) {
305326 if cb .state == state {
306327 return
307328 }
@@ -311,8 +332,8 @@ func (cb *CircuitBreaker[T]) setState(state State, now time.Time) {
311332
312333 cb .toNewGeneration (now )
313334
314- if cb .onStateChange != nil {
315- cb .onStateChange ( cb .name , prev , state )
335+ if cb .onStateChangeCtx != nil {
336+ cb .onStateChangeCtx ( ctx , cb .name , prev , state )
316337 }
317338}
318339
0 commit comments