Skip to content

Commit 61a4fb6

Browse files
committed
Support for state enter callbacks
1 parent 4b9758e commit 61a4fb6

File tree

2 files changed

+90
-26
lines changed

2 files changed

+90
-26
lines changed

act/statemachine.go

+40-7
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ type StateMachine[D any] struct {
6464
// Value: The message handler (any). There is a compile-time guarantee
6565
// that the handler is of type StateCallHandler[D, M, R].
6666
stateCallHandlers map[gen.Atom]map[string]any
67+
68+
stateEnterCallback StateEnterCallback[D]
6769
}
6870

6971
// Type alias for MessageHandler callbacks.
@@ -77,11 +79,14 @@ type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.At
7779
// R is the type of the result value.
7880
type StateCallHandler[D any, M any, R any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, R, error)
7981

82+
type StateEnterCallback[D any] func(gen.Atom, gen.Atom, D, gen.Process) (gen.Atom, D, error)
83+
8084
type StateMachineSpec[D any] struct {
8185
initialState gen.Atom
8286
data D
8387
stateMessageHandlers map[gen.Atom]map[string]any
8488
stateCallHandlers map[gen.Atom]map[string]any
89+
stateEnterCallback StateEnterCallback[D]
8590
}
8691

8792
type Option[D any] func(*StateMachineSpec[D])
@@ -124,12 +129,32 @@ func WithStateCallHandler[D any, M any, R any](state gen.Atom, handler StateCall
124129
}
125130
}
126131

132+
func WithStateEnterCallback[D any](callback StateEnterCallback[D]) Option[D] {
133+
return func(s *StateMachineSpec[D]) {
134+
s.stateEnterCallback = callback
135+
}
136+
}
137+
127138
func (s *StateMachine[D]) CurrentState() gen.Atom {
128139
return s.currentState
129140
}
130141

131142
func (s *StateMachine[D]) SetCurrentState(state gen.Atom) {
132-
s.currentState = state
143+
if state != s.currentState {
144+
s.Log().Info("setting current state to %v", state)
145+
oldState := s.currentState
146+
s.currentState = state
147+
148+
// Execute state enter callback until no new transition is triggered.
149+
if s.stateEnterCallback != nil {
150+
newState, newData, err := s.stateEnterCallback(oldState, state, s.data, s)
151+
if err != nil {
152+
panic(fmt.Sprintf("error in StateEnterCallback for state %s", state))
153+
}
154+
s.SetData(newData)
155+
s.SetCurrentState(newState)
156+
}
157+
}
133158
}
134159

135160
func (s *StateMachine[D]) Data() D {
@@ -175,6 +200,7 @@ func (sm *StateMachine[D]) ProcessInit(process gen.Process, args ...any) (rr err
175200
sm.data = spec.data
176201
sm.stateMessageHandlers = spec.stateMessageHandlers
177202
sm.stateCallHandlers = spec.stateCallHandlers
203+
sm.stateEnterCallback = spec.stateEnterCallback
178204

179205
return nil
180206
}
@@ -357,15 +383,18 @@ func (sm *StateMachine[D]) invokeMessageHandler(handler any, message *gen.Mailbo
357383
if len(results) != 3 {
358384
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
359385
"error when invoking call handler for %v", typeName(message))
386+
return gen.TerminateReasonPanic
360387
}
361388
if !results[2].IsNil() {
362389
return results[2].Interface().(error)
363390
}
364-
//TODO: panic if new state or new data is not provided
365-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
366-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
391+
367392
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
368393
setDataMethod.Call([]reflect.Value{results[1]})
394+
// It is important that we set the state last as this can potentially trigger
395+
// a state enter callback
396+
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
397+
setCurrentStateMethod.Call([]reflect.Value{results[0]})
369398

370399
return nil
371400
}
@@ -391,18 +420,22 @@ func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMe
391420
if len(results) != 4 {
392421
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
393422
"error when invoking call handler for %v", typeName(message))
423+
return nil, gen.TerminateReasonPanic
394424
}
395425

396426
if !results[3].IsNil() {
397427
err := results[1].Interface().(error)
398428
return nil, err
399429
}
400-
//TODO: panic if new state or new data is not provided
401-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
402-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
430+
403431
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
404432
setDataMethod.Call([]reflect.Value{results[1]})
433+
// It is important that we set the state last as this can potentially trigger
434+
// a state enter callback
435+
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
436+
setCurrentStateMethod.Call([]reflect.Value{results[0]})
405437

406438
result := results[2].Interface()
439+
407440
return result, nil
408441
}

tests/001_local/t018_statemachine_test.go

+50-19
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ type t18statemachine struct {
5454
}
5555

5656
type t18data struct {
57-
count int
57+
transitions int
58+
stateEnterCallbacks int
5859
}
5960

6061
type t18transitionState1toState2 struct {
@@ -63,29 +64,50 @@ type t18transitionState1toState2 struct {
6364
type t18transitionState2toState1 struct {
6465
}
6566

67+
type t18query struct {
68+
}
69+
6670
func (sm *t18statemachine) Init(args ...any) (act.StateMachineSpec[t18data], error) {
6771
spec := act.NewStateMachineSpec(gen.Atom("state1"),
6872
// initial data
69-
act.WithData(t18data{count: 1}),
73+
act.WithData(t18data{}),
7074

7175
// set up a message handler for the transition state1 -> state2
7276
act.WithStateMessageHandler(gen.Atom("state1"), state1to2),
7377

74-
// set up a call handler for the transition state2 -> state1
75-
act.WithStateCallHandler(gen.Atom("state2"), state2to1),
78+
// set up a call handler to query the data
79+
act.WithStateCallHandler(gen.Atom("state3"), queryData),
80+
81+
// set up a state enter callback
82+
act.WithStateEnterCallback(stateEnter),
7683
)
7784

7885
return spec, nil
7986
}
8087

8188
func state1to2(state gen.Atom, data t18data, message t18transitionState1toState2, proc gen.Process) (gen.Atom, t18data, error) {
82-
data.count++
89+
data.transitions++
8390
return gen.Atom("state2"), data, nil
8491
}
8592

86-
func state2to1(state gen.Atom, data t18data, message t18transitionState2toState1, proc gen.Process) (gen.Atom, t18data, int, error) {
87-
data.count++
88-
return gen.Atom("state1"), data, data.count, nil
93+
func queryData(state gen.Atom, data t18data, message t18query, proc gen.Process) (gen.Atom, t18data, t18data, error) {
94+
return state, data, data, nil
95+
}
96+
97+
func stateEnter(oldState gen.Atom, newState gen.Atom, data t18data, proc gen.Process) (gen.Atom, t18data, error) {
98+
data.stateEnterCallbacks++
99+
100+
if newState == gen.Atom("state2") {
101+
data.transitions++
102+
return gen.Atom("state3"), data, nil
103+
104+
}
105+
return newState, data, nil
106+
}
107+
108+
func state3Enter(state gen.Atom, data t18data, proc gen.Process) (gen.Atom, t18data, error) {
109+
data.stateEnterCallbacks++
110+
return state, data, nil
89111
}
90112

91113
func (t *t18) TestStateMachine(input any) {
@@ -100,31 +122,40 @@ func (t *t18) TestStateMachine(input any) {
100122
return
101123
}
102124

103-
// send message to transition from state 1 to 2
125+
// Send a message to transition to state 2. The state enter callback should
126+
// automatically transition to state 3 where another state enter callback
127+
// does not trigger any further state transitions.
104128
err = t.Send(pid, t18transitionState1toState2{})
105-
106129
if err != nil {
107-
t.Log().Error("sending to the statemachine process failed: %s", err)
130+
t.Log().Error("send 't18transitionState1toState2' failed: %s", err)
108131
t.testcase.err <- err
109132
return
110133
}
111134

112-
// send call to transition from result 2 to 1
113-
result, err := t.Call(pid, t18transitionState2toState1{})
135+
// Query the data from the state machine (and test StateCallHandler behavior)
136+
result, err := t.Call(pid, t18query{})
114137
if err != nil {
115-
t.Log().Error("call to the statemachine process failed: %s", err)
138+
t.Log().Error("call 't18query' failed: %s", err)
116139
t.testcase.err <- err
117140
return
118141
}
119-
// initial count was 1, after 2 state transitions we expect the count to be 3
120-
if result != 3 {
121-
t.testcase.err <- fmt.Errorf("expected 3, got %v", result)
142+
143+
// We expect 2 state transitions (state1 -> state2 -> state3)
144+
data := result.(t18data)
145+
if data.transitions != 2 {
146+
t.testcase.err <- fmt.Errorf("expected 2 state transitions, got %v", result)
122147
return
123148
}
124149

125-
// statemachine process should crash on invalid state transition
150+
// We expect a chain of 2 state enter callback functions to be called, one
151+
// for state2 and one for state3.
152+
if data.stateEnterCallbacks != 2 {
153+
t.testcase.err <- fmt.Errorf("expected 2 state enter function invocations, got %d", data.stateEnterCallbacks)
154+
}
155+
156+
// Statemachine process should crash on invalid state transition
126157
err = t.testcase.expectProcessToTerminate(pid, t, func(p gen.Process) error {
127-
return p.Send(pid, t18transitionState2toState1{}) // we are in state1
158+
return p.Send(pid, t18transitionState2toState1{}) // we are in state3
128159
})
129160
if err != nil {
130161
t.testcase.err <- err

0 commit comments

Comments
 (0)