11import pickle
2+ from asyncio import Task as AsyncTask
3+ from asyncio import TaskGroup
24from collections .abc import Mapping
35
46from sqlmodel import col , select
@@ -41,7 +43,7 @@ async def consider_campaigns(session: AsyncSession) -> None:
4143 # table should be to ensure the same *node* is not added twice if
4244 # the same campaign is "considered" by multiple daemons.
4345 for node in graph .processable_graph_nodes (campaign_graph ):
44- logger .info ("Daemon considering node" , id = node .id )
46+ logger .info ("Daemon considering node" , id = str ( node .id ) )
4547 node_task = Task (
4648 namespace = campaign_id ,
4749 node = node .id ,
@@ -71,61 +73,59 @@ async def consider_nodes(session: AsyncSession) -> None:
7173
7274 # Using a TaskGroup context manager means all "tasks" added to the group
7375 # are awaited when the CM exits, giving us concurrency for all the nodes
74- # being considered in the current iteration, but complicating the wrap-up
75- # operations that involve repickling the node's machine for a future
76- # iteration.
77- # . async with TaskGroup() as tg:
78- for task in tasks :
79- node = await session .get_one (Node , task .node )
80- # We expunge the node from *this* session because it will be added to
81- # whatever session the node_machine acquires during its transition
82- session .expunge (node )
83-
84- node_machine : NodeMachine
85- node_machine_pickle : Machine | None
86- if node .machine is None :
87- # create a new machine for the node
88- node_machine = node_machine_factory (node .kind )(o = node )
89- node_machine_pickle = None
90- else :
91- # unpickle the node's machine and rehydrate the Node Stateful Model
92- node_machine_pickle = await session .get_one (Machine , node .machine )
93- node_machine = (pickle .loads (node_machine_pickle .state )).model
94- node_machine .node = node
95- # discard the pickled machine from this session and context
96- session .expunge (node_machine_pickle )
97- del node_machine_pickle
98-
99- # the task's status field is the target status for the node, so the
100- # daemon intends to evolve the node machine to that state.
101- try :
102- assert node .status is task .previous_status
103- except AssertionError :
104- logger .error ("Node status out of sync with Machine" , id = node .id )
105- continue
106-
107- # check possible triggers for state
108- # TODO how to pick the "best" trigger from multiple available?
109- # - Add a caller-backed conditional to the triggers, to identify which
110- # . triggers the daemon is "allowed" to use
111- # - Determine the "desired" trigger from the (source, dest) task state
112- # . tuple
113- if (trigger := trigger_for_transition (task , node_machine .machine .events )) is None :
114- logger .warning (
115- "No trigger available for desired state transition" ,
116- source = task .previous_status ,
117- dest = task .status ,
118- )
119- continue
120-
121- # Add the node transition trigger method to the task group
122- _ = await node_machine .trigger (trigger )
123-
124- # wrap up - the task is removed from the session
125- # TODO if the node has been transitioned to a terminal state, we
126- # should not need to keep its machine around.
127- await session .delete (task )
128- await session .commit ()
76+ # being considered in the current iteration.
77+ async with TaskGroup () as tg :
78+ for task in tasks :
79+ node = await session .get_one (Node , task .node )
80+
81+ # the task's status field is the target status for the node, so the
82+ # daemon intends to evolve the node machine to that state.
83+ try :
84+ assert node .status is task .previous_status
85+ except AssertionError :
86+ logger .error ("Node status out of sync with Machine" , id = str (node .id ))
87+ continue
88+
89+ # Expunge the node from *this* session because it will be added to
90+ # whatever session the node_machine acquires during its transition
91+ session .expunge (node )
92+
93+ node_machine : NodeMachine
94+ node_machine_pickle : Machine | None
95+ if node .machine is None :
96+ # create a new machine for the node
97+ node_machine = node_machine_factory (node .kind )(o = node )
98+ node_machine_pickle = None
99+ else :
100+ # unpickle the node's machine and rehydrate the Stateful Model
101+ node_machine_pickle = await session .get_one (Machine , node .machine )
102+ node_machine = (pickle .loads (node_machine_pickle .state )).model
103+ node_machine .node = node
104+ # discard the pickled machine from this session and context
105+ session .expunge (node_machine_pickle )
106+ del node_machine_pickle
107+
108+ # check possible triggers for state
109+ # TODO how to pick the "best" trigger from multiple available?
110+ # - Add a caller-backed conditional to the triggers, to identify
111+ # . triggers the daemon is "allowed" to use
112+ # - Determine the "desired" trigger from the task (source, dest)
113+ if (trigger := trigger_for_transition (task , node_machine .machine .events )) is None :
114+ logger .warning (
115+ "No trigger available for desired state transition" ,
116+ source = task .previous_status ,
117+ dest = task .status ,
118+ )
119+ continue
120+
121+ # Add the node transition trigger method to the task group
122+ # TODO give this a name and a callback
123+ task_ = tg .create_task (node_machine .trigger (trigger ), name = str (node .id ))
124+ task_ .add_done_callback (task_runner_callback )
125+
126+ # wrap up - the task is removed from the db.
127+ await session .delete (task )
128+ await session .commit ()
129129
130130
131131async def daemon_iteration (session : AsyncSession ) -> None :
@@ -138,6 +138,7 @@ async def daemon_iteration(session: AsyncSession) -> None:
138138 await consider_campaigns (session )
139139 if config .daemon .process_nodes :
140140 await consider_nodes (session )
141+ await session .close ()
141142
142143
143144def trigger_for_transition (task : Task , events : Mapping [str , Event ]) -> str | None :
@@ -156,3 +157,8 @@ def trigger_for_transition(task: Task, events: Mapping[str, Event]) -> str | Non
156157 ):
157158 return trigger
158159 return None
160+
161+
162+ def task_runner_callback (task : AsyncTask ):
163+ """Callback function for `asyncio.TaskGroup` tasks."""
164+ logger .info ("Transition complete" , id = task .get_name ())
0 commit comments