1
+ //
2
+ // Created by konstantin on 27.07.24.
3
+ //
4
+
5
+ #include < any>
6
+ #include < cassert>
7
+ #include < concepts>
8
+ #include < coroutine>
9
+ #include < iostream>
10
+ #include < unordered_map>
11
+
12
+ #include " generator.h"
13
+
14
+ struct Resumable final {
15
+ struct promise_type {
16
+ using coro_handle = std::coroutine_handle<promise_type>;
17
+
18
+ size_t current_value{};
19
+
20
+ std::suspend_always yield_value (const size_t pos) {
21
+ current_value = pos;
22
+ return {};
23
+ }
24
+
25
+ auto get_return_object () { return coro_handle ::from_promise (*this ); }
26
+
27
+ std::suspend_always initial_suspend () { return {}; }
28
+
29
+ std::suspend_always final_suspend () noexcept { return {}; }
30
+
31
+ auto unhandled_exception () { std::terminate (); }
32
+ };
33
+
34
+ Resumable (promise_type::coro_handle handle) : handle_(handle) {}
35
+
36
+ ~Resumable () {
37
+ if (handle_) handle_.destroy ();
38
+ }
39
+
40
+ promise_type::coro_handle handle () {
41
+ promise_type::coro_handle h = handle_;
42
+ handle_ = nullptr ;
43
+ return h;
44
+ }
45
+
46
+ bool resume () {
47
+ if (handle_) handle_.resume ();
48
+ return !handle_.done ();
49
+ }
50
+
51
+ private:
52
+ promise_type::coro_handle handle_;
53
+ };
54
+
55
+ using coro_t = std::coroutine_handle<>;
56
+
57
+ enum class Sym : char { A, B, Term };
58
+ enum class State { A, B };
59
+
60
+ template <class State , class Sym >
61
+ class StateMachine ;
62
+
63
+ using stm_t = StateMachine<State, Sym>;
64
+
65
+ template <class F >
66
+ concept CanInvokeWithStm = requires(F f, stm_t & stm) {
67
+ { f (stm) } -> std::same_as<Resumable>;
68
+ };
69
+
70
+ Generator<Sym> input_seq (std::string seq) {
71
+ for (char c : seq) {
72
+ switch (c) {
73
+ case ' a' :
74
+ co_yield Sym::A;
75
+ break ;
76
+ case ' b' :
77
+ co_yield Sym::B;
78
+ break ;
79
+ default :
80
+ co_yield Sym::Term;
81
+ break ;
82
+ }
83
+ }
84
+ for (;;) {
85
+ co_yield Sym::Term;
86
+ }
87
+ }
88
+
89
+ template <typename TableTransition, typename SM>
90
+ struct stm_awaiter : public TableTransition {
91
+ SM& stm;
92
+ stm_awaiter (TableTransition transition, SM& stm)
93
+ : TableTransition(transition), stm(stm) {}
94
+
95
+ bool await_ready () const noexcept { return false ; }
96
+ coro_t await_suspend (std::coroutine_handle<>) noexcept {
97
+ stm.gennext ();
98
+ auto sym = stm.genval ();
99
+ auto new_state = TableTransition::operator ()(sym);
100
+ return stm[new_state];
101
+ }
102
+ [[nodiscard]] bool await_resume () const noexcept {
103
+ return stm.genval () == Sym::Term;
104
+ }
105
+ };
106
+
107
+ template <class State , class Sym >
108
+ class StateMachine final {
109
+ public:
110
+ StateMachine (Generator<Sym> gen) : gen(std::move(gen)) {}
111
+
112
+ coro_t operator [](State s) { return states[s]; }
113
+
114
+ template <typename F>
115
+ auto get_awaiter (F transition) {
116
+ return stm_awaiter (transition, *this );
117
+ }
118
+
119
+ template <CanInvokeWithStm F>
120
+ void add_state (State state, F stf) {
121
+ states[state] = stf (*this ).handle ();
122
+ }
123
+
124
+ void run (State initial) {
125
+ current_state = initial;
126
+ states[current_state].resume ();
127
+ }
128
+
129
+ Sym genval () const { return gen.current_value (); }
130
+
131
+ void gennext () { gen.move_next (); }
132
+
133
+ State current () const { return current_state; }
134
+
135
+ private:
136
+ State current_state{};
137
+ std::unordered_map<State, coro_t > states{};
138
+ Generator<Sym> gen;
139
+ };
140
+
141
+ Resumable StateA (stm_t & stm) {
142
+ auto transmission = [](auto sym) {
143
+ if (sym == Sym::B) {
144
+ return State::B;
145
+ }
146
+ return State::A;
147
+ };
148
+ for (;;) {
149
+ std::cout << " State A" << std::endl;
150
+ bool finish = co_await stm.get_awaiter (transmission);
151
+ if (finish) {
152
+ break ;
153
+ }
154
+ }
155
+ }
156
+
157
+ Resumable StateB (stm_t & stm) {
158
+ auto transmission = [](auto sym) {
159
+ if (sym == Sym::A) {
160
+ return State::A;
161
+ }
162
+ return State::B;
163
+ };
164
+ for (;;) {
165
+ std::cout << " State B" << std::endl;
166
+ bool finish = co_await stm.get_awaiter (transmission);
167
+ if (finish) {
168
+ break ;
169
+ }
170
+ }
171
+ }
172
+
173
+ int main () {
174
+ auto gen = input_seq (" aaabbaba" );
175
+ stm_t stm (gen);
176
+ stm.add_state (State::A, StateA);
177
+ stm.add_state (State::B, StateB);
178
+
179
+ stm.run (State::A);
180
+
181
+ auto curr = stm.current ();
182
+ assert (curr == State::A);
183
+ }
0 commit comments