@@ -78,7 +78,10 @@ type StateMachine[D any] struct {
78
78
stateEnterCallback StateEnterCallback [D ]
79
79
80
80
// Pointer to the most recently configured state timeout.
81
- activeStateTimeout * ActiveStateTimeout
81
+ stateTimeout * ActiveStateTimeout
82
+
83
+ // Pointer to the most recently configured message timeout.
84
+ messageTimeout * ActiveMessageTimeout
82
85
83
86
// genericTimeouts maps the name of a generic timeout to the timeout
84
87
genericTimeouts map [gen.Atom ]* ActiveGenericTimeout
@@ -116,6 +119,19 @@ type ActiveGenericTimeout struct {
116
119
cancel context.CancelFunc
117
120
}
118
121
122
+ type MessageTimeout struct {
123
+ Duration time.Duration
124
+ Message any
125
+ }
126
+
127
+ func (MessageTimeout ) isAction () {}
128
+
129
+ type ActiveMessageTimeout struct {
130
+ timeout MessageTimeout
131
+ ctx context.Context
132
+ cancel context.CancelFunc
133
+ }
134
+
119
135
// Type alias for MessageHandler callbacks.
120
136
// D is the type of the data associated with the StateMachine.
121
137
// M is the type of the message this handler accepts.
@@ -159,6 +175,7 @@ func NewStateMachineSpec[D any](initialState gen.Atom, options ...Option[D]) Sta
159
175
}
160
176
return spec
161
177
}
178
+
162
179
func WithData [D any ](data D ) Option [D ] {
163
180
return func (s * StateMachineSpec [D ]) {
164
181
s .data = data
@@ -211,9 +228,9 @@ func (s *StateMachine[D]) SetCurrentState(state gen.Atom) {
211
228
// just registered this timeout in `ProcessActions` and we should not
212
229
// touch it. Otherwise we should cancel the active state timeout if there
213
230
// is one.
214
- if s .hasActiveStateTimeout () && s .activeStateTimeout .state != state {
231
+ if s .hasActiveStateTimeout () && s .stateTimeout .state != state {
215
232
s .Log ().Info ("StateMachine: canceling state timeout for state %s" , state )
216
- s .activeStateTimeout .cancel ()
233
+ s .stateTimeout .cancel ()
217
234
}
218
235
// Execute state enter callback until no new transition is triggered.
219
236
if s .stateEnterCallback != nil {
@@ -236,7 +253,11 @@ func (s *StateMachine[D]) SetData(data D) {
236
253
}
237
254
238
255
func (s * StateMachine [D ]) hasActiveStateTimeout () bool {
239
- return s .activeStateTimeout != nil && s .activeStateTimeout .ctx .Err () == nil
256
+ return s .stateTimeout != nil && s .stateTimeout .ctx .Err () == nil
257
+ }
258
+
259
+ func (s * StateMachine [D ]) hasActiveMessageTimeout () bool {
260
+ return s .messageTimeout != nil && s .messageTimeout .ctx .Err () == nil
240
261
}
241
262
242
263
func (s * StateMachine [D ]) hasActiveGenericTimeout (name gen.Atom ) bool {
@@ -353,6 +374,11 @@ func (s *StateMachine[D]) ProcessRun() (rr error) {
353
374
return nil
354
375
}
355
376
377
+ // Any message should cancel the active message timeout
378
+ if s .hasActiveMessageTimeout () {
379
+ s .messageTimeout .cancel ()
380
+ }
381
+
356
382
switch message .Type {
357
383
case gen .MailboxMessageTypeRegular :
358
384
switch message .Message .(type ) {
@@ -471,10 +497,10 @@ func (s *StateMachine[D]) ProcessActions(actions []Action, state gen.Atom) {
471
497
switch action := action .(type ) {
472
498
case StateTimeout :
473
499
if s .hasActiveStateTimeout () {
474
- s .activeStateTimeout .cancel ()
500
+ s .stateTimeout .cancel ()
475
501
}
476
502
ctx , cancel := context .WithTimeout (context .Background (), action .Duration )
477
- s .activeStateTimeout = & ActiveStateTimeout {
503
+ s .stateTimeout = & ActiveStateTimeout {
478
504
state : state ,
479
505
timeout : action ,
480
506
ctx : ctx ,
@@ -492,6 +518,14 @@ func (s *StateMachine[D]) ProcessActions(actions []Action, state gen.Atom) {
492
518
cancel : cancel ,
493
519
}
494
520
go startGenericTimeout (ctx , action .Name , action .Message , s )
521
+ case MessageTimeout :
522
+ ctx , cancel := context .WithTimeout (context .Background (), action .Duration )
523
+ s .messageTimeout = & ActiveMessageTimeout {
524
+ timeout : action ,
525
+ ctx : ctx ,
526
+ cancel : cancel ,
527
+ }
528
+ go startMessageTimeout (ctx , action .Message , s )
495
529
default :
496
530
panic ("unsupported action" )
497
531
}
@@ -528,6 +562,21 @@ func startGenericTimeout(ctx context.Context, name gen.Atom, message any, proc g
528
562
}
529
563
}
530
564
565
+ func startMessageTimeout (ctx context.Context , message any , proc gen.Process ) {
566
+ select {
567
+ case <- ctx .Done ():
568
+ switch ctx .Err () {
569
+ case context .DeadlineExceeded :
570
+ proc .Log ().Info ("StateMachine: message timeout timed out" )
571
+ proc .Send (proc .PID (), message )
572
+ return
573
+ case context .Canceled :
574
+ proc .Log ().Info ("StateMachine: message timeout canceled" )
575
+ return
576
+ }
577
+ }
578
+ }
579
+
531
580
func (s * StateMachine [D ]) lookupMessageHandler (messageType string ) (any , bool ) {
532
581
if stateMessageHandlers , exists := s .stateMessageHandlers [s .currentState ]; exists == true {
533
582
if callback , exists := stateMessageHandlers [messageType ]; exists == true {
0 commit comments