Skip to content

Commit ed02bf2

Browse files
committed
Support for state enter callbacks
1 parent 4b9758e commit ed02bf2

File tree

2 files changed

+87
-26
lines changed

2 files changed

+87
-26
lines changed

act/statemachine.go

+39-7
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,14 @@ type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.At
7777
// R is the type of the result value.
7878
type StateCallHandler[D any, M any, R any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, R, error)
7979

80+
type StateEnterCallback[D any] func(gen.Atom, D, gen.Process) (gen.Atom, D, error)
81+
8082
type StateMachineSpec[D any] struct {
8183
initialState gen.Atom
8284
data D
8385
stateMessageHandlers map[gen.Atom]map[string]any
8486
stateCallHandlers map[gen.Atom]map[string]any
87+
stateEnterCallbacks map[gen.Atom]StateEnterCallback[D]
8588
}
8689

8790
type Option[D any] func(*StateMachineSpec[D])
@@ -91,6 +94,7 @@ func NewStateMachineSpec[D any](initialState gen.Atom, options ...Option[D]) Sta
9194
initialState: initialState,
9295
stateMessageHandlers: make(map[gen.Atom]map[string]any),
9396
stateCallHandlers: make(map[gen.Atom]map[string]any),
97+
stateEnterCallbacks: make(map[gen.Atom]StateEnterCallback[D]),
9498
}
9599
for _, opt := range options {
96100
opt(&spec)
@@ -124,12 +128,32 @@ func WithStateCallHandler[D any, M any, R any](state gen.Atom, handler StateCall
124128
}
125129
}
126130

131+
func WithStateEnterCallback[D any](state gen.Atom, callback StateEnterCallback[D]) Option[D] {
132+
return func(s *StateMachineSpec[D]) {
133+
s.stateEnterCallbacks[state] = callback
134+
}
135+
}
136+
127137
func (s *StateMachine[D]) CurrentState() gen.Atom {
128138
return s.currentState
129139
}
130140

131141
func (s *StateMachine[D]) SetCurrentState(state gen.Atom) {
132-
s.currentState = state
142+
if state != s.currentState {
143+
s.Log().Info("setting current state to %v", state)
144+
s.currentState = state
145+
146+
// Execute state enter callback if one exist for the state and continue
147+
// this chain until no new transition is triggered.
148+
if callback, exists := s.spec.stateEnterCallbacks[state]; exists == true {
149+
newState, newData, err := callback(state, s.data, s)
150+
if err != nil {
151+
panic(fmt.Sprintf("error in StateEnterCallback for state %s", state))
152+
}
153+
s.SetData(newData)
154+
s.SetCurrentState(newState)
155+
}
156+
}
133157
}
134158

135159
func (s *StateMachine[D]) Data() D {
@@ -175,6 +199,7 @@ func (sm *StateMachine[D]) ProcessInit(process gen.Process, args ...any) (rr err
175199
sm.data = spec.data
176200
sm.stateMessageHandlers = spec.stateMessageHandlers
177201
sm.stateCallHandlers = spec.stateCallHandlers
202+
sm.spec.stateEnterCallbacks = spec.stateEnterCallbacks
178203

179204
return nil
180205
}
@@ -357,15 +382,18 @@ func (sm *StateMachine[D]) invokeMessageHandler(handler any, message *gen.Mailbo
357382
if len(results) != 3 {
358383
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
359384
"error when invoking call handler for %v", typeName(message))
385+
return gen.TerminateReasonPanic
360386
}
361387
if !results[2].IsNil() {
362388
return results[2].Interface().(error)
363389
}
364-
//TODO: panic if new state or new data is not provided
365-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
366-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
390+
367391
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
368392
setDataMethod.Call([]reflect.Value{results[1]})
393+
// It is important that we set the state last as this can potentially trigger
394+
// a state enter callback
395+
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
396+
setCurrentStateMethod.Call([]reflect.Value{results[0]})
369397

370398
return nil
371399
}
@@ -391,18 +419,22 @@ func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMe
391419
if len(results) != 4 {
392420
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
393421
"error when invoking call handler for %v", typeName(message))
422+
return nil, gen.TerminateReasonPanic
394423
}
395424

396425
if !results[3].IsNil() {
397426
err := results[1].Interface().(error)
398427
return nil, err
399428
}
400-
//TODO: panic if new state or new data is not provided
401-
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
402-
setCurrentStateMethod.Call([]reflect.Value{results[0]})
429+
403430
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
404431
setDataMethod.Call([]reflect.Value{results[1]})
432+
// It is important that we set the state last as this can potentially trigger
433+
// a state enter callback
434+
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
435+
setCurrentStateMethod.Call([]reflect.Value{results[0]})
405436

406437
result := results[2].Interface()
438+
407439
return result, nil
408440
}

tests/001_local/t018_statemachine_test.go

+48-19
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ type t18statemachine struct {
5454
}
5555

5656
type t18data struct {
57-
count int
57+
transitions int
58+
stateEnterCallbacks int
5859
}
5960

6061
type t18transitionState1toState2 struct {
@@ -63,29 +64,48 @@ type t18transitionState1toState2 struct {
6364
type t18transitionState2toState1 struct {
6465
}
6566

67+
type t18query struct {
68+
}
69+
6670
func (sm *t18statemachine) Init(args ...any) (act.StateMachineSpec[t18data], error) {
6771
spec := act.NewStateMachineSpec(gen.Atom("state1"),
6872
// initial data
69-
act.WithData(t18data{count: 1}),
73+
act.WithData(t18data{}),
7074

7175
// set up a message handler for the transition state1 -> state2
7276
act.WithStateMessageHandler(gen.Atom("state1"), state1to2),
7377

74-
// set up a call handler for the transition state2 -> state1
75-
act.WithStateCallHandler(gen.Atom("state2"), state2to1),
78+
// set up a call handler to query the data
79+
act.WithStateCallHandler(gen.Atom("state3"), queryData),
80+
81+
// set up a state enter callback for state2
82+
act.WithStateEnterCallback(gen.Atom("state2"), state2Enter),
83+
84+
// set up a state enter callback for state3 to test chaining state enter callbacks
85+
act.WithStateEnterCallback(gen.Atom("state3"), state3Enter),
7686
)
7787

7888
return spec, nil
7989
}
8090

8191
func state1to2(state gen.Atom, data t18data, message t18transitionState1toState2, proc gen.Process) (gen.Atom, t18data, error) {
82-
data.count++
92+
data.transitions++
8393
return gen.Atom("state2"), data, nil
8494
}
8595

86-
func state2to1(state gen.Atom, data t18data, message t18transitionState2toState1, proc gen.Process) (gen.Atom, t18data, int, error) {
87-
data.count++
88-
return gen.Atom("state1"), data, data.count, nil
96+
func queryData(state gen.Atom, data t18data, message t18query, proc gen.Process) (gen.Atom, t18data, t18data, error) {
97+
return state, data, data, nil
98+
}
99+
100+
func state2Enter(state gen.Atom, data t18data, proc gen.Process) (gen.Atom, t18data, error) {
101+
data.stateEnterCallbacks++
102+
data.transitions++
103+
return gen.Atom("state3"), data, nil
104+
}
105+
106+
func state3Enter(state gen.Atom, data t18data, proc gen.Process) (gen.Atom, t18data, error) {
107+
data.stateEnterCallbacks++
108+
return state, data, nil
89109
}
90110

91111
func (t *t18) TestStateMachine(input any) {
@@ -100,31 +120,40 @@ func (t *t18) TestStateMachine(input any) {
100120
return
101121
}
102122

103-
// send message to transition from state 1 to 2
123+
// Send a message to transition to state 2. The state enter callback should
124+
// automatically transition to state 3 where another state enter callback
125+
// does not trigger any further state transitions.
104126
err = t.Send(pid, t18transitionState1toState2{})
105-
106127
if err != nil {
107-
t.Log().Error("sending to the statemachine process failed: %s", err)
128+
t.Log().Error("send 't18transitionState1toState2' failed: %s", err)
108129
t.testcase.err <- err
109130
return
110131
}
111132

112-
// send call to transition from result 2 to 1
113-
result, err := t.Call(pid, t18transitionState2toState1{})
133+
// Query the data from the state machine (and test StateCallHandler behavior)
134+
result, err := t.Call(pid, t18query{})
114135
if err != nil {
115-
t.Log().Error("call to the statemachine process failed: %s", err)
136+
t.Log().Error("call 't18query' failed: %s", err)
116137
t.testcase.err <- err
117138
return
118139
}
119-
// initial count was 1, after 2 state transitions we expect the count to be 3
120-
if result != 3 {
121-
t.testcase.err <- fmt.Errorf("expected 3, got %v", result)
140+
141+
// We expect 2 state transitions (state1 -> state2 -> state3)
142+
data := result.(t18data)
143+
if data.transitions != 2 {
144+
t.testcase.err <- fmt.Errorf("expected 2 state transitions, got %v", result)
122145
return
123146
}
124147

125-
// statemachine process should crash on invalid state transition
148+
// We expect a chain of 2 state enter callback functions to be called, one
149+
// for state2 and one for state3.
150+
if data.stateEnterCallbacks != 2 {
151+
t.testcase.err <- fmt.Errorf("expected 2 state enter function invocations, got %d", data.stateEnterCallbacks)
152+
}
153+
154+
// Statemachine process should crash on invalid state transition
126155
err = t.testcase.expectProcessToTerminate(pid, t, func(p gen.Process) error {
127-
return p.Send(pid, t18transitionState2toState1{}) // we are in state1
156+
return p.Send(pid, t18transitionState2toState1{}) // we are in state3
128157
})
129158
if err != nil {
130159
t.testcase.err <- err

0 commit comments

Comments
 (0)