diff --git a/alembic/versions/1da92a1c740f_create_v2_tables.py b/alembic/versions/1da92a1c740f_create_v2_tables.py index 127ee5604..27fac9815 100644 --- a/alembic/versions/1da92a1c740f_create_v2_tables.py +++ b/alembic/versions/1da92a1c740f_create_v2_tables.py @@ -178,7 +178,7 @@ def upgrade() -> None: sa.Column("node", postgresql.UUID(), nullable=False), sa.Column("priority", postgresql.INTEGER(), nullable=True), sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False), - sa.Column("last_processed_at", postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.Column("submitted_at", postgresql.TIMESTAMP(timezone=True), nullable=True), sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True), sa.Column("wms_id", postgresql.VARCHAR(), nullable=True), sa.Column("site_affinity", postgresql.ARRAY(postgresql.VARCHAR()), nullable=True), @@ -201,8 +201,10 @@ def upgrade() -> None: "activity_log_v2", sa.Column("id", postgresql.UUID(), nullable=False), sa.Column("namespace", postgresql.UUID(), nullable=False), - sa.Column("node", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=False), + sa.Column("node", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=True), sa.Column("operator", postgresql.VARCHAR(), nullable=False, default="root"), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True), sa.Column("from_status", ENUM_COLUMN_AS_VARCHAR, nullable=False), sa.Column("to_status", ENUM_COLUMN_AS_VARCHAR, nullable=False), sa.Column( @@ -219,6 +221,7 @@ def upgrade() -> None: default=dict, server_default=sa.text("'{}'::json"), ), + sa.PrimaryKeyConstraint("id"), if_not_exists=True, ) diff --git a/src/lsst/cmservice/common/daemon_v2.py b/src/lsst/cmservice/common/daemon_v2.py new file mode 100644 index 000000000..2612d98f2 --- /dev/null +++ b/src/lsst/cmservice/common/daemon_v2.py @@ -0,0 +1,198 @@ +import pickle +from 
asyncio import Task as AsyncTask +from asyncio import TaskGroup, create_task +from collections.abc import Awaitable, Mapping +from typing import TYPE_CHECKING +from uuid import UUID, uuid5 + +from sqlalchemy.dialects.postgresql import insert +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession +from transitions import Event + +from ..common import graph, timestamp +from ..common.enums import StatusEnum +from ..config import config +from ..db.campaigns_v2 import Campaign, Edge, Machine, Node, Task +from ..db.session import db_session_dependency +from ..machines.node import NodeMachine, node_machine_factory +from .logging import LOGGER + +logger = LOGGER.bind(module=__name__) + + +async def consider_campaigns(session: AsyncSession) -> None: + """In Phase One, the daemon considers campaigns. Campaigns subject to + consideration have a non-terminal prepared status (ready or running), and + optionally tagged with a priority value lower than the daemon's own + priority. + + For any campaigns thus discovered, the daemon then constructs a graph + from the campaign's Edges, and starting at the START node, walks the graph + until a Node is found that requires attention. Each Node found is added to + the Tasks table as a queue item. 
+ """ + c_statement = ( + select(Campaign.id) + .where(col(Campaign.status).in_((StatusEnum.ready, StatusEnum.running))) + .with_for_update(key_share=True, skip_locked=True) + ) + campaigns = (await session.exec(c_statement)).all() + + for campaign_id in campaigns: + logger.info("Daemon considering campaign", id=campaign_id) + + # Fetch the Edges for the campaign + e_statement = select(Edge).filter_by(namespace=campaign_id) + edges = (await session.exec(e_statement)).all() + campaign_graph = await graph.graph_from_edge_list_v2(edges=edges, session=session) + + for node in graph.processable_graph_nodes(campaign_graph): + logger.info("Daemon considering node", id=str(node.id)) + desired_state = node.status.next_status() + node_task = Task( + id=uuid5(node.id, desired_state.name), + namespace=campaign_id, + node=node.id, + status=desired_state, + previous_status=node.status, + ) + statement = insert(node_task.__table__).values(**node_task.model_dump()).on_conflict_do_nothing() # type: ignore[attr-defined] + await session.exec(statement) # type: ignore[call-overload] + + await session.commit() + + +async def consider_nodes(session: AsyncSession) -> None: + """In Phase Two, the daemon considers Nodes. Nodes subject to consideration + are only those Nodes found on the Tasks table that have a priority lower + than the daemon's own priority, and share the daemon's site affinity. + + For each node considered by the daemon, the Node's FSM is loaded from the + Machines table, or creates one if needed. The daemon uses methods on the + Node's Stateful Model to evolve the state of the Node. + + After handling, the Node's FSM is serialized and the Node is updated with + new values as necessary. The Task is not returned to the Task table. 
+ """ + # Select and lock unsubmitted tasks + statement = select(Task).where(col(Task.submitted_at).is_(None)) + # TODO add filter criteria for priority and site affinity + statement = statement.with_for_update(skip_locked=True) + + cm_tasks = (await session.exec(statement)).all() + + # Using a TaskGroup context manager means all "tasks" added to the group + # are awaited when the CM exits, giving us concurrency for all the nodes + # being considered in the current iteration. + async with TaskGroup() as tg: + for cm_task in cm_tasks: + node = await session.get_one(Node, cm_task.node) + + # the task's status field is the target status for the node, so the + # daemon intends to evolve the node machine to that state. + try: + assert node.status is cm_task.previous_status + except AssertionError: + logger.error("Node status out of sync with Machine", id=str(node.id)) + continue + + # Expunge the node from *this* session because it will be added to + # whatever session the node_machine acquires during its transition + session.expunge(node) + + node_machine: NodeMachine + node_machine_pickle: Machine | None + if node.machine is None: + # create a new machine for the node + node_machine = node_machine_factory(node.kind)(o=node) + node_machine_pickle = None + else: + # unpickle the node's machine and rehydrate the Stateful Model + node_machine_pickle = await session.get_one(Machine, node.machine) + node_machine = (pickle.loads(node_machine_pickle.state)).model + node_machine.db_model = node + # discard the pickled machine from this session and context + session.expunge(node_machine_pickle) + del node_machine_pickle + + # check possible triggers for state + # TODO how to pick the "best" trigger from multiple available? + # - Add a caller-backed conditional to the triggers, to identify + # . 
triggers the daemon is "allowed" to use + - Determine the "desired" trigger from the task (source, dest) + if (trigger := trigger_for_transition(cm_task, node_machine.machine.events)) is None: + logger.warning( + "No trigger available for desired state transition", + source=cm_task.previous_status, + dest=cm_task.status, + ) + continue + + # Add the node transition trigger method to the task group + task = tg.create_task(node_machine.trigger(trigger), name=str(cm_task.id)) + task.add_done_callback(task_runner_callback) + + # wrap up - update the task and commit + cm_task.submitted_at = timestamp.now_utc() + await session.commit() + + +async def daemon_iteration(session: AsyncSession) -> None: + """A single iteration of the CM daemon's work loop, which is carried out in + two phases: Campaigns and Nodes. + """ + iteration_start = timestamp.now_utc() + logger.debug("Daemon V2 Iteration: %s", iteration_start) + if config.daemon.process_campaigns: + await consider_campaigns(session) + if config.daemon.process_nodes: + await consider_nodes(session) + await session.close() + + +def trigger_for_transition(task: Task, events: Mapping[str, Event]) -> str | None: + """Determine the trigger name for the transition that matches the desired state + tuple as indicated on a Task. + """ + + for trigger, event in events.items(): + for transition_list in event.transitions.values(): + for transition in transition_list: + if all( + [ + transition.source == task.previous_status.name, + transition.dest == task.status.name, + ] + ): + return trigger + return None + + +async def finalize_runner_callback(context: AsyncTask) -> None: + """Callback function for finalizing the CM Task runner.""" + + # Using the task name as the ID of a task, get the object and update its + # finished_at column. Alternately, we could delete the task from the table + # now. 
+ if TYPE_CHECKING: + assert db_session_dependency.sessionmaker is not None + + logger.info("Finalizing CM Task", id=context.get_name()) + async with db_session_dependency.sessionmaker.begin() as session: + cm_task = await session.get_one(Task, UUID(context.get_name())) + cm_task.finished_at = timestamp.now_utc() + + +def task_runner_callback(context: AsyncTask) -> None: + """Callback function for `asyncio.TaskGroup` tasks.""" + if (exc := context.exception()) is not None: + logger.error(exc) + return + + logger.info("Transition complete", id=context.get_name()) + callbacks: set[Awaitable] = set() + # TODO: notification callback + finalizer = create_task(finalize_runner_callback(context)) + finalizer.add_done_callback(callbacks.discard) + callbacks.add(finalizer) diff --git a/src/lsst/cmservice/common/enums.py b/src/lsst/cmservice/common/enums.py index 50d07f331..8371b3f02 100644 --- a/src/lsst/cmservice/common/enums.py +++ b/src/lsst/cmservice/common/enums.py @@ -178,6 +178,17 @@ def is_processable_script(self) -> bool: """Is this a processable state for an elememnt""" return self.value >= StatusEnum.waiting.value and self.value <= StatusEnum.running.value + def next_status(self) -> StatusEnum: + """If the status is on the "happy" path, return the next status along + that path, otherwise return the failed status. 
+ """ + happy_path = [StatusEnum.waiting, StatusEnum.ready, StatusEnum.running, StatusEnum.accepted] + if self in happy_path: + i = happy_path.index(self) + return happy_path[i + 1] + else: + return StatusEnum.failed + class TaskStatusEnum(enum.Enum): """Defines possible outcomes for Pipetask tasks""" @@ -295,6 +306,10 @@ class ManifestKind(enum.Enum): campaign = enum.auto() node = enum.auto() edge = enum.auto() + # Node kinds + grouped_step = enum.auto() + step_group = enum.auto() + collect_groups = enum.auto() # Legacy kinds specification = enum.auto() spec_block = enum.auto() diff --git a/src/lsst/cmservice/common/graph.py b/src/lsst/cmservice/common/graph.py index 184c20917..b206034c6 100644 --- a/src/lsst/cmservice/common/graph.py +++ b/src/lsst/cmservice/common/graph.py @@ -1,8 +1,8 @@ from collections.abc import Iterable, Mapping, MutableSet, Sequence from typing import Literal +from uuid import UUID import networkx as nx -from sqlalchemy import select from ..db import Script, ScriptDependency, Step, StepDependency from ..db.campaigns_v2 import Edge, Node @@ -70,9 +70,7 @@ async def graph_from_edge_list_v2( # but we want to hydrate the entire Node model for subsequent users of this # graph to reference without dipping back to the Database. for node in g.nodes: - s = select(Node).where(Node.id == node) - db_node: Node = (await session.execute(s)).scalars().one() - + db_node = await session.get_one(Node, node) # This Node is going on an adventure where it does not need to drag its # SQLAlchemy baggage along, so we expunge it from the session before # adding it to the graph. 
@@ -81,9 +79,10 @@ async def graph_from_edge_list_v2( # for the simple node view, the goal is to minimize the amount of # data attached to the node and ensure that this data is json- # serializable and otherwise appropriate for an API response - g.nodes[node]["id"] = str(db_node.id) + g.nodes[node]["uuid"] = str(db_node.id) g.nodes[node]["status"] = db_node.status.name g.nodes[node]["kind"] = db_node.kind.name + g.nodes[node]["version"] = db_node.version relabel_mapping[node] = db_node.name else: g.nodes[node]["model"] = db_node @@ -91,7 +90,6 @@ async def graph_from_edge_list_v2( if relabel_mapping: g = nx.relabel_nodes(g, mapping=relabel_mapping, copy=False) - # TODO validate graph now raise exception, or leave it to the caller? return g @@ -107,7 +105,7 @@ def graph_to_dict(g: nx.DiGraph) -> Mapping: return nx.node_link_data(g, edges="edges") -def validate_graph(g: nx.DiGraph, source: str = "START", sink: str = "END") -> bool: +def validate_graph(g: nx.DiGraph, source: UUID | str = "START", sink: UUID | str = "END") -> bool: """Validates a graph by asserting by traversal that a complete and correct path exists between `source` and `sink` nodes. 
diff --git a/src/lsst/cmservice/config.py b/src/lsst/cmservice/config.py index 099bc8810..73ed94c9c 100644 --- a/src/lsst/cmservice/config.py +++ b/src/lsst/cmservice/config.py @@ -522,6 +522,26 @@ class DaemonConfiguration(BaseModel): ), ) + v1_enabled: bool = Field( + default=True, + description="Whether the v1 daemon is enabled and included in the event loop.", + ) + + v2_enabled: bool = Field( + default=False, + description="Whether the v2 daemon is enabled and included in the event loop.", + ) + + process_campaigns: bool = Field( + default=True, + description="Whether the v2 daemon processes Campaigns in the event loop.", + ) + + process_nodes: bool = Field( + default=True, + description="Whether the v2 daemon processes Nodes in the event loop.", + ) + class NotificationConfiguration(BaseModel): """Configurations for notifications. diff --git a/src/lsst/cmservice/daemon.py b/src/lsst/cmservice/daemon.py index 4d226ad97..d7e407896 100644 --- a/src/lsst/cmservice/daemon.py +++ b/src/lsst/cmservice/daemon.py @@ -11,6 +11,7 @@ from . 
import __version__ from .common.butler import BUTLER_FACTORY # noqa: F401 from .common.daemon import daemon_iteration +from .common.daemon_v2 import daemon_iteration as daemon_iteration_v2 from .common.logging import LOGGER from .common.panda import get_panda_token from .config import config @@ -56,7 +57,10 @@ async def main_loop(app: FastAPI) -> None: while True: _iteration_count += 1 logger.info("Daemon starting iteration.") - await daemon_iteration(session) + if config.daemon.v1_enabled: + await daemon_iteration(session) + if config.daemon.v2_enabled: + await daemon_iteration_v2(session) _iteration_time = current_time() logger.info(f"Daemon completed {_iteration_count} iterations at {_iteration_time}.") _next_wakeup = _iteration_time + sleep_time diff --git a/src/lsst/cmservice/db/campaigns_v2.py b/src/lsst/cmservice/db/campaigns_v2.py index 8350f43aa..b2968f06e 100644 --- a/src/lsst/cmservice/db/campaigns_v2.py +++ b/src/lsst/cmservice/db/campaigns_v2.py @@ -1,16 +1,17 @@ """ORM Models for v2 tables and objects.""" -from datetime import datetime +from collections.abc import MutableSequence from typing import Any -from uuid import NAMESPACE_DNS, UUID, uuid5 +from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5 -from pydantic import AliasChoices, ValidationInfo, model_validator +from pydantic import AliasChoices, AwareDatetime, ValidationInfo, model_validator from sqlalchemy.dialects import postgresql from sqlalchemy.ext.mutable import MutableDict, MutableList from sqlalchemy.types import PickleType -from sqlmodel import Column, Enum, Field, MetaData, SQLModel, String +from sqlmodel import Column, DateTime, Enum, Field, MetaData, SQLModel, String from ..common.enums import ManifestKind, StatusEnum +from ..common.timestamp import now_utc from ..common.types import KindField, StatusField from ..config import config @@ -46,6 +47,8 @@ def jsonb_column(name: str, aliases: list[str] | None = None) -> Any: class BaseSQLModel(SQLModel): + """Shared base SQL model for all 
tables.""" + __table_args__ = {"schema": config.db.table_schema} metadata = metadata @@ -57,12 +60,13 @@ class CampaignBase(BaseSQLModel): name: str namespace: UUID owner: str | None = Field(default=None) - status: StatusField | None = Field( + status: StatusField = Field( default=StatusEnum.waiting, sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), ) metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"]) + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") @model_validator(mode="before") @classmethod @@ -85,8 +89,6 @@ class Campaign(CampaignBase, table=True): __tablename__: str = "campaigns_v2" # type: ignore[misc] - machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") - class CampaignUpdate(BaseSQLModel): """Model representing updatable fields for a PATCH operation on a Campaign @@ -97,6 +99,20 @@ class CampaignUpdate(BaseSQLModel): status: StatusField | None = None +class CampaignSummary(CampaignBase): + """Model for the response of a Campaign Summary route.""" + + node_summary: MutableSequence["NodeStatusSummary"] + + +class NodeStatusSummary(BaseSQLModel): + """Model for a Node Status Summary.""" + + status: StatusField = Field(description="A state name") + count: int = Field(description="Count of nodes in this state") + mtime: AwareDatetime | None = Field(description="The most recent update time for nodes in this state") + + class NodeBase(BaseSQLModel): """nodes_v2 db table""" @@ -120,6 +136,7 @@ def __hash__(self) -> int: ) metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"]) + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") 
@model_validator(mode="before") @classmethod @@ -142,8 +159,6 @@ def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any: class Node(NodeBase, table=True): __tablename__: str = "nodes_v2" # type: ignore[misc] - machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") - class EdgeBase(BaseSQLModel): """edges_v2 db table""" @@ -169,8 +184,8 @@ class Edge(EdgeBase, table=True): class MachineBase(BaseSQLModel): """machines_v2 db table.""" - id: UUID = Field(primary_key=True) - state: Any | None = Field(sa_column=Column("state", PickleType)) + id: UUID = Field(primary_key=True, default_factory=uuid4) + state: Any = Field(sa_column=Column("state", PickleType)) class Machine(MachineBase, table=True): @@ -203,21 +218,42 @@ class Task(BaseSQLModel, table=True): __tablename__: str = "tasks_v2" # type: ignore[misc] - id: UUID = Field(primary_key=True) - namespace: UUID = Field(foreign_key="campaigns_v2.id") - node: UUID = Field(foreign_key="nodes_v2.id") - priority: int - created_at: datetime - last_processed_at: datetime - finished_at: datetime - wms_id: str - site_affinity: list[str] = Field( - sa_column=Column("site_affinity", MutableList.as_mutable(postgresql.ARRAY(String()))) + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + description="A hash of the related Node ID and target status, as a UUID5.", + ) + namespace: UUID = Field(foreign_key="campaigns_v2.id", description="The ID of a Campaign") + node: UUID = Field(foreign_key="nodes_v2.id", description="The ID of the target node") + priority: int | None = Field(default=None) + created_at: AwareDatetime = Field( + description="The `datetime` (UTC) at which this Task was first added to the queue", + default_factory=now_utc, + sa_column=Column(DateTime(timezone=True)), + ) + submitted_at: AwareDatetime | None = Field( + description="The `datetime` (UTC) at which this Task was first submitted as work to the event loop", + default=None, + 
sa_column=Column(DateTime(timezone=True)), + ) + finished_at: AwareDatetime | None = Field( + description=( + "The `datetime` (UTC) at which this Task successfully finalized. " + "A Task whose `finished_at` is not `None` is tombstoned and is subject to deletion." + ), + default=None, + sa_column=Column(DateTime(timezone=True)), + ) + wms_id: str | None = Field(default=None) + site_affinity: list[str] | None = Field( + default=None, sa_column=Column("site_affinity", MutableList.as_mutable(postgresql.ARRAY(String()))) ) status: StatusField = Field( + description="The 'target' status to which this Task will attempt to transition the Node", sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), ) previous_status: StatusField = Field( + description="The 'original' status from which this Task will attempt to transition the Node", sa_column=Column( "previous_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) ), @@ -225,21 +261,38 @@ class Task(BaseSQLModel, table=True): class ActivityLogBase(BaseSQLModel): - id: UUID = Field(primary_key=True) - namespace: UUID = Field(foreign_key="campaigns_v2.id") - node: UUID = Field(foreign_key="nodes_v2.id") - operator: str + id: UUID = Field(primary_key=True, default_factory=uuid4) + namespace: UUID = Field(foreign_key="campaigns_v2.id", description="The ID of a Campaign") + node: UUID | None = Field(default=None, foreign_key="nodes_v2.id", description="The ID of a Node") + operator: str = Field(description="The name of the operator or pilot who triggered the activity") + created_at: AwareDatetime = Field( + description="The `datetime` in UTC at which this log entry was created.", + default_factory=now_utc, + sa_column=Column(DateTime(timezone=True)), + ) + finished_at: AwareDatetime | None = Field( + description="The `datetime` in UTC at which this log entry was finalized.", + default=None, + sa_column=Column(DateTime(timezone=True), nullable=True), + ) 
to_status: StatusField = Field( + description=( + "The `target` state to which this activity tried to transition. " + "This may be the same as `from_status` in cases where no transition was attempted " + "(such as for a conditional check)." + ), sa_column=Column( "to_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) ), ) from_status: StatusField = Field( + description="The `original` state from which this activity tried to transition", sa_column=Column( "from_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) ), ) detail: dict = jsonb_column("detail") + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) class ActivityLog(ActivityLogBase, table=True): diff --git a/src/lsst/cmservice/db/manifests_v2.py b/src/lsst/cmservice/db/manifests_v2.py index af3cfde28..22947dd2e 100644 --- a/src/lsst/cmservice/db/manifests_v2.py +++ b/src/lsst/cmservice/db/manifests_v2.py @@ -10,6 +10,7 @@ from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationInfo, model_validator from ..common.enums import DEFAULT_NAMESPACE, ManifestKind +from ..common.timestamp import element_time from ..common.types import KindField @@ -30,17 +31,6 @@ class Manifest[MetadataT, SpecT](BaseModel): ) -class ManifestMetadata(BaseModel): - """Generic metadata model for Manifests. - - Conventionally denormalized fields are excluded from the model_dump when - serialized for ORM use. - """ - - name: str - namespace: str - - class ManifestSpec(BaseModel): """Generic spec model for Manifests. @@ -55,18 +45,30 @@ class ManifestSpec(BaseModel): model_config = ConfigDict(extra="allow") +class ManifestMetadata(BaseModel): + """Generic metadata model for Manifests. + + Conventionally denormalized fields are excluded from the model_dump when + serialized for ORM use. 
+ """ + + name: str = Field(exclude=True) + namespace: str = Field(exclude=True) + crtime: int = Field(default_factory=element_time) + + class VersionedMetadata(ManifestMetadata): """Metadata model for versioned Manifests.""" - version: int = 0 + version: int = Field(exclude=True, default=0) class ManifestModelMetadata(VersionedMetadata): """Manifest model for general Manifests. These manifests are versioned but - a namespace is optional. + a namespace is optional (defaultable). """ - namespace: str = Field(default=str(DEFAULT_NAMESPACE)) + namespace: str = Field(default=str(DEFAULT_NAMESPACE), exclude=True) class ManifestModel(Manifest[ManifestModelMetadata, ManifestSpec]): @@ -81,16 +83,7 @@ def custom_model_validator(self, info: ValidationInfo) -> Self: return self -class CampaignMetadata(BaseModel): - """Metadata model for a Campaign Manifest. - - Campaign metadata does not require a namespace field. - """ - - name: str - - -class CampaignManifest(Manifest[CampaignMetadata, ManifestSpec]): +class CampaignManifest(Manifest[ManifestModelMetadata, ManifestSpec]): """validating model for campaigns""" @model_validator(mode="after") @@ -108,14 +101,15 @@ class EdgeMetadata(ManifestMetadata): A default random alphanumeric 8-byte name is generated if no name provided. 
""" - name: str = Field(default_factory=lambda: uuid4().hex[:8]) + name: str = Field(default_factory=lambda: uuid4().hex[:8], exclude=True) + crtime: int = Field(default_factory=element_time) class EdgeSpec(ManifestSpec): """Spec model for an Edge Manifest.""" - source: str - target: str + source: str = Field(exclude=True) + target: str = Field(exclude=True) class EdgeManifest(Manifest[EdgeMetadata, EdgeSpec]): diff --git a/src/lsst/cmservice/db/session.py b/src/lsst/cmservice/db/session.py index 9a1f12c04..e912e24ad 100644 --- a/src/lsst/cmservice/db/session.py +++ b/src/lsst/cmservice/db/session.py @@ -7,14 +7,17 @@ from sqlalchemy.pool import AsyncAdaptedQueuePool, Pool from sqlmodel.ext.asyncio.session import AsyncSession +from ..common.logging import LOGGER from ..config import config +logger = LOGGER.bind(module=__name__) -class DatabaseSessionDependency: + +class DatabaseManager: """A database session manager class designed to manage an async sqlalchemy engine and produce sessions. - A module-level instance of this class is created, and when called, a new + A module-level instance of this class is created, and when called a new async session is yielded. """ @@ -32,7 +35,7 @@ async def initialize( *, use_async: bool = True, ) -> None: - """Initialize the session dependency. + """Initialize the database manager. Parameters ---------- @@ -61,7 +64,8 @@ async def initialize( self.sessionmaker = async_sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) async def __call__(self) -> AsyncGenerator[AsyncSession]: - """Yields a database session. + """Yields a database session, rolls it back on error and closes it on + completion. 
Yields ------- @@ -72,7 +76,14 @@ async def __call__(self) -> AsyncGenerator[AsyncSession]: raise RuntimeError("Async sessionmaker is not initialized") async with self.sessionmaker() as session: - yield session + try: + yield session + except Exception: + logger.exception() + await session.rollback() + raise + finally: + await session.close() async def aclose(self) -> None: """Shut down the database engine.""" @@ -82,10 +93,5 @@ async def aclose(self) -> None: self.engine = None -db_session_dependency = DatabaseSessionDependency() -"""A module-level instance of the session manager""" - - -# FIXME not sure why this pattern -async def get_async_session() -> AsyncSession: - return await anext(db_session_dependency()) +db_session_dependency = DatabaseManager() +"""A module-level instance of the database manager""" diff --git a/src/lsst/cmservice/handlers/functions.py b/src/lsst/cmservice/handlers/functions.py index ff6c93e51..d3451bd58 100644 --- a/src/lsst/cmservice/handlers/functions.py +++ b/src/lsst/cmservice/handlers/functions.py @@ -24,7 +24,7 @@ from ..db.pipetask_error import PipetaskError from ..db.pipetask_error_type import PipetaskErrorType from ..db.product_set import ProductSet -from ..db.session import get_async_session +from ..db.session import db_session_dependency from ..db.spec_block import SpecBlock from ..db.specification import Specification from ..db.step import Step @@ -393,7 +393,8 @@ async def force_accept_node( """ local_session = False if session is None: - session = await get_async_session() + assert db_session_dependency.sessionmaker is not None + session = db_session_dependency.sessionmaker() local_session = True the_node = await db_class.get_row(session, node) @@ -444,7 +445,8 @@ async def render_campaign_steps( """ local_session = False if session is None: - session = await get_async_session() + assert db_session_dependency.sessionmaker is not None + session = db_session_dependency.sessionmaker() local_session = True if 
isinstance(campaign, int): campaign = await Campaign.get_row(session, campaign) diff --git a/src/lsst/cmservice/machines/__init__.py b/src/lsst/cmservice/machines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lsst/cmservice/machines/abc.py b/src/lsst/cmservice/machines/abc.py new file mode 100644 index 000000000..8d64454e7 --- /dev/null +++ b/src/lsst/cmservice/machines/abc.py @@ -0,0 +1,255 @@ +"""Abstract Base Classes used by Stateful Model and/or Machine classes. + +These primarily exist and are used to satisfy static type checkers that are +otherwise unaware of any dynamic methods added to Stateful Model classes by +a Machine instance. + +Notes +----- +These ABCs were generated automatically by `transitions.experimental.utils. generate_base_model` and simplified and/or modified for use by the application. + +These ABCs do not use abstractclasses because the implementations will not be +available to static type checkers (i.e., they only exist at runtime). + +These ABCs may implement methods that are not used by application, i.e., that +involve states that are not referenced by any transition. +""" + +from abc import ABC, abstractmethod +from typing import Any + +from sqlmodel.ext.asyncio.session import AsyncSession +from transitions import EventData, Machine +from transitions.extensions.asyncio import AsyncMachine + +from ..common.enums import ManifestKind, StatusEnum +from ..db.campaigns_v2 import ActivityLog, Campaign, Node + +type AnyStatefulObject = Campaign | Node +type AnyMachine = Machine | AsyncMachine + + +class StatefulModel(ABC): + """Base ABC for a Stateful Model, where the Machine will override abstract + methods and properties when it is created. 
+ """ + + __kind__ = [ManifestKind.other] + activity_log_entry: ActivityLog | None = None + db_model: AnyStatefulObject | None + machine: AnyMachine + state: StatusEnum + session: AsyncSession | None = None + + @abstractmethod + def __init__( + self, *args: Any, o: AnyStatefulObject, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: ... + + @abstractmethod + async def error_handler(self, event: EventData) -> None: ... + + @abstractmethod + async def prepare_activity_log(self, event: EventData) -> None: ... + + @abstractmethod + async def update_persistent_status(self, event: EventData) -> None: ... + + @abstractmethod + async def finalize(self, event: EventData) -> None: ... + + async def may_trigger(self, trigger_name: str) -> bool: + raise NotImplementedError("Must be overridden by a Machine") + + async def trigger(self, trigger_name: str, **kwargs: Any) -> bool: + raise NotImplementedError("Must be overridden by a Machine") + + async def resume(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_resume(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def force(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_force(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def pause(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_pause(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def start(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_start(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def unblock(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_unblock(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def unprepare(self) -> bool: + raise NotImplementedError("This should be 
overridden") + + async def may_unprepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def stop(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_stop(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def retry(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_retry(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def finish(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_finish(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def block(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_block(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def prepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_prepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def fail(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_fail(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_rejected(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rejected(self) -> bool: + raise 
NotImplementedError("This should be overridden") + + async def may_to_rejected(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_paused(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_paused(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_paused(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_running(self) -> bool: + raise NotImplementedError("This 
should be overridden") + + async def to_running(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_running(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") diff --git a/src/lsst/cmservice/machines/campaign.py b/src/lsst/cmservice/machines/campaign.py new file mode 100644 index 000000000..55cff74b6 --- /dev/null +++ b/src/lsst/cmservice/machines/campaign.py @@ -0,0 +1,190 @@ +"""Module for state machine implementations related to Campaigns. + +A Campaign state machine should be a simple one, since a Campaign itself does +not need to implement much in the way of Actions or Triggers. A campaign's +status should generally reflect the "worst-case" status of any of Nodes active +in its namespace. + +Since a campaign is mostly a container, the critical path of its state machine +should focus on validity and completeness of its graph, while providing useful +information about the overall campaign progress to pilots and other users. 
+""" + +from typing import TYPE_CHECKING, Any +from uuid import uuid5 + +from sqlmodel import select +from transitions import EventData +from transitions.extensions.asyncio import AsyncMachine + +from ..common import timestamp +from ..common.enums import ManifestKind, StatusEnum +from ..common.graph import graph_from_edge_list_v2, validate_graph +from ..common.logging import LOGGER +from ..db.campaigns_v2 import ActivityLog, Campaign, Edge, Node +from .node import NodeMachine + +logger = LOGGER.bind(module=__name__) + + +TRANSITIONS = [ + # The critical/happy path of state evolution from waiting to accepted + { + "trigger": "start", + "source": StatusEnum.waiting, + "dest": StatusEnum.running, + "conditions": "has_valid_graph", + }, + { + "trigger": "finish", + "source": StatusEnum.running, + "dest": StatusEnum.accepted, + "conditions": "is_successful", + }, + # User-initiated transitions + {"trigger": "pause", "source": StatusEnum.running, "dest": StatusEnum.paused}, + { + "trigger": "resume", + "source": StatusEnum.paused, + "dest": StatusEnum.running, + "conditions": "has_valid_graph", + }, +] +"""Transitions available to a Campaign, expressed as source-destination pairs +with a named trigger-verb. +""" + + +class InvalidCampaignGraphError(Exception): ... + + +class CampaignMachine(NodeMachine): + """Class representing the stateful structure of a Campaign State Machine, + including callbacks and actions to be executed during transitions. 
+ """ + + __kind__ = [ManifestKind.campaign] + + def __init__( + self, *args: Any, o: Campaign, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + self.db_model = o + self.machine = AsyncMachine( + model=self, + states=StatusEnum, + transitions=TRANSITIONS, + initial=initial_state, + auto_transitions=False, + prepare_event=["prepare_session", "prepare_activity_log"], + after_state_change="update_persistent_status", + finalize_event="finalize", + on_exception="error_handler", + send_event=True, + model_override=True, + ) + + async def error_handler(self, event: EventData) -> None: + """Error handler function for the Stateful Model, called by the Machine + if any exception is raised in a callback function. + """ + if TYPE_CHECKING: + assert self.db_model is not None + + if event.error is None: + return + + logger.exception(event.error, id=self.db_model.id) + if self.activity_log_entry is not None: + self.activity_log_entry.detail["trigger"] = event.event.name + self.activity_log_entry.detail["error"] = str(event.error) + self.activity_log_entry.finished_at = timestamp.now_utc() + + async def prepare_activity_log(self, event: EventData) -> None: + """Callback method invoked by the Machine before every state-change.""" + + if TYPE_CHECKING: + assert self.db_model is not None + + if self.activity_log_entry is not None: + return None + + from_state = StatusEnum[event.transition.source] if event.transition else self.state + to_state = ( + StatusEnum[event.transition.dest] if event.transition and event.transition.dest else self.state + ) + + self.activity_log_entry = ActivityLog( + namespace=self.db_model.id, + operator=event.kwargs.get("operator", "daemon"), + from_status=from_state, + to_status=to_state, + detail={}, + metadata_={"request_id": event.kwargs.get("request_id")}, + ) + + async def finalize(self, event: EventData) -> None: + """Callback method invoked by the Machine unconditionally at the end + of every callback chain. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + # The activity log entry is added to the db. For failed transitions it + # may include error detail. For other transitions it is not necessary + # to log every attempt, so if no callback has registered any detail + # for the log entry it is not persisted. + if self.activity_log_entry is None: + return + elif self.activity_log_entry.finished_at is None: + return + + try: + self.session.add(self.activity_log_entry) + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(self.activity_log_entry) + self.activity_log_entry = None + + await self.session.close() + self.session = None + self.activity_log_entry = None + + async def is_successful(self, event: EventData) -> bool: + """A conditional method associated with a transition. + + This callback should assert that the campaign is in a complete and + accepted state by the virtue of all its Nodes also being in a complete + and accepted state. The campaign's "END" node is used as a proxy + for this assertion, because by the rules of the campaign's graph, the + "END" node may only be reached if all other nodes have been success- + fully evolved by an executor. + """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + end_node = await self.session.get_one(Node, uuid5(self.db_model.id, "END.1")) + logger.info(f"Checking whether campaign {self.db_model.name} is finished.", end_node=end_node.status) + return end_node.status is StatusEnum.accepted + + async def has_valid_graph(self, event: EventData) -> bool: + """A conditional method associated with a transition. + + This callback asserts that the campaign graph is valid as a condition + that must be met before the campaign may transition to a "ready" state. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + edges = await self.session.exec(select(Edge).where(Edge.namespace == self.db_model.id)) + graph = await graph_from_edge_list_v2(edges.all(), self.session) + source = uuid5(self.db_model.id, "START.1") + sink = uuid5(self.db_model.id, "END.1") + graph_is_valid = validate_graph(graph, source, sink) + if not graph_is_valid: + raise InvalidCampaignGraphError("Invalid campaign graph") + return graph_is_valid diff --git a/src/lsst/cmservice/machines/node.py b/src/lsst/cmservice/machines/node.py new file mode 100644 index 000000000..cecf0a47a --- /dev/null +++ b/src/lsst/cmservice/machines/node.py @@ -0,0 +1,452 @@ +"""Module for state machine implementations related to Nodes.""" + +import inspect +import pickle +import shutil +import sys +from functools import cache +from os.path import expandvars +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from anyio import Path +from fastapi.concurrency import run_in_threadpool +from transitions import EventData +from transitions.extensions.asyncio import AsyncEvent, AsyncMachine + +from ..common import timestamp +from ..common.enums import ManifestKind, StatusEnum +from ..common.logging import LOGGER +from ..common.timestamp import element_time +from ..config import config +from ..db.campaigns_v2 import ActivityLog, Machine, Node +from ..db.session import db_session_dependency +from .abc import StatefulModel + +logger = LOGGER.bind(module=__name__) + + +TRANSITIONS = [ + # The critical/happy path of state evolution from waiting to accepted + { + "trigger": "prepare", + "source": StatusEnum.waiting, + "dest": StatusEnum.ready, + }, + { + "trigger": "start", + "source": StatusEnum.ready, + "dest": StatusEnum.running, + "conditions": "is_startable", + }, + { + "trigger": "finish", + "source": StatusEnum.running, + "dest": StatusEnum.accepted, + "conditions": "is_done_running", + }, + # The bad transitions + {"trigger": 
"block", "source": StatusEnum.running, "dest": StatusEnum.blocked}, + {"trigger": "fail", "source": StatusEnum.running, "dest": StatusEnum.failed}, + # User-initiated transitions + {"trigger": "pause", "source": StatusEnum.running, "dest": StatusEnum.paused}, + {"trigger": "unblock", "source": StatusEnum.blocked, "dest": StatusEnum.running}, + {"trigger": "resume", "source": StatusEnum.paused, "dest": StatusEnum.running}, + {"trigger": "force", "source": StatusEnum.failed, "dest": StatusEnum.accepted}, + # Inverse transitions, i.e., rollbacks + {"trigger": "unprepare", "source": StatusEnum.ready, "dest": StatusEnum.waiting}, + {"trigger": "stop", "source": StatusEnum.paused, "dest": StatusEnum.ready}, + {"trigger": "retry", "source": StatusEnum.failed, "dest": StatusEnum.ready}, +] +"""Transitions available to a Node, expressed as source-destination pairs +with a named trigger-verb. +""" + + +class NodeMachine(StatefulModel): + """General state model for a Node in a Campaign Graph.""" + + __kind__ = [ManifestKind.node] + + def __init__( + self, *args: Any, o: Node, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + self.db_model = o + self.machine = AsyncMachine( + model=self, + states=StatusEnum, + transitions=TRANSITIONS, + initial=initial_state, + auto_transitions=False, + prepare_event=["prepare_session", "prepare_activity_log"], + after_state_change="update_persistent_status", + finalize_event="finalize", + on_exception="error_handler", + send_event=True, + model_override=True, + ) + self.post_init() + + def post_init(self) -> None: + """Additional initialization method called at the end of ``__init__``, + as a convenience to child classes. 
+ """ + pass + + def __getstate__(self) -> dict: + """Prepares the stateful model for serialization, as with pickle.""" + # Remove members that are not picklable or should not be included + # in the pickle + state = self.__dict__.copy() + del state["session"] + del state["db_model"] + del state["activity_log_entry"] + return state + + async def error_handler(self, event: EventData) -> None: + """Error handler function for the Stateful Model, called by the Machine + if any exception is raised in a callback function. + """ + if event.error is None: + return + + logger.exception(event.error) + if self.activity_log_entry is not None: + self.activity_log_entry.detail["trigger"] = event.event.name + self.activity_log_entry.detail["error"] = str(event.error) + self.activity_log_entry.finished_at = timestamp.now_utc() + + # Auto-transition on error + match event.event: + case AsyncEvent(name="finish"): + # TODO if we need to distinguish between types of failures, + # e.g., fail vs block, we'd have to inspect the error here + await self.trigger("fail") + case _: + ... + + async def prepare_session(self, event: EventData) -> None: + """Prepares the machine by acquiring a database session.""" + # This positive assertion concerning the ORM member will prevent + # any callback from proceeding if no such member is defined, but type + # checkers don't know this, which is why it repeated in a TYPE_CHECKING + # guard in each method that accesses the ORM member. + assert self.db_model is not None, "Stateful Model must have a Node member." 
+ + logger.debug("Preparing session for transition", id=str(self.db_model.id)) + if self.session is not None: + await self.session.close() + else: + assert db_session_dependency.sessionmaker is not None + self.session = db_session_dependency.sessionmaker() + + async def prepare_activity_log(self, event: EventData) -> None: + """Callback method invoked by the Machine before every state-change.""" + if TYPE_CHECKING: + assert self.db_model is not None + + if self.activity_log_entry is not None: + return None + + logger.debug("Preparing activity log for transition", id=str(self.db_model.id)) + + from_state = StatusEnum[event.transition.source] if event.transition else self.state + to_state = ( + StatusEnum[event.transition.dest] if event.transition and event.transition.dest else self.state + ) + + self.activity_log_entry = ActivityLog( + namespace=self.db_model.namespace, + node=self.db_model.id, + operator="daemon", + from_status=from_state, + to_status=to_state, + detail={}, + metadata_={}, + ) + + async def update_persistent_status(self, event: EventData) -> None: + """Callback method invoked by the Machine after every state-change.""" + # Update activity log entry with new state and timestamp + if TYPE_CHECKING: + assert self.db_model is not None, "Stateful Model must have a Node member." + assert self.session is not None + logger.debug("Updating the ORM instance after transition.", id=str(self.db_model.id)) + + if self.activity_log_entry is not None: + self.activity_log_entry.to_status = self.state + self.activity_log_entry.finished_at = timestamp.now_utc() + + # Ensure database record for transitioned object is updated + self.db_model = await self.session.merge(self.db_model, load=False) + self.db_model.status = self.state + self.db_model.metadata_["mtime"] = element_time() + await self.session.commit() + + async def finalize(self, event: EventData) -> None: + """Callback method invoked by the Machine unconditionally at the end + of every callback chain. 
During this callback, if the activity log + indicates that change has occurred, it is written to the db and the + machine is serialized to the Machines table for later use. + """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + # The activity log entry is added to the db. For failed transitions it + # may include error detail. For other transitions it is not necessary + # to log every attempt. + if self.activity_log_entry is None: + return + elif self.activity_log_entry.finished_at is None: + return + + # ensure the orm instance is in the session + if self.db_model not in self.session: + self.db_model = await self.session.merge(self.db_model, load=False) + + # flush the activity log entry to the db + try: + logger.debug("Finalizing the activity log after transition.", id=str(self.db_model.id)) + self.session.add(self.activity_log_entry) + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(self.activity_log_entry) + self.activity_log_entry = None + + # create or update a machine entry in the db + new_machine = Machine.model_validate( + dict(id=self.db_model.machine or uuid4(), state=pickle.dumps(self.machine)) + ) + try: + logger.debug("Serializing the state machine after transition.", id=str(self.db_model.id)) + await self.session.merge(new_machine) + self.db_model.machine = new_machine.id + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(new_machine) + + await self.session.close() + self.session = None + + async def is_startable(self, event: EventData) -> bool: + """Conditional method called to check whether a ``start`` trigger may + be called. + """ + return True + + async def is_done_running(self, event: EventData) -> bool: + """Conditional method called to check whether a ``finish`` trigger may + be called. 
+ """ + return True + + +class StartMachine(NodeMachine): + """Conceptually, a campaign's START node may participate in activities like + any other kind of node, even though its purpose is to provide a solid well- + known root to the campaign graph. Some activities assigned to the Campaign + Machine could also be modeled as belonging to the START node instead. The + END node could serve a similar purpose. + """ + + __kind__ = [ManifestKind.node] + + def post_init(self) -> None: + """Post init, set class-specific callback triggers.""" + self.machine.before_prepare("do_prepare") + self.machine.before_unprepare("do_unprepare") + self.machine.before_start("do_start") + + async def do_prepare(self, event: EventData) -> None: + """Action method invoked when executing the "prepare" transition. + + For a Campaign to enter the ready state, the machine must consider: + + Conditions + ---------- + - the campaign's graph is valid. + + Callbacks + --------- + - artifact directory is created and writable. + """ + if TYPE_CHECKING: + assert self.db_model is not None + + logger.info("Preparing START node", id=str(self.db_model.id)) + + artifact_location = Path(expandvars(config.bps.artifact_path)) / str(self.db_model.namespace) + await artifact_location.mkdir(parents=False, exist_ok=True) + + async def do_unprepare(self, event: EventData) -> None: + if TYPE_CHECKING: + assert self.db_model is not None + + logger.info("Unpreparing START node", id=str(self.db_model.id)) + artifact_location = Path(expandvars(config.bps.artifact_path)) / str(self.db_model.namespace) + await run_in_threadpool(shutil.rmtree, artifact_location) + + async def do_start(self, event: EventData) -> None: + """Callback invoked when entering the "running" state. + + There is no particular work performed when a campaign enters a running + state other than to update the record's entry in the database which + acts as a flag to an executor to signal that a campaign's graph Nodes + may now be evolved. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + + logger.debug("Starting START Node for Campaign", id=str(self.db_model.id)) + return None + + +class StepMachine(NodeMachine): + """Specific state model for a Node of kind GroupedStep. + + The Step-Nodes may be the most involved state models, as the logic that + must execute during each transition is complex. The behaviors are generally + the same as the "scripts" associated with a Step/Group/Job in the legacy + CM implementation. + + A summary of the logic at each transition: + + - prepare + - determine number of groups and group membership + - create new Manifest for each Group + - start + - create new StepGroup Nodes (reading prepared Manifests) + - create new StepCollect Node + - create edges + - finish + - (condition) campaign graph is valid + - unprepare (rollback) + - no action taken, but know that on the next use of "prepare" + new versions of the group manifests may be created. + + Failure modes may include + - Butler errors (can't query for group membership) + - Bad inputs (group membership rules don't make sense) + """ + + __kind__ = [ManifestKind.grouped_step] + + def __init__( + self, *args: Any, o: Node, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + super().__init__(*args, o, initial_state, **kwargs) + self.machine.before_prepare("do_prepare") + self.machine.before_start("do_start") + self.machine.before_unprepare("do_unprepare") + self.machine.before_finish("do_finish") + + async def do_prepare(self, event: EventData) -> None: ... + + async def do_unprepare(self, event: EventData) -> None: ... + + async def do_start(self, event: EventData) -> None: ... + + async def do_finish(self, event: EventData) -> None: ... + + async def is_successful(self, event: EventData) -> bool: + """Checks whether the WMS job is finished or not based on the result of + a bps-report or similar. 
Returns a True value if the batch is done and + good, a False value if it is still running. Raises an exception in any + other terminal WMS state (HELD or FAILED). + + ``` + bps_report: WmsStatusReport = get_wms_status_from_bps(...) + + match bps_report: + case WmsStatusReport(wms_status="FINISHED"): + return True + case WmsStatusReport(wms_status="HELD"): + raise WmsBlockedError() + case WmsStatusReport(wms_status="FAILED"): + raise WmsFailedError() + case WmsStatusReport(wms_status="RUNNING"): + return False + ``` + """ + return True + + +class GroupMachine(NodeMachine): + """Specific state model for a Node of kind StepGroup. + + A summary of the logic at each transition: + + - prepare + - create artifact output directory + - collect all relevant configuration Manifests + - render bps workflow artifacts + - create butler in collection(s) + + - start + - bps submit + - (after_start) determine bps submit directory + + - finish + - (condition) bps report == done + - create butler out collection(s) + + - fail + - read/parse bps output logs + + - stop (rollback) + - bps cancel + + - unprepare (rollback) + - remove artifact output directory + - Butler collections are not modified (paint-over pattern) + + Failure modes may include: + - Unwritable artifact output directory + - Manifests insufficient to render bps workflow artifacts + - Butler errors + - BPS or other middleware errors + """ + + __kind__ = [ManifestKind.step_group] + + ... + + +class StepCollectMachine(NodeMachine): + """Specific state model for a Node of kind StepCollect. + + - prepare + - create step output chained butler collection + + - start + - (condition) ancestor output collections exist in butler? + - add each ancestor output collection to step output chain + + - finish + - (condition) all ancestor output collections in chain + """ + + __kind__ = [ManifestKind.collect_groups] + + ... 
+ + +@cache +def node_machine_factory(kind: ManifestKind) -> type[NodeMachine]: + """Returns the Stateful Model for a node based on its kind, by matching + the ``__kind__`` attribute of available classes in this module. + + TODO: May "construct" new classes from multiple matches, but this is not + yet necessary. + """ + for _, o in inspect.getmembers(sys.modules[__name__], inspect.isclass): + if issubclass(o, NodeMachine) and kind in o.__kind__: + return o + return NodeMachine diff --git a/src/lsst/cmservice/machines/tasks.py b/src/lsst/cmservice/machines/tasks.py new file mode 100644 index 000000000..93c74604d --- /dev/null +++ b/src/lsst/cmservice/machines/tasks.py @@ -0,0 +1,47 @@ +"""Background task implementations for FSM-related operations performed via +API routes. +""" + +from uuid import UUID + +from ..common.enums import StatusEnum +from ..common.logging import LOGGER +from ..db.campaigns_v2 import Campaign +from .campaign import CampaignMachine + +logger = LOGGER.bind(module=__name__) + + +async def change_campaign_state(campaign: Campaign, desired_state: StatusEnum, request_id: UUID) -> None: + """A Background Task to affect a state change in a Campaign, using an + FSM by triggering based on one of a handful of possible user-initiated + state changes, as by PATCHing a campaign using the REST API. 
+ """ + + logger.info( + "Updating campaign state", + campaign=str(campaign.id), + request_id=str(request_id), + dest=desired_state.name, + ) + # Establish an FSM for the Campaign initialized to the current status + campaign_machine = CampaignMachine(o=campaign, initial_state=campaign.status) + + trigger: str + match (campaign.status, desired_state): + case (StatusEnum.waiting, StatusEnum.running): + trigger = "start" + case (StatusEnum.running, StatusEnum.paused): + trigger = "pause" + case (StatusEnum.paused, StatusEnum.running): + trigger = "resume" + case _: + logger.warning( + "Invalid campaign transition requested", + id=str(campaign.id), + source=campaign.status, + dest=desired_state, + ) + return None + + await campaign_machine.trigger(trigger, request_id=str(request_id)) diff --git a/src/lsst/cmservice/routers/v2/__init__.py b/src/lsst/cmservice/routers/v2/__init__.py index 938b69d99..4a0d14c12 100644 --- a/src/lsst/cmservice/routers/v2/__init__.py +++ b/src/lsst/cmservice/routers/v2/__init__.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from . 
import ( + activity_log, campaigns, edges, manifests, @@ -11,6 +12,7 @@ prefix="/v2", ) +router.include_router(activity_log.router) router.include_router(campaigns.router) router.include_router(edges.router) router.include_router(manifests.router) diff --git a/src/lsst/cmservice/routers/v2/activity_log.py b/src/lsst/cmservice/routers/v2/activity_log.py new file mode 100644 index 000000000..e9650efcb --- /dev/null +++ b/src/lsst/cmservice/routers/v2/activity_log.py @@ -0,0 +1,88 @@ +"""http routers for accessing the activity log.""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from ...common.logging import LOGGER +from ...db.campaigns_v2 import ActivityLog +from ...db.session import db_session_dependency + +# TODO should probably bind a logger to the fastapi app or something +logger = LOGGER.bind(module=__name__) + + +# Build the router +router = APIRouter( + prefix="/logs", + tags=["logs", "v2"], +) + + +@router.get( + "/", + summary="Get a list of Activity Log entries", +) +async def read_activity_collection( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + limit: Annotated[int, Query(le=100)] = 10, + offset: Annotated[int, Query()] = 0, + node: Annotated[UUID | None, Query(description="ID of a Node")] = None, + campaign: Annotated[UUID | None, Query(description="ID of a Campaign")] = None, + pilot: Annotated[str | None, Query(description="String name of a pilot")] = None, + since: Annotated[ + datetime | None, Query(description="Datetime of earliest log, in any format pydantic can validate") + ] = None, +) -> Sequence[ActivityLog]: + """A paginated API returning a list of all Activity Logs known to the + application, with query parameters allowing some 
constraints. + """ + statement = select(ActivityLog) + + # Add predicates for query parameters + if node is not None: + statement = statement.where(ActivityLog.node == node) + if campaign is not None: + statement = statement.where(ActivityLog.namespace == campaign) + if pilot is not None: + statement = statement.where(ActivityLog.operator == pilot) + if since is not None: + statement = statement.where(ActivityLog.finished_at >= since) # type: ignore[operator] + + statement = statement.order_by(col(ActivityLog.created_at).desc()).offset(offset).limit(limit) + activity_logs = (await session.exec(statement)).all() + return activity_logs + + +@router.get( + "/{activity_log_id}", + summary="Get single activity log detail", +) +async def read_activity_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + activity_log_id: UUID, +) -> ActivityLog: + """Fetch a single activity log from the database given the log's ID""" + + activity_log = await session.get(ActivityLog, activity_log_id) + # set the response headers + if activity_log is not None: + response.headers["Campaign"] = str( + request.url_for("read_campaign_resource", campaign_name_or_id=activity_log.namespace) + ) + response.headers["Node"] = str(request.url_for("read_node_resource", node_name=activity_log.node)) + response.headers["Self"] = str( + request.url_for("read_activity_resource", activity_log_id=activity_log_id) + ) + return activity_log + else: + raise HTTPException(status_code=404) diff --git a/src/lsst/cmservice/routers/v2/campaigns.py b/src/lsst/cmservice/routers/v2/campaigns.py index 71ab74966..86352a14c 100644 --- a/src/lsst/cmservice/routers/v2/campaigns.py +++ b/src/lsst/cmservice/routers/v2/campaigns.py @@ -6,9 +6,11 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Annotated -from uuid import UUID, uuid5 +from uuid import UUID, uuid4, uuid5 -from fastapi import APIRouter, Depends, HTTPException, 
Query, Request, Response +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, Response +from pydantic import UUID5 +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import aliased from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,9 +18,10 @@ from ...common.graph import graph_from_edge_list_v2, graph_to_dict from ...common.logging import LOGGER from ...common.timestamp import element_time -from ...db.campaigns_v2 import Campaign, CampaignUpdate, Edge, Manifest, Node +from ...db.campaigns_v2 import ActivityLog, Campaign, CampaignUpdate, Edge, Manifest, Node from ...db.manifests_v2 import CampaignManifest from ...db.session import db_session_dependency +from ...machines.tasks import change_campaign_state # TODO should probably bind a logger to the fastapi app or something logger = LOGGER.bind(module=__name__) @@ -72,14 +75,14 @@ async def read_campaign_collection( @router.get( - "/{campaign_name}", + "/{campaign_name_or_id}", summary="Get campaign detail", ) async def read_campaign_resource( request: Request, response: Response, session: Annotated[AsyncSession, Depends(db_session_dependency)], - campaign_name: str, + campaign_name_or_id: str, ) -> Campaign: """Fetch a single campaign from the database given either the campaign id or its name. @@ -87,23 +90,25 @@ async def read_campaign_resource( s = select(Campaign) # The input could be a campaign UUID or it could be a literal name. 
try: - if campaign_id := UUID(campaign_name): + if campaign_id := UUID(campaign_name_or_id): s = s.where(Campaign.id == campaign_id) except ValueError: - s = s.where(Campaign.name == campaign_name) + s = s.where(Campaign.name == campaign_name_or_id) campaign = (await session.exec(s)).one_or_none() # set the response headers if campaign is not None: - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) + response.headers["Self"] = str( + request.url_for("read_campaign_resource", campaign_name_or_id=campaign.id) + ) response.headers["Nodes"] = str( - request.url_for("read_campaign_node_collection", campaign_name=campaign.id) + request.url_for("read_campaign_node_collection", campaign_id=campaign.id) ) response.headers["Edges"] = str( - request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + request.url_for("read_campaign_edge_collection", campaign_id=campaign.id) ) response.headers["Manifests"] = str( - request.url_for("read_campaign_manifest_collection", campaign_name=campaign.id) + request.url_for("read_campaign_manifest_collection", campaign_id=campaign.id) ) return campaign else: @@ -119,28 +124,27 @@ async def update_campaign_resource( request: Request, response: Response, session: Annotated[AsyncSession, Depends(db_session_dependency)], + background_tasks: BackgroundTasks, campaign_name: str, patch_data: CampaignUpdate, ) -> Campaign: """Partial update method for campaigns. Should primarily be used to set the status of a campaign, e.g., from - waiting->ready, in order to trigger any validation rules contained in that - transition. - - Another common use case would be to set status to "paused". - - This could be used to update a campaign's metadata, but otherwise the - status is the only field available for modification, and even then there is - not an imperative "change the status" command, rather a request to evolve - the state of a campaign from A to B, which may or may not be successful. 
- - Rather than manipulating the campaign's record, a change to status should - instead create a work item for the task processing queue for an executor - to discover and attempt to act upon. Barring that, the work should be - delegated to a Background Task. This is why the method returns a 202; the - user needs to check back "later" to see if the requested state change has - occurred. + waiting->running or running->paused. + + Rather than directly manipulating the campaign's record, a change to status + uses a Background Task, which may or may not perform the requested update. + This is why the method returns a 202; the user needs to check back "later" + to see if the requested state change has occurred. + + Note + ---- + For patching a Campaign status, this API accepts only RFC7396 "Merge-Patch" + updates with the appropriate request header set. + + This route returns the Campaign subject to the PATCH, which may or may not + reflect all the requested updates (subject to background task resolution). 
""" use_rfc7396 = False use_rfc6902 = False @@ -170,33 +174,46 @@ async def update_campaign_resource( if campaign is None: raise HTTPException(status_code=404, detail="No such campaign") - # update the campaign with the patch data - update_data = patch_data.model_dump(exclude_unset=True) + # update the campaign with the patch data as a Merge operation + update_data = patch_data.model_dump(exclude={"status"}, exclude_unset=True) campaign.sqlmodel_update(update_data) - session.add(campaign) await session.commit() - await session.refresh(campaign) + session.expunge(campaign) + + # If the patch data is requesting a status change, we will not affect that + # directly, but defer it to a background task + if patch_data.status is not None: + # TODO implement middleware to assign a request_id to every request + request_id = uuid4() + background_tasks.add_task(change_campaign_state, campaign, patch_data.status, request_id) + response.headers["StatusUpdate"] = ( + f"""{request.url_for("read_campaign_activity_log", campaign_name=campaign.id)}""" + f"""?request-id={request_id}""" + ).strip() + # set the response headers if campaign is not None: - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) + response.headers["Self"] = str( + request.url_for("read_campaign_resource", campaign_name_or_id=campaign.id) + ) response.headers["Nodes"] = str( - request.url_for("read_campaign_node_collection", campaign_name=campaign.id) + request.url_for("read_campaign_node_collection", campaign_id=campaign.id) ) response.headers["Edges"] = str( - request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + request.url_for("read_campaign_edge_collection", campaign_id=campaign.id) ) return campaign @router.get( - "/{campaign_name}/nodes", + "/{campaign_id}/nodes", summary="Get campaign Nodes", ) async def read_campaign_node_collection( request: Request, response: Response, session: Annotated[AsyncSession, 
Depends(db_session_dependency)], - campaign_name: str, + campaign_id: UUID5, limit: Annotated[int, Query(le=100)] = 10, offset: Annotated[int, Query()] = 0, ) -> Sequence[Node]: @@ -204,37 +221,34 @@ async def read_campaign_node_collection( single Campaign. """ - # The input could be a campaign UUID or it could be a literal name. - # TODO this could just as well be a campaign query with a join to nodes - statement = select(Node).order_by(Node.metadata_["crtime"].asc().nulls_last()) + statement = ( + select(Node) + .where(Node.namespace == campaign_id) + .order_by(Node.metadata_["crtime"].asc().nulls_last()) + .offset(offset) + .limit(limit) + ) - try: - if campaign_id := UUID(campaign_name): - statement = statement.where(Node.namespace == campaign_id) - except ValueError: - # FIXME get an id from a name - raise HTTPException(status_code=422, detail="campaign_name must be a uuid") - statement = statement.offset(offset).limit(limit) nodes = await session.exec(statement) response.headers["Next"] = str( request.url_for( "read_campaign_node_collection", - campaign_name=campaign_id, + campaign_id=campaign_id, ).include_query_params(offset=(offset + limit), limit=limit), ) - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign_id)) + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name_or_id=campaign_id)) return nodes.all() @router.get( - "/{campaign_name}/manifests", + "/{campaign_id}/manifests", summary="Get campaign Manifests", ) async def read_campaign_manifest_collection( request: Request, response: Response, session: Annotated[AsyncSession, Depends(db_session_dependency)], - campaign_name: str, + campaign_id: UUID5, limit: Annotated[int, Query(le=100)] = 10, offset: Annotated[int, Query()] = 0, ) -> Sequence[Manifest]: @@ -242,36 +256,34 @@ async def read_campaign_manifest_collection( single Campaign. """ - # The input could be a campaign UUID or it could be a literal name. 
- statement = select(Manifest).order_by(Manifest.metadata_["crtime"].asc().nulls_last()) + statement = ( + select(Manifest) + .where(Manifest.namespace == campaign_id) + .order_by(Manifest.metadata_["crtime"].asc().nulls_last()) + .offset(offset) + .limit(limit) + ) - try: - if campaign_id := UUID(campaign_name): - statement = statement.where(Manifest.namespace == campaign_id) - except ValueError: - # FIXME get an id from a name - raise HTTPException(status_code=422, detail="campaign_name must be a uuid") - statement = statement.offset(offset).limit(limit) nodes = await session.exec(statement) response.headers["Next"] = str( request.url_for( "read_campaign_manifest_collection", - campaign_name=campaign_id, + campaign_id=campaign_id, ).include_query_params(offset=(offset + limit), limit=limit), ) - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign_id)) + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name_or_id=campaign_id)) return nodes.all() @router.get( - "/{campaign_name}/edges", + "/{campaign_id}/edges", summary="Get campaign Edges", ) async def read_campaign_edge_collection( request: Request, response: Response, session: Annotated[AsyncSession, Depends(db_session_dependency)], - campaign_name: str, + campaign_id: UUID5, *, resolve_names: bool = False, ) -> Sequence[Edge]: @@ -280,41 +292,33 @@ async def read_campaign_edge_collection( graph. """ - # The input could be a campaign UUID or it could be a literal name. - # This is why raw SQL is better than ORMs - # This is probably better off as two queries instead of a "complicated" - # join. 
if resolve_names: source_nodes = aliased(Node, name="source") target_nodes = aliased(Node, name="target") - s = ( - select( + statement = ( + select( # type: ignore[call-overload] col(Edge.id).label("id"), col(Edge.name).label("name"), col(Edge.namespace).label("namespace"), col(source_nodes.name).label("source"), col(target_nodes.name).label("target"), col(Edge.configuration).label("configuration"), - ) # type: ignore + ) .join_from(Edge, source_nodes, Edge.source == source_nodes.id) .join_from(Edge, target_nodes, Edge.target == target_nodes.id) ) else: - s = select(Edge).order_by(col(Edge.name).asc().nulls_last()) - try: - if campaign_id := UUID(campaign_name): - s = s.where(Edge.namespace == campaign_id) - except ValueError: - # FIXME get an id from a name - raise HTTPException(status_code=422, detail="campaign_name must be a uuid") - edges = await session.exec(s) + statement = select(Edge).order_by(col(Edge.name).asc().nulls_last()) - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign_id)) + statement = statement.where(Edge.namespace == campaign_id) + edges = await session.exec(statement) + + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name_or_id=campaign_id)) return edges.all() @router.delete( - "/{campaign_name}/edges/{edge_name}", + "/{campaign_id}/edges/{edge_name}", summary="Delete campaign edge", status_code=204, ) @@ -322,16 +326,10 @@ async def delete_campaign_edge_resource( request: Request, response: Response, session: Annotated[AsyncSession, Depends(db_session_dependency)], - campaign_name: str, + campaign_id: UUID5, edge_name: str, ) -> None: - """Delete an edge resource from the campaign, using either name or id.""" - # If the campaign name is not a uuid, find the appropriate id - try: - campaign_id = UUID(campaign_name) - except ValueError: - # FIXME get an id from a name - raise HTTPException(status_code=422, detail="campaign_name must be a uuid") + """Delete an 
edge resource from the campaign.""" try: edge_id = UUID(edge_name) @@ -359,52 +357,66 @@ async def create_campaign_resource( session: Annotated[AsyncSession, Depends(db_session_dependency)], manifest: CampaignManifest, ) -> Campaign: - """An API to create a Campaign from an appropriate Manifest.""" - # Create a campaign spec from the manifest, delegating the creation of new - # dynamic fields to the model validation method, -OR- create new dynamic - # fields here. - campaign_metadata = manifest.metadata_.model_dump() - campaign_metadata |= {"crtime": element_time()} + """An API to create a Campaign from an appropriate Manifest. + + If a duplicate campaign is created, the route returns the original campaign + from the database with a 409 (conflict) status code. + """ + campaign = Campaign.model_validate( dict( - name=campaign_metadata.pop("name"), - metadata_=campaign_metadata, + name=manifest.metadata_.name, + metadata_=manifest.metadata_.model_dump(), # owner = ... # TODO Get username from gafaelfawr # noqa: ERA001 ) ) # A new campaign comes with a START and END node - start_node = Node.model_validate(dict(name="START", namespace=campaign.id)) - end_node = Node.model_validate(dict(name="END", namespace=campaign.id)) + start_node = Node.model_validate( + dict(name="START", namespace=campaign.id, metadata_={"crtime": element_time()}) + ) + end_node = Node.model_validate( + dict(name="END", namespace=campaign.id, metadata_={"crtime": element_time()}) + ) - # Put the campaign in the database - session.add(campaign) - session.add(start_node) - session.add(end_node) - await session.commit() - await session.refresh(campaign) + try: + # Put the campaign in the database + session.add(campaign) + session.add(start_node) + session.add(end_node) + await session.commit() + except IntegrityError: + # campaign already exists in the database, set the conflict status + # response but allow the response to proceed + logger.exception() + await session.rollback() + campaign = await 
session.get_one(Campaign, campaign.id) + response.status_code = 409 + except Exception as e: + logger.exception() + raise HTTPException(status_code=500, detail=str(e)) # set the response headers - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) - response.headers["Nodes"] = str( - request.url_for("read_campaign_node_collection", campaign_name=campaign.id) - ) - response.headers["Edges"] = str( - request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name_or_id=campaign.id)) + response.headers["Nodes"] = str(request.url_for("read_campaign_node_collection", campaign_id=campaign.id)) + response.headers["Edges"] = str(request.url_for("read_campaign_edge_collection", campaign_id=campaign.id)) + response.headers["Graph"] = str(request.url_for("read_campaign_graph", campaign_name=campaign.id)) + response.headers["Activity"] = str( + request.url_for("read_campaign_activity_log", campaign_name=campaign.id) ) return campaign @router.get( - "/{campaign_name_or_id}/graph", + "/{campaign_name}/graph", status_code=200, summary="Construct and return a Campaign's graph of nodes", ) async def read_campaign_graph( request: Request, response: Response, - campaign_name_or_id: str, + campaign_name: str, session: Annotated[AsyncSession, Depends(db_session_dependency)], ) -> Mapping: """Reads the graph resource for a campaign and returns its JSON represent- @@ -415,9 +427,9 @@ async def read_campaign_graph( # The input could be a campaign UUID or it could be a literal name. 
campaign_id: UUID | None try: - campaign_id = UUID(campaign_name_or_id) + campaign_id = UUID(campaign_name) except ValueError: - s = select(Campaign.id).where(Campaign.name == campaign_name_or_id) + s = select(Campaign.id).where(Campaign.name == campaign_name) campaign_id = (await session.exec(s)).one_or_none() if campaign_id is None: @@ -431,5 +443,42 @@ async def read_campaign_graph( # current database attributes according to the "simple" node view. graph = await graph_from_edge_list_v2(edges=edges, node_type=Node, session=session, node_view="simple") - response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign_id)) + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name_or_id=campaign_id)) return graph_to_dict(graph) + + +@router.get( + "/{campaign_name}/logs", + status_code=200, + summary="Obtain a collection of Activity Log records for a Campaign.", +) +async def read_campaign_activity_log( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, + request_id: Annotated[str | None, Query(validation_alias="request_id", alias="request-id")] = None, +) -> Sequence[ActivityLog]: + """Returns the collection of Activity Log resources associated with a + Campaign by its namespace. Optionally, a ``?request-id=...`` query param + may constrain entries to specific client requests. + """ + + # The input could be a campaign UUID or it could be a literal name. 
+ campaign_id: UUID | None + try: + campaign_id = UUID(campaign_name) + except ValueError: + s = select(Campaign.id).where(Campaign.name == campaign_name) + campaign_id = (await session.exec(s)).one_or_none() + + if campaign_id is None: + raise HTTPException(status_code=404, detail="No such campaign found.") + + # Fetch the Activity Log entries for the campaign + statement = select(ActivityLog).where(ActivityLog.namespace == campaign_id) + if request_id is not None: + statement = statement.filter(ActivityLog.metadata_["request_id"].astext == request_id) + logs = (await session.exec(statement)).all() + + return logs diff --git a/src/lsst/cmservice/routers/v2/edges.py b/src/lsst/cmservice/routers/v2/edges.py index d96857966..80bfccdf3 100644 --- a/src/lsst/cmservice/routers/v2/edges.py +++ b/src/lsst/cmservice/routers/v2/edges.py @@ -95,10 +95,10 @@ async def read_edge_resource( response.headers["Source"] = request.url_for("read_node_resource", node_name=edge.source).__str__() response.headers["Target"] = request.url_for("read_node_resource", node_name=edge.target).__str__() response.headers["Campaign"] = request.url_for( - "read_campaign_resource", campaign_name=edge.namespace + "read_campaign_resource", campaign_name_or_id=edge.namespace ).__str__() response.headers["Graph"] = request.url_for( - "read_campaign_edge_collection", campaign_name=edge.namespace + "read_campaign_edge_collection", campaign_id=edge.namespace ).__str__() return edge @@ -135,13 +135,12 @@ async def create_edge_resource( target_node = manifest.spec.target # A edge must exist in the namespace of an existing campaign - edge_namespace: str = manifest.metadata_.namespace try: - edge_namespace_uuid: UUID | None = UUID(edge_namespace) + edge_namespace_uuid: UUID | None = UUID(manifest.metadata_.namespace) except ValueError: # get the campaign ID by its name to use as a namespace edge_namespace_uuid = ( - await session.exec(select(Campaign.id).where(Campaign.name == edge_namespace)) + await 
session.exec(select(Campaign.id).where(Campaign.name == manifest.metadata_.namespace)) ).one_or_none() # it is an error if the provided namespace (campaign) does not exist @@ -168,7 +167,8 @@ async def create_edge_resource( namespace=edge_namespace_uuid, source=uuid5(edge_namespace_uuid, source_node), target=uuid5(edge_namespace_uuid, target_node), - configuration=manifest.spec.model_dump(), + metadata_=manifest.metadata_.model_dump(exclude_none=True), + configuration=manifest.spec.model_dump(exclude_none=True), ) # The merge operation is effectively an upsert should an edge matching the @@ -180,10 +180,10 @@ async def create_edge_resource( response.headers["Source"] = request.url_for("read_node_resource", node_name=edge.source).__str__() response.headers["Target"] = request.url_for("read_node_resource", node_name=edge.target).__str__() response.headers["Campaign"] = request.url_for( - "read_campaign_resource", campaign_name=edge.namespace + "read_campaign_resource", campaign_name_or_id=edge.namespace ).__str__() response.headers["Graph"] = request.url_for( - "read_campaign_edge_collection", campaign_name=edge.namespace + "read_campaign_edge_collection", campaign_id=edge.namespace ).__str__() return edge diff --git a/src/lsst/cmservice/routers/v2/manifests.py b/src/lsst/cmservice/routers/v2/manifests.py index 32be3106b..b83a9b1df 100644 --- a/src/lsst/cmservice/routers/v2/manifests.py +++ b/src/lsst/cmservice/routers/v2/manifests.py @@ -15,7 +15,7 @@ from ...common.jsonpatch import JSONPatch, JSONPatchError, apply_json_patch from ...common.logging import LOGGER from ...common.timestamp import element_time -from ...db.campaigns_v2 import Campaign, Manifest, _default_campaign_namespace +from ...db.campaigns_v2 import Campaign, Manifest from ...db.manifests_v2 import ManifestModel from ...db.session import db_session_dependency @@ -125,23 +125,21 @@ async def create_one_or_more_manifests( # A manifest must exist in the namespace of an existing campaign # or the 
default namespace - _namespace: str | None = manifest.metadata_.namespace - if _namespace is None: - _namespace_uuid = _default_campaign_namespace - else: - try: - _namespace_uuid = UUID(_namespace) - except ValueError: - # get the campaign ID by its name to use as a namespace - # it is an error if the namespace/campaign does not exist - # FIXME but this could also be handled by FK constraints - if ( - _campaign_id := ( - await session.exec(select(Campaign.id).where(Campaign.name == _namespace)) - ).one_or_none() - ) is None: - raise HTTPException(status_code=422, detail="Requested namespace does not exist.") - _namespace_uuid = _campaign_id + _namespace = manifest.metadata_.namespace + + try: + _namespace_uuid = UUID(_namespace) + except ValueError: + # get the campaign ID by its name to use as a namespace + # it is an error if the namespace/campaign does not exist + # FIXME but this could also be handled by FK constraints + if ( + _campaign_id := ( + await session.exec(select(Campaign.id).where(Campaign.name == _namespace)) + ).one_or_none() + ) is None: + raise HTTPException(status_code=422, detail="Requested namespace does not exist.") + _namespace_uuid = _campaign_id # A manifest must be a new version if name+namespace already exists # check db for manifest as name+namespace, get version and increment @@ -158,16 +156,14 @@ async def create_one_or_more_manifests( _version = _previous.version if _previous else manifest.metadata_.version _version += 1 - _manifest_metadata = manifest.metadata_.model_dump() - _manifest_metadata |= {"crtime": element_time()} _manifest = Manifest( id=uuid5(_namespace_uuid, f"{_name}.{_version}"), - name=_manifest_metadata.pop("name"), + name=manifest.metadata_.name, namespace=_namespace_uuid, kind=manifest.kind, version=_version, - metadata_=_manifest_metadata, - spec=manifest.spec.model_dump(), + metadata_=manifest.metadata_.model_dump(exclude_none=True), + spec=manifest.spec.model_dump(exclude_none=True), ) # Put the node in the 
database diff --git a/src/lsst/cmservice/routers/v2/nodes.py b/src/lsst/cmservice/routers/v2/nodes.py index 59b95619e..67e5a91be 100644 --- a/src/lsst/cmservice/routers/v2/nodes.py +++ b/src/lsst/cmservice/routers/v2/nodes.py @@ -9,6 +9,7 @@ from uuid import UUID, uuid5 from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response +from pydantic import UUID5 from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -77,6 +78,7 @@ async def read_node_resource( response: Response, node_name: str, session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_id: Annotated[UUID5 | None, Query(validation_alias="campaign_id", alias="campaign-id")] = None, ) -> Node: """Fetch a single node from the database given either the node id or its name. @@ -87,14 +89,21 @@ async def read_node_resource( if node_id := UUID(node_name): s = s.where(Node.id == node_id) except ValueError: - s = s.where(Node.name == node_name) + # node name by itself is not sufficient to identify a single node in + # the database, so we must also constrain the request with the campaign + # namespace or raise an error. + if campaign_id is None: + raise HTTPException( + status_code=400, detail="Cannot locate Node by name alone. 
Try including `?campaign-id=...`" + ) + s = s.where(Node.name == node_name).where(Node.namespace == campaign_id) node = (await session.exec(s)).one_or_none() if node is None: raise HTTPException(status_code=404) response.headers["Self"] = request.url_for("read_node_resource", node_name=node.id).__str__() response.headers["Campaign"] = request.url_for( - "read_campaign_resource", campaign_name=node.namespace + "read_campaign_resource", campaign_name_or_id=node.namespace ).__str__() return node @@ -140,15 +149,13 @@ async def create_node_resource( node_version = previous_node.version if previous_node else node_version node_version += 1 - node_metadata = manifest.metadata_.model_dump() - node_metadata |= {"crtime": element_time()} node = Node( id=uuid5(node_namespace_uuid, f"{node_name}.{node_version}"), - name=node_metadata.pop("name"), + name=node_name, namespace=node_namespace_uuid, version=node_version, - configuration=manifest.spec.model_dump(), - metadata_=node_metadata, + configuration=manifest.spec.model_dump(exclude_none=True), + metadata_=manifest.metadata_.model_dump(exclude_none=True), ) # Put the node in the database @@ -157,7 +164,7 @@ async def create_node_resource( await session.refresh(node) response.headers["Self"] = request.url_for("read_node_resource", node_name=node.id).__str__() response.headers["Campaign"] = request.url_for( - "read_campaign_resource", campaign_name=node.namespace + "read_campaign_resource", campaign_name_or_id=node.namespace ).__str__() return node @@ -242,7 +249,7 @@ async def update_node_resource( response.headers["Self"] = request.url_for("read_node_resource", node_name=new_manifest_db.id).__str__() response.headers["Campaign"] = request.url_for( - "read_campaign_resource", campaign_name=new_manifest_db.namespace + "read_campaign_resource", campaign_name_or_id=new_manifest_db.namespace ).__str__() return new_manifest_db diff --git a/tests/v2/conftest.py b/tests/v2/conftest.py index ccd19505c..44905d3ea 100644 --- 
a/tests/v2/conftest.py +++ b/tests/v2/conftest.py @@ -2,13 +2,12 @@ import importlib import os -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Callable, Generator from typing import TYPE_CHECKING from uuid import NAMESPACE_DNS, uuid4 import pytest import pytest_asyncio -from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient from sqlalchemy import insert from sqlalchemy.pool import NullPool @@ -18,7 +17,7 @@ from lsst.cmservice.common.types import AnyAsyncSession from lsst.cmservice.config import config from lsst.cmservice.db.campaigns_v2 import metadata -from lsst.cmservice.db.session import DatabaseSessionDependency, db_session_dependency +from lsst.cmservice.db.session import DatabaseManager, db_session_dependency if TYPE_CHECKING: from fastapi import FastAPI @@ -32,8 +31,17 @@ def monkeypatch_module() -> Generator[pytest.MonkeyPatch]: yield mp +@pytest.fixture(scope="module", autouse=True) +def patched_config(monkeypatch_module: pytest.MonkeyPatch, tmp_path_factory: pytest.TempPathFactory) -> None: + """Fixture which monkeypatches configuration settings""" + monkeypatch_module.setattr( + target=config.bps, name="artifact_path", value=tmp_path_factory.mktemp("output") + ) + monkeypatch_module.setattr(target=config.db, name="echo", value=False) + + @pytest_asyncio.fixture(scope="module", loop_scope="module") -async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[DatabaseSessionDependency]: +async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[DatabaseManager]: """Test fixture for a postgres container. 
A scoped ephemeral container will be created for the test if the env var @@ -74,7 +82,7 @@ async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[Databa @pytest_asyncio.fixture(scope="module", loop_scope="module") -async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSessionDependency]: +async def testdb(rawdb: DatabaseManager) -> AsyncGenerator[DatabaseManager]: """Test fixture for a migrated postgres container. This fixture creates all the database objects defined for the ORM metadata @@ -93,6 +101,9 @@ async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSes namespace=str(NAMESPACE_DNS), name="DEFAULT", owner="root", + status="accepted", + metadata={}, + configuration={}, ) ) await aconn.commit() @@ -103,50 +114,145 @@ async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSes await aconn.commit() -@pytest_asyncio.fixture(name="session", scope="module", loop_scope="module") -async def session_fixture(testdb: DatabaseSessionDependency) -> AsyncGenerator[AnyAsyncSession]: - """Test fixture for an async database session""" - assert testdb.engine is not None - assert testdb.sessionmaker is not None - async with testdb.sessionmaker() as session: - try: - yield session - finally: - await session.close() - await testdb.engine.dispose() - - -def client_fixture(session: AnyAsyncSession) -> Generator[TestClient]: - """Test fixture for a FastAPI test client with dependency injection - overriden. 
- """ +@pytest_asyncio.fixture(name="session_factory", scope="module", loop_scope="module") +async def session_factory_fixture( + testdb: DatabaseManager, +) -> Callable[..., AsyncGenerator[AnyAsyncSession]]: + """Test fixture for providing an AsyncSession Factory.""" - def get_session_override() -> AnyAsyncSession: - return session + async def session_factory() -> AsyncGenerator[AnyAsyncSession]: + assert testdb.engine is not None + assert testdb.sessionmaker is not None + async with testdb.sessionmaker() as session: + try: + yield session + finally: + await session.close() + await testdb.engine.dispose() - main_ = importlib.import_module("lsst.cmservice.main") - app: FastAPI = getattr(main_, "app") + return session_factory - app.dependency_overrides[db_session_dependency] = get_session_override - client = TestClient(app) - yield client - app.dependency_overrides.clear() + +@pytest_asyncio.fixture(name="session", scope="module", loop_scope="module") +async def session_fixture(session_factory: Callable) -> AsyncGenerator[AnyAsyncSession]: + """Test fixture for an async database session, useful for tests that need + a database session directly. + """ + async for session in session_factory(): + yield session @pytest_asyncio.fixture(name="aclient", scope="module", loop_scope="module") -async def async_client_fixture(session: AnyAsyncSession) -> AsyncGenerator[AsyncClient]: - """Test fixture for an HTTPX async test client with dependency injection - overriden. +async def async_client_fixture( + session_factory: Callable, testdb: DatabaseManager +) -> AsyncGenerator[AsyncClient]: + """Test fixture for an HTTPX async test client backed by a FastAPI app with + its dependency injections overriden by factory fixtures. 
""" main_ = importlib.import_module("lsst.cmservice.main") app: FastAPI = getattr(main_, "app") + app.dependency_overrides[db_session_dependency] = session_factory - def get_session_override() -> AnyAsyncSession: - return session - - app.dependency_overrides[db_session_dependency] = get_session_override async with AsyncClient( follow_redirects=True, transport=ASGITransport(app), base_url="http://test" ) as aclient: yield aclient app.dependency_overrides.clear() + + +@pytest_asyncio.fixture(scope="function", loop_scope="module") +async def test_campaign(aclient: AsyncClient) -> AsyncGenerator[str]: + """Fixture managing a test campaign with three (additional) nodes, which + yields the URL for the campaign's edges endpoint. + """ + campaign_name = uuid4().hex[-8:] + node_ids = [] + + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata": {"name": campaign_name}, + "spec": {}, + }, + ) + campaign_edge_url = x.headers["Edges"] + campaign = x.json() + + # create a trio of nodes for the campaign + for _ in range(3): + x = await aclient.post( + "/cm-service/v2/nodes", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "node", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": {}, + }, + ) + node = x.json() + node_ids.append(node["name"]) + + # Create edges between each campaign node with parallelization + _ = await aclient.post( + "/cm-service/v2/edges", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "edge", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": { + "source": "START", + "target": node_ids[0], + }, + }, + ) + _ = await aclient.post( + "/cm-service/v2/edges", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "edge", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": { + "source": node_ids[0], + "target": node_ids[1], + }, + }, + ) + _ = await aclient.post( + 
"/cm-service/v2/edges", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "edge", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": { + "source": node_ids[0], + "target": node_ids[2], + }, + }, + ) + _ = await aclient.post( + "/cm-service/v2/edges", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "edge", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": { + "source": node_ids[1], + "target": "END", + }, + }, + ) + _ = await aclient.post( + "/cm-service/v2/edges", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "edge", + "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, + "spec": { + "source": node_ids[2], + "target": "END", + }, + }, + ) + yield campaign_edge_url diff --git a/tests/v2/test_campaign_routes.py b/tests/v2/test_campaign_routes.py index 75a514055..db8b98c1f 100644 --- a/tests/v2/test_campaign_routes.py +++ b/tests/v2/test_campaign_routes.py @@ -122,7 +122,7 @@ async def test_create_campaign(aclient: AsyncClient) -> None: assert len(edges) == 0 -async def test_patch_campaign(aclient: AsyncClient) -> None: +async def test_patch_campaign(aclient: AsyncClient, caplog: pytest.LogCaptureFixture) -> None: # Create a new campaign with spec data campaign_name = uuid4().hex[-8:] x = await aclient.post( @@ -160,15 +160,29 @@ async def test_patch_campaign(aclient: AsyncClient) -> None: assert y.status_code == 501 # Update the campaign using RFC7396 and campaign id + caplog.clear() y = await aclient.patch( campaign_url, - json={"status": "ready", "owner": "bob_loblaw"}, + json={"status": "running", "owner": "bob_loblaw"}, headers={"Content-Type": "application/merge-patch+json"}, ) assert y.is_success + # Obtain the Status Update URL from the response headers + status_update_url = y.headers["StatusUpdate"] + + # Check the updates are as expected updated_campaign = y.json() assert updated_campaign["owner"] == "bob_loblaw" - assert updated_campaign["status"] 
== "ready" + # the status update will not be applied + assert updated_campaign["status"] == "waiting" + + # we should see a failed transition in the log + log_entry_found = False + for r in caplog.records: + if "Invalid campaign graph" in r.message: + log_entry_found = True + break + assert log_entry_found # Update the campaign again using RFC7396, ensuring only a single field # is patched, using campaign name @@ -180,4 +194,13 @@ async def test_patch_campaign(aclient: AsyncClient) -> None: assert y.is_success updated_campaign = y.json() assert updated_campaign["owner"] == "alice_bob" - assert updated_campaign["status"] == "ready" + # the previous status update will not be successful + assert updated_campaign["status"] == "waiting" + + # The (failed) attempt to resume a campaign with a broken graph should + # produce an error detail at the reported location, i.e., at the logs + # API with a request-id query param. + y = await aclient.get(status_update_url) + assert y.is_success + activity_log_entry = y.json()[0] + assert activity_log_entry["detail"] == {"error": "Invalid campaign graph", "trigger": "start"} diff --git a/tests/v2/test_daemon.py b/tests/v2/test_daemon.py new file mode 100644 index 000000000..61a157fc3 --- /dev/null +++ b/tests/v2/test_daemon.py @@ -0,0 +1,152 @@ +"""tests for the v2 daemon""" + +from urllib.parse import urlparse +from uuid import uuid5 + +import pytest +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from lsst.cmservice.common.daemon_v2 import consider_campaigns, consider_nodes +from lsst.cmservice.common.enums import StatusEnum +from lsst.cmservice.db.campaigns_v2 import Campaign, Node, Task + +pytestmark = pytest.mark.asyncio(loop_scope="module") +"""All tests in this module will run in the same event loop.""" + + +async def test_daemon_campaign( + caplog: pytest.LogCaptureFixture, test_campaign: str, session: AsyncSession +) -> None: + """Tests the handling of campaigns during daemon 
iteration, which is + primarily done by checking side effects. This test assesses a test campaign + with three nodes and asserts that node(s) are added to the Task table in + the correct order and that as they enter an accepted state the daemon can + continue to traverse the campaign graph and visit more nodes until the + END node is reached. + """ + + # At first, the test_campaign in a waiting state is not subject to daemon + # consideration + caplog.clear() + await consider_campaigns(session) + + # extract the test campaign id from the fixture url + campaign_id = urlparse(test_campaign).path.split("/")[-2:][0] + + campaign = await session.get_one(Campaign, campaign_id) + assert campaign.status is not None + x = campaign.status + assert x is StatusEnum.waiting + + # check the next state in the happy path and set the campaign status + # (i.e., without using the Campaign FSM) + assert campaign.status.next_status() is StatusEnum.ready + campaign.status = campaign.status.next_status() + await session.commit() + + # now the daemon should consider the prepared campaign + caplog.clear() + await consider_campaigns(session) + found_log_messages = 0 + for r in caplog.records: + if any(["considering campaign" in r.message, "considering node" in r.message]): + found_log_messages += 1 + assert found_log_messages == 2 + + # A task should now be in the task table + tasks = (await session.exec(select(Task))).all() + assert len(tasks) == 1 + + # Fetch the node, set it to a terminal status, and consider the campaign + # again + node = await session.get_one(Node, tasks[0].node) + assert node.name == "START" + node.status = StatusEnum.accepted + await session.commit() + + caplog.clear() + await consider_campaigns(session) + tasks = (await session.exec(select(Task))).all() + # One additional task should be in the table now + assert len(tasks) == 2 + node = await session.get_one(Node, tasks[-1].node) + node.status = StatusEnum.accepted + await session.commit() + + # The next assessment 
should produce two nodes to be handled in parallel + caplog.clear() + await consider_campaigns(session) + tasks = (await session.exec(select(Task))).all() + # Two additional tasks should be in the table now + assert len(tasks) == 4 + + node = await session.get_one(Node, tasks[-1].node) + node.status = StatusEnum.accepted + node = await session.get_one(Node, tasks[-2].node) + node.status = StatusEnum.accepted + await session.commit() + + # The next assessment should produce the END node + caplog.clear() + await consider_campaigns(session) + tasks = (await session.exec(select(Task))).all() + # One additional task should be in the table now + assert len(tasks) == 5 + node = await session.get_one(Node, tasks[-1].node) + assert node.name == "END" + node.status = StatusEnum.accepted + await session.commit() + + +async def test_daemon_node( + caplog: pytest.LogCaptureFixture, test_campaign: str, session: AsyncSession +) -> None: + # set the campaign to running (without involving a campaign machine) + campaign_id = urlparse(url=test_campaign).path.split("/")[-2:][0] + campaign = await session.get_one(Campaign, campaign_id) + campaign.status = StatusEnum.running + await session.commit() + + # after the equivalent of a single iteration, the test campaign's START + # Node will be on the task list with a transition from waiting->ready. + await consider_campaigns(session) + await consider_nodes(session) + + # As we continue to iterate the daemon over the campaign's 5 nodes + # (including its START and END), each node in the graph is evolved. + # To simulate this, we'll pull the END node from the database and wait + # until it is in its terminal "accepted" state. + end_node = await session.get_one(Node, uuid5(campaign.id, "END.1")) + + # set up a release valve to stave off infinite loops in case things don't + # go well. This should take 11 iterations (4 nodes * 3 transitions, one of + # which is already done.) 
+ i = 12 + while end_node.status is not StatusEnum.accepted: + i -= 1 + await consider_campaigns(session) + await consider_nodes(session) + # the end node is expunged from the session as a side effect when the + # graph is built + session.add(end_node) + await session.refresh(end_node, attribute_names=["status"]) + if not i: + raise RuntimeError("Node evolution took too long") + ... + + +def test_dynamic_node_machine() -> None: + """Test the dynamic resolution of a ``NodeMachine`` class based on a Node's + ``kind`` attribute. + """ + from lsst.cmservice.common.enums import ManifestKind + from lsst.cmservice.machines.node import GroupMachine, NodeMachine, node_machine_factory + + k = ManifestKind.node + x = node_machine_factory(k) + assert x is NodeMachine + + k = ManifestKind.step_group + x = node_machine_factory(k) + assert x is GroupMachine diff --git a/tests/v2/test_db.py b/tests/v2/test_db.py index 7995f52f2..c55fe8a42 100644 --- a/tests/v2/test_db.py +++ b/tests/v2/test_db.py @@ -6,11 +6,11 @@ from sqlmodel import select from lsst.cmservice.db.campaigns_v2 import Campaign, Machine, _default_campaign_namespace -from lsst.cmservice.db.session import DatabaseSessionDependency +from lsst.cmservice.db.session import DatabaseManager @pytest.mark.asyncio -async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None: +async def test_create_campaigns_v2(testdb: DatabaseManager) -> None: """Tests the campaigns_v2 table by creating and updating a Campaign.""" assert testdb.sessionmaker is not None @@ -55,7 +55,7 @@ async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None: @pytest.mark.asyncio -async def test_create_machines_v2(testdb: DatabaseSessionDependency) -> None: +async def test_create_machines_v2(testdb: DatabaseManager) -> None: """Tests the machines_v2 table by storing + retrieving a pickled object.""" assert testdb.sessionmaker is not None diff --git a/tests/v2/test_graph.py b/tests/v2/test_graph.py index 
3f57d3444..491571f82 100644 --- a/tests/v2/test_graph.py +++ b/tests/v2/test_graph.py @@ -1,11 +1,7 @@ """Tests graph operations using v2 objects""" -from collections.abc import AsyncGenerator -from uuid import uuid4 - import networkx as nx import pytest -import pytest_asyncio from httpx import AsyncClient from lsst.cmservice.common.enums import StatusEnum @@ -17,102 +13,6 @@ """All tests in this module will run in the same event loop.""" -@pytest_asyncio.fixture(scope="module", loop_scope="module") -async def test_campaign(aclient: AsyncClient) -> AsyncGenerator[str]: - """Fixture managing a test campaign with two (additional) nodes.""" - campaign_name = uuid4().hex[-8:] - node_ids = [] - - x = await aclient.post( - "/cm-service/v2/campaigns", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "campaign", - "metadata": {"name": campaign_name}, - "spec": {}, - }, - ) - campaign_edge_url = x.headers["Edges"] - campaign = x.json() - - # create a trio of nodes for the campaign - for _ in range(3): - x = await aclient.post( - "/cm-service/v2/nodes", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "node", - "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": {}, - }, - ) - node = x.json() - node_ids.append(node["name"]) - - # Create edges between each campaign node with parallelization - _ = await aclient.post( - "/cm-service/v2/edges", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "edge", - "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": { - "source": "START", - "target": node_ids[0], - }, - }, - ) - _ = await aclient.post( - "/cm-service/v2/edges", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "edge", - "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": { - "source": node_ids[0], - "target": node_ids[1], - }, - }, - ) - _ = await aclient.post( - "/cm-service/v2/edges", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "edge", - 
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": { - "source": node_ids[0], - "target": node_ids[2], - }, - }, - ) - _ = await aclient.post( - "/cm-service/v2/edges", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "edge", - "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": { - "source": node_ids[1], - "target": "END", - }, - }, - ) - _ = await aclient.post( - "/cm-service/v2/edges", - json={ - "apiVersion": "io.lsst.cmservice/v1", - "kind": "edge", - "metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]}, - "spec": { - "source": node_ids[2], - "target": "END", - }, - }, - ) - yield campaign_edge_url - - async def test_build_and_walk_graph( aclient: AsyncClient, session: AnyAsyncSession, test_campaign: str ) -> None: diff --git a/tests/v2/test_machines.py b/tests/v2/test_machines.py new file mode 100644 index 000000000..cf7626695 --- /dev/null +++ b/tests/v2/test_machines.py @@ -0,0 +1,236 @@ +"""Tests for State Machines.""" + +import pickle +import random +from unittest.mock import patch +from urllib.parse import urlparse +from uuid import UUID, uuid4, uuid5 + +import pytest +from httpx import AsyncClient +from sqlmodel.ext.asyncio.session import AsyncSession + +from lsst.cmservice.common.enums import StatusEnum +from lsst.cmservice.db.campaigns_v2 import Campaign, Machine, Node +from lsst.cmservice.machines.node import NodeMachine, StartMachine +from lsst.cmservice.machines.tasks import change_campaign_state + +pytestmark = pytest.mark.asyncio(loop_scope="module") +"""All tests in this module will run in the same event loop.""" + + +async def test_node_machine(test_campaign: str, session: AsyncSession) -> None: + """Test the critical/happy path of a node state machine.""" + + # extract the test campaign id from the fixture url and determine the START + # node id + campaign_id = urlparse(url=test_campaign).path.split("/")[-2:][0] + node_id = uuid5(UUID(campaign_id), "START.1") + + 
node = await session.get_one(Node, node_id) + x = node.status + assert x is StatusEnum.waiting + await session.commit() + + node_machine = StartMachine(o=node) + assert node_machine.is_waiting() + assert await node_machine.may_prepare() + + did_prepare = await node_machine.prepare() + assert did_prepare + assert await node_machine.may_start() + + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.ready + await session.commit() + + did_start = await node_machine.start() + assert did_start + assert await node_machine.may_pause() + assert await node_machine.may_finish() + + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.running + await session.commit() + + did_finish = await node_machine.trigger("finish") + assert did_finish + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.accepted + await session.commit() + + +async def test_bad_transition(test_campaign: str, session: AsyncSession) -> None: + """Forces a state transition to raise an exception and tests the generation + of an activity log record. 
+ """ + # extract the test campaign id from the fixture url + campaign_id = urlparse(url=test_campaign).path.split("/")[-2:][0] + node_id = uuid5(UUID(campaign_id), "START.1") + + node = await session.get_one(Node, node_id) + x = node.status + assert x is StatusEnum.waiting + await session.commit() + + node_machine = StartMachine(o=node) + with patch( + "lsst.cmservice.machines.node.StartMachine.do_prepare", + side_effect=RuntimeError("Error: unknown error"), + ): + assert await node_machine.may_trigger("prepare") + assert not await node_machine.trigger("prepare") + + assert await node_machine.prepare() + assert await node_machine.start() + + # test the automatic transition to failure when an exception is raised in + # the success check conditional callback + with patch( + "lsst.cmservice.machines.node.StartMachine.is_done_running", + side_effect=RuntimeError("Error: WMS execution failure"), + ): + await node_machine.may_finish() + + # Both the state machine and db object should reflect the failed state + assert node_machine.is_failed() + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.failed + await session.commit() + + # retry by rolling back to ready + assert await node_machine.may_retry() + await node_machine.trigger("retry") + + assert node_machine.is_ready() + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.ready + await session.commit() + + # start and finish without further error + assert await node_machine.start() + assert await node_machine.finish() + + await session.refresh(node, attribute_names=["status"]) + x = node.status + assert x is StatusEnum.accepted + await session.commit() + + +async def test_machine_pickle(test_campaign: str, session: AsyncSession) -> None: + """Tests the serialization of a state machine in and out of a pickle.""" + # extract the test campaign id from the fixture url + campaign_id = 
urlparse(url=test_campaign).path.split("/")[-2:][0] + node_id = uuid5(UUID(campaign_id), "START.1") + + node = await session.get_one(Node, node_id) + x = node.status + assert x is StatusEnum.waiting + await session.commit() + + node_machine = NodeMachine(o=node) + # evolve the machine to "prepared" + await node_machine.prepare() + assert node_machine.state is StatusEnum.ready + + # pickles can go to the database + machine_pickle = pickle.dumps(node_machine.machine) + machine_db = Machine.model_validate(dict(id=uuid4(), state=machine_pickle)) + + # campaigns can have a reference to their machine, but the machine ids are + # not necessarily deterministic (i.e., they do not have to be namespaced.) + session.add(machine_db) + node.machine = machine_db.id + await session.commit() + + await session.close() + del node_machine + del node + del machine_db + del machine_pickle + + # get new objects from the database + new_node = await session.get_one(Node, node_id) + machine_unpickle = await session.get_one(Machine, new_node.machine) + + # Before pickling, the stateful model had the machine as its `.machine` + # member, and the machine itself had the stateful model as its `.model` + # member. 
To get back where we started, we would want to use the unpickled + machine's `.model` attribute as our first-class named object, then re- + assign any members we had to disassociate before pickling + assert machine_unpickle.state is not None + new_node_machine: NodeMachine = (pickle.loads(machine_unpickle.state)).model + new_node_machine.db_model = new_node + + # test the rehydrated machine and continue to evolve it + assert new_node_machine.state == new_node.status + await session.commit() + + await new_node_machine.start() + await session.refresh(new_node, attribute_names=["status"]) + assert new_node_machine.state == new_node.status + await session.commit() + + await new_node_machine.finish() + await session.refresh(new_node, attribute_names=["status"]) + assert new_node_machine.state == new_node.status + await session.commit() + + +async def test_change_campaign_state( + session: AsyncSession, aclient: AsyncClient, test_campaign: str, caplog: pytest.LogCaptureFixture +) -> None: + """Tests that the campaign state change background task sets a campaign + status when it is supposed to, and does not when it is not. 
+ """ + # the test_campaign fixture produces a valid graph so it should be possible + # to set the campaign to running, then paused + # extract the test campaign id from the fixture url + edge_list = (await aclient.get(test_campaign)).json() + + campaign_id = urlparse(test_campaign).path.split("/")[-2:][0] + campaign = await session.get_one(Campaign, campaign_id) + await session.commit() + + x = (await aclient.get(f"/cm-service/v2/campaigns/{campaign_id}")).json() + assert x["status"] == "waiting" + + await change_campaign_state(campaign, StatusEnum.running, uuid4()) + await session.refresh(campaign, attribute_names=["status"]) + await session.commit() + + x = (await aclient.get(f"/cm-service/v2/campaigns/{campaign_id}")).json() + assert x["status"] == "running" + + await change_campaign_state(campaign, StatusEnum.paused, uuid4()) + await session.refresh(campaign, attribute_names=["status"]) + await session.commit() + + x = (await aclient.get(f"/cm-service/v2/campaigns/{campaign_id}")).json() + assert x["status"] == "paused" + + # Break the graph by removing an edge from the campaign, and try to enter + # the running state + edge_to_remove = random.choice(edge_list)["id"] + x = await aclient.delete(f"/cm-service/v2/edges/{edge_to_remove}") + + caplog.clear() + await change_campaign_state(campaign, StatusEnum.running, uuid4()) + await session.refresh(campaign, attribute_names=["status"]) + await session.commit() + + x = (await aclient.get(f"/cm-service/v2/campaigns/{campaign_id}")).json() + assert x["status"] == "paused" + + # Check log messages + log_entry_found = False + for r in caplog.records: + if "Invalid campaign graph" in r.message: + log_entry_found = True + break + assert log_entry_found diff --git a/tests/v2/test_misc_routes.py b/tests/v2/test_misc_routes.py new file mode 100644 index 000000000..e880487a3 --- /dev/null +++ b/tests/v2/test_misc_routes.py @@ -0,0 +1,27 @@ +"""Tests for miscellaneous routes.""" + +from uuid import uuid4 + +import pytest 
+from httpx import AsyncClient + +pytestmark = pytest.mark.asyncio(loop_scope="module") +"""All tests in this module will run in the same event loop.""" + + +async def test_activity_log_routes( + aclient: AsyncClient, caplog: pytest.LogCaptureFixture, test_campaign: str +) -> None: + """Tests the activity log multipurpose collection route, ensuring that + query parameters behave as expected to generate executable SQL to satisfy + the request. + """ + # Test multiple query parameters at once + y = await aclient.get( + f"/cm-service/v2/logs?pilot=bobloblaw&campaign={uuid4()}&node={uuid4()}&since=1989-06-03T12:34:56Z" + ) + assert y.is_success + + # Test a different datetime format + y = await aclient.get("/cm-service/v2/logs?pilot=bobloblaw&since=946684800") + assert y.is_success diff --git a/tests/v2/test_node_routes.py b/tests/v2/test_node_routes.py index 4ee5a2ddb..b087f38ad 100644 --- a/tests/v2/test_node_routes.py +++ b/tests/v2/test_node_routes.py @@ -70,6 +70,7 @@ async def test_node_negative(aclient: AsyncClient) -> None: async def test_node_lifecycle(aclient: AsyncClient) -> None: """Tests node lifecycle.""" campaign_name = uuid4().hex[:8] + node_name = uuid4().hex[:8] # Create a campaign for edges. Campaigns come with START and END nodes. 
x = await aclient.post( @@ -88,7 +89,7 @@ async def test_node_lifecycle(aclient: AsyncClient) -> None: "/cm-service/v2/nodes", json={ "kind": "node", - "metadata": {"name": uuid4().hex[8:], "namespace": campaign_id}, + "metadata": {"name": node_name, "namespace": campaign_id}, "spec": { "handler": "lsst.cmservice.handlers.element_handler.ElementHandler", "pipeline_yaml": "${DRP_PIPE_DIR}/pipelines/HSC/DRP-RC2.yaml#step1", @@ -108,6 +109,12 @@ async def test_node_lifecycle(aclient: AsyncClient) -> None: assert node["version"] == 1 node_url = x.headers["Self"] + # Get a node using its name (fail) and its name+namespace (succeed) + x = await aclient.get(f"/cm-service/v2/nodes/{node_name}") + assert x.is_client_error + x = await aclient.get(f"/cm-service/v2/nodes/{node_name}?campaign-id={campaign_id}") + assert x.is_success + + # Edit a Node using RFC6902 json-patch x = await aclient.patch( node_url,