Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions alembic/versions/1da92a1c740f_create_v2_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def upgrade() -> None:
sa.Column("node", postgresql.UUID(), nullable=False),
sa.Column("priority", postgresql.INTEGER(), nullable=True),
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False),
sa.Column("last_processed_at", postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column("submitted_at", postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column("wms_id", postgresql.VARCHAR(), nullable=True),
sa.Column("site_affinity", postgresql.ARRAY(postgresql.VARCHAR()), nullable=True),
Expand All @@ -201,8 +201,10 @@ def upgrade() -> None:
"activity_log_v2",
sa.Column("id", postgresql.UUID(), nullable=False),
sa.Column("namespace", postgresql.UUID(), nullable=False),
sa.Column("node", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=False),
sa.Column("node", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=True),
sa.Column("operator", postgresql.VARCHAR(), nullable=False, default="root"),
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False),
sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True),
sa.Column("from_status", ENUM_COLUMN_AS_VARCHAR, nullable=False),
sa.Column("to_status", ENUM_COLUMN_AS_VARCHAR, nullable=False),
sa.Column(
Expand All @@ -219,6 +221,7 @@ def upgrade() -> None:
default=dict,
server_default=sa.text("'{}'::json"),
),
sa.PrimaryKeyConstraint("id"),
if_not_exists=True,
)

Expand Down
198 changes: 198 additions & 0 deletions src/lsst/cmservice/common/daemon_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import pickle
from asyncio import Task as AsyncTask
from asyncio import TaskGroup, create_task
from collections.abc import Awaitable, Mapping
from typing import TYPE_CHECKING
from uuid import UUID, uuid5

from sqlalchemy.dialects.postgresql import insert
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from transitions import Event

from ..common import graph, timestamp
from ..common.enums import StatusEnum
from ..config import config
from ..db.campaigns_v2 import Campaign, Edge, Machine, Node, Task
from ..db.session import db_session_dependency
from ..machines.node import NodeMachine, node_machine_factory
from .logging import LOGGER

logger = LOGGER.bind(module=__name__)


async def consider_campaigns(session: AsyncSession) -> None:
    """In Phase One, the daemon considers campaigns. Campaigns subject to
    consideration have a non-terminal prepared status (ready or running), and
    optionally tagged with a priority value lower than the daemon's own
    priority.

    For any campaigns thus discovered, the daemon then constructs a graph
    from the campaign's Edges, and starting at the START node, walks the graph
    until a Node is found that requires attention. Each Node found is added to
    the Tasks table as a queue item.

    Parameters
    ----------
    session : AsyncSession
        Open database session; rows selected here are locked until the
        per-campaign commit below releases them.
    """
    # Row-level lock (`FOR KEY SHARE SKIP LOCKED`) so concurrent daemons
    # neither block on nor double-process the same campaigns.
    c_statement = (
        select(Campaign.id)
        .where(col(Campaign.status).in_((StatusEnum.ready, StatusEnum.running)))
        .with_for_update(key_share=True, skip_locked=True)
    )
    campaigns = (await session.exec(c_statement)).all()

    for campaign_id in campaigns:
        logger.info("Daemon considering campaign", id=campaign_id)

        # Fetch the Edges for the campaign
        e_statement = select(Edge).filter_by(namespace=campaign_id)
        edges = (await session.exec(e_statement)).all()
        campaign_graph = await graph.graph_from_edge_list_v2(edges=edges, session=session)

        for node in graph.processable_graph_nodes(campaign_graph):
            logger.info("Daemon considering node", id=str(node.id))
            desired_state = node.status.next_status()
            # uuid5(node.id, desired_state.name) is deterministic, so re-running
            # this phase produces the same Task id; together with
            # `on_conflict_do_nothing` this makes enqueueing idempotent.
            node_task = Task(
                id=uuid5(node.id, desired_state.name),
                namespace=campaign_id,
                node=node.id,
                status=desired_state,
                previous_status=node.status,
            )
            statement = insert(node_task.__table__).values(**node_task.model_dump()).on_conflict_do_nothing() # type: ignore[attr-defined]
            await session.exec(statement) # type: ignore[call-overload]

        # Per-campaign commit; this also releases the campaign row lock taken
        # by the SELECT ... FOR UPDATE above.
        await session.commit()


async def consider_nodes(session: AsyncSession) -> None:
    """In Phase Two, the daemon considers Nodes. Nodes subject to consideration
    are only those Nodes found on the Tasks table that have a priority lower
    than the daemon's own priority, and share the daemon's site affinity.

    For each node considered by the daemon, the Node's FSM is loaded from the
    Machines table, or creates one if needed. The daemon uses methods on the
    Node's Stateful Model to evolve the state of the Node.

    After handling, the Node's FSM is serialized and the Node is updated with
    new values as necessary. The Task is not returned to the Task table.

    Parameters
    ----------
    session : AsyncSession
        Open database session; selected Task rows are locked until the final
        commit releases them.
    """
    # Select and lock unsubmitted tasks; `skip_locked` lets concurrent daemons
    # pass over rows another daemon is already working on.
    statement = select(Task).where(col(Task.submitted_at).is_(None))
    # TODO add filter criteria for priority and site affinity
    statement = statement.with_for_update(skip_locked=True)

    cm_tasks = (await session.exec(statement)).all()

    # Using a TaskGroup context manager means all "tasks" added to the group
    # are awaited when the CM exits, giving us concurrency for all the nodes
    # being considered in the current iteration.
    async with TaskGroup() as tg:
        for cm_task in cm_tasks:
            node = await session.get_one(Node, cm_task.node)

            # The task's status field is the target status for the node, so the
            # daemon intends to evolve the node machine to that state. Use an
            # explicit comparison rather than `assert` so the guard is not
            # stripped when Python runs with -O.
            if node.status is not cm_task.previous_status:
                logger.error("Node status out of sync with Machine", id=str(node.id))
                continue

            # Expunge the node from *this* session because it will be added to
            # whatever session the node_machine acquires during its transition
            session.expunge(node)

            node_machine: NodeMachine
            node_machine_pickle: Machine | None
            if node.machine is None:
                # create a new machine for the node
                node_machine = node_machine_factory(node.kind)(o=node)
                node_machine_pickle = None
            else:
                # Unpickle the node's machine and rehydrate the Stateful Model.
                # NOTE(review): pickle.loads is only safe because the payload
                # originates from our own Machines table, never external input.
                node_machine_pickle = await session.get_one(Machine, node.machine)
                node_machine = (pickle.loads(node_machine_pickle.state)).model
                node_machine.db_model = node
                # discard the pickled machine from this session and context
                session.expunge(node_machine_pickle)
                del node_machine_pickle

            # check possible triggers for state
            # TODO how to pick the "best" trigger from multiple available?
            # - Add a caller-backed conditional to the triggers, to identify
            # . triggers the daemon is "allowed" to use
            # - Determine the "desired" trigger from the task (source, dest)
            if (trigger := trigger_for_transition(cm_task, node_machine.machine.events)) is None:
                logger.warning(
                    "No trigger available for desired state transition",
                    source=cm_task.previous_status,
                    dest=cm_task.status,
                )
                continue

            # Add the node transition trigger method to the task group
            task = tg.create_task(node_machine.trigger(trigger), name=str(cm_task.id))
            task.add_done_callback(task_runner_callback)

            # Stamp *this* task as submitted so it is not re-selected next
            # iteration. (Bug fix: the assignment previously sat outside the
            # loop, so only the final task of a batch was ever stamped and
            # every other task — including skipped ones — was re-processed.)
            cm_task.submitted_at = timestamp.now_utc()

    await session.commit()


async def daemon_iteration(session: AsyncSession) -> None:
    """Run a single iteration of the CM daemon's work loop.

    The loop proceeds in two phases — Campaigns, then Nodes — each gated by
    its own configuration flag. The session is closed when both phases are
    done.
    """
    iteration_start = timestamp.now_utc()
    logger.debug("Daemon V2 Iteration: %s", iteration_start)

    # Phase handlers run in order, each only when enabled by configuration.
    phases = (
        (config.daemon.process_campaigns, consider_campaigns),
        (config.daemon.process_nodes, consider_nodes),
    )
    for enabled, phase_handler in phases:
        if enabled:
            await phase_handler(session)

    await session.close()


def trigger_for_transition(task: Task, events: Mapping[str, Event]) -> str | None:
    """Determine the trigger name for transition that matches the desired state
    tuple as indicated on a Task.

    Returns the name of the first trigger whose transition goes from the
    task's previous status to its target status, or ``None`` when no such
    trigger exists.
    """
    wanted_source = task.previous_status.name
    wanted_dest = task.status.name

    for trigger_name, event in events.items():
        # Flatten the per-event transition lists into one stream.
        candidates = (t for group in event.transitions.values() for t in group)
        for transition in candidates:
            if transition.source == wanted_source and transition.dest == wanted_dest:
                return trigger_name
    return None


async def finalize_runner_callback(context: AsyncTask) -> None:
    """Callback function for finalizing the CM Task runner.

    The asyncio task is named after the CM Task's UUID, so the name doubles
    as the database lookup key. An alternative design would delete the row
    here instead of stamping its ``finished_at`` column.
    """
    if TYPE_CHECKING:
        assert db_session_dependency.sessionmaker is not None

    task_id = context.get_name()
    logger.info("Finalizing CM Task", id=task_id)

    # `begin()` commits the finished_at update when the context exits.
    async with db_session_dependency.sessionmaker.begin() as session:
        finished_task = await session.get_one(Task, UUID(task_id))
        finished_task.finished_at = timestamp.now_utc()


# Module-level strong references to in-flight finalizer tasks. The event loop
# holds only weak references to tasks; with the set as a local variable (as it
# was previously), the set and the finalizer task referenced only each other
# once the callback returned, forming an unreachable cycle the GC could
# collect before the finalizer ran. A module-level set — the pattern from the
# asyncio documentation — keeps the finalizer alive until it completes.
_finalizer_tasks: set[AsyncTask] = set()


def task_runner_callback(context: AsyncTask) -> None:
    """Callback function for `asyncio.TaskGroup` tasks.

    Logs any exception raised by the completed transition task; on success,
    schedules an async finalizer that stamps the CM Task's ``finished_at``.
    """
    if (exc := context.exception()) is not None:
        logger.error(exc)
        return

    logger.info("Transition complete", id=context.get_name())
    # TODO: notification callback
    finalizer = create_task(finalize_runner_callback(context))
    # Discard the strong reference once the finalizer completes.
    finalizer.add_done_callback(_finalizer_tasks.discard)
    _finalizer_tasks.add(finalizer)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't something need to await callbacks?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly. This is a bit of a rube goldberg sequence of async tasks that starts with the primary task that the daemon is adding to a task group, which is awaited at the end of the TaskGroup context manager. Each of these tasks has a callback coro (task_runner_callback) which is awaited when the original task completes; and that coro itself has a callback coro to be awaited when it finishes, and this last bit is where callbacks comes in. callbacks is a set collection that is holding a strong reference to the callbacks set up by task_runner_callback so they don't get lost in the shuffle. The last callback coro (callbacks.discard) cleans up the collection and discards the strong reference to coros as they complete. A simpler example of this pattern is in the python docs.

15 changes: 15 additions & 0 deletions src/lsst/cmservice/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ def is_processable_script(self) -> bool:
"""Is this a processable state for an elememnt"""
return self.value >= StatusEnum.waiting.value and self.value <= StatusEnum.running.value

def next_status(self) -> StatusEnum:
"""If the status is on the "happy" path, return the next status along
that path, otherwise return the failed status.
"""
happy_path = [StatusEnum.waiting, StatusEnum.ready, StatusEnum.running, StatusEnum.accepted]
if self in happy_path:
i = happy_path.index(self)
return happy_path[i + 1]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope nobody calls this on StatusEnum.accepted?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true by fiat but in uncontrolled hands it could be an IndexError waiting to happen. I'll make a note to guard against that.

else:
return StatusEnum.failed


class TaskStatusEnum(enum.Enum):
"""Defines possible outcomes for Pipetask tasks"""
Expand Down Expand Up @@ -295,6 +306,10 @@ class ManifestKind(enum.Enum):
campaign = enum.auto()
node = enum.auto()
edge = enum.auto()
# Node kinds
grouped_step = enum.auto()
step_group = enum.auto()
collect_groups = enum.auto()
# Legacy kinds
specification = enum.auto()
spec_block = enum.auto()
Expand Down
12 changes: 5 additions & 7 deletions src/lsst/cmservice/common/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Iterable, Mapping, MutableSet, Sequence
from typing import Literal
from uuid import UUID

import networkx as nx
from sqlalchemy import select

from ..db import Script, ScriptDependency, Step, StepDependency
from ..db.campaigns_v2 import Edge, Node
Expand Down Expand Up @@ -70,9 +70,7 @@ async def graph_from_edge_list_v2(
# but we want to hydrate the entire Node model for subsequent users of this
# graph to reference without dipping back to the Database.
for node in g.nodes:
s = select(Node).where(Node.id == node)
db_node: Node = (await session.execute(s)).scalars().one()

db_node = await session.get_one(Node, node)
# This Node is going on an adventure where it does not need to drag its
# SQLAlchemy baggage along, so we expunge it from the session before
# adding it to the graph.
Expand All @@ -81,17 +79,17 @@ async def graph_from_edge_list_v2(
# for the simple node view, the goal is to minimize the amount of
# data attached to the node and ensure that this data is json-
# serializable and otherwise appropriate for an API response
g.nodes[node]["id"] = str(db_node.id)
g.nodes[node]["uuid"] = str(db_node.id)
g.nodes[node]["status"] = db_node.status.name
g.nodes[node]["kind"] = db_node.kind.name
g.nodes[node]["version"] = db_node.version
relabel_mapping[node] = db_node.name
else:
g.nodes[node]["model"] = db_node

if relabel_mapping:
g = nx.relabel_nodes(g, mapping=relabel_mapping, copy=False)

# TODO validate graph now raise exception, or leave it to the caller?
return g


Expand All @@ -107,7 +105,7 @@ def graph_to_dict(g: nx.DiGraph) -> Mapping:
return nx.node_link_data(g, edges="edges")


def validate_graph(g: nx.DiGraph, source: str = "START", sink: str = "END") -> bool:
def validate_graph(g: nx.DiGraph, source: UUID | str = "START", sink: UUID | str = "END") -> bool:
"""Validates a graph by asserting by traversal that a complete and correct
path exists between `source` and `sink` nodes.

Expand Down
20 changes: 20 additions & 0 deletions src/lsst/cmservice/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,26 @@ class DaemonConfiguration(BaseModel):
),
)

v1_enabled: bool = Field(
default=True,
description="Whether the v1 daemon is enabled and included in the event loop.",
)

v2_enabled: bool = Field(
default=False,
description="Whether the v2 daemon is enabled and included in the event loop.",
)

process_campaigns: bool = Field(
default=True,
description="Whether the v2 daemon processes Campaigns in the event loop.",
)

process_nodes: bool = Field(
default=True,
description="Whether the v2 daemon processes Nodes in the event loop.",
)


class NotificationConfiguration(BaseModel):
"""Configurations for notifications.
Expand Down
6 changes: 5 additions & 1 deletion src/lsst/cmservice/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from . import __version__
from .common.butler import BUTLER_FACTORY # noqa: F401
from .common.daemon import daemon_iteration
from .common.daemon_v2 import daemon_iteration as daemon_iteration_v2
from .common.logging import LOGGER
from .common.panda import get_panda_token
from .config import config
Expand Down Expand Up @@ -56,7 +57,10 @@ async def main_loop(app: FastAPI) -> None:
while True:
_iteration_count += 1
logger.info("Daemon starting iteration.")
await daemon_iteration(session)
if config.daemon.v1_enabled:
await daemon_iteration(session)
if config.daemon.v2_enabled:
await daemon_iteration_v2(session)
_iteration_time = current_time()
logger.info(f"Daemon completed {_iteration_count} iterations at {_iteration_time}.")
_next_wakeup = _iteration_time + sleep_time
Expand Down
Loading