Skip to content

Commit 578b1c9

Browse files
committed
Add onEnter and onExit events to states
1 parent d8c86b9 commit 578b1c9

File tree

4 files changed

+139
-22
lines changed

4 files changed

+139
-22
lines changed

Swift/Sources/StateMachine/StateMachine.swift

+57-8
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,33 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
5454
private let states: States
5555
private var observers: [Observer] = []
5656

57+
private typealias EnterExitAction = (State) throws -> Void
58+
59+
private var onEnterActions: [State.HashableIdentifier: EnterExitAction]
60+
private var onExitActions: [State.HashableIdentifier: EnterExitAction]
61+
5762
private var isNotifying: Bool = false
5863

5964
public init(@DefinitionBuilder build: () -> Definition) {
6065
let definition: Definition = build()
6166
state = definition.initialState.state
62-
states = definition.states.reduce(into: States()) {
63-
$0[$1.state] = $1.events.reduce(into: Events()) {
64-
$0[$1.event] = $1.action
67+
var enterActions: [State.HashableIdentifier: EnterExitAction] = [:]
68+
var exitActions: [State.HashableIdentifier: EnterExitAction] = [:]
69+
states = definition.states.reduce(into: States()) { result, tuple in
70+
let (state, events) = tuple
71+
result[state] = events.reduce(into: Events()) {
72+
switch $1.eventType {
73+
case .onEnter(let action):
74+
enterActions[state] = action
75+
case .onExit(let action):
76+
exitActions[state] = action
77+
case .normal(let event, let action):
78+
$0[event] = action
79+
}
6580
}
6681
}
82+
onEnterActions = enterActions
83+
onExitActions = exitActions
6784
observers = definition.callbacks.map {
6885
Observer(object: self, callback: $0)
6986
}
@@ -104,10 +121,18 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
104121
event: event,
105122
toState: action.toState ?? state,
106123
sideEffects: action.sideEffects)
124+
let fromState = state
107125
if let toState: State = action.toState {
108126
state = toState
109127
}
128+
110129
result = .success(transition)
130+
131+
// if not `dontTransition`
132+
if action.toState != nil {
133+
try? onExitActions[stateIdentifier]?(fromState)
134+
try? onEnterActions[state.hashableIdentifier]?(state)
135+
}
111136
} else {
112137
result = .failure(Transition.Invalid())
113138
}
@@ -172,25 +197,41 @@ extension StateMachineBuilder {
172197
.state(state: state, events: build())
173198
}
174199

200+
public static func onEnter(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
201+
[EventHandler(eventType: .onEnter(perform))]
202+
}
203+
204+
public static func onExit(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
205+
[EventHandler(eventType: .onExit(perform))]
206+
}
207+
208+
public static func onEnter(_ perform: @escaping () throws -> Void) -> [EventHandler] {
209+
[EventHandler(eventType: .onEnter({ _ in try perform() }))]
210+
}
211+
212+
public static func onExit(_ perform: @escaping () throws -> Void) -> [EventHandler] {
213+
[EventHandler(eventType: .onExit({ _ in try perform() }))]
214+
}
215+
175216
public static func on(
176217
_ event: Event.HashableIdentifier,
177218
perform: @escaping (State, Event) throws -> Action
178219
) -> [EventHandler] {
179-
[EventHandler(event: event, action: perform)]
220+
[EventHandler(eventType: .normal(event, perform))]
180221
}
181222

182223
public static func on(
183224
_ event: Event.HashableIdentifier,
184225
perform: @escaping (State) throws -> Action
185226
) -> [EventHandler] {
186-
[EventHandler(event: event) { state, _ in try perform(state) }]
227+
[EventHandler(eventType: .normal(event, { state, _ in try perform(state) }))]
187228
}
188229

189230
public static func on(
190231
_ event: Event.HashableIdentifier,
191232
perform: @escaping () throws -> Action
192233
) -> [EventHandler] {
193-
[EventHandler(event: event) { _, _ in try perform() }]
234+
[EventHandler(eventType: .normal(event, { _, _ in try perform() }))]
194235
}
195236

196237
public static func transition(
@@ -277,8 +318,16 @@ public enum StateMachineTypes {
277318

278319
public struct EventHandler<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {
279320

280-
fileprivate let event: Event.HashableIdentifier
281-
fileprivate let action: Action<State, Event, SideEffect>.Factory
321+
fileprivate var eventType: EventType<State, Event, SideEffect>
322+
323+
fileprivate enum EventType<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {
324+
325+
fileprivate typealias EnterExitAction = (State) throws -> Void
326+
327+
case normal(Event.HashableIdentifier, Action<State, Event, SideEffect>.Factory)
328+
case onEnter(EnterExitAction)
329+
case onExit(EnterExitAction)
330+
}
282331
}
283332

284333
public struct Action<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

Swift/Tests/StateMachineTests/StateMachineTests.swift

+10
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,13 @@ func log(_ expectedMessages: String...) -> Predicate<Logger> {
224224
return PredicateResult(bool: actualMessages == expectedMessages, message: message)
225225
}
226226
}
227+
228+
func noLog() -> Predicate<Logger> {
229+
return Predicate {
230+
let actualMessages: [String]? = try $0.evaluate()?.messages
231+
let actualString: String = stringify(actualMessages?.joined(separator: "\\n"))
232+
let message: ExpectationMessage = .expectedCustomValueTo("no logs",
233+
actual: "<\(actualString)>")
234+
return PredicateResult(bool: actualString.count == 0, message: message)
235+
}
236+
}

Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift

+38-14
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,41 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
2828
typealias ValidTransition = MatterStateMachine.Transition.Valid
2929
typealias InvalidTransition = MatterStateMachine.Transition.Invalid
3030

31-
enum Message {
32-
33-
static let melted: String = "I melted"
34-
static let frozen: String = "I froze"
35-
static let vaporized: String = "I vaporized"
36-
static let condensed: String = "I condensed"
31+
enum Message: String {
32+
33+
case melted = "I melted"
34+
case frozen = "I froze"
35+
case vaporized = "I vaporized"
36+
case condensed = "I condensed"
37+
case enteredSolid
38+
case exitedSolid
39+
case enteredLiquid
40+
case exitedLiquid
41+
case enteredGas
42+
case exitedGas
3743
}
3844

3945
static func matterStateMachine(withInitialState _state: State, logger: Logger) -> MatterStateMachine {
4046
MatterStateMachine {
4147
initialState(_state)
4248
state(.solid) {
49+
onEnter { _ in
50+
logger.log(Message.enteredSolid.rawValue)
51+
}
52+
onExit { _ in
53+
logger.log(Message.exitedSolid.rawValue)
54+
}
4355
on(.melt) {
4456
transition(to: .liquid, emit: .logMelted)
4557
}
4658
}
4759
state(.liquid) {
60+
onEnter { _ in
61+
logger.log(Message.enteredLiquid.rawValue)
62+
}
63+
onExit { _ in
64+
logger.log(Message.exitedLiquid.rawValue)
65+
}
4866
on(.freeze) {
4967
transition(to: .solid, emit: .logFrozen)
5068
}
@@ -53,6 +71,12 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
5371
}
5472
}
5573
state(.gas) {
74+
onEnter { _ in
75+
logger.log(Message.enteredGas.rawValue)
76+
}
77+
onExit { _ in
78+
logger.log(Message.exitedGas.rawValue)
79+
}
5680
on(.condense) {
5781
transition(to: .liquid, emit: .logCondensed)
5882
}
@@ -61,10 +85,10 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
6185
guard case let .success(transition) = $0 else { return }
6286
transition.sideEffects.forEach { sideEffect in
6387
switch sideEffect {
64-
case .logMelted: logger.log(Message.melted)
65-
case .logFrozen: logger.log(Message.frozen)
66-
case .logVaporized: logger.log(Message.vaporized)
67-
case .logCondensed: logger.log(Message.condensed)
88+
case .logMelted: logger.log(Message.melted.rawValue)
89+
case .logFrozen: logger.log(Message.frozen.rawValue)
90+
case .logVaporized: logger.log(Message.vaporized.rawValue)
91+
case .logCondensed: logger.log(Message.condensed.rawValue)
6892
}
6993
}
7094
}
@@ -103,7 +127,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
103127
event: .melt,
104128
toState: .liquid,
105129
sideEffects: [.logMelted])))
106-
expect(self.logger).to(log(Message.melted))
130+
expect(self.logger).to(log(Message.exitedSolid.rawValue, Message.enteredLiquid.rawValue, Message.melted.rawValue))
107131
}
108132

109133
func test_givenStateIsSolid_whenFrozen_shouldThrowInvalidTransitionError() throws {
@@ -136,7 +160,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
136160
event: .freeze,
137161
toState: .solid,
138162
sideEffects: [.logFrozen])))
139-
expect(self.logger).to(log(Message.frozen))
163+
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredSolid.rawValue, Message.frozen.rawValue))
140164
}
141165

142166
func test_givenStateIsLiquid_whenVaporized_shouldTransitionToGasState() throws {
@@ -153,7 +177,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
153177
event: .vaporize,
154178
toState: .gas,
155179
sideEffects: [.logVaporized])))
156-
expect(self.logger).to(log(Message.vaporized))
180+
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredGas.rawValue, Message.vaporized.rawValue))
157181
}
158182

159183
func test_givenStateIsGas_whenCondensed_shouldTransitionToLiquidState() throws {
@@ -170,6 +194,6 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
170194
event: .condense,
171195
toState: .liquid,
172196
sideEffects: [.logCondensed])))
173-
expect(self.logger).to(log(Message.condensed))
197+
expect(self.logger).to(log(Message.exitedGas.rawValue, Message.enteredLiquid.rawValue, Message.condensed.rawValue))
174198
}
175199
}

Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift

+34
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,25 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
3232
typealias TurnstileStateMachine = StateMachine<State, Event, SideEffect>
3333
typealias ValidTransition = TurnstileStateMachine.Transition.Valid
3434

35+
enum Message: String {
36+
case enteredLocked
37+
case exitedLocked
38+
case enteredUnlocked
39+
case exitedUnlocked
40+
case enteredBroken
41+
case exitedBroken
42+
}
43+
3544
static func turnstileStateMachine(withInitialState _state: State, logger: Logger) -> TurnstileStateMachine {
3645
TurnstileStateMachine {
3746
initialState(_state)
3847
state(.locked) {
48+
onEnter { state in
49+
logger.log("\(Message.enteredLocked.rawValue) \(try state.credit() as Int)")
50+
}
51+
onExit {
52+
logger.log(Message.exitedLocked.rawValue)
53+
}
3954
on(.insertCoin) { locked, insertCoin in
4055
let newCredit: Int = try locked.credit() + insertCoin.value()
4156
if newCredit >= Constant.farePrice {
@@ -52,11 +67,23 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
5267
}
5368
}
5469
state(.unlocked) {
70+
onEnter {
71+
logger.log(Message.enteredUnlocked.rawValue)
72+
}
73+
onExit {
74+
logger.log(Message.exitedUnlocked.rawValue)
75+
}
5576
on(.admitPerson) {
5677
transition(to: .locked(credit: 0), emit: .closeDoors)
5778
}
5879
}
5980
state(.broken) {
81+
onEnter {
82+
logger.log(Message.enteredBroken.rawValue)
83+
}
84+
onExit {
85+
logger.log(Message.exitedBroken.rawValue)
86+
}
6087
on(.machineRepairDidComplete) { broken in
6188
transition(to: try broken.oldState())
6289
}
@@ -96,6 +123,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
96123
event: .insertCoin(10),
97124
toState: .locked(credit: 10),
98125
sideEffects: [])))
126+
expect(self.logger).to(log(Message.exitedLocked.rawValue, "\(Message.enteredLocked.rawValue) 10"))
99127
}
100128

101129
func test_givenStateIsLocked_whenInsertCoin_andCreditEqualsFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
@@ -112,6 +140,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
112140
event: .insertCoin(15),
113141
toState: .unlocked,
114142
sideEffects: [.openDoors])))
143+
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
115144
}
116145

117146
func test_givenStateIsLocked_whenInsertCoin_andCreditMoreThanFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
@@ -128,6 +157,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
128157
event: .insertCoin(20),
129158
toState: .unlocked,
130159
sideEffects: [.openDoors])))
160+
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
131161
}
132162

133163
func test_givenStateIsLocked_whenAdmitPerson_shouldTransitionToLockedStateAndSoundAlarm() throws {
@@ -144,6 +174,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
144174
event: .admitPerson,
145175
toState: .locked(credit: 35),
146176
sideEffects: [.soundAlarm])))
177+
expect(self.logger).to(noLog())
147178
}
148179

149180
func test_givenStateIsLocked_whenMachineDidFail_shouldTransitionToBrokenStateAndOrderRepair() throws {
@@ -160,6 +191,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
160191
event: .machineDidFail,
161192
toState: .broken(oldState: .locked(credit: 15)),
162193
sideEffects: [.orderRepair])))
194+
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredBroken.rawValue))
163195
}
164196

165197
func test_givenStateIsUnlocked_whenAdmitPerson_shouldTransitionToLockedStateAndCloseDoors() throws {
@@ -176,6 +208,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
176208
event: .admitPerson,
177209
toState: .locked(credit: 0),
178210
sideEffects: [.closeDoors])))
211+
expect(self.logger).to(log(Message.exitedUnlocked.rawValue, "\(Message.enteredLocked.rawValue) 0"))
179212
}
180213

181214
func test_givenStateIsBroken_whenMachineRepairDidComplete_shouldTransitionToLockedState() throws {
@@ -192,6 +225,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
192225
event: .machineRepairDidComplete,
193226
toState: .locked(credit: 15),
194227
sideEffects: [])))
228+
expect(self.logger).to(log(Message.exitedBroken.rawValue, "\(Message.enteredLocked.rawValue) 15"))
195229
}
196230
}
197231

0 commit comments

Comments
 (0)