-
Notifications
You must be signed in to change notification settings - Fork 279
/
Copy pathnesting.py
62 lines (49 loc) · 2.08 KB
/
nesting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import asyncio
import sys
import traceback
from typing import Literal, TypeAlias
from pydantic import BaseModel
from beeai_framework.errors import FrameworkError
from beeai_framework.workflows import Workflow, WorkflowReservedStepName
WorkflowStep: TypeAlias = Literal["pre_process", "add_loop", "post_process"]
async def main() -> None:
    """Multiply two integers with a looping workflow built from repeated addition.

    The workflow has three steps:

    * ``pre_process`` records ``abs(y)`` as the number of additions to perform.
    * ``add_loop`` adds ``x`` to the running result and re-enters itself
      (``Workflow.SELF``) until no repetitions remain, then moves on to
      ``post_process``.
    * ``post_process`` negates the result when ``y`` is negative and ends the
      workflow (``Workflow.END``).

    The workflow is run for ``8 * 5`` and ``8 * -5``, and each final result
    is printed.
    """

    # State carried between workflow steps.
    class State(BaseModel):
        x: int
        y: int
        # Additions still to perform; filled in by pre_process.
        abs_repetitions: int | None = None
        # Running sum; None until the first addition (or until post_process).
        result: int | None = None

    def pre_process(state: State) -> WorkflowStep:
        print("pre_process")
        state.abs_repetitions = abs(state.y)
        return "add_loop"

    def add_loop(state: State) -> WorkflowStep | WorkflowReservedStepName:
        remaining = state.abs_repetitions or 0
        # No additions left (None, 0, or a non-positive count): finish up.
        if remaining <= 0:
            return "post_process"
        result = (state.result or 0) + state.x
        print(f"add_loop: intermediate result {result}")
        state.abs_repetitions = remaining - 1
        state.result = result
        # Run this same step again for the next addition.
        return Workflow.SELF

    def post_process(state: State) -> WorkflowReservedStepName:
        print("post_process")
        # When y == 0, add_loop never adds anything and result is still None.
        # Normalize it so that x * 0 yields 0 rather than None.
        result = state.result if state.result is not None else 0
        state.result = -result if state.y < 0 else result
        return Workflow.END

    multiplication_workflow = Workflow[State, WorkflowStep](name="MultiplicationWorkflow", schema=State)
    multiplication_workflow.add_step("pre_process", pre_process)
    multiplication_workflow.add_step("add_loop", add_loop)
    multiplication_workflow.add_step("post_process", post_process)

    for x, y in ((8, 5), (8, -5)):
        response = await multiplication_workflow.run(State(x=x, y=y))
        print(f"result: {response.state.result}")
if __name__ == "__main__":
    # Entry point: run the async example. On a framework error, print the
    # full traceback, then exit with the error's explanation as the exit
    # message.
    try:
        asyncio.run(main())
    except FrameworkError as err:
        traceback.print_exc()
        explanation = err.explain()
        sys.exit(explanation)