@@ -321,12 +321,20 @@ func (d *dummyStateSpent) IsTerminal() bool {
321
321
return true
322
322
}
323
323
324
- func assertState [Event any , Env Environment ](t * testing.T ,
325
- m * StateMachine [Event , Env ], expectedState State [Event , Env ]) {
324
+ // assertState asserts that the state machine is currently in the expected
325
+ // state type and returns the state cast to that type.
326
+ func assertState [Event any , Env Environment , S State [Event , Env ]](t * testing.T ,
327
+ m * StateMachine [Event , Env ], expectedState S ) S {
326
328
327
329
state , err := m .CurrentState ()
328
330
require .NoError (t , err )
329
331
require .IsType (t , expectedState , state )
332
+
333
+ // Perform the type assertion to return the concrete type.
334
+ concreteState , ok := state .(S )
335
+ require .True (t , ok , "state type assertion failed" )
336
+
337
+ return concreteState
330
338
}
331
339
332
340
func assertStateTransitions [Event any , Env Environment ](
@@ -626,18 +634,15 @@ func TestStateMachineConfMapper(t *testing.T) {
626
634
assertStateTransitions (t , stateSub , expectedStates )
627
635
628
636
// Final state assertion.
629
- finalState , err := stateMachine .CurrentState ()
630
- require .NoError (t , err )
631
- require .IsType (t , & dummyStateConfirmed {}, finalState )
637
+ finalState := assertState (t , & stateMachine , & dummyStateConfirmed {})
632
638
633
639
// Assert that the details from the confirmation event were correctly
634
640
// propagated to the final state.
635
- finalStateDetails := finalState .(* dummyStateConfirmed )
636
641
require .Equal (t ,
637
- * simulatedConf .BlockHash , finalStateDetails .blockHash ,
642
+ * simulatedConf .BlockHash , finalState .blockHash ,
638
643
)
639
644
require .Equal (t ,
640
- simulatedConf .BlockHeight , finalStateDetails .blockHeight ,
645
+ simulatedConf .BlockHeight , finalState .blockHeight ,
641
646
)
642
647
643
648
adapters .AssertExpectations (t )
@@ -706,18 +711,15 @@ func TestStateMachineSpendMapper(t *testing.T) {
706
711
assertStateTransitions (t , stateSub , expectedStates )
707
712
708
713
// Final state assertion.
709
- finalState , err := stateMachine .CurrentState ()
710
- require .NoError (t , err )
711
- require .IsType (t , & dummyStateSpent {}, finalState )
714
+ finalState := assertState (t , & stateMachine , & dummyStateSpent {})
712
715
713
716
// Assert that the details from the spend event were correctly
714
717
// propagated to the final state.
715
- finalStateDetails := finalState .(* dummyStateSpent )
716
718
require .Equal (t ,
717
- * simulatedSpend .SpenderTxHash , finalStateDetails .spenderTxHash ,
719
+ * simulatedSpend .SpenderTxHash , finalState .spenderTxHash ,
718
720
)
719
721
require .Equal (t ,
720
- simulatedSpend .SpendingHeight , finalStateDetails .spendingHeight ,
722
+ simulatedSpend .SpendingHeight , finalState .spendingHeight ,
721
723
)
722
724
723
725
adapters .AssertExpectations (t )
0 commit comments