Skip to content

Commit ecd7bcd

Browse files
committed
Support for events
1 parent 04b1f9c commit ecd7bcd

File tree

2 files changed

+122
-18
lines changed

2 files changed

+122
-18
lines changed

act/statemachine.go

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ type StateMachine[D any] struct {
6565
// that the handler is of type StateCallHandler[D, M, R].
6666
stateCallHandlers map[gen.Atom]map[string]any
6767

68+
// eventHandlers maps events to the handler for the event.
69+
// Key: Event name (gen.Atom) - The name of the event
70+
// Value: The event handler (any). There is a compile-time guarantee that
71+
// the handler is of type EventHandler[D, E]
72+
eventHandlers map[gen.Event]any
73+
74+
// Callback that is invoked immediately after every state change. If no
75+
// callback is registered stateEnterCallback is nil.
6876
stateEnterCallback StateEnterCallback[D]
6977
}
7078

@@ -79,13 +87,21 @@ type StateMessageHandler[D any, M any] func(gen.Atom, D, M, gen.Process) (gen.At
7987
// R is the type of the result value.
8088
type StateCallHandler[D any, M any, R any] func(gen.Atom, D, M, gen.Process) (gen.Atom, D, R, error)
8189

90+
// Type alias for event handler callbacks.
91+
// D is the type of the data associated with the StateMachine.
92+
// E is the type of the event.
93+
type EventHandler[D any, E any] func(gen.Atom, D, E, gen.Process) (gen.Atom, D, error)
94+
95+
// Type alias for StateEnter callback.
96+
// D is the type of the data associated with the StateMachine.
8297
type StateEnterCallback[D any] func(gen.Atom, gen.Atom, D, gen.Process) (gen.Atom, D, error)
8398

8499
type StateMachineSpec[D any] struct {
85100
initialState gen.Atom
86101
data D
87102
stateMessageHandlers map[gen.Atom]map[string]any
88103
stateCallHandlers map[gen.Atom]map[string]any
104+
eventHandlers map[gen.Event]any
89105
stateEnterCallback StateEnterCallback[D]
90106
}
91107

@@ -96,6 +112,7 @@ func NewStateMachineSpec[D any](initialState gen.Atom, options ...Option[D]) Sta
96112
initialState: initialState,
97113
stateMessageHandlers: make(map[gen.Atom]map[string]any),
98114
stateCallHandlers: make(map[gen.Atom]map[string]any),
115+
eventHandlers: make(map[gen.Event]any),
99116
}
100117
for _, opt := range options {
101118
opt(&spec)
@@ -135,6 +152,12 @@ func WithStateEnterCallback[D any](callback StateEnterCallback[D]) Option[D] {
135152
}
136153
}
137154

155+
func WithEventHandler[D any, E any](event gen.Event, handler EventHandler[D, E]) Option[D] {
156+
return func(s *StateMachineSpec[D]) {
157+
s.eventHandlers[event] = handler
158+
}
159+
}
160+
138161
func (s *StateMachine[D]) CurrentState() gen.Atom {
139162
return s.currentState
140163
}
@@ -165,6 +188,8 @@ func (s *StateMachine[D]) SetData(data D) {
165188
s.data = data
166189
}
167190

191+
type startMonitoringEvents struct{}
192+
168193
//
169194
// ProcessBehavior implementation
170195
//
@@ -200,8 +225,14 @@ func (sm *StateMachine[D]) ProcessInit(process gen.Process, args ...any) (rr err
200225
sm.data = spec.data
201226
sm.stateMessageHandlers = spec.stateMessageHandlers
202227
sm.stateCallHandlers = spec.stateCallHandlers
228+
sm.eventHandlers = spec.eventHandlers
203229
sm.stateEnterCallback = spec.stateEnterCallback
204230

231+
// if we have event handlers we need to start listening for events
232+
if len(sm.eventHandlers) > 0 {
233+
sm.Send(sm.PID(), startMonitoringEvents{})
234+
}
235+
205236
return nil
206237
}
207238

@@ -263,20 +294,33 @@ func (sm *StateMachine[D]) ProcessRun() (rr error) {
263294

264295
switch message.Type {
265296
case gen.MailboxMessageTypeRegular:
266-
// check if there is a handler for the message in the current state
267-
messageType := typeName(message)
268-
handler, ok := sm.lookupMessageHandler(messageType)
269-
if ok == false {
270-
return fmt.Errorf("No handler for message %s in state %s", messageType, sm.currentState)
297+
switch message.Message.(type) {
298+
case startMonitoringEvents:
299+
// start monitoring
300+
for event := range sm.eventHandlers {
301+
if _, err := sm.MonitorEvent(event); err != nil {
302+
panic(fmt.Sprintf("Error monitoring event: %v.", err))
303+
}
304+
}
305+
sm.Log().Info("StateMachine %s is now monitoring events", sm.PID)
306+
return nil
307+
308+
default:
309+
// check if there is a handler for the message in the current state
310+
messageType := reflect.TypeOf(message.Message).String()
311+
handler, ok := sm.lookupMessageHandler(messageType)
312+
if ok == false {
313+
return fmt.Errorf("No handler for message %s in state %s", messageType, sm.currentState)
314+
}
315+
return sm.invokeMessageHandler(handler, message)
271316
}
272-
return sm.invokeMessageHandler(handler, message)
273317

274318
case gen.MailboxMessageTypeRequest:
275319
var reason error
276320
var result any
277321

278322
// check if there is a handler for the call in the current state
279-
messageType := typeName(message)
323+
messageType := reflect.TypeOf(message.Message).String()
280324
handler, ok := sm.lookupCallHandler(messageType)
281325
if ok == false {
282326
return fmt.Errorf("No handler for message %s in state %s", messageType, sm.currentState)
@@ -294,9 +338,12 @@ func (sm *StateMachine[D]) ProcessRun() (rr error) {
294338
sm.SendResponse(message.From, message.Ref, result)
295339

296340
case gen.MailboxMessageTypeEvent:
297-
if reason := sm.behavior.HandleEvent(message.Message.(gen.MessageEvent)); reason != nil {
298-
return reason
341+
event := message.Message.(gen.MessageEvent)
342+
handler, exists := sm.eventHandlers[event.Event]
343+
if exists == false {
344+
return fmt.Errorf("No handler for event %v", event)
299345
}
346+
return sm.invokeEventHandler(handler, &event)
300347

301348
case gen.MailboxMessageTypeExit:
302349
switch exit := message.Message.(type) {
@@ -358,10 +405,6 @@ func (s *StateMachine[D]) Terminate(reason error) {}
358405
// Internals
359406
//
360407

361-
func typeName(message *gen.MailboxMessage) string {
362-
return reflect.TypeOf(message.Message).String()
363-
}
364-
365408
func (sm *StateMachine[D]) lookupMessageHandler(messageType string) (any, bool) {
366409
if stateMessageHandlers, exists := sm.stateMessageHandlers[sm.currentState]; exists == true {
367410
if callback, exists := stateMessageHandlers[messageType]; exists == true {
@@ -382,7 +425,7 @@ func (sm *StateMachine[D]) invokeMessageHandler(handler any, message *gen.Mailbo
382425

383426
if len(results) != 3 {
384427
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
385-
"error when invoking call handler for %v", typeName(message))
428+
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
386429
return gen.TerminateReasonPanic
387430
}
388431
if !results[2].IsNil() {
@@ -419,12 +462,12 @@ func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMe
419462

420463
if len(results) != 4 {
421464
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
422-
"error when invoking call handler for %v", typeName(message))
465+
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
423466
return nil, gen.TerminateReasonPanic
424467
}
425468

426469
if !results[3].IsNil() {
427-
err := results[1].Interface().(error)
470+
err := results[3].Interface().(error)
428471
return nil, err
429472
}
430473

@@ -439,3 +482,33 @@ func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMe
439482

440483
return result, nil
441484
}
485+
486+
func (sm *StateMachine[D]) invokeEventHandler(handler any, message *gen.MessageEvent) error {
487+
callbackValue := reflect.ValueOf(handler)
488+
stateValue := reflect.ValueOf(sm.currentState)
489+
dataValue := reflect.ValueOf(sm.Data())
490+
msgValue := reflect.ValueOf(message.Message)
491+
procValue := reflect.ValueOf(sm)
492+
493+
results := callbackValue.Call([]reflect.Value{stateValue, dataValue, msgValue, procValue})
494+
495+
if len(results) != 3 {
496+
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
497+
"error when invoking call handler for %s", reflect.TypeOf(message.Message))
498+
return gen.TerminateReasonPanic
499+
}
500+
501+
if !results[2].IsNil() {
502+
err := results[2].Interface().(error)
503+
return err
504+
}
505+
506+
setDataMethod := reflect.ValueOf(sm).MethodByName("SetData")
507+
setDataMethod.Call([]reflect.Value{results[1]})
508+
// It is important that we set the state last as this can potentially trigger
509+
// a state enter callback
510+
setCurrentStateMethod := reflect.ValueOf(sm).MethodByName("SetCurrentState")
511+
setCurrentStateMethod.Call([]reflect.Value{results[0]})
512+
513+
return nil
514+
}

tests/001_local/t018_statemachine_test.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"ergo.services/ergo"
99
"ergo.services/ergo/act"
1010
"ergo.services/ergo/gen"
11+
"ergo.services/ergo/lib"
1112
)
1213

1314
//
@@ -56,6 +57,7 @@ type t18statemachine struct {
5657
type t18data struct {
5758
transitions int
5859
stateEnterCallbacks int
60+
testEventReceived bool
5961
}
6062

6163
type t18transitionState1toState2 struct {
@@ -67,6 +69,10 @@ type t18transitionState2toState1 struct {
6769
type t18query struct {
6870
}
6971

72+
type t18event struct {
73+
payload string
74+
}
75+
7076
func (sm *t18statemachine) Init(args ...any) (act.StateMachineSpec[t18data], error) {
7177
spec := act.NewStateMachineSpec(gen.Atom("state1"),
7278
// initial data
@@ -80,6 +86,9 @@ func (sm *t18statemachine) Init(args ...any) (act.StateMachineSpec[t18data], err
8086

8187
// set up a state enter callback
8288
act.WithStateEnterCallback(stateEnter),
89+
90+
// register event handler
91+
act.WithEventHandler(gen.Event{Name: "testEvent", Node: "t18node@localhost"}, handleTestEvent),
8392
)
8493

8594
return spec, nil
@@ -105,8 +114,8 @@ func stateEnter(oldState gen.Atom, newState gen.Atom, data t18data, proc gen.Pro
105114
return newState, data, nil
106115
}
107116

108-
func state3Enter(state gen.Atom, data t18data, proc gen.Process) (gen.Atom, t18data, error) {
109-
data.stateEnterCallbacks++
117+
func handleTestEvent(state gen.Atom, data t18data, event t18event, proc gen.Process) (gen.Atom, t18data, error) {
118+
data.testEventReceived = true
110119
return state, data, nil
111120
}
112121

@@ -115,6 +124,11 @@ func (t *t18) TestStateMachine(input any) {
115124
t.testcase = nil
116125
}()
117126

127+
// Register the event first, otherwise the StateMachine will not be able
128+
// to start monitoring.
129+
testEvent := gen.Atom("testEvent")
130+
token, err := t.RegisterEvent(testEvent, gen.EventOptions{})
131+
118132
pid, err := t.Spawn(factory_t18statemachine, gen.ProcessOptions{})
119133
if err != nil {
120134
t.Log().Error("unable to spawn statemachine process: %s", err)
@@ -151,6 +165,23 @@ func (t *t18) TestStateMachine(input any) {
151165
// for state2 and one for state3.
152166
if data.stateEnterCallbacks != 2 {
153167
t.testcase.err <- fmt.Errorf("expected 2 state enter function invocations, got %d", data.stateEnterCallbacks)
168+
return
169+
}
170+
171+
event := t18event{lib.RandomString(8)}
172+
t.SendEvent(testEvent, token, event)
173+
174+
// Query the data from the state machine
175+
result, err = t.Call(pid, t18query{})
176+
if err != nil {
177+
t.Log().Error("call 't18query' failed: %s", err)
178+
t.testcase.err <- err
179+
return
180+
}
181+
data = result.(t18data)
182+
if data.testEventReceived == false {
183+
t.testcase.err <- fmt.Errorf("expected test event to be received")
184+
return
154185
}
155186

156187
// Statemachine process should crash on invalid state transition

0 commit comments

Comments
 (0)