|
8 | 8 | from .error import EffectError |
9 | 9 |
|
10 | 10 |
|
11 | | -_TInner = TypeVar("_TInner") |
12 | | -_TOuter = TypeVar("_TOuter") |
| 11 | +_T = TypeVar("_T") # for value type |
| 12 | +_M = TypeVar("_M") # for monadic type |
13 | 13 | _P = ParamSpec("_P") |
14 | 14 |
|
15 | 15 |
|
16 | | -class Builder(Generic[_TInner, _TOuter], ABC): |
| 16 | +class BuilderState(Generic[_T]): |
| 17 | + """Encapsulates the state of a builder computation.""" |
| 18 | + |
| 19 | + def __init__(self): |
| 20 | + self.is_done = False |
| 21 | + |
| 22 | + |
| 23 | +class Builder(Generic[_T, _M], ABC): # Corrected Generic definition |
17 | 24 | """Effect builder.""" |
18 | 25 |
|
19 | | - def bind(self, xs: _TOuter, fn: Callable[[Any], _TOuter]) -> _TOuter: |
20 | | - raise NotImplementedError("Builder does not implement a bind method") |
| 26 | + # Required methods |
| 27 | + def bind(self, xs: _M, fn: Callable[[_T], _M]) -> _M: # Use concrete types for Callable input and output |
| 28 | + raise NotImplementedError("Builder does not implement a `bind` method") |
21 | 29 |
|
22 | | - def return_(self, x: _TInner) -> _TOuter: |
23 | | - raise NotImplementedError("Builder does not implement a return method") |
| 30 | + def return_(self, x: _T) -> _M: |
| 31 | + raise NotImplementedError("Builder does not implement a `return` method") |
24 | 32 |
|
25 | | - def return_from(self, xs: _TOuter) -> _TOuter: |
26 | | - raise NotImplementedError("Builder does not implement a return from method") |
| 33 | + def return_from(self, xs: _M) -> _M: |
| 34 | + raise NotImplementedError("Builder does not implement a `return` from method") |
27 | 35 |
|
28 | | - def combine(self, xs: _TOuter, ys: _TOuter) -> _TOuter: |
| 36 | + def combine(self, xs: _M, ys: _M) -> _M: |
29 | 37 | """Used for combining multiple statements in the effect.""" |
30 | | - raise NotImplementedError("Builder does not implement a combine method") |
| 38 | + raise NotImplementedError("Builder does not implement a `combine` method") |
31 | 39 |
|
32 | | - def zero(self) -> _TOuter: |
| 40 | + def zero(self) -> _M: |
33 | 41 | """Zero effect. |
34 | 42 |
|
35 | 43 | Called if the effect raises StopIteration without a value, i.e |
36 | 44 | returns None. |
37 | 45 | """ |
38 | | - raise NotImplementedError("Builder does not implement a zero method") |
| 46 | + raise NotImplementedError("Builder does not implement a `zero` method") |
39 | 47 |
|
40 | | - def delay(self, fn: Callable[[], _TOuter]) -> _TOuter: |
| 48 | + # Optional methods for control flow |
| 49 | + def delay(self, fn: Callable[[], _M]) -> _M: |
41 | 50 | """Delay the computation. |
42 | 51 |
|
43 | | - In F# computation expressions, delay wraps the entire computation to ensure |
44 | | - it is not evaluated until run. This enables proper sequencing of effects |
45 | | - and lazy evaluation. |
46 | | -
|
47 | | - Args: |
48 | | - fn: The computation to delay |
49 | | -
|
50 | | - Returns: |
51 | | - The delayed computation |
| 52 | + Default implementation is to return the result of the function. |
52 | 53 | """ |
53 | 54 | return fn() |
54 | 55 |
|
55 | | - def run(self, computation: _TOuter) -> _TOuter: |
| 56 | + def run(self, computation: _M) -> _M: |
56 | 57 | """Run a computation. |
57 | 58 |
|
58 | | - Forces evaluation of a delayed computation. In F# computation expressions, |
59 | | - run is called at the end to evaluate the entire computation that was |
60 | | - wrapped in delay. |
61 | | -
|
62 | | - Args: |
63 | | - computation: The computation to run |
64 | | -
|
65 | | - Returns: |
66 | | - The evaluated result |
| 59 | + Default implementation is to return the computation as is. |
67 | 60 | """ |
68 | 61 | return computation |
69 | 62 |
|
| 63 | + # Internal implementation |
70 | 64 | def _send( |
71 | 65 | self, |
72 | 66 | gen: Generator[Any, Any, Any], |
73 | | - done: list[bool], |
74 | | - value: _TInner | None = None, |
75 | | - ) -> _TOuter: |
| 67 | + state: BuilderState[_T], # Use BuilderState |
| 68 | + value: _T, |
| 69 | + ) -> _M: |
76 | 70 | try: |
77 | 71 | yielded = gen.send(value) |
78 | 72 | return self.return_(yielded) |
79 | 73 | except EffectError as error: |
80 | | - # Effect errors (Nothing, Error, etc) short circuits the processing so we |
81 | | - # set `done` to `True` here. |
82 | | - done.append(True) |
83 | | - # get value from exception |
84 | | - value = error.args[0] |
85 | | - return self.return_from(cast("_TOuter", value)) |
| 74 | + # Effect errors (Nothing, Error, etc) short circuits |
| 75 | + state.is_done = True |
| 76 | + return self.return_from(cast("_M", error.args[0])) |
86 | 77 | except StopIteration as ex: |
87 | | - done.append(True) |
| 78 | + state.is_done = True |
| 79 | + |
88 | 80 | # Return of a value in the generator produces StopIteration with a value |
89 | 81 | if ex.value is not None: |
90 | 82 | return self.return_(ex.value) |
91 | | - raise |
| 83 | + |
| 84 | + raise # Raise StopIteration with no value |
| 85 | + |
92 | 86 | except RuntimeError: |
93 | | - done.append(True) |
94 | | - raise StopIteration |
| 87 | + state.is_done = True |
| 88 | + return self.zero() # Return zero() to handle generator runtime errors instead of raising StopIteration |
95 | 89 |
|
96 | 90 | def __call__( |
97 | 91 | self, |
98 | 92 | fn: Callable[ |
99 | 93 | _P, |
100 | | - Generator[_TInner | None, _TInner, _TInner | None] | Generator[_TInner | None, None, _TInner | None], |
| 94 | + Generator[_T | None, _T, _T | None] | Generator[_T | None, None, _T | None], |
101 | 95 | ], |
102 | | - ) -> Callable[_P, _TOuter]: |
103 | | - """Option builder. |
104 | | -
|
105 | | - Enables the use of computational expressions using coroutines. |
106 | | - Thus inside the coroutine the keywords `yield` and `yield from` |
107 | | - reassembles `yield` and `yield!` from F#. |
108 | | -
|
109 | | - Args: |
110 | | - fn: A function that contains a computational expression and |
111 | | - returns either a coroutine, generator or an option. |
112 | | -
|
113 | | - Returns: |
114 | | - A `builder` function that can wrap coroutines into builders. |
115 | | - """ |
| 96 | + ) -> Callable[_P, _M]: |
| 97 | + """The builder decorator.""" |
116 | 98 |
|
117 | 99 | @wraps(fn) |
118 | | - def wrapper(*args: _P.args, **kw: _P.kwargs) -> _TOuter: |
| 100 | + def wrapper(*args: _P.args, **kw: _P.kwargs) -> _M: |
119 | 101 | gen = fn(*args, **kw) |
120 | | - done: list[bool] = [] |
121 | | - |
122 | | - result: _TOuter | None = None |
| 102 | + state = BuilderState[_T]() # Initialize BuilderState |
| 103 | + result: _M = self.zero() # Initialize result |
| 104 | + value: _M |
123 | 105 |
|
124 | | - def binder(value: Any) -> _TOuter: |
125 | | - ret = self._send(gen, done, value) |
126 | | - |
127 | | - # Delay every result except the first |
128 | | - if result is not None: |
129 | | - return self.delay(lambda: ret) |
130 | | - return ret |
| 106 | + def binder(value: Any) -> _M: |
| 107 | + ret = self._send(gen, state, value) # Pass state to _send |
| 108 | + return self.delay(lambda: ret) # Delay every bind call |
131 | 109 |
|
132 | 110 | try: |
133 | | - result = self._send(gen, done) |
| 111 | + # Initialize co-routine with None to start the generator and get the |
| 112 | + # first value |
| 113 | + result = value = binder(None) |
134 | 114 |
|
135 | | - while not done: |
136 | | - cont = self.bind(result, binder) |
| 115 | + while not state.is_done: # Loop until coroutine is exhausted |
| 116 | + value: _M = self.bind(value, binder) # Send value to coroutine |
| 117 | + result = self.combine(result, value) # Combine previous result with new value |
137 | 118 |
|
138 | | - # Combine every result except the first |
139 | | - if result is None: |
140 | | - result = cont |
141 | | - else: |
142 | | - result = self.combine(result, cont) |
143 | 119 | except StopIteration: |
| 120 | + # This will happens if the generator exits by returning None |
144 | 121 | pass |
145 | 122 |
|
146 | | - # If anything returns `None` (i.e raises StopIteration without a value) then |
147 | | - # we expect the effect to have a zero method implemented. |
148 | | - if result is None: |
149 | | - result = self.zero() |
150 | | - |
151 | | - # Run the computation at the end |
152 | | - return self.run(result) |
| 123 | + return self.run(result) # Run the result |
153 | 124 |
|
154 | 125 | return wrapper |
0 commit comments