diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fa4318e45..2e4a8b8c0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,6 +12,7 @@ name: "CI" env: UV_FROZEN: "1" + TEST__LOCAL_DB: "1" # signals test fixtures to not use testcontainers jobs: rebase-checker: @@ -30,7 +31,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Run pre-commit uses: pre-commit/action@v3.0.1 @@ -39,7 +40,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.11"] + python-version: ["3.12"] needs: - rebase-checker steps: @@ -49,7 +50,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - version: "0.6.x" + version: "0.7.x" enable-cache: true python-version: ${{ matrix.python-version }} @@ -63,7 +64,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.11"] + python-version: ["3.12"] needs: - lint - mypy @@ -75,7 +76,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - version: "0.6.x" + version: "0.7.x" enable-cache: true python-version: ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aecfc6aba..44c1c0ef0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: v0.9.2 + rev: v0.11.13 hooks: - id: ruff - id: ruff diff --git a/.python-version b/.python-version index 2c0733315..e4fba2183 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.12 diff --git a/Makefile b/Makefile index d6d5066b8..f01fa4fe8 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,7 @@ update: update-deps init .PHONY: build build: export BUILDKIT_PROGRESS=plain +build: export COMPOSE_BAKE=true build: docker compose build cmservice docker compose build cmworker @@ -107,6 +108,7 @@ test: PGPORT=$(shell docker compose port postgresql 5432 | cut -d: -f2) test: export DB__URL=postgresql://cm-service@localhost:${PGPORT}/cm-service test: export DB__PASSWORD=INSECURE-PASSWORD test: export DB__TABLE_SCHEMA=cm_service_test +test: export BPS__ARTIFACT_PATH=$(PWD)/output test: run-compose alembic upgrade head pytest -vvv --asyncio-mode=auto --cov=lsst.cmservice --cov-branch --cov-report=term --cov-report=html ${PYTEST_ARGS} diff --git a/README.md b/README.md index 89c3d54db..422aca9d4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # cm-service +![Python](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Flsst-dm%2Fcm-service%2Frefs%2Fheads%2Fmain%2Fpyproject.toml) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv) @@ -8,7 +9,7 @@ https://cm-service.lsst.io. ## Developer Quick Start -You can build and run `cm-service` on any system which has Python 3.11 or greater, `uv`, `make`, and Docker w/ the +You can build and run `cm-service` on any system which has Python 3.12 or greater, `uv`, `make`, and Docker w/ the Docker Compose V2 CLI plugin (this includes, in particular, recent MacOS with Docker Desktop). 
Proceed as follows: diff --git a/alembic/README.md b/alembic/README.md index 8e57ac78d..04bf03826 100644 --- a/alembic/README.md +++ b/alembic/README.md @@ -3,8 +3,6 @@ Database migrations and schema evolution are handled by `alembic`, a database tool that is part of the `sqlalchemy` toolkit ecosystem. -Alembic is included in the project's dependency graph via the Safir package. - ## Running Alembic The `alembic` tool establishes an execution environment via the `env.py` file which diff --git a/alembic/versions/1da92a1c740f_create_v2_tables.py b/alembic/versions/1da92a1c740f_create_v2_tables.py new file mode 100644 index 000000000..27fac9815 --- /dev/null +++ b/alembic/versions/1da92a1c740f_create_v2_tables.py @@ -0,0 +1,250 @@ +"""create v2 tables + +Revision ID: 1da92a1c740f +Revises: acf951c80750 +Create Date: 2025-06-13 14:56:31.238050+00:00 + +""" + +from collections.abc import Sequence +from enum import Enum +from uuid import NAMESPACE_DNS + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1da92a1c740f" +down_revision: str | None = "acf951c80750" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +DEFAULT_CAMPAIGN_NAMESPACE = "dda54a0c-6878-5c95-ac4f-007f6808049e" +"""UUID5 of name 'io.lsst.cmservice' in `uuid.NAMESPACE_DNS`.""" + +# DB model uses mapped columns with Python Enum types, but we do not care +# to use native enums in the database, so when we have such a column, this +# definition will produce a VARCHAR instead. 
+ENUM_COLUMN_AS_VARCHAR = sa.Enum(Enum, length=20, native_enum=False, create_constraint=False) + + +def upgrade() -> None: + # Create table for machines v2 + machines_v2 = op.create_table( + "machines_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("state", sa.PickleType, nullable=False), + sa.PrimaryKeyConstraint("id"), + if_not_exists=True, + ) + + # Create table for campaigns v2 + campaigns_v2 = op.create_table( + "campaigns_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("name", postgresql.VARCHAR(), nullable=False), + sa.Column("namespace", postgresql.UUID(), nullable=False, default=DEFAULT_CAMPAIGN_NAMESPACE), + sa.Column("owner", postgresql.VARCHAR(), nullable=True), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column( + "configuration", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column("status", ENUM_COLUMN_AS_VARCHAR, nullable=False, default="waiting"), + sa.Column( + "machine", postgresql.UUID(), sa.ForeignKey(machines_v2.c.id, ondelete="CASCADE"), nullable=True + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "namespace"), + if_not_exists=True, + ) + + # Create node and edges tables for campaign digraph + nodes_v2 = op.create_table( + "nodes_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column( + "namespace", + postgresql.UUID(), + sa.ForeignKey(campaigns_v2.c.id, ondelete="CASCADE"), + nullable=False, + ), + sa.Column("name", postgresql.VARCHAR(), nullable=False), + sa.Column("version", postgresql.INTEGER(), nullable=False, default=1), + sa.Column("kind", ENUM_COLUMN_AS_VARCHAR, nullable=False, default="node"), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column( + "configuration", + postgresql.JSONB(), + nullable=False, + default=dict, + 
server_default=sa.text("'{}'::json"), + ), + sa.Column("status", ENUM_COLUMN_AS_VARCHAR, nullable=False, default="waiting"), + sa.Column( + "machine", postgresql.UUID(), sa.ForeignKey(machines_v2.c.id, ondelete="CASCADE"), nullable=True + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "version", "namespace"), + if_not_exists=True, + ) + + _ = op.create_table( + "edges_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("name", postgresql.VARCHAR(), nullable=False), + sa.Column( + "namespace", + postgresql.UUID(), + sa.ForeignKey(campaigns_v2.c.id, ondelete="CASCADE"), + nullable=False, + ), + sa.Column("source", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=False), + sa.Column("target", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=False), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column( + "configuration", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("source", "target", "namespace"), + if_not_exists=True, + ) + + # Create table for spec blocks v2 ("manifests") + _ = op.create_table( + "manifests_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("name", postgresql.VARCHAR(), nullable=False), + sa.Column( + "namespace", + postgresql.UUID(), + sa.ForeignKey(campaigns_v2.c.id, ondelete="CASCADE"), + nullable=False, + ), + sa.Column("version", postgresql.INTEGER(), nullable=False, default=1), + sa.Column("kind", ENUM_COLUMN_AS_VARCHAR, nullable=False, default="other"), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column( + "spec", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "version", 
"namespace"), + if_not_exists=True, + ) + + # Create table for tasks v2 + _ = op.create_table( + "tasks_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("namespace", postgresql.UUID(), nullable=False), + sa.Column("node", postgresql.UUID(), nullable=False), + sa.Column("priority", postgresql.INTEGER(), nullable=True), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column("submitted_at", postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.Column("wms_id", postgresql.VARCHAR(), nullable=True), + sa.Column("site_affinity", postgresql.ARRAY(postgresql.VARCHAR()), nullable=True), + sa.Column("status", ENUM_COLUMN_AS_VARCHAR, nullable=False), + sa.Column("previous_status", ENUM_COLUMN_AS_VARCHAR, nullable=True), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["node"], ["nodes_v2.id"]), + sa.ForeignKeyConstraint(["namespace"], ["campaigns_v2.id"]), + if_not_exists=True, + ) + + _ = op.create_table( + "activity_log_v2", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("namespace", postgresql.UUID(), nullable=False), + sa.Column("node", postgresql.UUID(), sa.ForeignKey(nodes_v2.c.id), nullable=True), + sa.Column("operator", postgresql.VARCHAR(), nullable=False, default="root"), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column("finished_at", postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.Column("from_status", ENUM_COLUMN_AS_VARCHAR, nullable=False), + sa.Column("to_status", ENUM_COLUMN_AS_VARCHAR, nullable=False), + sa.Column( + "detail", + postgresql.JSONB(), + nullable=False, + default=dict, + server_default=sa.text("'{}'::json"), + ), + sa.Column( + "metadata", + postgresql.JSONB(), + nullable=False, + default=dict, + 
server_default=sa.text("'{}'::json"), + ), + sa.PrimaryKeyConstraint("id"), + if_not_exists=True, + ) + + # Insert default campaign (namespace) record + op.bulk_insert( + campaigns_v2, + [ + { + "id": DEFAULT_CAMPAIGN_NAMESPACE, + "namespace": str(NAMESPACE_DNS), + "name": "DEFAULT", + "owner": "root", + } + ], + ) + + +def downgrade() -> None: + """Drop tables in the reverse order in which they were created.""" + op.drop_table("activity_log_v2", if_exists=True) + op.drop_table("tasks_v2", if_exists=True) + op.drop_table("manifests_v2", if_exists=True) + op.drop_table("edges_v2", if_exists=True) + op.drop_table("nodes_v2", if_exists=True) + op.drop_table("campaigns_v2", if_exists=True) + op.drop_table("machines_v2", if_exists=True) diff --git a/docker/Dockerfile b/docker/Dockerfile index d92b48eb2..bba03da78 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 -ARG PYTHON_VERSION="3.11" -ARG UV_VERSION="0.6" +ARG PYTHON_VERSION="3.12" +ARG UV_VERSION="0.7" ARG ASGI_PORT="8080" #============================================================================== diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 3d78cfce5..8bb337ad5 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -19,9 +19,7 @@ Releases are performed at an unspecified cadence, to be no shorter than 1 week a - Releases are named according to their semantic version (major.minor.patch). - Releases are made by adding a named tag to the trunk branch. - Each release will increment the minor version and set the patch level to 0, e.g., `1.0.12` -> `1.1.0` -- If a bugfix commit needs to be added to a release, then a retroactive branch will be created from the - release tag; the commit is cherry-picked into the release branch and a new tag is written with an incremented - patch level, e.g., `1.23.0` -> `1.23.1`. 
+- If a bugfix commit in the trunk needs to be added to a release, then a retroactive branch will be created from the affected release tag; any fix commits are cherry-picked into the release branch and a new tag is written with an incremented patch level, e.g., `1.23.0` -> `1.23.1`. This release branch is never merged to `main` (trunk) but is kept for subsequent cherry-picked fixes. - The major version is incremented only in the presence of user-facing breaking changes. This project uses `python-semantic-release` to manage releases. A release may be triggered by any ticket branch diff --git a/docs/DEPLOYING.md b/docs/DEPLOYING.md index 753817620..f2bc7759a 100644 --- a/docs/DEPLOYING.md +++ b/docs/DEPLOYING.md @@ -50,3 +50,18 @@ The CM Service consumes several secrets from the k8s application environment. Th - The AWS "default" profile is configured to use environment variables for its credentials source, so the appropriate secrets should be assigned to the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables. - Authentication values for other profiles are stored in a secret mounted as a file at `/etc/aws/credentials`. + +## Dev / Staging Deployments +Development or staging builds are deployed to the USDF Kubernetes vcluster `usdf-cm-dev`. + +Prior to deploying a new build using a prerelease tag, release tag, or ticket branch tag, it may be necessary to clear the database of all data so it can be migrated from scratch. + +### Clearing the Database +After appropriately setting the Kubernetes context and namespace: + +1. Scale down the daemon deployment using `kubectl scale deployment cm-service-daemon --replicas=0`. +1. Obtain a shell in an API server pod using `kubectl exec -it cm-service-server- -- bash`. +1. Within this shell, downgrade the database migration using `alembic downgrade base`. + +> [!CAUTION] +> This operation unconditionally destroys the database contents and all objects. 
diff --git a/pyproject.toml b/pyproject.toml index 0a7c1f213..35907ae34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/ name = "lsst-cm-service" -description = "Rubin Observatory campaign management FastAPI service" +description = "Rubin Observatory Campaign Management Service" license = { file = "LICENSE" } readme = "README.md" keywords = ["rubin", "lsst"] @@ -11,13 +11,13 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Intended Audience :: Developers", "Natural Language :: English", "Operating System :: POSIX", "Typing :: Typed", ] -requires-python = ">=3.11,<3.13" +requires-python = ">=3.12,<3.13" dynamic = ["version"] dependencies = [ @@ -27,32 +27,34 @@ dependencies = [ "click==8.1.*", "fastapi==0.115.*", "greenlet==3.1.*", - "htcondor==24.0.6; sys_platform == 'linux'", + "htcondor==24.0.7; sys_platform == 'linux'", "jinja2==3.1.*", "numpy==2.1.*", "psycopg2-binary==2.9.*", - "pydantic==2.10.*", - "pydantic-settings==2.7.*", + "pydantic==2.11.*", + "pydantic-settings==2.9.*", "python-multipart==0.0.*", "rich==13.9.*", "structlog==24.4.*", "tabulate==0.9.*", "sqlalchemy[asyncio]==2.0.*", "safir[db]==7.0.*", - "uvicorn[standard]==0.32.*", + "uvicorn[standard]==0.34.*", "panda-client>=1.5.82", "httpx>=0.27.2", "networkx>=3.5", + "sqlmodel>=0.0.24", + "transitions>=0.9.2", ] [dependency-groups] lsst = [ - "lsst-ctrl-bps>=29.2025.1500", - "lsst-ctrl-bps-htcondor>=29.2025.1500; sys_platform == 'linux'", - "lsst-ctrl-bps-panda>=29.2025.1500", - "lsst-daf-butler>=29.2025.1500", - "lsst-pipe-base>=29.2025.1500", - "lsst-utils>=29.2025.1500", + "lsst-ctrl-bps>=29.2025.2300", + "lsst-ctrl-bps-htcondor>=29.2025.2300; sys_platform == 'linux'", + "lsst-ctrl-bps-panda>=29.2025.2300", + 
"lsst-daf-butler>=29.2025.2300", + "lsst-pipe-base>=29.2025.2300", + "lsst-utils>=29.2025.2300", ] dev = [ @@ -60,8 +62,8 @@ dev = [ "asgi-lifespan>=2.1.0", "coverage[toml]>=7.6.7", "greenlet>=3.1.1", - "mypy>=1.13.0", - "pre-commit>=4.0.1", + "mypy>=1.16.0", + "pre-commit>=4.2.0", "pytest>=8.3.3", "pytest-asyncio>=0.24.0", "pytest-cov>=6.0.0", @@ -73,6 +75,7 @@ dev = [ "types-pyyaml>=6.0.12.20240917", "types-tabulate>=0.9.0.20240106", "respx>=0.22.0", + "testcontainers[postgres]==4.10.*", ] [project.scripts] @@ -113,11 +116,7 @@ exclude_lines = [ "def __repr__", "if self.debug:", "if settings.DEBUG", - "raise AssertionError", "raise NotImplementedError", - "except Exception as msg", - "except KeyError as msg", - "except IntegrityError as msg", "if 0:", "if __name__ == .__main__.:", "if TYPE_CHECKING:", @@ -128,7 +127,7 @@ exclude = [ "__init__.py", ] line-length = 110 -target-version = "py311" +target-version = "py312" [tool.ruff.lint] ignore = [ @@ -171,7 +170,7 @@ max-doc-length = 79 convention = "numpy" [tool.pytest.ini_options] -asyncio_mode = "strict" +asyncio_mode = "auto" asyncio_default_fixture_loop_scope="function" # The python_files setting is not for test detection (pytest will pick up any # test files named *_test.py without this setting) but to enable special @@ -225,10 +224,16 @@ tag_format = "{version}" match = "main" prerelease = false +# A Ticketed release branch [tool.semantic_release.branches.release] match = "^tickets/DM-\\d+(.*)/release$" prerelease = false +# A non-ticketed release branch +[tool.semantic_release.branches.bugfix] +match = "^release/\\d+.\\d+.\\d+$" +prerelease = false + [tool.semantic_release.branches.ticket] match = "^tickets/DM-\\d+(.*)$" prerelease_token = "rc" diff --git a/src/lsst/cmservice/__init__.py b/src/lsst/cmservice/__init__.py index 7d78d4378..7f5039908 100644 --- a/src/lsst/cmservice/__init__.py +++ b/src/lsst/cmservice/__init__.py @@ -2,4 +2,4 @@ __all__ = ["__version__"] -__version__ = "0.4.0" +__version__ = 
"0.5.0" diff --git a/src/lsst/cmservice/cli/wrappers.py b/src/lsst/cmservice/cli/wrappers.py index 7b23e112b..99e751c97 100644 --- a/src/lsst/cmservice/cli/wrappers.py +++ b/src/lsst/cmservice/cli/wrappers.py @@ -11,7 +11,7 @@ import json from collections.abc import Callable, Sequence from enum import Enum -from typing import Any, TypeAlias +from typing import Any import click import yaml @@ -20,7 +20,7 @@ from ..client.client import CMClient from ..common.enums import StatusEnum -from ..db import Job, Script, SpecBlock, Specification +from ..db import ElementMixin, Job, RowMixin, Script, SpecBlock, Specification from . import options @@ -127,7 +127,7 @@ def output_dict( def get_list_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that gets all the rows from a table and attaches that function to the cli. @@ -143,7 +143,7 @@ def get_list_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -170,7 +170,7 @@ def get_rows( def get_row_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that gets a row from a table and attaches that function to the cli. @@ -183,7 +183,7 @@ def get_row_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -212,7 +212,7 @@ def get_row( def get_row_by_name_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that gets a row from a table and attaches that function to the cli. 
@@ -225,7 +225,7 @@ def get_row_by_name_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -254,7 +254,7 @@ def get_row_by_name( def get_row_by_fullname_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that gets a row from a table and attaches that function to the cli. @@ -267,7 +267,7 @@ def get_row_by_fullname_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -296,7 +296,7 @@ def get_row_by_fullname( def get_create_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], create_options: list[Callable], ) -> Callable: """Return a function that creates a new row in the table @@ -310,7 +310,7 @@ def get_create_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class create_options: list[Callable] @@ -342,7 +342,7 @@ def create( def get_update_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], update_options: list[Callable], ) -> Callable: """Return a function that updates a row in the table @@ -356,7 +356,7 @@ def get_update_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class update_options: list[Callable] @@ -512,7 +512,7 @@ def get_resolved_collections_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -693,7 +693,7 @@ def get_spec_aliases( def get_update_status_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> 
Callable: """Return a function that updates the status of row in the table and attaches that function to the cli. @@ -706,7 +706,7 @@ def get_update_status_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -967,9 +967,6 @@ def get_action_run_check_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin - Underlying database class - Returns ------- the_function: Callable @@ -996,7 +993,7 @@ def run_check( def get_action_accept_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that marks a row in the table as accepted and attaches that function to the cli. @@ -1009,7 +1006,7 @@ def get_action_accept_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -1038,7 +1035,7 @@ def accept( def get_action_reject_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that marks a row in the table as rejected and attaches that function to the cli. @@ -1051,7 +1048,7 @@ def get_action_reject_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -1080,7 +1077,7 @@ def reject( def get_action_reset_command( group_command: Callable, sub_client_name: str, - db_class: TypeAlias, + db_class: type[RowMixin], ) -> Callable: """Return a function that resets the status of a row in the table and attaches that function to the cli. 
@@ -1093,7 +1090,7 @@ def get_action_reset_command( sub_client_name: str Name of python API sub-client to use - db_class: TypeAlias = db.RowMixin + db_class: type Underlying database class Returns @@ -1125,7 +1122,7 @@ def reset( def get_element_parent_command( group_command: Callable, sub_client_name: str, - db_parent_class: TypeAlias, + db_parent_class: type[ElementMixin], ) -> Callable: """Return a function that gets the parent of an element @@ -1137,7 +1134,7 @@ def get_element_parent_command( sub_client_name: str Name of python API sub-client to use - db_parent_class: TypeAlias = db.RowMixin + db_parent_class: type Underlying parent database class Returns diff --git a/src/lsst/cmservice/client/script_dependencies.py b/src/lsst/cmservice/client/script_dependencies.py index 95c2b75e1..1f66e3df0 100644 --- a/src/lsst/cmservice/client/script_dependencies.py +++ b/src/lsst/cmservice/client/script_dependencies.py @@ -50,9 +50,4 @@ def client(self) -> httpx.Client: f"{router_string}/create", ) - update = wrappers.update_row_function( - ResponseModelClass, - f"{router_string}/update", - ) - delete = wrappers.delete_row_function(f"{router_string}/delete") diff --git a/src/lsst/cmservice/client/step_dependencies.py b/src/lsst/cmservice/client/step_dependencies.py index 39e74fc72..eb73bfdf3 100644 --- a/src/lsst/cmservice/client/step_dependencies.py +++ b/src/lsst/cmservice/client/step_dependencies.py @@ -50,9 +50,4 @@ def client(self) -> httpx.Client: f"{router_string}/create", ) - update = wrappers.update_row_function( - ResponseModelClass, - f"{router_string}/update", - ) - delete = wrappers.delete_row_function(f"{router_string}/delete") diff --git a/src/lsst/cmservice/client/wrappers.py b/src/lsst/cmservice/client/wrappers.py index 1f41b2410..4107abc28 100644 --- a/src/lsst/cmservice/client/wrappers.py +++ b/src/lsst/cmservice/client/wrappers.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, TypeAlias from httpx import ConnectError, HTTPStatusError -from 
pydantic import BaseModel, TypeAdapter +from pydantic import TypeAdapter from .. import models from ..common.logging import LOGGER @@ -17,7 +17,7 @@ def get_rows_no_parent_function( - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that gets all the rows from a table @@ -28,7 +28,7 @@ def get_rows_no_parent_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -53,7 +53,7 @@ def get_rows(obj: CMClient) -> list[response_model_class]: def get_rows_function( - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, query: str = "", ) -> Callable: # pragma: no cover """Return a function that gets all the rows from a table @@ -66,7 +66,7 @@ def get_rows_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -99,7 +99,7 @@ def get_rows( def get_row_function( - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that gets a single row from a table (by ID) @@ -107,7 +107,7 @@ def get_row_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -131,8 +131,8 @@ def row_get( def create_row_function( - response_model_class: TypeAlias = BaseModel, - create_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, + create_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that creates a single row in a table @@ -140,10 +140,10 @@ def create_row_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize 
the return value - create_model_class: TypeAlias = BaseModel, + create_model_class: TypeAlias, Pydantic class used to serialize the inputs value query: str @@ -168,8 +168,8 @@ def row_create(obj: CMClient, **kwargs: Any) -> response_model_class: def update_row_function( - response_model_class: TypeAlias = BaseModel, - update_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, + update_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that updates a single row in a table @@ -177,10 +177,10 @@ def update_row_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value - update_model_class: TypeAlias = BaseModel, + update_model_class: TypeAlias, Pydantic class used to serialize the input values query: str @@ -233,7 +233,7 @@ def row_delete( def get_row_by_fullname_function( - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that gets a single row from a table (by fullname) @@ -241,7 +241,7 @@ def get_row_by_fullname_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -268,7 +268,7 @@ def get_row_by_fullname( def get_row_by_name_function( - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, query: str = "", ) -> Callable: """Return a function that gets a single row from a table (by name) @@ -276,7 +276,7 @@ def get_row_by_name_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -316,7 +316,7 @@ def get_node_property_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the 
return value query: str @@ -353,7 +353,7 @@ def get_node_post_query_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query_class: TypeAlias @@ -394,7 +394,7 @@ def get_node_post_no_query_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -421,7 +421,7 @@ def node_update( def get_general_post_function( - query_class: TypeAlias = BaseModel, + query_class: TypeAlias, response_model_class: TypeAlias = Any, query: str = "", results_key: str | None = None, @@ -431,7 +431,7 @@ def get_general_post_function( Parameters ---------- - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str @@ -460,7 +460,7 @@ def general_post_function( def get_general_query_function( - query_class: TypeAlias = BaseModel, + query_class: TypeAlias, response_model_class: TypeAlias = Any, query: str = "", query_suffix: str = "", @@ -474,7 +474,7 @@ def get_general_query_function( query_class: TypeAlias Pydantic class used to serialize the query parameters - response_model_class: TypeAlias = BaseModel, + response_model_class: TypeAlias, Pydantic class used to serialize the return value query: str diff --git a/src/lsst/cmservice/common/daemon.py b/src/lsst/cmservice/common/daemon.py index 0a8cb5855..c49f7d159 100644 --- a/src/lsst/cmservice/common/daemon.py +++ b/src/lsst/cmservice/common/daemon.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.future import select from ..common import notification, timestamp from ..common.enums import StatusEnum +from ..common.types import AnyAsyncSession from ..config import config from ..db.node import NodeMixin from ..db.queue import Queue @@ -15,19 +15,19 
@@ logger = LOGGER.bind(module=__name__) -async def check_due_date(session: async_scoped_session, node: NodeMixin, time_next_check: datetime) -> None: +async def check_due_date(session: AnyAsyncSession, node: NodeMixin, time_next_check: datetime) -> None: """For a provided due date, check if the queue entry is overdue""" - due_date = node.metadata_.get("due_date", None) + due_date: int | None = node.metadata_.get("due_date", None) if due_date is None: return None - if time_next_check > due_date: + if time_next_check.timestamp() > due_date: campaign = await node.get_campaign(session) await notification.send_notification(for_status=StatusEnum.overdue, for_campaign=campaign) -async def daemon_iteration(session: async_scoped_session) -> None: +async def daemon_iteration(session: AnyAsyncSession) -> None: iteration_start = timestamp.now_utc() processed_nodes = 0 queue_entries = await session.execute( diff --git a/src/lsst/cmservice/common/daemon_v2.py b/src/lsst/cmservice/common/daemon_v2.py new file mode 100644 index 000000000..2612d98f2 --- /dev/null +++ b/src/lsst/cmservice/common/daemon_v2.py @@ -0,0 +1,198 @@ +import pickle +from asyncio import Task as AsyncTask +from asyncio import TaskGroup, create_task +from collections.abc import Awaitable, Mapping +from typing import TYPE_CHECKING +from uuid import UUID, uuid5 + +from sqlalchemy.dialects.postgresql import insert +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession +from transitions import Event + +from ..common import graph, timestamp +from ..common.enums import StatusEnum +from ..config import config +from ..db.campaigns_v2 import Campaign, Edge, Machine, Node, Task +from ..db.session import db_session_dependency +from ..machines.node import NodeMachine, node_machine_factory +from .logging import LOGGER + +logger = LOGGER.bind(module=__name__) + + +async def consider_campaigns(session: AsyncSession) -> None: + """In Phase One, the daemon considers campaigns. 
Campaigns subject to + consideration have a non-terminal prepared status (ready or running), and + optionally tagged with a priority value lower than the daemon's own + priority. + + For any campaigns thus discovered, the daemon then constructs a graph + from the campaign's Edges, and starting at the START node, walks the graph + until a Node is found that requires attention. Each Node found is added to + the Tasks table as a queue item. + """ + c_statement = ( + select(Campaign.id) + .where(col(Campaign.status).in_((StatusEnum.ready, StatusEnum.running))) + .with_for_update(key_share=True, skip_locked=True) + ) + campaigns = (await session.exec(c_statement)).all() + + for campaign_id in campaigns: + logger.info("Daemon considering campaign", id=campaign_id) + + # Fetch the Edges for the campaign + e_statement = select(Edge).filter_by(namespace=campaign_id) + edges = (await session.exec(e_statement)).all() + campaign_graph = await graph.graph_from_edge_list_v2(edges=edges, session=session) + + for node in graph.processable_graph_nodes(campaign_graph): + logger.info("Daemon considering node", id=str(node.id)) + desired_state = node.status.next_status() + node_task = Task( + id=uuid5(node.id, desired_state.name), + namespace=campaign_id, + node=node.id, + status=desired_state, + previous_status=node.status, + ) + statement = insert(node_task.__table__).values(**node_task.model_dump()).on_conflict_do_nothing() # type: ignore[attr-defined] + await session.exec(statement) # type: ignore[call-overload] + + await session.commit() + + +async def consider_nodes(session: AsyncSession) -> None: + """In Phase Two, the daemon considers Nodes. Nodes subject to consideration + are only those Nodes found on the Tasks table that have a priority lower + than the daemon's own priority, and share the daemon's site affinity. + + For each node considered by the daemon, the Node's FSM is loaded from the + Machines table, or creates one if needed. 
The daemon uses methods on the + Node's Stateful Model to evolve the state of the Node. + + After handling, the Node's FSM is serialized and the Node is updated with + new values as necessary. The Task is not returned to the Task table. + """ + # Select and lock unsubmitted tasks + statement = select(Task).where(col(Task.submitted_at).is_(None)) + # TODO add filter criteria for priority and site affinity + statement = statement.with_for_update(skip_locked=True) + + cm_tasks = (await session.exec(statement)).all() + + # Using a TaskGroup context manager means all "tasks" added to the group + # are awaited when the CM exits, giving us concurrency for all the nodes + # being considered in the current iteration. + async with TaskGroup() as tg: + for cm_task in cm_tasks: + node = await session.get_one(Node, cm_task.node) + + # the task's status field is the target status for the node, so the + # daemon intends to evolve the node machine to that state. + try: + assert node.status is cm_task.previous_status + except AssertionError: + logger.error("Node status out of sync with Machine", id=str(node.id)) + continue + + # Expunge the node from *this* session because it will be added to + # whatever session the node_machine acquires during its transition + session.expunge(node) + + node_machine: NodeMachine + node_machine_pickle: Machine | None + if node.machine is None: + # create a new machine for the node + node_machine = node_machine_factory(node.kind)(o=node) + node_machine_pickle = None + else: + # unpickle the node's machine and rehydrate the Stateful Model + node_machine_pickle = await session.get_one(Machine, node.machine) + node_machine = (pickle.loads(node_machine_pickle.state)).model + node_machine.db_model = node + # discard the pickled machine from this session and context + session.expunge(node_machine_pickle) + del node_machine_pickle + + # check possible triggers for state + # TODO how to pick the "best" trigger from multiple available? 
+ # - Add a caller-backed conditional to the triggers, to identify + # . triggers the daemon is "allowed" to use + # - Determine the "desired" trigger from the task (source, dest) + if (trigger := trigger_for_transition(cm_task, node_machine.machine.events)) is None: + logger.warning( + "No trigger available for desired state transition", + source=cm_task.previous_status, + dest=cm_task.status, + ) + continue + + # Add the node transition trigger method to the task group + task = tg.create_task(node_machine.trigger(trigger), name=str(cm_task.id)) + task.add_done_callback(task_runner_callback) + + # wrap up - update the task and commit + cm_task.submitted_at = timestamp.now_utc() + await session.commit() + + +async def daemon_iteration(session: AsyncSession) -> None: + """A single iteraton of the CM daemon's work loop, which is carried out in + two phases: Campaigns and Nodes. + """ + iteration_start = timestamp.now_utc() + logger.debug("Daemon V2 Iteration: %s", iteration_start) + if config.daemon.process_campaigns: + await consider_campaigns(session) + if config.daemon.process_nodes: + await consider_nodes(session) + await session.close() + + +def trigger_for_transition(task: Task, events: Mapping[str, Event]) -> str | None: + """Determine the trigger name for transition that matches the desired state + tuple as indicated on a Task. + """ + + for trigger, event in events.items(): + for transition_list in event.transitions.values(): + for transition in transition_list: + if all( + [ + transition.source == task.previous_status.name, + transition.dest == task.status.name, + ] + ): + return trigger + return None + + +async def finalize_runner_callback(context: AsyncTask) -> None: + """Callback function for finalizing the CM Task runner.""" + + # Using the task name as the ID of a task, get the object and update its + # finished_at column. Alternately, we could delete the task from the table + # now. 
+ if TYPE_CHECKING: + assert db_session_dependency.sessionmaker is not None + + logger.info("Finalizing CM Task", id=context.get_name()) + async with db_session_dependency.sessionmaker.begin() as session: + cm_task = await session.get_one(Task, UUID(context.get_name())) + cm_task.finished_at = timestamp.now_utc() + + +def task_runner_callback(context: AsyncTask) -> None: + """Callback function for `asyncio.TaskGroup` tasks.""" + if (exc := context.exception()) is not None: + logger.error(exc) + return + + logger.info("Transition complete", id=context.get_name()) + callbacks: set[Awaitable] = set() + # TODO: notification callback + finalizer = create_task(finalize_runner_callback(context)) + finalizer.add_done_callback(callbacks.discard) + callbacks.add(finalizer) diff --git a/src/lsst/cmservice/common/enums.py b/src/lsst/cmservice/common/enums.py index 6b362e779..8371b3f02 100644 --- a/src/lsst/cmservice/common/enums.py +++ b/src/lsst/cmservice/common/enums.py @@ -145,7 +145,7 @@ def is_successful_element(self) -> bool: def is_successful_script(self) -> bool: """Is this successful state for Script""" - return self.value >= StatusEnum.reviewable.value + return self.value >= StatusEnum.accepted.value def is_bad(self) -> bool: """Is this a failed state""" @@ -166,6 +166,7 @@ def is_terminal_script(self) -> bool: [ self.is_successful_script(), self.is_bad(), + self is StatusEnum.reviewable, ] ) @@ -177,6 +178,17 @@ def is_processable_script(self) -> bool: """Is this a processable state for an elememnt""" return self.value >= StatusEnum.waiting.value and self.value <= StatusEnum.running.value + def next_status(self) -> StatusEnum: + """If the status is on the "happy" path, return the next status along + that path, otherwise return the failed status. 
+ """ + happy_path = [StatusEnum.waiting, StatusEnum.ready, StatusEnum.running, StatusEnum.accepted] + if self in happy_path: + i = happy_path.index(self) + return happy_path[i + 1] + else: + return StatusEnum.failed + class TaskStatusEnum(enum.Enum): """Defines possible outcomes for Pipetask tasks""" @@ -286,3 +298,24 @@ class WmsComputeSite(enum.Enum): lanc = 2 ral = 3 in2p3 = 4 + + +class ManifestKind(enum.Enum): + """Define a manifest kind""" + + campaign = enum.auto() + node = enum.auto() + edge = enum.auto() + # Node kinds + grouped_step = enum.auto() + step_group = enum.auto() + collect_groups = enum.auto() + # Legacy kinds + specification = enum.auto() + spec_block = enum.auto() + step = enum.auto() + group = enum.auto() + job = enum.auto() + script = enum.auto() + # Fallback kind + other = enum.auto() diff --git a/src/lsst/cmservice/common/errors.py b/src/lsst/cmservice/common/errors.py index a52489f88..d4b6c6662 100644 --- a/src/lsst/cmservice/common/errors.py +++ b/src/lsst/cmservice/common/errors.py @@ -1,11 +1,9 @@ """cm-service specific error types""" -from typing import Any, TypeVar +from typing import Any from sqlalchemy.exc import IntegrityError -T = TypeVar("T") - class CMCheckError(KeyError): """Raised when script checking fails""" @@ -123,7 +121,7 @@ class CMYamlParseError(KeyError): """Raised when parsing a yaml file fails""" -def test_type_and_raise(object: Any, expected_type: type[T], var_name: str) -> T: +def test_type_and_raise[T](object: Any, expected_type: type[T], var_name: str) -> T: if not isinstance(object, expected_type): raise CMBadParameterTypeError(f"{var_name} expected type {expected_type} got {type(object)}") return object diff --git a/src/lsst/cmservice/common/graph.py b/src/lsst/cmservice/common/graph.py index c994d60ba..b206034c6 100644 --- a/src/lsst/cmservice/common/graph.py +++ b/src/lsst/cmservice/common/graph.py @@ -1,18 +1,22 @@ -from collections.abc import Mapping, Sequence -from typing import TypeVar +from 
collections.abc import Iterable, Mapping, MutableSet, Sequence +from typing import Literal +from uuid import UUID import networkx as nx -from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from ..db import Script, ScriptDependency, Step, StepDependency +from ..db.campaigns_v2 import Edge, Node from ..parsing.string import parse_element_fullname +from .types import AnyAsyncSession -A = TypeVar("A", async_scoped_session, AsyncSession) -N = TypeVar("N", type[Step], type[Script]) +type AnyGraphEdge = StepDependency | ScriptDependency +type AnyGraphNode = Step | Script async def graph_from_edge_list( - edges: Sequence[StepDependency | ScriptDependency], node_type: N, session: A + edges: Sequence[AnyGraphEdge], + node_type: type[AnyGraphNode], + session: AnyAsyncSession, ) -> nx.DiGraph: """Given a sequence of edge-tuples, create a directed graph for these edges with nodes derived from database lookups of the related objects. @@ -34,8 +38,157 @@ async def graph_from_edge_list( return g +async def graph_from_edge_list_v2( + edges: Sequence[Edge], + session: AnyAsyncSession, + node_type: type[Node] = Node, + node_view: Literal["simple", "model"] = "model", +) -> nx.DiGraph: + """Given a sequence of Edges, create a directed graph for these + edges with nodes derived from database lookups of the related objects. + + Parameters + ---------- + edges: Sequence[Edge] + The list of edges forming the graph + + node_type: type + The pydantic or sqlmodel class representing the graph node model + + node_view: "simple" or "model" + Whether the node metadata in the graph should be simplified (dict) or + using the full expunged model form. 
+ + session + An async database session + """ + g = nx.DiGraph() + g.add_edges_from([(e.source, e.target) for e in edges]) + relabel_mapping = {} + + # The graph understands the nodes in terms of the IDs used in the edges, + # but we want to hydrate the entire Node model for subsequent users of this + # graph to reference without dipping back to the Database. + for node in g.nodes: + db_node = await session.get_one(Node, node) + # This Node is going on an adventure where it does not need to drag its + # SQLAlchemy baggage along, so we expunge it from the session before + # adding it to the graph. + session.expunge(db_node) + if node_view == "simple": + # for the simple node view, the goal is to minimize the amount of + # data attached to the node and ensure that this data is json- + # serializable and otherwise appropriate for an API response + g.nodes[node]["uuid"] = str(db_node.id) + g.nodes[node]["status"] = db_node.status.name + g.nodes[node]["kind"] = db_node.kind.name + g.nodes[node]["version"] = db_node.version + relabel_mapping[node] = db_node.name + else: + g.nodes[node]["model"] = db_node + + if relabel_mapping: + g = nx.relabel_nodes(g, mapping=relabel_mapping, copy=False) + + return g + + def graph_to_dict(g: nx.DiGraph) -> Mapping: """Renders a networkx directed graph to a mapping format suitable for JSON serialization. + + Notes + ----- + The "edges" attribute name in the node link data is "edges" instead of the + default "links". """ return nx.node_link_data(g, edges="edges") + + +def validate_graph(g: nx.DiGraph, source: UUID | str = "START", sink: UUID | str = "END") -> bool: + """Validates a graph by asserting by traversal that a complete and correct + path exists between `source` and `sink` nodes. + + "Correct" means that there are no cycles or isolate nodes (nodes with + degree 0) and no nodes with degree 1. 
+ """ + try: + # Test that G is a directed graph with no cycles + is_valid = nx.is_directed_acyclic_graph(g) + assert is_valid + + # And that any path from source to sink exists + is_valid = nx.has_path(g, source, sink) + assert is_valid + + # Guard against bad graphs where START and/or END have been connected + # such that they are no longer the only source and sink + ... + + # Test that there are no isolated Nodes in the graph. A node becomes + # isolated if it was involved with an edge that has been removed from + # G with no replacement edge added, in which case the node should also + # be removed. + is_valid = nx.number_of_isolates(g) == 0 + assert is_valid + + # TODO Given the set of nodes in the graph, consider all paths in G + # from source to sink, making sure every node appears in a path? + + # Every node in G that is not the START/END node must have a degree + # of at least 2 (one inbound and one outbound edge). If G has any + # node with a degree of 1, it cannot be considered valid. + g_degree_view: Iterable = nx.degree(g, (n for n in g.nodes if n not in [source, sink])) + is_valid = min([d[1] for d in g_degree_view]) > 1 + assert is_valid + except (nx.exception.NodeNotFound, AssertionError): + return False + return True + + +def processable_graph_nodes(g: nx.DiGraph) -> Iterable[Node]: + """Traverse the graph G and produce an iterator of any nodes that are + candidates for processing, i.e., their status is waiting/prepared/running + and their ancestors are complete/successful. Graph nodes in a failed state + will block the graph and prevent candidacy for subsequent nodes. + + Yields + ------ + `lsst.cmservice.db.campaigns_v2.Node` + A Node ORM object that has been ``expunge``d from its ``Session``. 
+ + Notes + ----- + This function operates only on valid graphs (see `validate_graph()`) that + have been built by the `graph_from_edge_list_v2()` function, where each + graph-node is decorated with a "model" attribute referring to an expunged + instance of ``Node``. This ``Node`` can be ``add``ed back to a ``Session`` + and manipulated in the usual way. + """ + processable_nodes: MutableSet[Node] = set() + + # A valid campaign graph will have only one source (START) with in_degree 0 + # and only one sink (END) with out_degree 0 + source = next(v for v, d in g.in_degree() if d == 0) + sink = next(v for v, d in g.out_degree() if d == 0) + + # For each path through the graph, evaluate the state of nodes to determine + # which nodes are up for processing. When there are multiple paths, we have + # parallelization and common ancestors may be evaluated more than once, + # which is an exercise in optimization left as a TODO + for path in nx.all_simple_paths(g, source, sink): + for n in path: + node: Node = g.nodes[n]["model"] + if node.status.is_processable_element(): + processable_nodes.add(node) + # We found a processable node in this path, stop traversal + break + elif node.status.is_bad(): + # We reached a failed node in this path, it is blocked + break + else: + # This node must be in a "successful" terminal state + continue + + # the inspection should stop when there are no more nodes to check + yield from processable_nodes diff --git a/src/lsst/cmservice/common/jsonpatch.py b/src/lsst/cmservice/common/jsonpatch.py new file mode 100644 index 000000000..5c0768060 --- /dev/null +++ b/src/lsst/cmservice/common/jsonpatch.py @@ -0,0 +1,252 @@ +"""Module implementing functions to support json-patch operations on Python +objects based on RFC6902. 
+""" + +import operator +from collections.abc import Mapping, MutableMapping, MutableSequence +from functools import reduce +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import AliasChoices, BaseModel, Field + +type AnyMutable = MutableMapping | MutableSequence + + +class JSONPatchError(Exception): + """Exception raised when a JSON patch operation cannot be completed.""" + + pass + + +class JSONPatch(BaseModel): + """Model representing a PATCH operation using RFC6902. + + This model will generally be accepted as a ``Sequence[JSONPatch]``. + """ + + op: Literal["add", "remove", "replace", "move", "copy", "test"] + path: str = Field( + description="An RFC6901 JSON Pointer", pattern=r"^\/(metadata|spec|configuration|metadata_|data)\/.*$" + ) + value: Any | None = None + from_: str | None = Field( + default=None, + pattern=r"^\/(metadata|spec|configuration|metadata_|data)\/.*$", + validation_alias=AliasChoices("from", "from_"), + ) + + +def apply_json_patch[T: MutableMapping](op: JSONPatch, o: T) -> T: + """Applies a jsonpatch to an object, returning the modified object. + + Modifications are made in-place (i.e., the input object is not copied). + + Notes + ----- + While this JSON Patch operation nominally implements RFC6902, there are + some edge cases inappropriate to the application that are supported by the + RFC but disallowed through lack of support: + + - Unsupported: JSON pointer values that refer to object/dict keys that are + numeric, e.g., {"1": "first", "2": "second"} + - Unsupported: JSON pointer values that refer to an entire object, e.g., + "" -- the JSON Patch must have a root element ("/") per the model. + - Unsupported: JSON pointer values that refer to a nameless object, e.g., + "/" -- JSON allows object keys to be the empty string ("") but this is + disallowed by the application. 
+ """ + # The JSON Pointer root value is discarded as the rest of the pointer is + # split into parts + op_path = op.path.split("/")[1:] + + # The terminal path part is either the name of a key or an index in a list + # FIXME this assumes that an "integer-string" in the path is always refers + # to a list index, although it could just as well be a key in a dict + # like ``{"1": "first, "2": "second"}`` which is complicated by the + # fact that Python dict keys can be either ints or strs but this is + # not allowed in JSON (i.e., object keys MUST be strings) + # FIXME this doesn't support, e.g., nested lists with multiple index values + # in the path, e.g., ``[["a", "A"], ["b", "B"]]`` + target_key_or_index: str | None = op_path.pop() + if target_key_or_index is None: + raise JSONPatchError("JSON Patch operations on empty keys not allowed.") + + reference_token: int | str + # the reference token is referring to a an array index if the token is + # numeric or is the single character "-" + if target_key_or_index == "-": + reference_token = target_key_or_index + elif target_key_or_index.isnumeric(): + reference_token = int(target_key_or_index) + else: + reference_token = str(target_key_or_index) + + # The remaining parts of the path are a pointer to the object needing + # modification, which should reduce to either a dict or a list + try: + op_target: AnyMutable = reduce(operator.getitem, op_path, o) + except KeyError: + raise JSONPatchError(f"Path {op.path} not found in object") + + match op: + case JSONPatch(op="add", value=new_value): + if reference_token == "-" and isinstance(op_target, MutableSequence): + # The "-" reference token is unique to the add operation and + # means the next element beyond the end of the current list + op_target.append(new_value) + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + op_target.insert(reference_token, new_value) + elif isinstance(reference_token, str) and isinstance(op_target, 
MutableMapping): + op_target[reference_token] = new_value + + case JSONPatch(op="replace", value=new_value): + # The main difference between replace and add is that replace will + # not create new properties or elements in the target + if reference_token == "-": + raise JSONPatchError("Cannot use reference token `-` with replace operation.") + elif isinstance(op_target, MutableMapping): + try: + assert reference_token in op_target.keys() + except AssertionError: + raise JSONPatchError(f"Cannot replace missing key {reference_token} in object") + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + try: + assert reference_token < len(op_target) + except AssertionError: + raise JSONPatchError(f"Cannot replace missing index {reference_token} in object") + + if TYPE_CHECKING: + assert isinstance(op_target, MutableMapping) + op_target[reference_token] = new_value + + case JSONPatch(op="remove"): + if isinstance(reference_token, str) and isinstance(op_target, MutableMapping): + if reference_token == "-": + raise JSONPatchError("Removal operations not allowed on `-` reference token") + _ = op_target.pop(reference_token, None) + elif isinstance(reference_token, int): + try: + _ = op_target.pop(reference_token) + except IndexError: + # The index we are meant to remove does not exist, but that + # is not an error (idempotence) + pass + else: + # This should be unreachable + raise ValueError("Reference token in JSON Patch must be int | str") + + case JSONPatch(op="move", from_=from_location): + # the move operation is equivalent to a remove(from) + add(target) + if TYPE_CHECKING: + assert from_location is not None + + # Handle the from_location with the same logic as the op.path + from_path = from_location.split("/")[1:] + + # Is the last element of the from_path an index or a key? 
+ from_target: str | int = from_path.pop() + try: + from_target = int(from_target) + except ValueError: + pass + + try: + from_object = reduce(operator.getitem, from_path, o) + value = from_object[from_target] + except (KeyError, IndexError): + raise JSONPatchError(f"Path {from_location} not found in object") + + # add the value to the new location + op_target[reference_token] = value # type: ignore[index] + # and remove it from the old + _ = from_object.pop(from_target) + + case JSONPatch(op="copy", from_=from_location): + # The copy op is the same as the move op except the original is not + # removed + if TYPE_CHECKING: + assert from_location is not None + + # Handle the from_location with the same logic as the op.path + from_path = from_location.split("/")[1:] + + # Is the last element of the from_path an index or a key? + from_target = from_path.pop() + try: + from_target = int(from_target) + except ValueError: + pass + + try: + from_object = reduce(operator.getitem, from_path, o) + value = from_object[from_target] + except (KeyError, IndexError): + raise JSONPatchError(f"Path {from_location} not found in object") + + # add the value to the new location + op_target[reference_token] = value # type: ignore[index] + + case JSONPatch(op="test", value=assert_value): + # assert that the patch value is present at the patch path + # The main difference between test and replace is that test does + # not make any modifications after its assertions + if reference_token == "-": + raise JSONPatchError("Cannot use reference token `-` with test operation.") + elif isinstance(op_target, MutableMapping): + try: + assert reference_token in op_target.keys() + except AssertionError: + raise JSONPatchError( + f"Test operation assertion failed: Key {reference_token} does not exist at {op.path}" + ) + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + try: + assert reference_token < len(op_target) + except AssertionError: + raise JSONPatchError( + 
f"Test operation assertion failed: " + f"Index {reference_token} does not exist at {op.path}" + ) + + if TYPE_CHECKING: + assert isinstance(op_target, MutableMapping) + try: + assert op_target[reference_token] == assert_value + except AssertionError: + raise JSONPatchError( + f"Test operation assertion failed: {op.path} does not match value {assert_value}" + ) + + case _: + # Model validation should prevent this from ever happening + raise JSONPatchError(f"Unknown JSON Patch operation: {op.op}") + + return o + + +def apply_json_merge[T: MutableMapping](patch: Any, o: T) -> T: + """Applies a patch to a mapping object as per the RFC7396 JSON Merge Patch. + + Notably, this operation may only target a ``MutableMapping`` as an analogue + of a JSON object. This means that any keyed value in a Mapping may be + replaced, added, or removed by a JSON Merge. This is not appropriate for + patches that need to perform more tactical updates, such as modifying + elements of a ``Sequence``. + + This function does not allow setting a field value in the target to `None`; + instead, any `None` value in a patch is an instruction to remove that + field from the target completely. + + This function differs from the RFC in the following ways: it will not + replace the entire target object with a new mapping (i.e., the target must + be a Mapping). + """ + if isinstance(patch, Mapping): + for k, v in patch.items(): + if v is None: + _ = o.pop(k, None) + else: + o[k] = apply_json_merge(v, o.get(k, {})) + return o + else: + return patch diff --git a/src/lsst/cmservice/common/types.py b/src/lsst/cmservice/common/types.py new file mode 100644 index 000000000..840d3bfa2 --- /dev/null +++ b/src/lsst/cmservice/common/types.py @@ -0,0 +1,28 @@ +from typing import Annotated + +from sqlalchemy.ext.asyncio import AsyncSession as AsyncSessionSA +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlmodel.ext.asyncio.session import AsyncSession + +from .. 
import models +from ..models.serde import EnumSerializer, ManifestKindEnumValidator, StatusEnumValidator +from .enums import ManifestKind, StatusEnum + +type AnyAsyncSession = AsyncSession | AsyncSessionSA | async_scoped_session +"""A type union of async database sessions the application may use""" + + +type AnyCampaignElement = models.Group | models.Campaign | models.Step | models.Job +"""A type union of Campaign elements""" + + +type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] +"""A type for fields representing a Status with a custom validator tuned for +enums operations. +""" + + +type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] +"""A type for fields representing a Kind with a custom validator tuned for +enums operations. +""" diff --git a/src/lsst/cmservice/config.py b/src/lsst/cmservice/config.py index bfb045c15..73ed94c9c 100644 --- a/src/lsst/cmservice/config.py +++ b/src/lsst/cmservice/config.py @@ -8,6 +8,7 @@ AliasChoices, BaseModel, Field, + SecretStr, computed_field, field_serializer, field_validator, @@ -60,6 +61,11 @@ class BpsConfiguration(BaseModel): default=16, ) + artifact_path: str = Field( + description="Filesystem path location for writing artifacts (`prod_area`)", + default="/prod_area", + ) + class ButlerConfiguration(BaseModel): """Configuration settings for butler client operations. 
@@ -456,6 +462,11 @@ class AsgiConfiguration(BaseModel): default="/cm-service", ) + enable_frontend: bool = Field( + description="Whether to run the frontend web app", + default=True, + ) + frontend_prefix: str = Field( description="The URL prefix for the frontend web app", default="/web_app", @@ -511,6 +522,26 @@ class DaemonConfiguration(BaseModel): ), ) + v1_enabled: bool = Field( + default=True, + description="Whether the v1 daemon is enabled and included in the event loop.", + ) + + v2_enabled: bool = Field( + default=False, + description="Whether the v2 daemon is enabled and included in the event loop.", + ) + + process_campaigns: bool = Field( + default=True, + description="Whether the v2 daemon processes Campaigns in the event loop.", + ) + + process_nodes: bool = Field( + default=True, + description="Whether the v2 daemon processes Nodes in the event loop.", + ) + class NotificationConfiguration(BaseModel): """Configurations for notifications. @@ -535,13 +566,13 @@ class DatabaseConfiguration(BaseModel): description="The URL for the cm-service database", ) - password: str | None = Field( + password: SecretStr | None = Field( default=None, description="The password for the cm-service database", ) - table_schema: str | None = Field( - default=None, + table_schema: str = Field( + default="public", description="Schema to use for cm-service database", ) @@ -550,6 +581,31 @@ class DatabaseConfiguration(BaseModel): description="SQLAlchemy engine echo setting for the cm-service database", ) + max_overflow: int = Field( + default=10, + description="Maximum connection overflow allowed for QueuePool.", + ) + + pool_size: int = Field( + default=5, + description="Number of open connections kept in the QueuePool", + ) + + pool_recycle: int = Field( + default=-1, + description="Timeout in seconds before connections are recycled", + ) + + pool_timeout: int = Field( + default=30, + description="Wait timeout for acquiring a connection from the pool", + ) + + pool_fields: 
set[str] = Field( + default={"max_overflow", "pool_size", "pool_recycle", "pool_timeout"}, + description="Set of fields used for connection pool configuration", + ) + class Configuration(BaseSettings): """Configuration for cm-service. diff --git a/src/lsst/cmservice/daemon.py b/src/lsst/cmservice/daemon.py index 3c242cd3f..d7e407896 100644 --- a/src/lsst/cmservice/daemon.py +++ b/src/lsst/cmservice/daemon.py @@ -6,15 +6,16 @@ import uvicorn from anyio import current_time, sleep_until from fastapi import FastAPI -from safir.database import create_async_session, create_database_engine from safir.logging import configure_uvicorn_logging from . import __version__ from .common.butler import BUTLER_FACTORY # noqa: F401 from .common.daemon import daemon_iteration +from .common.daemon_v2 import daemon_iteration as daemon_iteration_v2 from .common.logging import LOGGER from .common.panda import get_panda_token from .config import config +from .db.session import db_session_dependency from .routers.healthz import health_router configure_uvicorn_logging(config.logging.level) @@ -31,10 +32,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: os.environ |= config.panda.model_dump(by_alias=True, exclude_none=True) os.environ |= config.htcondor.model_dump(by_alias=True, exclude_none=True) app.state.tasks = set() + # Dependency inits before app starts running + await db_session_dependency.initialize() + assert db_session_dependency.engine is not None daemon = create_task(main_loop(app=app), name="daemon") app.state.tasks.add(daemon) yield # stop + await db_session_dependency.aclose() async def main_loop(app: FastAPI) -> None: @@ -43,24 +48,24 @@ async def main_loop(app: FastAPI) -> None: With a database session, perform a single daemon interation and then sleep until the next daemon appointment. 
""" - engine = create_database_engine(config.db.url, config.db.password) - sleep_time = config.daemon.processing_interval - async with engine.begin(): - session = await create_async_session(engine, logger) - logger.info("Daemon starting.") - _iteration_count = 0 + session = await anext(db_session_dependency()) + logger.info("Daemon starting.") + _iteration_count = 0 - while True: - _iteration_count += 1 - logger.info("Daemon starting iteration.") + while True: + _iteration_count += 1 + logger.info("Daemon starting iteration.") + if config.daemon.v1_enabled: await daemon_iteration(session) - _iteration_time = current_time() - logger.info(f"Daemon completed {_iteration_count} iterations at {_iteration_time}.") - _next_wakeup = _iteration_time + sleep_time - logger.info(f"Daemon next iteration at {_next_wakeup}.") - await sleep_until(_next_wakeup) + if config.daemon.v2_enabled: + await daemon_iteration_v2(session) + _iteration_time = current_time() + logger.info(f"Daemon completed {_iteration_count} iterations at {_iteration_time}.") + _next_wakeup = _iteration_time + sleep_time + logger.info(f"Daemon next iteration at {_next_wakeup}.") + await sleep_until(_next_wakeup) def main() -> None: diff --git a/src/lsst/cmservice/db/__init__.py b/src/lsst/cmservice/db/__init__.py index 11f61d4fe..eeec4db4e 100644 --- a/src/lsst/cmservice/db/__init__.py +++ b/src/lsst/cmservice/db/__init__.py @@ -1,5 +1,6 @@ """Database table definitions and utility functions""" +from . 
import campaigns_v2 from .base import Base from .campaign import Campaign from .element import ElementMixin diff --git a/src/lsst/cmservice/db/campaign.py b/src/lsst/cmservice/db/campaign.py index 559f3a471..c31afcbfc 100644 --- a/src/lsst/cmservice/db/campaign.py +++ b/src/lsst/cmservice/db/campaign.py @@ -5,7 +5,6 @@ from sqlalchemy import JSON from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -22,6 +21,7 @@ from .specification import Specification if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .job import Job from .script import Script from .step import Step @@ -62,7 +62,7 @@ class Campaign(Base, ElementMixin): metadata_: Mapped[dict] = mapped_column("metadata_", type_=MutableDict.as_mutable(JSONB), default=dict) child_config: Mapped[dict | list | None] = mapped_column(type_=JSON) collections: Mapped[dict | list | None] = mapped_column(type_=JSON) - spec_aliases: Mapped[dict | list | None] = mapped_column(type_=JSON) + spec_aliases: Mapped[dict | None] = mapped_column(type_=JSON) spec_: Mapped[Specification] = relationship( "Specification", @@ -94,7 +94,7 @@ def level(self) -> LevelEnum: async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Campaign: """Maps self to self.get_campaign() for consistency""" assert session # For mypy @@ -105,7 +105,7 @@ def __repr__(self) -> str: async def children( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Iterable: """Maps self.s_ to self.children() for consistency""" await session.refresh(self, attribute_names=["s_"]) @@ -113,7 +113,7 @@ async def children( async def get_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedWmsTaskReportDict: the_dict = 
MergedWmsTaskReportDict(reports={}) @@ -125,7 +125,7 @@ async def get_wms_reports( async def get_tasks( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedTaskSetDict: the_dict = MergedTaskSetDict(reports={}) @@ -136,7 +136,7 @@ async def get_tasks( async def get_products( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedProductSetDict: the_dict = MergedProductSetDict(reports={}) @@ -148,7 +148,7 @@ async def get_products( @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: name = kwargs["name"] diff --git a/src/lsst/cmservice/db/campaigns_v2.py b/src/lsst/cmservice/db/campaigns_v2.py new file mode 100644 index 000000000..b2968f06e --- /dev/null +++ b/src/lsst/cmservice/db/campaigns_v2.py @@ -0,0 +1,299 @@ +"""ORM Models for v2 tables and objects.""" + +from collections.abc import MutableSequence +from typing import Any +from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5 + +from pydantic import AliasChoices, AwareDatetime, ValidationInfo, model_validator +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.mutable import MutableDict, MutableList +from sqlalchemy.types import PickleType +from sqlmodel import Column, DateTime, Enum, Field, MetaData, SQLModel, String + +from ..common.enums import ManifestKind, StatusEnum +from ..common.timestamp import now_utc +from ..common.types import KindField, StatusField +from ..config import config + +_default_campaign_namespace = uuid5(namespace=NAMESPACE_DNS, name="io.lsst.cmservice") +"""Default UUID5 namespace for campaigns""" + +metadata: MetaData = MetaData(schema=config.db.table_schema) +"""SQLModel metadata for table models""" + + +def jsonb_column(name: str, aliases: list[str] | None = None) -> Any: + """Constructor for a Field based on a JSONB database column. 
+ + If provided, a list of aliases will be used to construct a pydantic + ``AliasChoices`` object for the field's validation alias, which improves + usability by making model validation more flexible (e.g., having "metadata" + and "metadata_" refer to the same field). + + Additionally, the first alias in the list will be used for the model's + serialization alias. + """ + schema_extra = {} + if aliases: + schema_extra = { + "validation_alias": AliasChoices(*aliases), + "serialization_alias": aliases[0], + } + return Field( + sa_column=Column(name, MutableDict.as_mutable(postgresql.JSONB)), + default_factory=dict, + schema_extra=schema_extra, + ) + + +class BaseSQLModel(SQLModel): + """Shared base SQL model for all tables.""" + + __table_args__ = {"schema": config.db.table_schema} + metadata = metadata + + +class CampaignBase(BaseSQLModel): + """Campaigns_v2 base model, used to create new Campaign objects.""" + + id: UUID = Field(primary_key=True) + name: str + namespace: UUID + owner: str | None = Field(default=None) + status: StatusField = Field( + default=StatusEnum.waiting, + sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), + ) + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) + configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"]) + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") + + @model_validator(mode="before") + @classmethod + def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any: + """Validates the model based on different types of raw inputs, + where some default non-optional fields can be auto-populated. 
+ """ + if isinstance(data, dict): + if "name" not in data: + raise ValueError(" name missing.") + if "namespace" not in data: + data["namespace"] = _default_campaign_namespace + if "id" not in data: + data["id"] = uuid5(namespace=data["namespace"], name=data["name"]) + return data + + +class Campaign(CampaignBase, table=True): + """Model used for database operations involving campaigns_v2 table rows""" + + __tablename__: str = "campaigns_v2" # type: ignore[misc] + + +class CampaignUpdate(BaseSQLModel): + """Model representing updatable fields for a PATCH operation on a Campaign + using RFC7396. + """ + + owner: str | None = None + status: StatusField | None = None + + +class CampaignSummary(CampaignBase): + """Model for the response of a Campaign Summary route.""" + + node_summary: MutableSequence["NodeStatusSummary"] + + +class NodeStatusSummary(BaseSQLModel): + """Model for a Node Status Summary.""" + + status: StatusField = Field(description="A state name") + count: int = Field(description="Count of nodes in this state") + mtime: AwareDatetime | None = Field(description="The most recent update time for nodes in this state") + + +class NodeBase(BaseSQLModel): + """nodes_v2 db table""" + + def __hash__(self) -> int: + """A Node is hashable according to its unique ID, so it can be used in + sets and other places hashable types are required. 
+ """ + return self.id.int + + id: UUID = Field(primary_key=True) + name: str + namespace: UUID + version: int + kind: KindField = Field( + default=ManifestKind.other, + sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)), + ) + status: StatusField = Field( + default=StatusEnum.waiting, + sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), + ) + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) + configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"]) + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") + + @model_validator(mode="before") + @classmethod + def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any: + """Validates the model based on different types of raw inputs, + where some default non-optional fields can be auto-populated. + """ + if isinstance(data, dict): + if (node_name := data.get("name")) is None: + raise ValueError(" name missing.") + if (node_namespace := data.get("namespace")) is None: + raise ValueError(" namespace missing.") + if (node_version := data.get("version")) is None: + data["version"] = node_version = 1 + if "id" not in data: + data["id"] = uuid5(namespace=node_namespace, name=f"{node_name}.{node_version}") + return data + + +class Node(NodeBase, table=True): + __tablename__: str = "nodes_v2" # type: ignore[misc] + + +class EdgeBase(BaseSQLModel): + """edges_v2 db table""" + + id: UUID = Field(primary_key=True) + name: str + namespace: UUID = Field(foreign_key="campaigns_v2.id") + source: UUID = Field(foreign_key="nodes_v2.id") + target: UUID = Field(foreign_key="nodes_v2.id") + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) + configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"]) + + +class EdgeResponseModel(EdgeBase): + source: Any + target: 
Any + + +class Edge(EdgeBase, table=True): + __tablename__: str = "edges_v2" # type: ignore[misc] + + +class MachineBase(BaseSQLModel): + """machines_v2 db table.""" + + id: UUID = Field(primary_key=True, default_factory=uuid4) + state: Any = Field(sa_column=Column("state", PickleType)) + + +class Machine(MachineBase, table=True): + """machines_v2 db table.""" + + __tablename__: str = "machines_v2" # type: ignore[misc] + + +class ManifestBase(BaseSQLModel): + """manifests_v2 db table""" + + id: UUID = Field(primary_key=True) + name: str + version: int + namespace: UUID = Field(foreign_key="campaigns_v2.id") + kind: KindField = Field( + default=ManifestKind.other, + sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)), + ) + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) + spec: dict = jsonb_column("spec", aliases=["spec", "configuration", "data"]) + + +class Manifest(ManifestBase, table=True): + __tablename__: str = "manifests_v2" # type: ignore[misc] + + +class Task(BaseSQLModel, table=True): + """tasks_v2 db table""" + + __tablename__: str = "tasks_v2" # type: ignore[misc] + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + description="A hash of the related Node ID and target status, as a UUID5.", + ) + namespace: UUID = Field(foreign_key="campaigns_v2.id", description="The ID of a Campaign") + node: UUID = Field(foreign_key="nodes_v2.id", description="The ID of the target node") + priority: int | None = Field(default=None) + created_at: AwareDatetime = Field( + description="The `datetime` (UTC) at which this Task was first added to the queue", + default_factory=now_utc, + sa_column=Column(DateTime(timezone=True)), + ) + submitted_at: AwareDatetime | None = Field( + description="The `datetime` (UTC) at which this Task was first submitted as work to the event loop", + default=None, + sa_column=Column(DateTime(timezone=True)), + ) + finished_at: AwareDatetime | None = 
Field( + description=( + "The `datetime` (UTC) at which this Task successfully finalized. " + "A Task whose `finished_at` is not `None` is tombstoned and is subject to deletion." + ), + default=None, + sa_column=Column(DateTime(timezone=True)), + ) + wms_id: str | None = Field(default=None) + site_affinity: list[str] | None = Field( + default=None, sa_column=Column("site_affinity", MutableList.as_mutable(postgresql.ARRAY(String()))) + ) + status: StatusField = Field( + description="The 'target' status to which this Task will attempt to transition the Node", + sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), + ) + previous_status: StatusField = Field( + description="The 'original' status from which this Task will attempt to transition the Node", + sa_column=Column( + "previous_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) + ), + ) + + +class ActivityLogBase(BaseSQLModel): + id: UUID = Field(primary_key=True, default_factory=uuid4) + namespace: UUID = Field(foreign_key="campaigns_v2.id", description="The ID of a Campaign") + node: UUID | None = Field(default=None, foreign_key="nodes_v2.id", description="The ID of a Node") + operator: str = Field(description="The name of the operator or pilot who triggered the activity") + created_at: AwareDatetime = Field( + description="The `datetime` in UTC at which this log entry was created.", + default_factory=now_utc, + sa_column=Column(DateTime(timezone=True)), + ) + finished_at: AwareDatetime | None = Field( + description="The `datetime` in UTC at which this log entry was finalized.", + default=None, + sa_column=Column(DateTime(timezone=True), nullable=True), + ) + to_status: StatusField = Field( + description=( + "The `target` state to which this activity tried to transition. " + "This may be the same as `from_status` in cases where no transition was attempted " + "(such as for a conditional check)." 
+ ), + sa_column=Column( + "to_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) + ), + ) + from_status: StatusField = Field( + description="The `original` state from which this activity tried to transition", + sa_column=Column( + "from_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) + ), + ) + detail: dict = jsonb_column("detail") + metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"]) + + +class ActivityLog(ActivityLogBase, table=True): + __tablename__: str = "activity_log_v2" # type: ignore[misc] diff --git a/src/lsst/cmservice/db/element.py b/src/lsst/cmservice/db/element.py index d96c6b7b7..6e9d82d5c 100644 --- a/src/lsst/cmservice/db/element.py +++ b/src/lsst/cmservice/db/element.py @@ -3,8 +3,6 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any -from sqlalchemy.ext.asyncio import async_scoped_session - from ..common.enums import LevelEnum, NodeTypeEnum, StatusEnum from ..common.errors import CMBadStateTransitionError, CMTooManyActiveScriptsError from ..config import config @@ -14,6 +12,7 @@ from .node import NodeMixin if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .job import Job from .script import Script @@ -37,7 +36,7 @@ def node_type(self) -> NodeTypeEnum: async def get_scripts( self, - session: async_scoped_session, + session: AnyAsyncSession, script_name: str | None = None, *, remaining_only: bool = False, @@ -47,7 +46,7 @@ async def get_scripts( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script_name: str | None @@ -78,7 +77,7 @@ async def get_scripts( async def get_jobs( self, - session: async_scoped_session, + session: AnyAsyncSession, *, remaining_only: bool = False, skip_superseded: bool = True, @@ -87,7 +86,7 @@ async def get_jobs( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager remaining_only: bool @@ 
-113,7 +112,7 @@ async def get_jobs( async def get_all_scripts( self, - session: async_scoped_session, + session: AnyAsyncSession, *, remaining_only: bool = False, skip_superseded: bool = True, @@ -122,7 +121,7 @@ async def get_all_scripts( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager remaining_only: bool @@ -153,7 +152,7 @@ async def get_all_scripts( async def children( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Iterable: """Maps to [] for consistency""" assert session # for mypy @@ -161,7 +160,7 @@ async def children( async def retry_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script_name: str, *, fake_reset: bool = True, @@ -170,7 +169,7 @@ async def retry_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script_name: str @@ -198,7 +197,7 @@ async def retry_script( async def estimate_sleep_time( self, - session: async_scoped_session, + session: AnyAsyncSession, minimum_sleep_time: int = 10, ) -> int: """Estimate how long to sleep before calling process again. 
@@ -208,7 +207,7 @@ async def estimate_sleep_time( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -234,14 +233,14 @@ async def estimate_sleep_time( async def get_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedWmsTaskReportDict: """Get the WmwTaskReports associated to this element Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -253,7 +252,7 @@ async def get_wms_reports( async def get_tasks( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedTaskSetDict: """Get the TaskSet associated to this element @@ -270,14 +269,14 @@ async def get_tasks( async def get_products( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedProductSetDict: """Get the ProductSet associated to this element Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -289,7 +288,7 @@ async def get_products( async def review( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> StatusEnum: """Run review() function on this Element @@ -298,7 +297,7 @@ async def review( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns diff --git a/src/lsst/cmservice/db/group.py b/src/lsst/cmservice/db/group.py index eca383bd7..e902fcf75 100644 --- a/src/lsst/cmservice/db/group.py +++ b/src/lsst/cmservice/db/group.py @@ -4,7 +4,6 @@ from sqlalchemy import JSON from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey, UniqueConstraint @@ -18,6 +17,7 @@ 
CMTooFewAcceptedJobsError, CMTooManyActiveScriptsError, ) +from ..common.types import AnyAsyncSession from ..models.merged_product_set import MergedProductSetDict from ..models.merged_task_set import MergedTaskSetDict from ..models.merged_wms_task_report import MergedWmsTaskReportDict @@ -61,7 +61,7 @@ class Group(Base, ElementMixin): metadata_: Mapped[dict] = mapped_column("metadata_", type_=MutableDict.as_mutable(JSONB), default=dict) child_config: Mapped[dict | list | None] = mapped_column(type_=JSON) collections: Mapped[dict | list | None] = mapped_column(type_=JSON) - spec_aliases: Mapped[dict | list | None] = mapped_column(type_=JSON) + spec_aliases: Mapped[dict | None] = mapped_column(type_=JSON) spec_block_: Mapped[SpecBlock] = relationship("SpecBlock", viewonly=True) c_: Mapped["Campaign"] = relationship( @@ -85,7 +85,7 @@ def level(self) -> LevelEnum: async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> "Campaign": """Maps self.c_ to self.get_campaign() for consistency""" await session.refresh(self, attribute_names=["c_"]) @@ -96,7 +96,7 @@ def __repr__(self) -> str: async def children( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Iterable: """Maps self.g_ to self.children() for consistency""" await session.refresh(self, attribute_names=["jobs_"]) @@ -104,7 +104,7 @@ async def children( async def get_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedWmsTaskReportDict: the_dict = MergedWmsTaskReportDict(reports={}) @@ -116,7 +116,7 @@ async def get_wms_reports( async def get_tasks( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedTaskSetDict: the_dict = MergedTaskSetDict(reports={}) @@ -127,7 +127,7 @@ async def get_tasks( async def get_products( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedProductSetDict: the_dict = 
MergedProductSetDict(reports={}) @@ -139,7 +139,7 @@ async def get_products( @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: try: @@ -175,7 +175,7 @@ async def get_create_kwargs( async def rescue_job( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> "Job": """Create a rescue `Job` @@ -183,7 +183,7 @@ async def rescue_job( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -208,13 +208,13 @@ async def rescue_job( async def mark_job_rescued( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> list["Job"]: """Mark jobs as `rescued` once one of their siblings is `accepted` Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns diff --git a/src/lsst/cmservice/db/handler.py b/src/lsst/cmservice/db/handler.py index 1a213ded0..9753df891 100644 --- a/src/lsst/cmservice/db/handler.py +++ b/src/lsst/cmservice/db/handler.py @@ -3,8 +3,6 @@ import types from typing import TYPE_CHECKING, Any, ClassVar -from sqlalchemy.ext.asyncio import async_scoped_session - from lsst.utils import doImport from lsst.utils.introspection import get_full_type_name @@ -13,6 +11,7 @@ from ..common.logging import LOGGER if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .element import ElementMixin from .node import NodeMixin from .script import Script @@ -93,7 +92,7 @@ def get_handler_class_name(self) -> str: async def process( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -101,7 +100,7 @@ async def process( Parameters ---------- - session : async_scoped_session + session : A DB session manager node: NodeMixin @@ -121,7 +120,7 @@ async def process( async def run_check( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, 
**kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -129,7 +128,7 @@ async def run_check( Parameters ---------- - session : async_scoped_session + session : A DB session manager node: NodeMixin @@ -149,7 +148,7 @@ async def run_check( async def reset( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, to_status: StatusEnum, *, @@ -159,7 +158,7 @@ async def reset( Parameters ---------- - session : async_scoped_session + session : A DB session manager node: NodeMixin @@ -180,7 +179,7 @@ async def reset( async def reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, to_status: StatusEnum, *, @@ -190,7 +189,7 @@ async def reset_script( Parameters ---------- - session : async_scoped_session + session : A DB session manager node: NodeMixin @@ -211,7 +210,7 @@ async def reset_script( async def review( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> StatusEnum: @@ -219,7 +218,7 @@ async def review( Parameters ---------- - session : async_scoped_session + session : A DB session manager element: ElementMixin @@ -234,7 +233,7 @@ async def review( async def review_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -243,7 +242,7 @@ async def review_script( Parameters ---------- - session : async_scoped_session + session : A DB session manager script: Script diff --git a/src/lsst/cmservice/db/job.py b/src/lsst/cmservice/db/job.py index 45a6ab523..cdb52df20 100644 --- a/src/lsst/cmservice/db/job.py +++ b/src/lsst/cmservice/db/job.py @@ -6,7 +6,6 @@ from sqlalchemy import JSON, and_, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema 
import ForeignKey @@ -27,6 +26,7 @@ from .step import Step if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .campaign import Campaign from .pipetask_error import PipetaskError from .product_set import ProductSet @@ -68,7 +68,7 @@ class Job(Base, ElementMixin): ) child_config: Mapped[dict | list | None] = mapped_column(type_=JSON) collections: Mapped[dict | list | None] = mapped_column(type_=JSON) - spec_aliases: Mapped[dict | list | None] = mapped_column(type_=JSON) + spec_aliases: Mapped[dict | None] = mapped_column(type_=JSON) wms_job_id: Mapped[str | None] = mapped_column() stamp_url: Mapped[str | None] = mapped_column() @@ -118,7 +118,7 @@ def level(self) -> LevelEnum: async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Campaign: """Maps self.c_ to self.get_campaign() for consistency""" await session.refresh(self, attribute_names=["c_"]) @@ -126,13 +126,13 @@ async def get_campaign( async def get_siblings( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Sequence[Job]: """Get the sibling Jobs Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -152,7 +152,7 @@ async def get_siblings( async def get_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedWmsTaskReportDict: await session.refresh(self, attribute_names=["wms_reports_"]) @@ -163,7 +163,7 @@ async def get_wms_reports( async def get_tasks( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedTaskSetDict: await session.refresh(self, attribute_names=["tasks_"]) @@ -172,7 +172,7 @@ async def get_tasks( async def get_products( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedProductSetDict: await session.refresh(self, attribute_names=["products_"]) @@ -181,7 +181,7 @@ async def get_products( async def get_errors( self, - session: 
async_scoped_session, + session: AnyAsyncSession, ) -> Sequence[PipetaskError]: await session.refresh(self, attribute_names=["errors_"]) return self.errors_ @@ -192,7 +192,7 @@ def __repr__(self) -> str: @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: try: @@ -230,14 +230,14 @@ async def get_create_kwargs( async def copy_job( self, - session: async_scoped_session, + session: AnyAsyncSession, parent: ElementMixin, ) -> Job: """Copy a Job Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager parent : ElementMixin @@ -298,7 +298,7 @@ async def copy_job( async def get_jobs( self, - session: async_scoped_session, + session: AnyAsyncSession, *, remaining_only: bool = False, skip_superseded: bool = True, diff --git a/src/lsst/cmservice/db/manifests_v2.py b/src/lsst/cmservice/db/manifests_v2.py new file mode 100644 index 000000000..22947dd2e --- /dev/null +++ b/src/lsst/cmservice/db/manifests_v2.py @@ -0,0 +1,136 @@ +"""Module for models representing generic CM Service manifests. + +These manifests are used in APIs, especially when creating resources. They do +not necessarily represent the object's database or ORM model. +""" + +from typing import Self +from uuid import uuid4 + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationInfo, model_validator + +from ..common.enums import DEFAULT_NAMESPACE, ManifestKind +from ..common.timestamp import element_time +from ..common.types import KindField + + +class Manifest[MetadataT, SpecT](BaseModel): + """A parameterized model for an object's Manifest, used by APIs where the + `spec` should be the kind's table model, more or less. 
+ """ + + apiversion: str = Field(default="io.lsst.cmservice/v1") + kind: KindField = Field(default=ManifestKind.other) + metadata_: MetadataT = Field( + validation_alias=AliasChoices("metadata", "metadata_"), + serialization_alias="metadata", + ) + spec: SpecT = Field( + validation_alias=AliasChoices("spec", "configuration", "data"), + serialization_alias="spec", + ) + + +class ManifestSpec(BaseModel): + """Generic spec model for Manifests. + + Notes + ----- + Any spec body is allowed via config, but any fields that aren't first-class + fields won't be subject to validation or available as model attributes + except in the ``__pydantic_extra__`` dictionary. The full spec will be + expressed via ``model_dump()``. + """ + + model_config = ConfigDict(extra="allow") + + +class ManifestMetadata(BaseModel): + """Generic metadata model for Manifests. + + Conventionally denormalized fields are excluded from the model_dump when + serialized for ORM use. + """ + + name: str = Field(exclude=True) + namespace: str = Field(exclude=True) + crtime: int = Field(default_factory=element_time) + + +class VersionedMetadata(ManifestMetadata): + """Metadata model for versioned Manifests.""" + + version: int = Field(exclude=True, default=0) + + +class ManifestModelMetadata(VersionedMetadata): + """Manifest model for general Manifests. These manifests are versioned but + a namespace is optional (defaultable). 
+ """ + + namespace: str = Field(default=str(DEFAULT_NAMESPACE), exclude=True) + + +class ManifestModel(Manifest[ManifestModelMetadata, ManifestSpec]): + """Manifest model for generic Manifest handling.""" + + @model_validator(mode="after") + def custom_model_validator(self, info: ValidationInfo) -> Self: + """Validate an Campaign Manifest after a model has been created.""" + if self.kind in [ManifestKind.campaign, ManifestKind.node, ManifestKind.edge]: + raise ValueError(f"Manifests may not be a {self.kind.name} kind.") + + return self + + +class CampaignManifest(Manifest[ManifestModelMetadata, ManifestSpec]): + """validating model for campaigns""" + + @model_validator(mode="after") + def custom_model_validator(self, info: ValidationInfo) -> Self: + """Validate an Campaign Manifest after a model has been created.""" + if self.kind is not ManifestKind.campaign: + raise ValueError("Campaigns may only be created from a manifest") + + return self + + +class EdgeMetadata(ManifestMetadata): + """Metadata model for an Edge Manifest. + + A default random alphanumeric 8-byte name is generated if no name provided. 
+ """ + + name: str = Field(default_factory=lambda: uuid4().hex[:8], exclude=True) + crtime: int = Field(default_factory=element_time) + + +class EdgeSpec(ManifestSpec): + """Spec model for an Edge Manifest.""" + + source: str = Field(exclude=True) + target: str = Field(exclude=True) + + +class EdgeManifest(Manifest[EdgeMetadata, EdgeSpec]): + """validating model for Edges""" + + @model_validator(mode="after") + def custom_model_validator(self, info: ValidationInfo) -> Self: + """Validate an Edge Manifest after a model has been created.""" + if self.kind is not ManifestKind.edge: + raise ValueError("Edges may only be created from an manifest") + + return self + + +class NodeManifest(Manifest[VersionedMetadata, ManifestSpec]): + """validating model for Nodes""" + + @model_validator(mode="after") + def custom_model_validator(self, info: ValidationInfo) -> Self: + """Validate a Node Manifest after a model has been created.""" + if self.kind is not ManifestKind.node: + raise ValueError("Nodes may only be created from an manifest") + + return self diff --git a/src/lsst/cmservice/db/node.py b/src/lsst/cmservice/db/node.py index 212f9abf3..244e2f21d 100644 --- a/src/lsst/cmservice/db/node.py +++ b/src/lsst/cmservice/db/node.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.orm.collections import InstrumentedList from ..common import timestamp @@ -27,6 +26,7 @@ from .specification import Specification if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .campaign import Campaign from .element import ElementMixin @@ -60,13 +60,13 @@ class NodeMixin(RowMixin): async def get_spec_block( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> SpecBlock: """Get the `SpecBlock` object associated to a particular row Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ 
-81,13 +81,13 @@ async def get_spec_block( async def get_specification( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Specification: """Get the `Specification` object associated to a particular row Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -101,13 +101,13 @@ async def get_specification( async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Campaign: """Get the parent `Campaign` Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -119,13 +119,13 @@ async def get_campaign( async def get_parent( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> ElementMixin: """Get the parent `Element` Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -140,7 +140,7 @@ async def get_parent( async def get_handler( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Handler: """Get the Handler object associated with a particular row @@ -149,7 +149,7 @@ async def get_handler( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -167,7 +167,7 @@ async def get_handler( async def resolve_collections( self, - session: async_scoped_session, + session: AnyAsyncSession, *, throw_overrides: bool = True, ) -> dict: @@ -181,7 +181,7 @@ async def resolve_collections( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager throw_overrides : bool @@ -221,7 +221,7 @@ async def resolve_collections( async def get_collections( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> dict: """Get the collection configuration associated with a particular row. 
@@ -230,7 +230,7 @@ async def get_collections( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -262,7 +262,7 @@ async def get_collections( async def get_child_config( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> dict: """Get the child configuration associated with a particular row. @@ -271,7 +271,7 @@ async def get_child_config( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -291,7 +291,7 @@ async def get_child_config( async def data_dict( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> dict: """Get the data configuration associated to a particular row @@ -300,7 +300,7 @@ async def data_dict( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -326,7 +326,7 @@ async def data_dict( async def get_spec_aliases( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> dict: """Get the spec_aliases associated with a particular node @@ -335,7 +335,7 @@ async def get_spec_aliases( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -359,14 +359,14 @@ async def get_spec_aliases( async def update_child_config( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> NodeMixin: """Update the child configuration associated with this Node Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs: Any @@ -402,7 +402,7 @@ async def update_child_config( async def update_collections( self, - session: async_scoped_session, + session: AnyAsyncSession, *, force: bool = False, **kwargs: Any, @@ -411,7 +411,7 @@ async def update_collections( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs: Any @@ -446,14 +446,14 @@ async def 
update_collections( async def update_spec_aliases( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> NodeMixin: """Update the spec_alisases configuration associated with this Node Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs: Any @@ -489,7 +489,7 @@ async def update_spec_aliases( async def update_metadata_dict( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> NodeMixin: """Update the metadata configuration associated with this Node. @@ -499,7 +499,7 @@ async def update_metadata_dict( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs: Any @@ -534,7 +534,7 @@ async def update_metadata_dict( async def update_data_dict( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> NodeMixin: """Update the data configuration associated with this Node. @@ -544,7 +544,7 @@ async def update_data_dict( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs: Any @@ -601,14 +601,14 @@ async def update_data_dict( async def check_prerequisites( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> bool: """Check if the prerequisties for processing a particular node are completed. 
Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -628,13 +628,13 @@ async def check_prerequisites( async def reject( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> NodeMixin: """Set a node as rejected Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -656,13 +656,13 @@ async def reject( async def accept( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> NodeMixin: """Set a node as accepted Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -690,7 +690,7 @@ async def accept( async def reset( self, - session: async_scoped_session, + session: AnyAsyncSession, *, fake_reset: bool = False, ) -> NodeMixin: @@ -698,7 +698,7 @@ async def reset( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fake_reset: bool @@ -718,7 +718,7 @@ async def reset( async def _clean_up_node( self, - session: async_scoped_session, + session: AnyAsyncSession, *, fake_reset: bool = False, ) -> NodeMixin: @@ -726,7 +726,7 @@ async def _clean_up_node( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fake_reset: bool @@ -741,7 +741,7 @@ async def _clean_up_node( async def process( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> tuple[bool, StatusEnum]: """Process this `Node` as much as possible @@ -750,7 +750,7 @@ async def process( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -766,7 +766,7 @@ async def process( async def run_check( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> tuple[bool, StatusEnum]: """Check on this Nodes's status @@ -775,7 +775,7 @@ async def run_check( Parameters ---------- - session : 
async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -790,7 +790,7 @@ async def run_check( async def estimate_sleep_time( self, - session: async_scoped_session, + session: AnyAsyncSession, minimum_sleep_time: int = 10, ) -> int: """Estimate how long to sleep before calling process again. @@ -800,7 +800,7 @@ async def estimate_sleep_time( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -815,7 +815,7 @@ async def estimate_sleep_time( async def update_mtime( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> None: """Update the mtime attribute in an element's hierarchy.""" mtime = timestamp.element_time() diff --git a/src/lsst/cmservice/db/queue.py b/src/lsst/cmservice/db/queue.py index d93feed14..9e776b470 100644 --- a/src/lsst/cmservice/db/queue.py +++ b/src/lsst/cmservice/db/queue.py @@ -1,11 +1,10 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta -from typing import Any +from typing import TYPE_CHECKING, Any from sqlalchemy import JSON, and_, select from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -22,6 +21,10 @@ from .script import Script from .step import Step +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + + logger = LOGGER.bind(module=__name__) @@ -71,13 +74,13 @@ class Queue(Base, NodeMixin): async def get_node( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> NodeMixin: """Get the parent `Node` Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -108,14 +111,14 @@ async def get_node( @classmethod async def get_queue_item( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: 
Any, ) -> Queue: """Get the queue row corresponding to a partiuclar node Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Keywords -------- @@ -157,7 +160,7 @@ async def get_queue_item( @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: fullname = kwargs["fullname"] @@ -200,7 +203,7 @@ async def get_create_kwargs( async def node_sleep_time( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> int: """Check how long to sleep based on what is running""" node = await self.get_node(session) @@ -224,7 +227,7 @@ def waiting( async def process_node( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> bool: # pragma: no cover """Process associated node and update queue row""" node = await self.get_node(session) diff --git a/src/lsst/cmservice/db/row.py b/src/lsst/cmservice/db/row.py index 3210da197..16dce5b25 100644 --- a/src/lsst/cmservice/db/row.py +++ b/src/lsst/cmservice/db/row.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from ..common.enums import StatusEnum from ..common.errors import ( @@ -21,8 +20,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - T = TypeVar("T", bound="RowMixin") - A = TypeVar("A", AsyncSession, async_scoped_session) + from ..common.types import AnyAsyncSession DELETABLE_STATES = [ StatusEnum.failed, @@ -38,23 +36,23 @@ class RowMixin: Defines an interface to manipulate any sort of table. 
""" - id: Any # Primary Key, typically an int - name: Any # Human-readable name for row fullname: Any # Human-readable unique name for row - + id: Any # Primary Key, typically an int class_string: str # Name to use for help functions and descriptions + col_names_for_table: list + name: Any # Human-readable name for row @classmethod - async def get_rows( + async def get_rows[T: RowMixin]( cls: type[T], - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> Sequence[T]: """Get rows associated to a particular table Parameters ---------- - session : async_scoped_session + session : A DB session manager Keywords @@ -109,16 +107,16 @@ async def get_rows( return results.all() @classmethod - async def get_row( + async def get_row[T: RowMixin]( cls: type[T], - session: A, + session: AnyAsyncSession, row_id: int, ) -> T: """Get a single row, matching row.id == row_id Parameters ---------- - session : async_scoped_session + session : A DB session manager row_id: int @@ -135,16 +133,16 @@ async def get_row( return result @classmethod - async def get_row_by_name( + async def get_row_by_name[T: RowMixin]( cls: type[T], - session: async_scoped_session, + session: AnyAsyncSession, name: str, ) -> T: """Get a single row, with row.name == name Parameters ---------- - session : async_scoped_session + session : A DB session manager name : str @@ -163,16 +161,16 @@ async def get_row_by_name( return row @classmethod - async def get_row_by_fullname( + async def get_row_by_fullname[T: RowMixin]( cls: type[T], - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, ) -> T: """Get a single row, with row.fullname == fullname Parameters ---------- - session : async_scoped_session + session : A DB session manager fullname : str @@ -193,14 +191,14 @@ async def get_row_by_fullname( @classmethod async def delete_row( cls, - session: async_scoped_session, + session: AnyAsyncSession, row_id: int, ) -> None: """Delete a single row, matching row.id == 
row_id Parameters ---------- - session : async_scoped_session + session : A DB session manager row_id: int @@ -231,14 +229,14 @@ async def delete_row( @classmethod async def _delete_hook( cls, - session: async_scoped_session, + session: AnyAsyncSession, row_id: int, ) -> None: """Hook called during delete_row Parameters ---------- - session : async_scoped_session + session : A DB session manager row_id: int @@ -250,9 +248,9 @@ async def _delete_hook( return @classmethod - async def update_row( + async def update_row[T: RowMixin]( cls: type[T], - session: async_scoped_session, + session: AnyAsyncSession, row_id: int, **kwargs: Any, ) -> T: @@ -260,7 +258,7 @@ async def update_row( Parameters ---------- - session : async_scoped_session + session : A DB session manager row_id: int @@ -311,16 +309,16 @@ async def update_row( return row @classmethod - async def create_row( + async def create_row[T: RowMixin]( cls: type[T], - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> T: """Create a single row Parameters ---------- - session : async_scoped_session + session : A DB session manager kwargs: Any @@ -346,8 +344,8 @@ async def create_row( @classmethod async def get_create_kwargs( - cls: type[T], - session: async_scoped_session, + cls, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: """Get additional keywords needed to create a row @@ -358,7 +356,7 @@ async def get_create_kwargs( Parameters ---------- - session : async_scoped_session + session : A DB session manager kwargs: Any @@ -371,16 +369,16 @@ async def get_create_kwargs( """ return kwargs - async def update_values( + async def update_values[T: RowMixin]( self: T, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> T: """Update values in a row Parameters ---------- - session : async_scoped_session + session : A DB session manager kwargs: Any diff --git a/src/lsst/cmservice/db/script.py b/src/lsst/cmservice/db/script.py index 26187b555..82646a1ef 100644 
--- a/src/lsst/cmservice/db/script.py +++ b/src/lsst/cmservice/db/script.py @@ -4,7 +4,6 @@ from sqlalchemy import JSON from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -25,6 +24,7 @@ from .step import Step if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .script_error import ScriptError @@ -111,14 +111,14 @@ def level(self) -> LevelEnum: async def get_script_errors( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> list[ScriptError]: await session.refresh(self, attribute_names=["errors_"]) return self.errors_ async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Campaign: """Maps self.get_parent().c_ to self.get_campaign() for consistency""" parent = await self.get_parent(session) @@ -131,13 +131,13 @@ def node_type(self) -> NodeTypeEnum: async def get_parent( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> ElementMixin: """Get the parent `Element` Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns @@ -165,7 +165,7 @@ async def get_parent( @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: try: @@ -229,7 +229,7 @@ async def get_create_kwargs( async def reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, to_status: StatusEnum, *, fake_reset: bool = False, @@ -243,7 +243,7 @@ async def reset_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager to_status : StatusEnum @@ -262,7 +262,7 @@ async def reset_script( async def review( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> StatusEnum: """Run 
review() function on this Script @@ -272,7 +272,7 @@ async def review( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns diff --git a/src/lsst/cmservice/db/script_dependency.py b/src/lsst/cmservice/db/script_dependency.py index 2ade4c254..e9ac58f75 100644 --- a/src/lsst/cmservice/db/script_dependency.py +++ b/src/lsst/cmservice/db/script_dependency.py @@ -4,7 +4,6 @@ from uuid import UUID import sqlalchemy.dialects.postgresql as sapg -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -13,6 +12,7 @@ from .row import RowMixin if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .script import Script @@ -41,13 +41,13 @@ def __repr__(self) -> str: async def is_done( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> bool: """Check if this dependency is completed Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns diff --git a/src/lsst/cmservice/db/session.py b/src/lsst/cmservice/db/session.py index 1141f03e1..e912e24ad 100644 --- a/src/lsst/cmservice/db/session.py +++ b/src/lsst/cmservice/db/session.py @@ -1,9 +1,97 @@ """Module to create and handle async database sessions""" -from safir.dependencies.db_session import db_session_dependency -from sqlalchemy.ext.asyncio import async_scoped_session +from collections.abc import AsyncGenerator +from sqlalchemy import URL, make_url +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlalchemy.pool import AsyncAdaptedQueuePool, Pool +from sqlmodel.ext.asyncio.session import AsyncSession -async def get_async_scoped_session() -> async_scoped_session: - """Provides an async session from the safir session maker.""" - return await anext(db_session_dependency()) +from ..common.logging import LOGGER +from ..config import config 
+ +logger = LOGGER.bind(module=__name__) + + +class DatabaseManager: + """A database session manager class designed to manage an async sqlalchemy + engine and produce sessions. + + A module-level instance of this class is created, and when called a new + async session is yielded. + """ + + engine: AsyncEngine | None + sessionmaker: async_sessionmaker[AsyncSession] | None + url: URL + pool_class: type[Pool] = AsyncAdaptedQueuePool + + def __init__(self) -> None: + self.engine = None + self.sessionmaker = None + + async def initialize( + self, + *, + use_async: bool = True, + ) -> None: + """Initialize the database manager. + + Parameters + ---------- + use_async + If true (default), the database drivername will be forced to an + async form. + """ + await self.aclose() + if isinstance(config.db.url, str): + self.url = make_url(config.db.url) + if use_async and self.url.drivername == "postgresql": + self.url = self.url.set(drivername="postgresql+asyncpg") + if config.db.password is not None: + self.url = self.url.set(password=config.db.password.get_secret_value()) + pool_kwargs = ( + config.db.model_dump(include=config.db.pool_fields) + if self.pool_class is AsyncAdaptedQueuePool + else {} + ) + self.engine = create_async_engine( + url=self.url, + echo=config.db.echo, + poolclass=self.pool_class, + **pool_kwargs, + ) + self.sessionmaker = async_sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) + + async def __call__(self) -> AsyncGenerator[AsyncSession]: + """Yields a database session, rolls it back on error and closes it on + completion. + + Yields + ------- + sqlmodel.ext.asyncio.AsyncSession + The newly-created session. 
+ """ + if not self.sessionmaker: + raise RuntimeError("Async sessionmaker is not initialized") + + async with self.sessionmaker() as session: + try: + yield session + except Exception: + logger.exception() + await session.rollback() + raise + finally: + await session.close() + + async def aclose(self) -> None: + """Shut down the database engine.""" + if self.engine: + self.sessionmaker = None + await self.engine.dispose() + self.engine = None + + +db_session_dependency = DatabaseManager() +"""A module-level instance of the database manager""" diff --git a/src/lsst/cmservice/db/spec_block.py b/src/lsst/cmservice/db/spec_block.py index 57a1856e7..5a135c645 100644 --- a/src/lsst/cmservice/db/spec_block.py +++ b/src/lsst/cmservice/db/spec_block.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from sqlalchemy import JSON -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, mapped_column @@ -11,6 +10,9 @@ from .handler import Handler from .row import RowMixin +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + class SpecBlock(Base, RowMixin): """Database table to manage blocks that are used to build campaigns @@ -45,7 +47,7 @@ def __repr__(self) -> str: @classmethod async def get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: handler = kwargs["handler"] @@ -64,7 +66,7 @@ async def get_create_kwargs( @classmethod async def _delete_hook( cls, - session: async_scoped_session, + session: AnyAsyncSession, row_id: int, ) -> None: Handler.remove_from_cache(row_id) diff --git a/src/lsst/cmservice/db/specification.py b/src/lsst/cmservice/db/specification.py index 6a7eee6d2..33a4b26b8 100644 --- a/src/lsst/cmservice/db/specification.py +++ b/src/lsst/cmservice/db/specification.py @@ -1,7 +1,8 @@ from __future__ import annotations +from typing import 
TYPE_CHECKING + from sqlalchemy import JSON -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, mapped_column @@ -10,6 +11,9 @@ from .row import RowMixin from .spec_block import SpecBlock +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + class Specification(Base, RowMixin): """Database table to manage mapping and grouping SpecBlocks @@ -24,7 +28,7 @@ class Specification(Base, RowMixin): data: Mapped[dict] = mapped_column(type_=JSON, default=dict) child_config: Mapped[dict | list | None] = mapped_column(type_=JSON) collections: Mapped[dict | list | None] = mapped_column(type_=JSON) - spec_aliases: Mapped[dict | list | None] = mapped_column(type_=JSON) + spec_aliases: Mapped[dict | None] = mapped_column(type_=JSON) col_names_for_table = ["id", "name"] @@ -38,14 +42,14 @@ def __repr__(self) -> str: async def get_block( self, - session: async_scoped_session, + session: AnyAsyncSession, spec_block_name: str, ) -> SpecBlock: """Get a SpecBlock associated to this Specification Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager spec_block_name: str diff --git a/src/lsst/cmservice/db/step.py b/src/lsst/cmservice/db/step.py index 5581ca682..54524c5cb 100644 --- a/src/lsst/cmservice/db/step.py +++ b/src/lsst/cmservice/db/step.py @@ -5,7 +5,6 @@ from sqlalchemy import JSON from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey, UniqueConstraint @@ -22,6 +21,7 @@ from .spec_block import SpecBlock if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .group import Group from .job import Job from .script import Script @@ -58,7 +58,7 @@ class Step(Base, ElementMixin): metadata_: Mapped[dict] = 
mapped_column("metadata_", type_=MutableDict.as_mutable(JSONB), default=dict) child_config: Mapped[dict | list | None] = mapped_column(type_=JSON) collections: Mapped[dict | list | None] = mapped_column(type_=JSON) - spec_aliases: Mapped[dict | list | None] = mapped_column(type_=JSON) + spec_aliases: Mapped[dict | None] = mapped_column(type_=JSON) spec_block_: Mapped[SpecBlock] = relationship("SpecBlock", viewonly=True) parent_: Mapped[Campaign] = relationship("Campaign", back_populates="s_") @@ -94,7 +94,7 @@ def __repr__(self) -> str: async def get_campaign( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Campaign: """Maps self.c_ to self.get_campaign() for consistency""" await session.refresh(self, attribute_names=["parent_"]) @@ -102,7 +102,7 @@ async def get_campaign( async def children( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> Iterable: """Maps self.g_ to self.children() for consistency""" await session.refresh(self, attribute_names=["g_"]) @@ -110,7 +110,7 @@ async def children( async def get_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedWmsTaskReportDict: the_dict = MergedWmsTaskReportDict(reports={}) @@ -122,7 +122,7 @@ async def get_wms_reports( async def get_tasks( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedTaskSetDict: the_dict = MergedTaskSetDict(reports={}) @@ -133,7 +133,7 @@ async def get_tasks( async def get_products( self, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> MergedProductSetDict: the_dict = MergedProductSetDict(reports={}) @@ -144,7 +144,7 @@ async def get_products( async def get_all_prereqs( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> list[Step]: all_prereqs: list[Step] = [] await session.refresh(self, attribute_names=["prereqs_"]) @@ -158,7 +158,7 @@ async def get_all_prereqs( @classmethod async def 
get_create_kwargs( cls, - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> dict: try: diff --git a/src/lsst/cmservice/db/step_dependency.py b/src/lsst/cmservice/db/step_dependency.py index 835214b90..063606b2a 100644 --- a/src/lsst/cmservice/db/step_dependency.py +++ b/src/lsst/cmservice/db/step_dependency.py @@ -4,7 +4,6 @@ from uuid import UUID import sqlalchemy.dialects.postgresql as sapg -from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -13,6 +12,7 @@ from .row import RowMixin if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from .step import Step @@ -41,13 +41,13 @@ def __repr__(self) -> str: async def is_done( self, - session: async_scoped_session, + session: AnyAsyncSession, ) -> bool: """Check if this dependency is completed Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager Returns diff --git a/src/lsst/cmservice/handlers/element_handler.py b/src/lsst/cmservice/handlers/element_handler.py index aae1cfdd4..73527f015 100644 --- a/src/lsst/cmservice/handlers/element_handler.py +++ b/src/lsst/cmservice/handlers/element_handler.py @@ -4,8 +4,6 @@ from typing import TYPE_CHECKING, Any from uuid import UUID, uuid5 -from sqlalchemy.ext.asyncio import async_scoped_session - from ..common.enums import LevelEnum, StatusEnum from ..common.errors import CMYamlParseError, test_type_and_raise from ..common.notification import send_notification @@ -18,6 +16,9 @@ from ..db.script_dependency import ScriptDependency from .functions import render_campaign_steps +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + class ElementHandler(Handler): """SubClass of Handler to deal with generic 'Element' operations, @@ -26,7 +27,7 @@ class ElementHandler(Handler): @staticmethod async def _add_prerequisite( - session: async_scoped_session, + session: AnyAsyncSession, 
script_id: int, prereq_id: int, namespace: UUID | None = None, @@ -35,7 +36,7 @@ async def _add_prerequisite( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script_id: int @@ -59,7 +60,7 @@ async def _add_prerequisite( async def process( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -67,7 +68,7 @@ async def process( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager node: NodeMixin @@ -116,7 +117,7 @@ async def process( async def run_check( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -126,7 +127,7 @@ async def run_check( async def prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, ) -> tuple[bool, StatusEnum]: """Prepare `Element` for processing @@ -135,7 +136,7 @@ async def prepare( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -218,7 +219,7 @@ async def prepare( async def continue_processing( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -228,7 +229,7 @@ async def continue_processing( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -253,7 +254,7 @@ async def continue_processing( async def review( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> StatusEnum: @@ -265,7 +266,7 @@ async def review( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -282,7 +283,7 @@ async def review( async def _run_script_checks( self, - session: async_scoped_session, + session: 
AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> bool: @@ -290,7 +291,7 @@ async def _run_script_checks( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -320,7 +321,7 @@ async def _run_script_checks( async def _run_job_checks( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> bool: @@ -328,7 +329,7 @@ async def _run_job_checks( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -357,7 +358,7 @@ async def _run_job_checks( async def check( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -366,7 +367,7 @@ async def check( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -424,7 +425,7 @@ async def check( async def _post_check( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> StatusEnum: @@ -432,7 +433,7 @@ async def _post_check( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin @@ -451,7 +452,7 @@ class CampaignHandler(ElementHandler): async def prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, ) -> tuple[bool, StatusEnum]: if TYPE_CHECKING: @@ -466,7 +467,7 @@ async def prepare( async def _post_check( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> StatusEnum: @@ -474,7 +475,7 @@ async def _post_check( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element: ElementMixin diff --git a/src/lsst/cmservice/handlers/elements.py b/src/lsst/cmservice/handlers/elements.py index 
9237b5430..567cc871b 100644 --- a/src/lsst/cmservice/handlers/elements.py +++ b/src/lsst/cmservice/handlers/elements.py @@ -2,11 +2,10 @@ from collections.abc import AsyncGenerator from functools import partial -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from anyio import to_thread -from sqlalchemy.ext.asyncio import async_scoped_session from ..common.butler import BUTLER_FACTORY from ..common.enums import StatusEnum @@ -20,6 +19,10 @@ from ..db.script import Script from .script_handler import FunctionHandler +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + + logger = LOGGER.bind(module=__name__) @@ -32,7 +35,7 @@ class RunElementScriptHandler(FunctionHandler): async def _do_run( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -49,7 +52,7 @@ async def _do_run( async def _do_check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -78,7 +81,7 @@ class RunJobsScriptHandler(RunElementScriptHandler): async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -102,7 +105,7 @@ async def _do_prepare( async def review_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -126,7 +129,7 @@ class Splitter: @classmethod async def split( cls, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -135,7 +138,7 @@ async def split( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager script: Script @@ -157,7 +160,7 @@ class NoSplit(Splitter): @classmethod async def split( cls, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -177,7 +180,7 @@ 
class SplitByVals(Splitter): @classmethod async def split( cls, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -203,7 +206,7 @@ class SplitByQuery(Splitter): @classmethod async def split( cls, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -281,7 +284,7 @@ class RunGroupsScriptHandler(RunElementScriptHandler): async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -320,7 +323,7 @@ class RunStepsScriptHandler(RunElementScriptHandler): async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, diff --git a/src/lsst/cmservice/handlers/functions.py b/src/lsst/cmservice/handlers/functions.py index ad25ccb32..d3451bd58 100644 --- a/src/lsst/cmservice/handlers/functions.py +++ b/src/lsst/cmservice/handlers/functions.py @@ -9,7 +9,6 @@ from anyio import Path from pydantic.v1.utils import deep_update from sqlalchemy import select -from sqlalchemy.ext.asyncio import async_scoped_session from lsst.ctrl.bps.bps_reports import compile_job_summary from lsst.ctrl.bps.wms_service import WmsRunReport, WmsStates @@ -17,6 +16,7 @@ from ..common.enums import DEFAULT_NAMESPACE, StatusEnum from ..common.errors import CMMissingFullnameError, CMYamlParseError from ..common.logging import LOGGER +from ..common.types import AnyAsyncSession from ..config import config from ..db.campaign import Campaign from ..db.element import ElementMixin @@ -24,7 +24,7 @@ from ..db.pipetask_error import PipetaskError from ..db.pipetask_error_type import PipetaskErrorType from ..db.product_set import ProductSet -from ..db.session import get_async_scoped_session +from ..db.session import db_session_dependency from ..db.spec_block import SpecBlock from ..db.specification import Specification from 
..db.step import Step @@ -36,7 +36,7 @@ async def upsert_spec_block( - session: async_scoped_session, + session: AnyAsyncSession, config_values: dict, loaded_specs: dict, *, @@ -48,7 +48,7 @@ async def upsert_spec_block( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager config_values: dict @@ -127,7 +127,7 @@ async def upsert_spec_block( async def upsert_specification( - session: async_scoped_session, + session: AnyAsyncSession, config_values: dict, *, allow_update: bool = False, @@ -138,7 +138,7 @@ async def upsert_specification( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager config_values: dict @@ -166,7 +166,7 @@ async def upsert_specification( async def load_specification( - session: async_scoped_session, + session: AnyAsyncSession, yaml_file: str | Path | deque, loaded_specs: dict | None = None, *, @@ -178,7 +178,7 @@ async def load_specification( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager yaml_file: str | anyio.Path @@ -259,13 +259,13 @@ async def load_specification( async def add_step_prerequisite( - session: async_scoped_session, depend_id: int, prereq_id: int, namespace: UUID | None = None + session: AnyAsyncSession, depend_id: int, prereq_id: int, namespace: UUID | None = None ) -> StepDependency: """Create and return a StepDependency Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager depend_id: int @@ -288,7 +288,7 @@ async def add_step_prerequisite( async def add_steps( - session: async_scoped_session, + session: AnyAsyncSession, campaign: Campaign, step_config_list: list[dict[str, dict]] | None, ) -> Campaign: @@ -296,7 +296,7 @@ async def add_steps( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager campaign: Campaign @@ -386,14 +386,15 @@ async def force_accept_node( node: int, db_class: 
type[ElementMixin], output_collection: str | None = None, - session: async_scoped_session | None = None, + session: AnyAsyncSession | None = None, ) -> None: """Force accept a node by bypassing state transition checks and setting node and node's scripts to accepted. """ local_session = False if session is None: - session = await get_async_scoped_session() + assert db_session_dependency.sessionmaker is not None + session = db_session_dependency.sessionmaker() local_session = True the_node = await db_class.get_row(session, node) @@ -431,7 +432,7 @@ async def force_accept_node( async def render_campaign_steps( campaign: Campaign | int, - session: async_scoped_session | None = None, + session: AnyAsyncSession | None = None, ) -> None: """Render the steps for a campaign. @@ -444,7 +445,8 @@ async def render_campaign_steps( """ local_session = False if session is None: - session = await get_async_scoped_session() + assert db_session_dependency.sessionmaker is not None + session = db_session_dependency.sessionmaker() local_session = True if isinstance(campaign, int): campaign = await Campaign.get_row(session, campaign) @@ -466,7 +468,7 @@ async def render_campaign_steps( async def match_pipetask_error( - session: async_scoped_session, + session: AnyAsyncSession, task_name: str, diagnostic_message: str, ) -> PipetaskErrorType | None: @@ -474,7 +476,7 @@ async def match_pipetask_error( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager task_name: str @@ -495,7 +497,7 @@ async def match_pipetask_error( async def load_manifest_report( - session: async_scoped_session, + session: AnyAsyncSession, job_name: str, yaml_file: str | Path, fake_status: StatusEnum | None = None, @@ -506,7 +508,7 @@ async def load_manifest_report( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager job_name: str @@ -715,7 +717,7 @@ def status_from_bps_report( async def load_wms_reports( - session: 
async_scoped_session, + session: AnyAsyncSession, job: Job, wms_run_report: WmsRunReport | None, ) -> Job: # pragma: no cover @@ -725,7 +727,7 @@ async def load_wms_reports( Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager job_name: str @@ -764,14 +766,14 @@ async def load_wms_reports( async def load_error_types( - session: async_scoped_session, + session: AnyAsyncSession, yaml_file: str | Path, ) -> list[PipetaskErrorType]: """Parse and load error types Parameters ---------- - session: async_scoped_session + session: AnyAsyncSession DB session manager yaml_file: str | anyio.Path @@ -810,7 +812,7 @@ async def load_error_types( async def compute_job_status( - session: async_scoped_session, + session: AnyAsyncSession, job: Job, ) -> StatusEnum: await session.refresh( diff --git a/src/lsst/cmservice/handlers/interface.py b/src/lsst/cmservice/handlers/interface.py index 7c0cb8e1c..8a69b98c1 100644 --- a/src/lsst/cmservice/handlers/interface.py +++ b/src/lsst/cmservice/handlers/interface.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import select -from sqlalchemy.ext.asyncio import async_scoped_session from .. import db from ..common.enums import LevelEnum, NodeTypeEnum, StatusEnum, TableEnum @@ -13,6 +12,7 @@ test_type_and_raise, ) from ..common.logging import LOGGER +from ..common.types import AnyAsyncSession from . 
import functions TABLE_DICT: dict[TableEnum, type[db.RowMixin]] = { @@ -64,7 +64,7 @@ def get_table( async def get_row_by_table_and_id( - session: async_scoped_session, + session: AnyAsyncSession, row_id: int, table_enum: TableEnum, ) -> db.RowMixin: @@ -72,7 +72,7 @@ async def get_row_by_table_and_id( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager row_id: int @@ -105,7 +105,7 @@ async def get_row_by_table_and_id( async def get_node_by_level_and_id( - session: async_scoped_session, + session: AnyAsyncSession, element_id: int, level: LevelEnum, ) -> db.NodeMixin: @@ -113,7 +113,7 @@ async def get_node_by_level_and_id( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager element_id: int @@ -164,14 +164,14 @@ def get_node_type_by_fullname( async def get_element_by_fullname( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, ) -> db.ElementMixin: """Get a `Element` from the DB Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -208,14 +208,14 @@ async def get_element_by_fullname( async def get_node_by_fullname( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, ) -> db.NodeMixin: """Get a `Node` from the DB Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -241,7 +241,7 @@ async def get_node_by_fullname( async def process_script( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, fake_status: StatusEnum | None = None, ) -> tuple[bool, StatusEnum]: @@ -249,7 +249,7 @@ async def process_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -276,7 +276,7 @@ async def process_script( async def process_element( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, 
fake_status: StatusEnum | None = None, ) -> tuple[bool, StatusEnum]: @@ -284,7 +284,7 @@ async def process_element( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -313,7 +313,7 @@ async def process_element( async def process( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, fake_status: StatusEnum | None = None, ) -> tuple[bool, StatusEnum]: @@ -321,7 +321,7 @@ async def process( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -356,7 +356,7 @@ async def process( async def reset_script( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, status: StatusEnum, *, @@ -374,7 +374,7 @@ async def reset_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -403,7 +403,7 @@ async def reset_script( async def rescue_job( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, ) -> db.Job: """Run a rescue on a `Job` @@ -414,7 +414,7 @@ async def rescue_job( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -441,7 +441,7 @@ async def rescue_job( async def mark_job_rescued( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, ) -> list[db.Job]: """Mark a `Job` as rescued @@ -452,7 +452,7 @@ async def mark_job_rescued( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -481,14 +481,14 @@ async def mark_job_rescued( async def create_campaign( - session: async_scoped_session, + session: AnyAsyncSession, **kwargs: Any, ) -> db.Campaign: """Create a new Campaign Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager kwargs : Any @@ -504,7 +504,7 @@ async def create_campaign( async def 
load_and_create_campaign( - session: async_scoped_session, + session: AnyAsyncSession, yaml_file: str, name: str, spec_block_assoc_name: str | None = None, @@ -514,7 +514,7 @@ async def load_and_create_campaign( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager yaml_file: str @@ -552,7 +552,7 @@ async def load_and_create_campaign( async def add_steps( - session: async_scoped_session, + session: AnyAsyncSession, fullname: str, child_configs: list[dict[str, Any]], ) -> db.Campaign: @@ -560,7 +560,7 @@ async def add_steps( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager fullname: str @@ -586,14 +586,14 @@ async def add_steps( async def load_error_types( - session: async_scoped_session, + session: AnyAsyncSession, yaml_file: str, ) -> list[db.PipetaskErrorType]: """Load a set of `PipetaskErrorType`s from a yaml file Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager yaml_file: str, @@ -609,7 +609,7 @@ async def load_error_types( async def load_manifest_report( - session: async_scoped_session, + session: AnyAsyncSession, yaml_file: str, fullname: str, *, @@ -619,7 +619,7 @@ async def load_manifest_report( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager yaml_file: str, @@ -641,7 +641,7 @@ async def load_manifest_report( async def match_pipetask_errors( - session: async_scoped_session, + session: AnyAsyncSession, *, rematch: bool = False, ) -> list[db.PipetaskError]: @@ -651,7 +651,7 @@ async def match_pipetask_errors( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager rematch: bool diff --git a/src/lsst/cmservice/handlers/job_handler.py b/src/lsst/cmservice/handlers/job_handler.py index c033e38cc..4224ba9d9 100644 --- a/src/lsst/cmservice/handlers/job_handler.py +++ b/src/lsst/cmservice/handlers/job_handler.py @@ 
-2,14 +2,13 @@ from typing import TYPE_CHECKING, Any -from sqlalchemy.ext.asyncio import async_scoped_session - from ..common.enums import ErrorActionEnum, StatusEnum from ..common.errors import CMBadEnumError from ..db.element import ElementMixin from .element_handler import ElementHandler if TYPE_CHECKING: + from ..common.types import AnyAsyncSession from ..db import Job @@ -18,7 +17,7 @@ class JobHandler(ElementHandler): async def _post_check( self, - session: async_scoped_session, + session: AnyAsyncSession, element: ElementMixin, **kwargs: Any, ) -> StatusEnum: diff --git a/src/lsst/cmservice/handlers/jobs.py b/src/lsst/cmservice/handlers/jobs.py index 0f735b704..1325d809e 100644 --- a/src/lsst/cmservice/handlers/jobs.py +++ b/src/lsst/cmservice/handlers/jobs.py @@ -10,7 +10,6 @@ from anyio import Path, to_thread from fastapi.concurrency import run_in_threadpool from jinja2 import Environment, PackageLoader -from sqlalchemy.ext.asyncio import async_scoped_session from lsst.ctrl.bps import BaseWmsService, WmsRunReport, WmsStates from lsst.utils import doImport @@ -36,6 +35,10 @@ from .functions import compute_job_status, load_manifest_report, load_wms_reports, status_from_bps_report from .script_handler import FunctionHandler, ScriptHandler +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + + WMS_TO_TASK_STATUS_MAP = { WmsStates.UNKNOWN: TaskStatusEnum.missing, WmsStates.MISFIT: TaskStatusEnum.missing, @@ -64,7 +67,7 @@ class BpsScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -75,11 +78,11 @@ async def _write_script( await session.refresh(parent, attribute_names=["c_"]) data_dict = await script.data_dict(session) resolved_cols = await script.resolve_collections(session) + prod_area = os.path.expandvars(config.bps.artifact_path) # Resolve mandatory data element inputs. 
All of these values must be # provided somewhere along the SpecBlock chain. try: - prod_area = os.path.expandvars(data_dict["prod_area"]) butler_repo = os.path.expandvars(data_dict["butler_repo"]) lsst_version = os.path.expandvars(data_dict.get("lsst_version", "w_latest")) lsst_distrib_dir = os.path.expandvars(data_dict["lsst_distrib_dir"]) @@ -149,7 +152,7 @@ async def _write_script( # unless they are unique to the submission and separated for # readability. The use of any kind of "shared" or "global" config # items breaks provenance for all campaigns that reference them. - bps_wms_extra_files = data_dict.get("bps_wms_extra_files", []) + bps_wms_extra_files: list = data_dict.get("bps_wms_extra_files", []) bps_wms_clustering_file = data_dict.get("bps_wms_clustering_file", None) bps_wms_resources_file = data_dict.get("bps_wms_resources_file", None) bps_wms_yaml_file = data_dict.get("bps_wms_yaml_file", None) @@ -233,7 +236,7 @@ async def _write_script( async def _check_slurm_job( self, - session: async_scoped_session, + session: AnyAsyncSession, slurm_id: str | None, script: Script, parent: ElementMixin, @@ -263,7 +266,7 @@ async def _check_slurm_job( async def _check_htcondor_job( self, - session: async_scoped_session, + session: AnyAsyncSession, htcondor_id: str | None, script: Script, parent: ElementMixin, @@ -284,7 +287,7 @@ async def _check_htcondor_job( # Irrespective of status, if the bps stdout log file exists, try to # parse it for valuable information # FIXME is this appropriate? 
maybe it should only be for terminal state - bps_submit_dir: str | None + bps_submit_dir: str | None = None if fake_status is not None: wms_job_id = "fake_job" bps_submit_dir = "fake_path" @@ -303,7 +306,7 @@ async def _check_htcondor_job( async def launch( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -330,7 +333,7 @@ def get_bps_submit_dir(cls, bps_dict: dict) -> str | None: async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -358,7 +361,7 @@ async def _reset_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -407,7 +410,7 @@ def _get_wms_svc(self, **kwargs: Any) -> BaseWmsService | None: async def _load_wms_reports( self, - session: async_scoped_session, + session: AnyAsyncSession, job: Job, wms_workflow_id: str | None, **kwargs: Any, @@ -462,7 +465,7 @@ async def _load_wms_reports( async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin | Job, **kwargs: Any, @@ -478,7 +481,7 @@ async def _do_prepare( async def _do_check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin | Job, **kwargs: Any, @@ -497,7 +500,7 @@ async def _do_check( async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -564,7 +567,7 @@ class ManifestReportScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -572,7 +575,7 @@ async def _write_script( if TYPE_CHECKING: assert isinstance(parent, Job) data_dict = await script.data_dict(session) - prod_area = await 
Path(os.path.expandvars(data_dict["prod_area"])).resolve() + prod_area = await Path(os.path.expandvars(config.bps.artifact_path)).resolve() resolved_cols = await script.resolve_collections(session) script_url = await self._set_script_files(session, script, prod_area) butler_repo = data_dict["butler_repo"] @@ -616,15 +619,14 @@ class ManifestReportLoadHandler(FunctionHandler): async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, ) -> StatusEnum: if TYPE_CHECKING: assert isinstance(parent, Job) - data_dict = await script.data_dict(session) - prod_area = await Path(os.path.expandvars(data_dict["prod_area"])).resolve() + prod_area = await Path(os.path.expandvars(config.bps.artifact_path)).resolve() report_url = parent.metadata_.get("report_url") or ( prod_area / parent.fullname / "manifest_report.yaml" @@ -644,7 +646,7 @@ async def _do_prepare( async def _do_check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin | Job, **kwargs: Any, @@ -661,7 +663,7 @@ async def _do_check( async def _load_pipetask_report( self, - session: async_scoped_session, + session: AnyAsyncSession, job: Job, pipetask_report_yaml: str, fake_status: StatusEnum | None = None, @@ -688,7 +690,7 @@ async def _load_pipetask_report( async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, diff --git a/src/lsst/cmservice/handlers/script_handler.py b/src/lsst/cmservice/handlers/script_handler.py index 57ac71484..6429fc958 100644 --- a/src/lsst/cmservice/handlers/script_handler.py +++ b/src/lsst/cmservice/handlers/script_handler.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any from anyio import Path -from sqlalchemy.ext.asyncio import async_scoped_session from ..common.bash import check_stamp_file, get_diagnostic_message, run_bash_job from ..common.enums import ErrorSourceEnum, 
ScriptMethodEnum, StatusEnum @@ -26,6 +25,10 @@ from ..db.script import Script from ..db.script_error import ScriptError +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + + logger = LOGGER.bind(module=__name__) DOUBLE_QUOTE = '"' @@ -36,7 +39,7 @@ class BaseScriptHandler(Handler): async def process( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -129,7 +132,7 @@ async def process( async def run_check( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, **kwargs: Any, ) -> tuple[bool, StatusEnum]: @@ -145,7 +148,7 @@ async def run_check( async def prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, ) -> StatusEnum: @@ -157,7 +160,7 @@ async def prepare( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -178,7 +181,7 @@ async def prepare( async def launch( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -191,7 +194,7 @@ async def launch( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -209,7 +212,7 @@ async def launch( async def check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -222,7 +225,7 @@ async def check( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -240,7 +243,7 @@ async def check( async def review_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -253,7 +256,7 @@ async def review_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -272,7 +275,7 @@ 
async def review_script( async def reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, node: NodeMixin, to_status: StatusEnum, *, @@ -305,7 +308,7 @@ async def reset_script( async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -315,7 +318,7 @@ async def _reset_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -325,7 +328,7 @@ async def _purge_products( async def update_status( self, - session: async_scoped_session, + session: AnyAsyncSession, status: StatusEnum, node: NodeMixin, **kwargs: Any, @@ -352,7 +355,7 @@ class ScriptHandler(BaseScriptHandler): @staticmethod async def _check_stamp_file( - session: async_scoped_session, + session: AnyAsyncSession, stamp_file: str | None, script: Script, parent: ElementMixin, @@ -362,7 +365,7 @@ async def _check_stamp_file( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager stamp_file: str | None @@ -390,7 +393,7 @@ async def _check_stamp_file( async def _check_slurm_job( self, - session: async_scoped_session, + session: AnyAsyncSession, slurm_id: str | None, script: Script, parent: ElementMixin, @@ -400,7 +403,7 @@ async def _check_slurm_job( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager slurm_id : str @@ -426,7 +429,7 @@ async def _check_slurm_job( async def _check_htcondor_job( self, - session: async_scoped_session, + session: AnyAsyncSession, htcondor_id: str | None, script: Script, parent: ElementMixin, @@ -436,7 +439,7 @@ async def _check_htcondor_job( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager htcondor_id : str | None @@ -462,7 +465,7 @@ async def _check_htcondor_job( async def prepare( self, - session: async_scoped_session, + session: 
AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -486,7 +489,7 @@ async def prepare( async def launch( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -535,7 +538,7 @@ async def launch( async def check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -581,7 +584,7 @@ async def check( async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -590,7 +593,7 @@ async def _write_script( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -608,7 +611,7 @@ async def _write_script( async def _set_script_files( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, prod_area: str | Path, ) -> str: @@ -620,7 +623,7 @@ async def _set_script_files( async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -649,7 +652,7 @@ class FunctionHandler(BaseScriptHandler): async def prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -667,7 +670,7 @@ async def prepare( async def launch( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -686,7 +689,7 @@ async def launch( async def check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -703,7 +706,7 @@ async def check( async def _do_prepare( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -712,7 +715,7 @@ async def _do_prepare( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession 
DB session manager script: Script @@ -730,7 +733,7 @@ async def _do_prepare( async def _do_run( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -739,7 +742,7 @@ async def _do_run( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -757,7 +760,7 @@ async def _do_run( async def _do_check( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -766,7 +769,7 @@ async def _do_check( Parameters ---------- - session : async_scoped_session + session : AnyAsyncSession DB session manager script: Script @@ -784,7 +787,7 @@ async def _do_check( async def _reset_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, diff --git a/src/lsst/cmservice/handlers/scripts.py b/src/lsst/cmservice/handlers/scripts.py index 3b5284270..f678880ea 100644 --- a/src/lsst/cmservice/handlers/scripts.py +++ b/src/lsst/cmservice/handlers/scripts.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any from anyio import Path -from sqlalchemy.ext.asyncio import async_scoped_session from ..common.bash import write_bash_script from ..common.butler import ( @@ -23,6 +22,10 @@ from ..db.step import Step from .script_handler import ScriptHandler +if TYPE_CHECKING: + from ..common.types import AnyAsyncSession + + logger = LOGGER.bind(module=__name__) @@ -31,7 +34,7 @@ class NullScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -40,7 +43,7 @@ async def _write_script( data_dict = await script.data_dict(session) try: output_coll = resolved_cols["output"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, 
config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as e: raise CMMissingScriptInputError(f"{script.fullname} missing an input: {e}") from e @@ -58,7 +61,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -82,7 +85,7 @@ class ChainCreateScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -92,7 +95,7 @@ async def _write_script( try: output_coll = resolved_cols["output"] input_colls = resolved_cols["inputs"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as msg: logger.exception() @@ -117,7 +120,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -148,7 +151,7 @@ class ChainPrependScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -158,7 +161,7 @@ async def _write_script( try: output_coll = resolved_cols["output"] input_coll = resolved_cols["input"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as msg: raise CMMissingScriptInputError(f"{script.fullname} missing an input: {msg}") from msg @@ -178,7 +181,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -210,7 +213,7 @@ class 
ChainCollectScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -238,7 +241,7 @@ async def _write_script( raise CMMissingScriptInputError( "Must specify what to collect in ChainCollectScriptHandler, jobs or steps", ) - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] command = f"{config.butler.butler_bin} collection-chain {butler_repo} {output_coll}" for collect_coll_ in collect_colls: @@ -257,7 +260,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -288,7 +291,7 @@ class TagInputsScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -298,7 +301,7 @@ async def _write_script( try: output_coll = resolved_cols["output"] input_coll = resolved_cols["input"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] data_query = data_dict.get("data_query") except KeyError as msg: @@ -318,7 +321,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -345,7 +348,7 @@ class TagCreateScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -354,7 +357,7 @@ async def _write_script( data_dict = await script.data_dict(session) try: output_coll = resolved_cols["output"] - 
script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as msg: raise CMMissingScriptInputError(f"{script.fullname} missing an input: {msg}") from msg @@ -371,7 +374,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -400,7 +403,7 @@ class TagAssociateScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -410,7 +413,7 @@ async def _write_script( try: input_coll = resolved_cols["input"] output_coll = resolved_cols["output"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as msg: raise CMMissingScriptInputError(f"{script.fullname} missing an input: {msg}") from msg @@ -428,7 +431,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -462,7 +465,7 @@ class PrepareStepScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, @@ -474,7 +477,7 @@ async def _write_script( resolved_cols = await script.resolve_collections(session) data_dict = await script.data_dict(session) try: - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] output_coll = resolved_cols["output"] except KeyError as msg: @@ -505,7 +508,7 @@ 
async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -527,14 +530,14 @@ class ResourceUsageScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, ) -> StatusEnum: resolved_cols = await script.resolve_collections(session) data_dict = await script.data_dict(session) - prod_area = os.path.expandvars(data_dict["prod_area"]) + prod_area = os.path.expandvars(config.bps.artifact_path) script_url = await self._set_script_files(session, script, prod_area) butler_repo = data_dict["butler_repo"] usage_graph_url = os.path.expandvars(f"{prod_area}/{parent.fullname}/resource_usage.qgraph") @@ -557,7 +560,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -584,14 +587,14 @@ class HipsMapsScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, parent: ElementMixin, **kwargs: Any, ) -> StatusEnum: resolved_cols = await script.resolve_collections(session) data_dict = await script.data_dict(session) - prod_area = os.path.expandvars(data_dict["prod_area"]) + prod_area = os.path.expandvars(config.bps.artifact_path) script_url = await self._set_script_files(session, script, prod_area) butler_repo = data_dict["butler_repo"] hips_maps_graph_url = os.path.expandvars(f"{prod_area}/{parent.fullname}/hips_maps.qgraph") @@ -649,7 +652,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, @@ -683,7 +686,7 @@ class ValidateScriptHandler(ScriptHandler): async def _write_script( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, 
parent: ElementMixin, **kwargs: Any, @@ -693,7 +696,7 @@ async def _write_script( try: input_coll = resolved_cols["input"] output_coll = resolved_cols["output"] - script_url = await self._set_script_files(session, script, data_dict["prod_area"]) + script_url = await self._set_script_files(session, script, config.bps.artifact_path) butler_repo = data_dict["butler_repo"] except KeyError as msg: raise CMMissingScriptInputError(f"{script.fullname} missing an input: {msg}") from msg @@ -710,7 +713,7 @@ async def _write_script( async def _purge_products( self, - session: async_scoped_session, + session: AnyAsyncSession, script: Script, to_status: StatusEnum, *, diff --git a/src/lsst/cmservice/machines/__init__.py b/src/lsst/cmservice/machines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lsst/cmservice/machines/abc.py b/src/lsst/cmservice/machines/abc.py new file mode 100644 index 000000000..8d64454e7 --- /dev/null +++ b/src/lsst/cmservice/machines/abc.py @@ -0,0 +1,255 @@ +"""Abstract Base Classes used by Stateful Model and/or Machine classes. + +These primarily exist and are used to satisfy static type checkers that are +otherwise unaware of any dynamic methods added to Stateful Model classes by +a Machine instance. + +Notes +----- +These ABCs were generated automatically by `transitions.experimental.utils. +generate_base_model and simplified and/or modified for use by the application. + +These ABCs do not use abstractclasses because the implmentations will not be +available to static type checkers (i.e., they only exist at runtime). + +These ABCs may implement methods that are not used by application, i.e., that +involve states that are not referenced by any transition. 
+""" + +from abc import ABC, abstractmethod +from typing import Any + +from sqlmodel.ext.asyncio.session import AsyncSession +from transitions import EventData, Machine +from transitions.extensions.asyncio import AsyncMachine + +from ..common.enums import ManifestKind, StatusEnum +from ..db.campaigns_v2 import ActivityLog, Campaign, Node + +type AnyStatefulObject = Campaign | Node +type AnyMachine = Machine | AsyncMachine + + +class StatefulModel(ABC): + """Base ABC for a Stateful Model, where the Machine will override abstract + methods and properties when it is created. + """ + + __kind__ = [ManifestKind.other] + activity_log_entry: ActivityLog | None = None + db_model: AnyStatefulObject | None + machine: AnyMachine + state: StatusEnum + session: AsyncSession | None = None + + @abstractmethod + def __init__( + self, *args: Any, o: AnyStatefulObject, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: ... + + @abstractmethod + async def error_handler(self, event: EventData) -> None: ... + + @abstractmethod + async def prepare_activity_log(self, event: EventData) -> None: ... + + @abstractmethod + async def update_persistent_status(self, event: EventData) -> None: ... + + @abstractmethod + async def finalize(self, event: EventData) -> None: ... 
+ + async def may_trigger(self, trigger_name: str) -> bool: + raise NotImplementedError("Must be overridden by a Machine") + + async def trigger(self, trigger_name: str, **kwargs: Any) -> bool: + raise NotImplementedError("Must be overridden by a Machine") + + async def resume(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_resume(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def force(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_force(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def pause(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_pause(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def start(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_start(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def unblock(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_unblock(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def unprepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_unprepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def stop(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_stop(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def retry(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_retry(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def finish(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_finish(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def block(self) -> bool: + 
raise NotImplementedError("This should be overridden") + + async def may_block(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def prepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_prepare(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def fail(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_fail(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_overdue(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_failed(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_rejected(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rejected(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_rejected(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_blocked(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_paused(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_paused(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_paused(self) -> bool: + raise NotImplementedError("This should be 
overridden") + + async def is_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_rescuable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_waiting(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_ready(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_prepared(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_running(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_running(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_running(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_reviewable(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + 
async def may_to_accepted(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def is_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def to_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") + + async def may_to_rescued(self) -> bool: + raise NotImplementedError("This should be overridden") diff --git a/src/lsst/cmservice/machines/campaign.py b/src/lsst/cmservice/machines/campaign.py new file mode 100644 index 000000000..55cff74b6 --- /dev/null +++ b/src/lsst/cmservice/machines/campaign.py @@ -0,0 +1,190 @@ +"""Module for state machine implementations related to Campaigns. + +A Campaign state machine should be a simple one, since a Campaign itself does +not need to implement much in the way of Actions or Triggers. A campaign's +status should generally reflect the "worst-case" status of any of Nodes active +in its namespace. + +Since a campaign is mostly a container, the critical path of its state machine +should focus on validity and completeness of its graph, while providing useful +information about the overall campaign progress to pilots and other users. 
+""" + +from typing import TYPE_CHECKING, Any +from uuid import uuid5 + +from sqlmodel import select +from transitions import EventData +from transitions.extensions.asyncio import AsyncMachine + +from ..common import timestamp +from ..common.enums import ManifestKind, StatusEnum +from ..common.graph import graph_from_edge_list_v2, validate_graph +from ..common.logging import LOGGER +from ..db.campaigns_v2 import ActivityLog, Campaign, Edge, Node +from .node import NodeMachine + +logger = LOGGER.bind(module=__name__) + + +TRANSITIONS = [ + # The critical/happy path of state evolution from waiting to accepted + { + "trigger": "start", + "source": StatusEnum.waiting, + "dest": StatusEnum.running, + "conditions": "has_valid_graph", + }, + { + "trigger": "finish", + "source": StatusEnum.running, + "dest": StatusEnum.accepted, + "conditions": "is_successful", + }, + # User-initiated transitions + {"trigger": "pause", "source": StatusEnum.running, "dest": StatusEnum.paused}, + { + "trigger": "resume", + "source": StatusEnum.paused, + "dest": StatusEnum.running, + "conditions": "has_valid_graph", + }, +] +"""Transitions available to a Campaign, expressed as source-destination pairs +with a named trigger-verb. +""" + + +class InvalidCampaignGraphError(Exception): ... + + +class CampaignMachine(NodeMachine): + """Class representing the stateful structure of a Campaign State Machine, + including callbacks and actions to be executed during transitions. 
+ """ + + __kind__ = [ManifestKind.campaign] + + def __init__( + self, *args: Any, o: Campaign, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + self.db_model = o + self.machine = AsyncMachine( + model=self, + states=StatusEnum, + transitions=TRANSITIONS, + initial=initial_state, + auto_transitions=False, + prepare_event=["prepare_session", "prepare_activity_log"], + after_state_change="update_persistent_status", + finalize_event="finalize", + on_exception="error_handler", + send_event=True, + model_override=True, + ) + + async def error_handler(self, event: EventData) -> None: + """Error handler function for the Stateful Model, called by the Machine + if any exception is raised in a callback function. + """ + if TYPE_CHECKING: + assert self.db_model is not None + + if event.error is None: + return + + logger.exception(event.error, id=self.db_model.id) + if self.activity_log_entry is not None: + self.activity_log_entry.detail["trigger"] = event.event.name + self.activity_log_entry.detail["error"] = str(event.error) + self.activity_log_entry.finished_at = timestamp.now_utc() + + async def prepare_activity_log(self, event: EventData) -> None: + """Callback method invoked by the Machine before every state-change.""" + + if TYPE_CHECKING: + assert self.db_model is not None + + if self.activity_log_entry is not None: + return None + + from_state = StatusEnum[event.transition.source] if event.transition else self.state + to_state = ( + StatusEnum[event.transition.dest] if event.transition and event.transition.dest else self.state + ) + + self.activity_log_entry = ActivityLog( + namespace=self.db_model.id, + operator=event.kwargs.get("operator", "daemon"), + from_status=from_state, + to_status=to_state, + detail={}, + metadata_={"request_id": event.kwargs.get("request_id")}, + ) + + async def finalize(self, event: EventData) -> None: + """Callback method invoked by the Machine unconditionally at the end + of every callback chain. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + # The activity log entry is added to the db. For failed transitions it + # may include error detail. For other transitions it is not necessary + # to log every attempt, so if no callback has registered any detail + # for the log entry it is not persisted. + if self.activity_log_entry is None: + return + elif self.activity_log_entry.finished_at is None: + return + + try: + self.session.add(self.activity_log_entry) + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(self.activity_log_entry) + self.activity_log_entry = None + + await self.session.close() + self.session = None + self.activity_log_entry = None + + async def is_successful(self, event: EventData) -> bool: + """A conditional method associated with a transition. + + This callback should assert that the campaign is in a complete and + accepted state by the virtue of all its Nodes also being in a complete + and accepted state. The campaign's "END" node is used as a proxy + for this assertion, because by the rules of the campaign's graph, the + "END" node may only be reached if all other nodes have been success- + fully evolved by an executor. + """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + end_node = await self.session.get_one(Node, uuid5(self.db_model.id, "END.1")) + logger.info(f"Checking whether campaign {self.db_model.name} is finished.", end_node=end_node.status) + return end_node.status is StatusEnum.accepted + + async def has_valid_graph(self, event: EventData) -> bool: + """A conditional method associated with a transition. + + This callback asserts that the campaign graph is valid as a condition + that must be met before the campaign may transition to a "ready" state. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + edges = await self.session.exec(select(Edge).where(Edge.namespace == self.db_model.id)) + graph = await graph_from_edge_list_v2(edges.all(), self.session) + source = uuid5(self.db_model.id, "START.1") + sink = uuid5(self.db_model.id, "END.1") + graph_is_valid = validate_graph(graph, source, sink) + if not graph_is_valid: + raise InvalidCampaignGraphError("Invalid campaign graph") + return graph_is_valid diff --git a/src/lsst/cmservice/machines/node.py b/src/lsst/cmservice/machines/node.py new file mode 100644 index 000000000..cecf0a47a --- /dev/null +++ b/src/lsst/cmservice/machines/node.py @@ -0,0 +1,452 @@ +"""Module for state machine implementations related to Nodes.""" + +import inspect +import pickle +import shutil +import sys +from functools import cache +from os.path import expandvars +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from anyio import Path +from fastapi.concurrency import run_in_threadpool +from transitions import EventData +from transitions.extensions.asyncio import AsyncEvent, AsyncMachine + +from ..common import timestamp +from ..common.enums import ManifestKind, StatusEnum +from ..common.logging import LOGGER +from ..common.timestamp import element_time +from ..config import config +from ..db.campaigns_v2 import ActivityLog, Machine, Node +from ..db.session import db_session_dependency +from .abc import StatefulModel + +logger = LOGGER.bind(module=__name__) + + +TRANSITIONS = [ + # The critical/happy path of state evolution from waiting to accepted + { + "trigger": "prepare", + "source": StatusEnum.waiting, + "dest": StatusEnum.ready, + }, + { + "trigger": "start", + "source": StatusEnum.ready, + "dest": StatusEnum.running, + "conditions": "is_startable", + }, + { + "trigger": "finish", + "source": StatusEnum.running, + "dest": StatusEnum.accepted, + "conditions": "is_done_running", + }, + # The bad transitions + {"trigger": 
"block", "source": StatusEnum.running, "dest": StatusEnum.blocked}, + {"trigger": "fail", "source": StatusEnum.running, "dest": StatusEnum.failed}, + # User-initiated transitions + {"trigger": "pause", "source": StatusEnum.running, "dest": StatusEnum.paused}, + {"trigger": "unblock", "source": StatusEnum.blocked, "dest": StatusEnum.running}, + {"trigger": "resume", "source": StatusEnum.paused, "dest": StatusEnum.running}, + {"trigger": "force", "source": StatusEnum.failed, "dest": StatusEnum.accepted}, + # Inverse transitions, i.e., rollbacks + {"trigger": "unprepare", "source": StatusEnum.ready, "dest": StatusEnum.waiting}, + {"trigger": "stop", "source": StatusEnum.paused, "dest": StatusEnum.ready}, + {"trigger": "retry", "source": StatusEnum.failed, "dest": StatusEnum.ready}, +] +"""Transitions available to a Node, expressed as source-destination pairs +with a named trigger-verb. +""" + + +class NodeMachine(StatefulModel): + """General state model for a Node in a Campaign Graph.""" + + __kind__ = [ManifestKind.node] + + def __init__( + self, *args: Any, o: Node, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + self.db_model = o + self.machine = AsyncMachine( + model=self, + states=StatusEnum, + transitions=TRANSITIONS, + initial=initial_state, + auto_transitions=False, + prepare_event=["prepare_session", "prepare_activity_log"], + after_state_change="update_persistent_status", + finalize_event="finalize", + on_exception="error_handler", + send_event=True, + model_override=True, + ) + self.post_init() + + def post_init(self) -> None: + """Additional initialization method called at the end of ``__init__``, + as a convenience to child classes. 
+ """ + pass + + def __getstate__(self) -> dict: + """Prepares the stateful model for serialization, as with pickle.""" + # Remove members that are not picklable or should not be included + # in the pickle + state = self.__dict__.copy() + del state["session"] + del state["db_model"] + del state["activity_log_entry"] + return state + + async def error_handler(self, event: EventData) -> None: + """Error handler function for the Stateful Model, called by the Machine + if any exception is raised in a callback function. + """ + if event.error is None: + return + + logger.exception(event.error) + if self.activity_log_entry is not None: + self.activity_log_entry.detail["trigger"] = event.event.name + self.activity_log_entry.detail["error"] = str(event.error) + self.activity_log_entry.finished_at = timestamp.now_utc() + + # Auto-transition on error + match event.event: + case AsyncEvent(name="finish"): + # TODO if we need to distinguish between types of failures, + # e.g., fail vs block, we'd have to inspect the error here + await self.trigger("fail") + case _: + ... + + async def prepare_session(self, event: EventData) -> None: + """Prepares the machine by acquiring a database session.""" + # This positive assertion concerning the ORM member will prevent + # any callback from proceeding if no such member is defined, but type + # checkers don't know this, which is why it repeated in a TYPE_CHECKING + # guard in each method that accesses the ORM member. + assert self.db_model is not None, "Stateful Model must have a Node member." 
+ + logger.debug("Preparing session for transition", id=str(self.db_model.id)) + if self.session is not None: + await self.session.close() + else: + assert db_session_dependency.sessionmaker is not None + self.session = db_session_dependency.sessionmaker() + + async def prepare_activity_log(self, event: EventData) -> None: + """Callback method invoked by the Machine before every state-change.""" + if TYPE_CHECKING: + assert self.db_model is not None + + if self.activity_log_entry is not None: + return None + + logger.debug("Preparing activity log for transition", id=str(self.db_model.id)) + + from_state = StatusEnum[event.transition.source] if event.transition else self.state + to_state = ( + StatusEnum[event.transition.dest] if event.transition and event.transition.dest else self.state + ) + + self.activity_log_entry = ActivityLog( + namespace=self.db_model.namespace, + node=self.db_model.id, + operator="daemon", + from_status=from_state, + to_status=to_state, + detail={}, + metadata_={}, + ) + + async def update_persistent_status(self, event: EventData) -> None: + """Callback method invoked by the Machine after every state-change.""" + # Update activity log entry with new state and timestamp + if TYPE_CHECKING: + assert self.db_model is not None, "Stateful Model must have a Node member." + assert self.session is not None + logger.debug("Updating the ORM instance after transition.", id=str(self.db_model.id)) + + if self.activity_log_entry is not None: + self.activity_log_entry.to_status = self.state + self.activity_log_entry.finished_at = timestamp.now_utc() + + # Ensure database record for transitioned object is updated + self.db_model = await self.session.merge(self.db_model, load=False) + self.db_model.status = self.state + self.db_model.metadata_["mtime"] = element_time() + await self.session.commit() + + async def finalize(self, event: EventData) -> None: + """Callback method invoked by the Machine unconditionally at the end + of every callback chain. 
During this callback, if the activity log + indicates that change has occurred, it is written to the db and the + machine is serialized to the Machines table for later use. + """ + if TYPE_CHECKING: + assert self.db_model is not None + assert self.session is not None + + # The activity log entry is added to the db. For failed transitions it + # may include error detail. For other transitions it is not necessary + # to log every attempt. + if self.activity_log_entry is None: + return + elif self.activity_log_entry.finished_at is None: + return + + # ensure the orm instance is in the session + if self.db_model not in self.session: + self.db_model = await self.session.merge(self.db_model, load=False) + + # flush the activity log entry to the db + try: + logger.debug("Finalizing the activity log after transition.", id=str(self.db_model.id)) + self.session.add(self.activity_log_entry) + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(self.activity_log_entry) + self.activity_log_entry = None + + # create or update a machine entry in the db + new_machine = Machine.model_validate( + dict(id=self.db_model.machine or uuid4(), state=pickle.dumps(self.machine)) + ) + try: + logger.debug("Serializing the state machine after transition.", id=str(self.db_model.id)) + await self.session.merge(new_machine) + self.db_model.machine = new_machine.id + await self.session.commit() + except Exception: + logger.exception() + await self.session.rollback() + finally: + self.session.expunge(new_machine) + + await self.session.close() + self.session = None + + async def is_startable(self, event: EventData) -> bool: + """Conditional method called to check whether a ``start`` trigger may + be called. + """ + return True + + async def is_done_running(self, event: EventData) -> bool: + """Conditional method called to check whether a ``finish`` trigger may + be called. 
+ """ + return True + + +class StartMachine(NodeMachine): + """Conceptually, a campaign's START node may participate in activities like + any other kind of node, even though its purpose is to provide a solid well- + known root to the campaign graph. Some activities assigned to the Campaign + Machine could also be modeled as belonging to the START node instead. The + END node could serve a similar purpose. + """ + + __kind__ = [ManifestKind.node] + + def post_init(self) -> None: + """Post init, set class-specific callback triggers.""" + self.machine.before_prepare("do_prepare") + self.machine.before_unprepare("do_unprepare") + self.machine.before_start("do_start") + + async def do_prepare(self, event: EventData) -> None: + """Action method invoked when executing the "prepare" transition. + + For a Campaign to enter the ready state, the machine must consider: + + Conditions + ---------- + - the campaign's graph is valid. + + Callbacks + --------- + - artifact directory is created and writable. + """ + if TYPE_CHECKING: + assert self.db_model is not None + + logger.info("Preparing START node", id=str(self.db_model.id)) + + artifact_location = Path(expandvars(config.bps.artifact_path)) / str(self.db_model.namespace) + await artifact_location.mkdir(parents=False, exist_ok=True) + + async def do_unprepare(self, event: EventData) -> None: + if TYPE_CHECKING: + assert self.db_model is not None + + logger.info("Unpreparing START node", id=str(self.db_model.id)) + artifact_location = Path(expandvars(config.bps.artifact_path)) / str(self.db_model.namespace) + await run_in_threadpool(shutil.rmtree, artifact_location) + + async def do_start(self, event: EventData) -> None: + """Callback invoked when entering the "running" state. + + There is no particular work performed when a campaign enters a running + state other than to update the record's entry in the database which + acts as a flag to an executor to signal that a campaign's graph Nodes + may now be evolved. 
+ """ + if TYPE_CHECKING: + assert self.db_model is not None + + logger.debug("Starting START Node for Campaign", id=str(self.db_model.id)) + return None + + +class StepMachine(NodeMachine): + """Specific state model for a Node of kind GroupedStep. + + The Step-Nodes may be the most involved state models, as the logic that + must execute during each transition is complex. The behaviors are generally + the same as the "scripts" associated with a Step/Group/Job in the legacy + CM implementation. + + A summary of the logic at each transition: + + - prepare + - determine number of groups and group membership + - create new Manifest for each Group + - start + - create new StepGroup Nodes (reading prepared Manifests) + - create new StepCollect Node + - create edges + - finish + - (condition) campaign graph is valid + - unprepare (rollback) + - no action taken, but know that on the next use of "prepare" + new versions of the group manifests may be created. + + Failure modes may include + - Butler errors (can't query for group membership) + - Bad inputs (group membership rules don't make sense) + """ + + __kind__ = [ManifestKind.grouped_step] + + def __init__( + self, *args: Any, o: Node, initial_state: StatusEnum = StatusEnum.waiting, **kwargs: Any + ) -> None: + super().__init__(*args, o, initial_state, **kwargs) + self.machine.before_prepare("do_prepare") + self.machine.before_start("do_start") + self.machine.before_unprepare("do_unprepare") + self.machine.before_finish("do_finish") + + async def do_prepare(self, event: EventData) -> None: ... + + async def do_unprepare(self, event: EventData) -> None: ... + + async def do_start(self, event: EventData) -> None: ... + + async def do_finish(self, event: EventData) -> None: ... + + async def is_successful(self, event: EventData) -> bool: + """Checks whether the WMS job is finished or not based on the result of + a bps-report or similar. 
Returns a True value if the batch is done and + good, a False value if it is still running. Raises an exception in any + other terminal WMS state (HELD or FAILED). + + ``` + bps_report: WmsStatusReport = get_wms_status_from_bps(...) + + match bps_report: + case WmsStatusReport(wms_status="FINISHED"): + return True + case WmsStatusReport(wms_status="HELD"): + raise WmsBlockedError() + case WmsStatusReport(wms_status="FAILED"): + raise WmsFailedError() + case WmsStatusReport(wms_status="RUNNING"): + return False + ``` + """ + return True + + +class GroupMachine(NodeMachine): + """Specific state model for a Node of kind StepGroup. + + A summary of the logic at each transition: + + - prepare + - create artifact output directory + - collect all relevant configuration Manifests + - render bps workflow artifacts + - create butler in collection(s) + + - start + - bps submit + - (after_start) determine bps submit directory + + - finish + - (condition) bps report == done + - create butler out collection(s) + + - fail + - read/parse bps output logs + + - stop (rollback) + - bps cancel + + - unprepare (rollback) + - remove artifact output directory + - Butler collections are not modified (paint-over pattern) + + Failure modes may include: + - Unwritable artifact output directory + - Manifests insufficient to render bps workflow artifacts + - Butler errors + - BPS or other middleware errors + """ + + __kind__ = [ManifestKind.step_group] + + ... + + +class StepCollectMachine(NodeMachine): + """Specific state model for a Node of kind StepCollect. + + - prepare + - create step output chained butler collection + + - start + - (condition) ancestor output collections exist in butler? + - add each ancestor output collection to step output chain + + - finish + - (condition) all ancestor output collections in chain + """ + + __kind__ = [ManifestKind.collect_groups] + + ... 
+ + +@cache +def node_machine_factory(kind: ManifestKind) -> type[NodeMachine]: + """Returns the Stateful Model for a node based on its kind, by matching + the ``__kind__`` attribute of available classes in this module. + + TODO: May "construct" new classes from multiple matches, but this is not + yet necessary. + """ + for _, o in inspect.getmembers(sys.modules[__name__], inspect.isclass): + if issubclass(o, NodeMachine) and kind in o.__kind__: + return o + return NodeMachine diff --git a/src/lsst/cmservice/machines/tasks.py b/src/lsst/cmservice/machines/tasks.py new file mode 100644 index 000000000..93c74604d --- /dev/null +++ b/src/lsst/cmservice/machines/tasks.py @@ -0,0 +1,47 @@ +"""Background task implementations for FSM-related operations performed via +API routes. +""" + +from uuid import UUID + +from ..common.enums import StatusEnum +from ..common.logging import LOGGER +from ..db.campaigns_v2 import Campaign +from .campaign import CampaignMachine + +logger = LOGGER.bind(module=__name__) + + +async def change_campaign_state(campaign: Campaign, desired_state: StatusEnum, request_id: UUID) -> None: + """A Background Task to affect a state change in a Campaign, using an + FSM by triggering based on one of a handful of possible user-initiated + state changes, as by PATCHing a campaign using the REST API. 
+ """ + + logger.info( + "Updating campaign state", + campaign=str(campaign.id), + request_id=str(request_id), + dest=desired_state.name, + ) + # Establish an FSM for the Campaign initialized to the current status + campaign_machine = CampaignMachine(o=campaign, initial_state=campaign.status) + + trigger: str + match (campaign.status, desired_state): + case (StatusEnum.waiting, StatusEnum.running): + trigger = "start" + case (StatusEnum.running, StatusEnum.paused): + trigger = "pause" + case (StatusEnum.paused, StatusEnum.running): + trigger = "resume" + case _: + logger.warning( + "Invalid campaign transition requested", + id=str(campaign.id), + source=campaign.status, + dest=desired_state, + ) + return None + + await campaign_machine.trigger(trigger, request_id=str(request_id)) diff --git a/src/lsst/cmservice/main.py b/src/lsst/cmservice/main.py index 5b53eec81..db3133e09 100644 --- a/src/lsst/cmservice/main.py +++ b/src/lsst/cmservice/main.py @@ -3,102 +3,33 @@ import uvicorn from fastapi import FastAPI -from safir.dependencies.db_session import db_session_dependency from safir.dependencies.http_client import http_client_dependency from safir.logging import configure_logging, configure_uvicorn_logging from safir.middleware.x_forwarded import XForwardedMiddleware from . import __version__ from .config import config +from .db.session import db_session_dependency from .routers import ( healthz, index, + tags_metadata, v1, + v2, ) from .web_app import web_app configure_uvicorn_logging(config.logging.level) configure_logging(profile=config.logging.profile, log_level=config.logging.level, name=config.asgi.title) -tags_metadata = [ - { - "name": "Loaders", - "description": "Operations that load Objects in to the DB.", - }, - { - "name": "Actions", - "description": "Operations perform actions on existing Objects in to the DB." 
- "In many cases this will result in the creating of new objects in the DB.", - }, - { - "name": "Campaigns", - "description": "Operations with `campaign`s. A `campaign` consists of several processing `step`s " - "which are run sequentially. A `campaign` also holds configuration such as a URL for a butler repo " - "and a production area. `campaign`s must be uniquely named withing a given `production`.", - }, - { - "name": "Steps", - "description": "Operations with `step`s. A `step` consists of several processing `group`s which " - "may be run in parallel. `step`s must be uniquely named within a give `campaign`.", - }, - { - "name": "Groups", - "description": "Operations with `groups`. A `group` can be processed in a single `workflow`, " - "but we also need to account for possible failures. `group`s must be uniquely named within a " - "given `step`.", - }, - { - "name": "Scripts", - "description": "Operations with `scripts`. A `script` does a single operation, either something" - "that is done asynchronously, such as making new collections in the Butler, or creating" - "new objects in the DB, such as new `steps` and `groups`.", - }, - { - "name": "Jobs", - "description": "Operations with `jobs`. 
A `job` runs a single `workflow`: keeps a count" - "of the results data products and keeps track of associated errors.", - }, - { - "name": "Pipetask Error Types", - "description": "Operations with `pipetask_error_type` table.", - }, - { - "name": "Pipetask Errors", - "description": "Operations with `pipetask_error` table.", - }, - { - "name": "Product Sets", - "description": "Operations with `product_set` table.", - }, - { - "name": "Task Sets", - "description": "Operations with `task_set` table.", - }, - { - "name": "Script Dependencies", - "description": "Operations with `script_dependency` table.", - }, - { - "name": "Step Dependencies", - "description": "Operations with `step_dependency` table.", - }, - { - "name": "Wms Task Reports", - "description": "Operations with `wms_task_report` table.", - }, - {"name": "Specifications", "description": "Operations with `specification` table."}, - {"name": "SpecBlocks", "description": "Operations with `spec_block` table."}, -] - @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator: """Hook FastAPI init/cleanups.""" app.state.tasks = set() # Dependency inits before app starts running - await db_session_dependency.initialize(config.db.url, config.db.password) - assert db_session_dependency._engine is not None - db_session_dependency._engine.echo = config.db.echo + await db_session_dependency.initialize() + assert db_session_dependency.engine is not None # App runs here... yield @@ -123,9 +54,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: app.include_router(healthz.health_router, prefix="") app.include_router(index.router, prefix="") app.include_router(v1.router, prefix=config.asgi.prefix) +app.include_router(v2.router, prefix=config.asgi.prefix) # Start the frontend web application. 
-app.mount(config.asgi.frontend_prefix, web_app) +if config.asgi.enable_frontend: + app.mount(config.asgi.frontend_prefix, web_app) if __name__ == "__main__": diff --git a/src/lsst/cmservice/models/element.py b/src/lsst/cmservice/models/element.py index a20ec71b2..eb4ebbc18 100644 --- a/src/lsst/cmservice/models/element.py +++ b/src/lsst/cmservice/models/element.py @@ -21,13 +21,13 @@ class ElementBase(BaseModel): metadata_: dict = Field(default_factory=dict) # Overrides for configuring child nodes - child_config: dict | str | None = None + child_config: dict | None = None # Overrides for making collection names - collections: dict | str | None = None + collections: dict | None = None # Overrides for which SpecBlocks to use in constructing child Nodes - spec_aliases: dict | str | None = None + spec_aliases: dict | None = None # Override for Callback handler class handler: str | None = None @@ -75,13 +75,13 @@ class ElementUpdate(BaseModel): metadata_: dict | None = None # Overrides for configuring child nodes - child_config: dict | str | None = None + child_config: dict | None = None # Overrides for making collection names - collections: dict | str | None = None + collections: dict | None = None # Overrides for which SpecBlocks to use in constructing child Nodes - spec_aliases: dict | str | None = None + spec_aliases: dict | None = None # Override for Callback handler class handler: str | None = None diff --git a/src/lsst/cmservice/models/interface.py b/src/lsst/cmservice/models/interface.py index 34336d622..7112724a7 100644 --- a/src/lsst/cmservice/models/interface.py +++ b/src/lsst/cmservice/models/interface.py @@ -138,13 +138,13 @@ class LoadAndCreateCampaign(YamlFileQuery): # If empty use {spec_name}#campaign spec_block_assoc_name: str | None = None # Parameter Overrides - data: dict | str | None = None + data: dict | None = None # Overrides for configuring child nodes - child_config: dict | str | None = None + child_config: dict | None = None # Overrides for 
making collection names - collections: dict | str | None = None + collections: dict | None = None # Overrides for which SpecBlocks to use in constructing child Nodes - spec_aliases: dict | str | None = None + spec_aliases: dict | None = None # Override for Callback handler class handler: str | None = None # Allow updating existing specifications diff --git a/src/lsst/cmservice/models/serde.py b/src/lsst/cmservice/models/serde.py new file mode 100644 index 000000000..02ac5c467 --- /dev/null +++ b/src/lsst/cmservice/models/serde.py @@ -0,0 +1,44 @@ +"""Module for serialization and deserialization support for pydantic and +other derivative models. +""" + +from enum import EnumType +from functools import partial +from typing import Any + +from pydantic import PlainSerializer, PlainValidator + +from ..common.enums import ManifestKind, StatusEnum + + +def EnumValidator[T: EnumType](value: Any, enum_: T) -> T: + """Create an enum from the input value. The input can be either the + enum name or its value. + + Used as a Validator for a pydantic field. + """ + try: + new_enum: T = enum_[value] if value in enum_.__members__ else enum_(value) + except (KeyError, ValueError): + raise ValueError(f"Value must be a member of {enum_.__qualname__}") + return new_enum + + +EnumSerializer = PlainSerializer( + lambda x: x.name, + return_type="str", + when_used="always", +) +"""A serializer for enums that produces its name, not the value.""" + + +StatusEnumValidator = PlainValidator(partial(EnumValidator, enum_=StatusEnum)) +"""A validator for the StatusEnum that can parse the enum from either a name +or a value. +""" + + +ManifestKindEnumValidator = PlainValidator(partial(EnumValidator, enum_=ManifestKind)) +"""A validator for the ManifestKindEnum that can parse the enum from a name +or a value. 
+""" diff --git a/src/lsst/cmservice/models/specification.py b/src/lsst/cmservice/models/specification.py index b7528dd79..462fd1a8f 100644 --- a/src/lsst/cmservice/models/specification.py +++ b/src/lsst/cmservice/models/specification.py @@ -14,16 +14,16 @@ class SpecificationBase(BaseModel): name: str # Parameter Overrides - data: dict | str | None = None + data: dict | None = None # Overrides for configuring child nodes - child_config: dict | str | None = None + child_config: dict | None = None # Overrides for making collection names - collections: dict | str | None = None + collections: dict | None = None # Overrides for which SpecBlocks to use in constructing child Nodes - spec_aliases: dict | str | None = None + spec_aliases: dict | None = None class SpecificationCreate(SpecificationBase): diff --git a/src/lsst/cmservice/parsing/string.py b/src/lsst/cmservice/parsing/string.py index 24a052ede..dc485a187 100644 --- a/src/lsst/cmservice/parsing/string.py +++ b/src/lsst/cmservice/parsing/string.py @@ -48,11 +48,11 @@ def parse_element_fullname(fullname: str) -> Fullname: fullname_r = re.compile( ( r"^" - r"(?P[\w]+){1}(?:\/)*" - r"(?P[\w]+){0,1}(?:\/)*" - r"(?P[\w]+){0,1}(?:\/)*" - r"(?P[\w]+){0,1}(?:\/)*" - r"(?P