Skip to content

Commit c63516a

Browse files
committed
Cleanup
1 parent d4407e0 commit c63516a

File tree

2 files changed

+93
-58
lines changed

2 files changed

+93
-58
lines changed

act/statemachine.go

+76-47
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,30 @@ type StateMachine[D any] struct {
4040
behavior StateMachineBehavior[D]
4141
mailbox gen.ProcessMailbox
4242

43+
// The specification for the StateMachine
4344
spec StateMachineSpec[D]
4445

45-
currentState gen.Atom
46-
data D
46+
// The state the StateMachine is currently in
47+
currentState gen.Atom
48+
49+
// The data associated with the StateMachine
50+
data D
51+
52+
// stateMessageHandlers maps states to the (asynchronous) handlers for the state.
53+
// Key: State (gen.Atom) - The state for which the handler is registered.
54+
// Value: Map of message type to the handler for that message.
55+
// Key: The type of the message received (String).
56+
// Value: The message handler (any). There is a compile-time guarantee
57+
// that the handler is of type StateMessageHandler[D, M].
4758
stateMessageHandlers map[gen.Atom]map[string]any
48-
stateCallHandlers map[gen.Atom]map[string]any
59+
60+
// stateCallHandlers maps states to the (synchronous) handlers for the state.
61+
// Key: State (gen.Atom) - The state for which the handler is registered.
62+
// Value: Map of message type to the handler for that message.
63+
// Key: The type of the message received (String).
64+
// Value: The message handler (any). There is a compile-time guarantee
65+
// that the handler is of type StateCallHandler[D, M, R].
66+
stateCallHandlers map[gen.Atom]map[string]any
4967
}
5068

5169
type StateMessageHandler[D any, M any] func(*StateMachine[D], M) error
@@ -67,8 +85,8 @@ func NewStateMachineSpec[D any](initialState gen.Atom, options ...Option[D]) Sta
6785
stateMessageHandlers: make(map[gen.Atom]map[string]any),
6886
stateCallHandlers: make(map[gen.Atom]map[string]any),
6987
}
70-
for _, cb := range options {
71-
cb(&spec)
88+
for _, opt := range options {
89+
opt(&spec)
7290
}
7391
return spec
7492
}
@@ -79,23 +97,23 @@ func WithData[D any](data D) Option[D] {
7997
}
8098
}
8199

82-
func WithStateMessageHandler[D any, M any](state gen.Atom, callback StateMessageHandler[D, M]) Option[D] {
83-
typeName := reflect.TypeOf((*M)(nil)).Elem().String()
100+
func WithStateMessageHandler[D any, M any](state gen.Atom, handler StateMessageHandler[D, M]) Option[D] {
101+
messageType := reflect.TypeOf((*M)(nil)).Elem().String()
84102
return func(s *StateMachineSpec[D]) {
85103
if _, exists := s.stateMessageHandlers[state]; exists == false {
86104
s.stateMessageHandlers[state] = make(map[string]any)
87105
}
88-
s.stateMessageHandlers[state][typeName] = callback
106+
s.stateMessageHandlers[state][messageType] = handler
89107
}
90108
}
91109

92-
func WithStateCallHandler[D any, M any, R any](state gen.Atom, callback StateCallHandler[D, M, R]) Option[D] {
93-
typeName := reflect.TypeOf((*M)(nil)).Elem().String()
110+
func WithStateCallHandler[D any, M any, R any](state gen.Atom, handler StateCallHandler[D, M, R]) Option[D] {
111+
messageType := reflect.TypeOf((*M)(nil)).Elem().String()
94112
return func(s *StateMachineSpec[D]) {
95113
if _, exists := s.stateCallHandlers[state]; exists == false {
96114
s.stateCallHandlers[state] = make(map[string]any)
97115
}
98-
s.stateCallHandlers[state][typeName] = callback
116+
s.stateCallHandlers[state][messageType] = handler
99117
}
100118
}
101119

@@ -146,9 +164,10 @@ func (sm *StateMachine[D]) ProcessInit(process gen.Process, args ...any) (rr err
146164
return err
147165
}
148166

149-
// set up callbacks
150167
sm.currentState = spec.initialState
168+
sm.data = spec.data
151169
sm.stateMessageHandlers = spec.stateMessageHandlers
170+
sm.stateCallHandlers = spec.stateCallHandlers
152171

153172
return nil
154173
}
@@ -212,46 +231,24 @@ func (sm *StateMachine[D]) ProcessRun() (rr error) {
212231
switch message.Type {
213232
case gen.MailboxMessageTypeRegular:
214233
// check if there is a handler for the message in the current state
215-
typeName := typeName(message)
216-
if callbackInterface, ok := sm.lookupMessageHandler(typeName); ok == true {
217-
callbackValue := reflect.ValueOf(callbackInterface)
218-
smValue := reflect.ValueOf(sm)
219-
msgValue := reflect.ValueOf(message.Message)
220-
221-
results := callbackValue.Call([]reflect.Value{smValue, msgValue})
222-
223-
if len(results) > 0 && !results[0].IsNil() {
224-
return results[0].Interface().(error)
225-
}
226-
return nil
234+
messageType := typeName(message)
235+
handler, ok := sm.lookupMessageHandler(messageType)
236+
if ok == false {
237+
return fmt.Errorf("No handler for message %s in state %s", messageType, sm.currentState)
227238
}
228-
return fmt.Errorf("Unsupported message %s for state %s", typeName, sm.currentState)
239+
return sm.invokeMessageHandler(handler, message)
229240

230241
case gen.MailboxMessageTypeRequest:
231242
var reason error
232243
var result any
233244

234-
sm.Log().Info("got request")
235-
236245
// check if there is a handler for the call in the current state
237-
typeName := typeName(message)
238-
if callbackInterface, ok := sm.lookupMessageHandler(typeName); ok == true {
239-
sm.Log().Info("found handler")
240-
241-
callbackValue := reflect.ValueOf(callbackInterface)
242-
smValue := reflect.ValueOf(sm)
243-
msgValue := reflect.ValueOf(message.Message)
244-
245-
results := callbackValue.Call([]reflect.Value{smValue, msgValue})
246-
if !results[0].IsZero() {
247-
result = results[0].Interface()
248-
}
249-
if !results[1].IsNil() {
250-
reason = results[1].Interface().(error)
251-
}
252-
} else {
253-
reason = fmt.Errorf("Unsupported call %s for state %s", typeName, sm.currentState)
246+
messageType := typeName(message)
247+
handler, ok := sm.lookupCallHandler(messageType)
248+
if ok == false {
249+
return fmt.Errorf("No handler for message %s in state %s", messageType, sm.currentState)
254250
}
251+
result, reason = sm.invokeCallHandler(handler, message)
255252

256253
if reason != nil {
257254
// if reason is "normal" and we got response - send it before termination
@@ -260,9 +257,7 @@ func (sm *StateMachine[D]) ProcessRun() (rr error) {
260257
}
261258
return reason
262259
}
263-
264260
// Note: we do not support async handling of sync request at the moment
265-
266261
sm.SendResponse(message.From, message.Ref, result)
267262

268263
case gen.MailboxMessageTypeEvent:
@@ -343,11 +338,45 @@ func (sm *StateMachine[D]) lookupMessageHandler(messageType string) (any, bool)
343338
return nil, false
344339
}
345340

341+
func (sm *StateMachine[D]) invokeMessageHandler(handler any, message *gen.MailboxMessage) error {
342+
callbackValue := reflect.ValueOf(handler)
343+
smValue := reflect.ValueOf(sm)
344+
msgValue := reflect.ValueOf(message.Message)
345+
346+
results := callbackValue.Call([]reflect.Value{smValue, msgValue})
347+
348+
if len(results) > 0 && !results[0].IsNil() {
349+
return results[0].Interface().(error)
350+
}
351+
return nil
352+
}
353+
346354
func (sm *StateMachine[D]) lookupCallHandler(messageType string) (any, bool) {
347-
if stateCallHandlers, exists := sm.stateMessageHandlers[sm.currentState]; exists == true {
355+
if stateCallHandlers, exists := sm.stateCallHandlers[sm.currentState]; exists == true {
348356
if callback, exists := stateCallHandlers[messageType]; exists == true {
349357
return callback, true
350358
}
351359
}
352360
return nil, false
353361
}
362+
363+
func (sm *StateMachine[D]) invokeCallHandler(handler any, message *gen.MailboxMessage) (any, error) {
364+
callbackValue := reflect.ValueOf(handler)
365+
smValue := reflect.ValueOf(sm)
366+
msgValue := reflect.ValueOf(message.Message)
367+
368+
results := callbackValue.Call([]reflect.Value{smValue, msgValue})
369+
370+
if len(results) != 2 {
371+
sm.Log().Panic("StateMachine terminated. Panic reason: unexpected "+
372+
"error when invoking call handler for %v", typeName(message))
373+
}
374+
375+
if !results[1].IsNil() {
376+
err := results[1].Interface().(error)
377+
return nil, err
378+
}
379+
380+
result := results[0].Interface()
381+
return result, nil
382+
}

tests/001_local/t018_statemachine_test.go

+17-11
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,13 @@ type t18transitionState2toState1 struct {
6565

6666
func (sm *t18statemachine) Init(args ...any) (act.StateMachineSpec[t18data], error) {
6767
spec := act.NewStateMachineSpec(gen.Atom("state1"),
68+
// initial data
6869
act.WithData(t18data{count: 1}),
70+
71+
// set up a message handler for the transition state1 -> state2
6972
act.WithStateMessageHandler(gen.Atom("state1"), state1to2),
73+
74+
// set up a call handler for the transition state2 -> state1
7075
act.WithStateCallHandler(gen.Atom("state2"), state2to1),
7176
)
7277

@@ -110,17 +115,18 @@ func (t *t18) TestStateMachine(input any) {
110115
return
111116
}
112117

113-
// send call to transition from result 2 to 1 (not working yet)
114-
// result, err := t.Call(pid, t18transitionState2toState1{})
115-
// if err != nil {
116-
// t.Log().Error("call to the statemachine process failed: %s", err)
117-
// t.testcase.err <- err
118-
// return
119-
// }
120-
// if result != 3 {
121-
// t.testcase.err <- fmt.Errorf("expected 3, got %v", result)
122-
// return
123-
// }
118+
// send call to transition from result 2 to 1
119+
result, err := t.Call(pid, t18transitionState2toState1{})
120+
if err != nil {
121+
t.Log().Error("call to the statemachine process failed: %s", err)
122+
t.testcase.err <- err
123+
return
124+
}
125+
// initial count was 1, after 2 state transitions we expect the count to be 3
126+
if result != 3 {
127+
t.testcase.err <- fmt.Errorf("expected 3, got %v", result)
128+
return
129+
}
124130

125131
// statemachine process should crash on invalid state transition
126132
err = t.testcase.expectProcessToTerminate(pid, t, func(p gen.Process) error {

0 commit comments

Comments
 (0)