Skip to content

Commit ccfdc65

Browse files
committed
Support for state timeouts
1 parent c428942 commit ccfdc65

5 files changed

+723
-100
lines changed

act/statemachine.go

+145-62
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package act
22

33
import (
4+
"context"
45
"fmt"
56
"reflect"
67
"runtime"
@@ -75,40 +76,44 @@ type StateMachine[D any] struct {
7576
// Callback that is invoked immediately after every state change. If no
7677
// callback is registered stateEnterCallback is nil.
7778
stateEnterCallback StateEnterCallback[D]
79+
80+
// Pointer to the most recently configured state timeout.
81+
activeStateTimeout *ActiveStateTimeout
7882
}
7983

8084
type Action interface {
8185
isAction()
8286
}
8387

84-
type StateTimeout[M any] struct {
88+
type StateTimeout struct {
8589
Duration time.Duration
86-
message M
90+
Message any
8791
}
8892

89-
func (StateTimeout[M]) IsAction() {}
93+
func (StateTimeout) isAction() {}
9094

91-
// state_timeout
92-
// timeout
95+
type ActiveStateTimeout struct {
96+
state gen.Atom
97+
timeout StateTimeout
98+
ctx context.Context
99+
cancel context.CancelFunc
100+
}
93101

94102
// Type alias for MessageHandler callbacks.
95103
// D is the type of the data associated with the StateMachine.
96104
// M is the type of the message this handler accepts.
97-
type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, error)
98-
99-
// new version with actions
100-
//type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, []Action, error)
105+
type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, []Action, error)
101106

102107
// Type alias for CallHandler callbacks.
103108
// D is the type of the data associated with the StateMachine.
104109
// M is the type of the message this handler accepts.
105110
// R is the type of the result value.
106-
type StateCallHandler[D any, M any, R any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, R, error)
111+
type StateCallHandler[D any, M any, R any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, R, []Action, error)
107112

108113
// Type alias for event handler callbacks.
109114
// D is the type of the data associated with the StateMachine.
110115
// E is the type of the event.
111-
type EventHandler[D any, E any] func(gen.Atom, D, E, gen.Process) (gen.Atom, D, error)
116+
type EventHandler[D any, E any] func(gen.Atom, D, E, gen.Process) (gen.Atom, D, []Action, error)
112117

113118
// Type alias for StateEnter callback.
114119
// D is the type of the data associated with the StateMachine.
@@ -137,7 +142,6 @@ func NewStateMachineSpec[D any](initialState gen.Atom, options ...Option[D]) Sta
137142
}
138143
return spec
139144
}
140-
141145
func WithData[D any](data D) Option[D] {
142146
return func(s *StateMachineSpec[D]) {
143147
s.data = data
@@ -182,10 +186,18 @@ func (s *StateMachine[D]) CurrentState() gen.Atom {
182186

183187
func (s *StateMachine[D]) SetCurrentState(state gen.Atom) {
184188
if state != s.currentState {
185-
s.Log().Info("setting current state to %v", state)
189+
s.Log().Info("StateMachine: switching to state %s", state)
186190
oldState := s.currentState
187191
s.currentState = state
188192

193+
// If there is a state timeout set up for the new state then we have
194+
// just registered this timeout in `ProcessActions` and we should not
195+
// touch it. Otherwise we should cancel the active state timeout if there
196+
// is one.
197+
if s.hasActiveStateTimeout() && s.activeStateTimeout.state != state {
198+
s.Log().Info("StateMachine: canceling state timeout for state %s", state)
199+
s.activeStateTimeout.cancel()
200+
}
189201
// Execute state enter callback until no new transition is triggered.
190202
if s.stateEnterCallback != nil {
191203
newState, newData, err := s.stateEnterCallback(oldState, state, s.data, s)
@@ -206,6 +218,10 @@ func (s *StateMachine[D]) SetData(data D) {
206218
s.data = data
207219
}
208220

221+
func (s *StateMachine[D]) hasActiveStateTimeout() bool {
222+
return s.activeStateTimeout != nil && s.activeStateTimeout.ctx.Err() == nil
223+
}
224+
209225
type startMonitoringEvents struct{}
210226

211227
//
@@ -246,10 +262,12 @@ func (sm *StateMachine[D]) ProcessInit(process gen.Process, args ...any) (rr err
246262
sm.eventHandlers = spec.eventHandlers
247263
sm.stateEnterCallback = spec.stateEnterCallback
248264

249-
// if we have event handlers we need to start listening for events
265+
// Send a message to ourselves to start monitoring events if there are
266+
// event handlers registerd.
250267
if len(sm.eventHandlers) > 0 {
251268
sm.Send(sm.PID(), startMonitoringEvents{})
252269
}
270+
sm.Log().Info("StateMachine: started in state %s", sm.currentState)
253271

254272
return nil
255273
}
@@ -320,7 +338,7 @@ func (sm *StateMachine[D]) ProcessRun() (rr error) {
320338
panic(fmt.Sprintf("Error monitoring event: %v.", err))
321339
}
322340
}
323-
sm.Log().Info("StateMachine %s is now monitoring events", sm.PID())
341+
sm.Log().Info("StateMachine: monitoring events")
324342
return nil
325343

326344
default:
@@ -423,6 +441,43 @@ func (s *StateMachine[D]) Terminate(reason error) {}
423441
// Internals
424442
//
425443

444+
func (sm *StateMachine[D]) ProcessActions(actions []Action, state gen.Atom) {
445+
for _, action := range actions {
446+
switch action := action.(type) {
447+
case StateTimeout:
448+
if sm.hasActiveStateTimeout() {
449+
sm.activeStateTimeout.cancel()
450+
}
451+
ctx, cancel := context.WithTimeout(context.Background(), action.Duration)
452+
sm.activeStateTimeout = &ActiveStateTimeout{
453+
state: state,
454+
timeout: action,
455+
ctx: ctx,
456+
cancel: cancel,
457+
}
458+
go startStateTimeout(ctx, state, action.Message, sm)
459+
return
460+
default:
461+
panic("unsupported action")
462+
}
463+
}
464+
}
465+
466+
func startStateTimeout(ctx context.Context, state gen.Atom, message any, proc gen.Process) {
467+
select {
468+
case <-ctx.Done():
469+
switch ctx.Err() {
470+
case context.DeadlineExceeded:
471+
proc.Log().Info("StateMachine: state timeout for state %s timed out", state)
472+
proc.Send(proc.PID(), message)
473+
return
474+
case context.Canceled:
475+
proc.Log().Info("StateMachine: state timeout for state %s canceled", state)
476+
return
477+
}
478+
}
479+
}
480+
426481
func (sm *StateMachine[D]) lookupMessageHandler(messageType string) (any, bool) {
427482
if stateMessageHandlers, exists := sm.stateMessageHandlers[sm.currentState]; exists == true {
428483
if callback, exists := stateMessageHandlers[messageType]; exists == true {
@@ -433,29 +488,20 @@ func (sm *StateMachine[D]) lookupMessageHandler(messageType string) (any, bool)
433488
}
434489

435490
func (sm *StateMachine[D]) invokeMessageHandler(handler any, message *gen.MailboxMessage) error {
491+
stateMachineValue := reflect.ValueOf(sm)
436492
callbackValue := reflect.ValueOf(handler)
437493
stateValue := reflect.ValueOf(sm.currentState)
438494
dataValue := reflect.ValueOf(sm.Data())
439495
msgValue := reflect.ValueOf(message.Message)
440-
procValue := reflect.ValueOf(sm)
496+
messageType := reflect.TypeOf(message).String()
441497

442-
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, procValue})
498+
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, stateMachineValue})
443499

444-
if len(results) != 3 {
445-
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
446-
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
447-
return gen.TerminateReasonPanic
448-
}
449-
if !results[2].IsNil() {
450-
return results[2].Interface().(error)
500+
validateResultSize(results, 4, messageType)
501+
if isError, err := resultIsError(results); isError == true {
502+
return err
451503
}
452-
453-
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
454-
setDataMethod.Call([]reflect.Value{results[1]})
455-
// It is important that we set the state last as this can potentially trigger
456-
// a state enter callback
457-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
458-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
504+
updateStateMachineWithResults(stateMachineValue, results)
459505

460506
return nil
461507
}
@@ -470,63 +516,100 @@ func (sm *StateMachine[D]) lookupCallHandler(messageType string) (any, bool) {
470516
}
471517

472518
func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMessage) (any, error) {
519+
stateMachineValue := reflect.ValueOf(sm)
473520
callbackValue := reflect.ValueOf(handler)
474521
stateValue := reflect.ValueOf(sm.currentState)
475522
dataValue := reflect.ValueOf(sm.Data())
476523
msgValue := reflect.ValueOf(message.Message)
477-
procValue := reflect.ValueOf(sm)
478-
479-
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, procValue})
524+
messageType := reflect.TypeOf(message).String()
480525

481-
if len(results) != 4 {
482-
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
483-
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
484-
return nil, gen.TerminateReasonPanic
485-
}
526+
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, stateMachineValue})
486527

487-
if !results[3].IsNil() {
488-
err := results[3].Interface().(error)
528+
validateResultSize(results, 5, messageType)
529+
if isError, err := resultIsError(results); isError == true {
489530
return nil, err
490531
}
491-
492-
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
493-
setDataMethod.Call([]reflect.Value{results[1]})
494-
// It is important that we set the state last as this can potentially trigger
495-
// a state enter callback
496-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
497-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
498-
532+
updateStateMachineWithResults(stateMachineValue, results)
499533
result := results[2].Interface()
500534

501535
return result, nil
502536
}
503537

504538
func (sm *StateMachine[D]) invokeEventHandler(handler any, message *gen.MessageEvent) error {
539+
stateMachineValue := reflect.ValueOf(sm)
505540
callbackValue := reflect.ValueOf(handler)
506541
stateValue := reflect.ValueOf(sm.currentState)
507542
dataValue := reflect.ValueOf(sm.Data())
508543
msgValue := reflect.ValueOf(message.Message)
509-
procValue := reflect.ValueOf(sm)
544+
messageType := reflect.TypeOf(message).String()
510545

511-
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, procValue})
546+
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, stateMachineValue})
512547

513-
if len(results) != 3 {
514-
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
515-
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
516-
return gen.TerminateReasonPanic
548+
validateResultSize(results, 4, messageType)
549+
if isError, err := resultIsError(results); isError == true {
550+
return err
517551
}
552+
updateStateMachineWithResults(stateMachineValue, results)
518553

519-
if !results[2].IsNil() {
520-
err := results[2].Interface().(error)
521-
return err
554+
return nil
555+
}
556+
557+
func validateResultSize(results []reflect.Value, expectedSize int, messageType string) {
558+
if len(results) != expectedSize {
559+
panic(fmt.Sprintf("StateMachine terminated. Panic reason: unexpected "+
560+
"error when invoking call handler for %s", messageType))
522561
}
562+
}
523563

524-
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
564+
func resultIsError(results []reflect.Value) (bool, error) {
565+
errIndex := len(results) - 1
566+
if !results[errIndex].IsNil() {
567+
err := results[errIndex].Interface().(error)
568+
return true, err
569+
}
570+
return false, nil
571+
}
572+
573+
func updateStateMachineWithResults(sm reflect.Value, results []reflect.Value) {
574+
// Check if any actions were returned. MessageHandler and EventHandler have
575+
// the result tuple (gen.Atom, D, []Action, error) with the actions at index
576+
// 2. CallHandler has the result typle (gen.Atom, D, R, []Action, error)
577+
// with the actions at index 3.
578+
var actionsIndex int
579+
hasResult := len(results) == 5
580+
if hasResult {
581+
actionsIndex = 3
582+
} else {
583+
actionsIndex = 2
584+
}
585+
if !isSliceNilOrEmpty(results[actionsIndex]) {
586+
processActionsMethod := sm.MethodByName("ProcessActions")
587+
if processActionsMethod.IsNil() {
588+
}
589+
processActionsMethod.Call([]reflect.Value{results[actionsIndex], results[0]})
590+
}
591+
592+
// Update the data
593+
setDataMethod := sm.MethodByName("SetData")
525594
setDataMethod.Call([]reflect.Value{results[1]})
595+
526596
// It is important that we set the state last as this can potentially trigger
527-
// a state enter callback
528-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
597+
// a state enter callback. By design state enter callbacks are triggered
598+
// after setting up state timeouts as state timeouts are tied to te state
599+
// they are defined for. A state enter callback could transition to another
600+
// state which then will cancel the state timeout.
601+
setCurrentStateMethod := sm.MethodByName("SetCurrentState")
529602
setCurrentStateMethod.Call([]reflect.Value{results[0]})
603+
}
530604

531-
return nil
605+
func isSliceNilOrEmpty(resultValue reflect.Value) bool {
606+
if resultValue.IsNil() {
607+
return true
608+
}
609+
610+
if resultValue.Len() == 0 {
611+
return true
612+
}
613+
614+
return false
532615
}

tests/001_local/t018_statemachine_test.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ func (t *t18) HandleMessage(from gen.PID, message any) error {
4040
return nil
4141
}
4242

43-
// Test state transitions with messages and calls
4443
func factory_t18_state_transitions() gen.ProcessBehavior {
4544
return &t18_state_transitions{}
4645
}
@@ -74,13 +73,13 @@ func (sm *t18_state_transitions) Init(args ...any) (act.StateMachineSpec[t18_sta
7473
return spec, nil
7574
}
7675

77-
func t18_move_to_state2(state gen.Atom, data t18_state_transitions_data, message t18_state2, proc gen.Process) (gen.Atom, t18_state_transitions_data, error) {
76+
func t18_move_to_state2(state gen.Atom, data t18_state_transitions_data, message t18_state2, proc gen.Process) (gen.Atom, t18_state_transitions_data, []act.Action, error) {
7877
data.transitions++
79-
return gen.Atom("state2"), data, nil
78+
return gen.Atom("state2"), data, nil, nil
8079
}
8180

82-
func t18_total_transitions(state gen.Atom, data t18_state_transitions_data, message t18_get_transitions, proc gen.Process) (gen.Atom, t18_state_transitions_data, int, error) {
83-
return state, data, data.transitions, nil
81+
func t18_total_transitions(state gen.Atom, data t18_state_transitions_data, message t18_get_transitions, proc gen.Process) (gen.Atom, t18_state_transitions_data, int, []act.Action, error) {
82+
return state, data, data.transitions, nil, nil
8483
}
8584

8685
func (t *t18) TestStateMachine(input any) {
@@ -130,9 +129,9 @@ func (t *t18) TestStateMachine(input any) {
130129
t.testcase.err <- nil
131130
}
132131

133-
func TestTt18template(t *testing.T) {
132+
func TestT18StateMachine(t *testing.T) {
134133
nopt := gen.NodeOptions{}
135-
//nopt.Log.DefaultLogger.Disable = true
134+
nopt.Log.DefaultLogger.Disable = true
136135
//nopt.Log.Level = gen.LogLevelTrace
137136
node, err := ergo.StartNode("t18node@localhost", nopt)
138137
if err != nil {

0 commit comments

Comments
 (0)