Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 155 additions & 1 deletion src/lsst/cmservice/common/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, MutableSet, Sequence
from typing import Literal

import networkx as nx
from sqlalchemy import select

from ..db import Script, ScriptDependency, Step, StepDependency
from ..db.campaigns_v2 import Edge, Node
from ..parsing.string import parse_element_fullname
from .types import AnyAsyncSession

Expand Down Expand Up @@ -35,8 +38,159 @@ async def graph_from_edge_list(
return g


async def graph_from_edge_list_v2(
    edges: Sequence[Edge],
    session: AnyAsyncSession,
    node_type: type[Node] = Node,
    node_view: Literal["simple", "model"] = "model",
) -> nx.DiGraph:
    """Given a sequence of Edges, create a directed graph for these
    edges with nodes derived from database lookups of the related objects.

    Parameters
    ----------
    edges: Sequence[Edge]
        The list of edges forming the graph

    session
        An async database session

    node_type: type
        The pydantic or sqlmodel class representing the graph node model;
        this class is used for the database lookup of each graph node.

    node_view: "simple" or "model"
        Whether the node metadata in the graph should be simplified (dict) or
        using the full expunged model form.

    Returns
    -------
    nx.DiGraph
        A directed graph whose nodes carry either a "model" attribute (the
        expunged ORM object) or, for the "simple" view, flat string
        attributes suitable for JSON serialization.

    Raises
    ------
    sqlalchemy.exc.NoResultFound
        If any edge endpoint has no corresponding row in the node table.
    """
    g = nx.DiGraph()
    g.add_edges_from([(e.source, e.target) for e in edges])
    relabel_mapping = {}

    # The graph understands the nodes in terms of the IDs used in the edges,
    # but we want to hydrate the entire Node model for subsequent users of this
    # graph to reference without dipping back to the Database.
    for node in g.nodes:
        # Query via the caller-provided ``node_type`` so alternate node models
        # can be used (previously this hardcoded ``Node`` and ignored the
        # parameter).
        s = select(node_type).where(node_type.id == node)
        db_node: Node = (await session.execute(s)).scalars().one()

        # This Node is going on an adventure where it does not need to drag its
        # SQLAlchemy baggage along, so we expunge it from the session before
        # adding it to the graph.
        session.expunge(db_node)
        if node_view == "simple":
            # for the simple node view, the goal is to minimize the amount of
            # data attached to the node and ensure that this data is json-
            # serializable and otherwise appropriate for an API response
            g.nodes[node]["id"] = str(db_node.id)
            g.nodes[node]["status"] = db_node.status.name
            g.nodes[node]["kind"] = db_node.kind.name
            relabel_mapping[node] = db_node.name
        else:
            g.nodes[node]["model"] = db_node

    # For the simple view, relabel graph nodes from opaque IDs to their
    # human-readable names (in place; no graph copy is made).
    if relabel_mapping:
        g = nx.relabel_nodes(g, mapping=relabel_mapping, copy=False)

    # TODO validate graph now raise exception, or leave it to the caller?
    return g


def graph_to_dict(g: nx.DiGraph) -> Mapping:
    """Serialize a directed graph into a JSON-compatible node-link mapping.

    Notes
    -----
    The edge list is emitted under the key "edges" rather than the networkx
    default attribute name of "links".
    """
    node_link_mapping: Mapping = nx.node_link_data(g, edges="edges")
    return node_link_mapping


def validate_graph(g: nx.DiGraph, source: str = "START", sink: str = "END") -> bool:
    """Validates a graph by asserting by traversal that a complete and correct
    path exists between `source` and `sink` nodes.

    "Correct" means that the graph is acyclic, a path from `source` to `sink`
    exists, and there are no isolate nodes (degree 0) and no nodes with
    degree 1.

    Parameters
    ----------
    g: nx.DiGraph
        The graph to validate.
    source: str
        The name of the graph's unique source node (default "START").
    sink: str
        The name of the graph's unique sink node (default "END").

    Returns
    -------
    bool
        True if the graph is valid, False otherwise.
    """
    # Test that G is a directed graph with no cycles
    if not nx.is_directed_acyclic_graph(g):
        return False

    # And that any path from source to sink exists; has_path raises
    # NodeNotFound when either endpoint is absent from the graph, which is
    # also an invalid graph rather than an error.
    try:
        if not nx.has_path(g, source, sink):
            return False
    except nx.exception.NodeNotFound:
        return False

    # TODO Guard against bad graphs where START and/or END have been connected
    # such that they are no longer the only source and sink

    # Test that there are no isolated Nodes in the graph. A node becomes
    # isolated if it was involved with an edge that has been removed from
    # G with no replacement edge added, in which case the node should also
    # be removed.
    if nx.number_of_isolates(g) > 0:
        return False

    # TODO Given the set of nodes in the graph, consider all paths in G
    # from source to sink, making sure every node appears in a path?

    # Every node in G that is not the START/END node must have a degree
    # of at least 2 (one inbound and one outbound edge). ``all`` is vacuously
    # True when there are no interior nodes (the previous ``min(...)`` form
    # raised an uncaught ValueError on the minimal source->sink graph).
    return all(d > 1 for _, d in nx.degree(g, (n for n in g.nodes if n not in (source, sink))))


def processable_graph_nodes(g: nx.DiGraph) -> Iterable[Node]:
    """Traverse the graph G and produce an iterator of any nodes that are
    candidates for processing, i.e., their status is waiting/prepared/running
    and their ancestors are complete/successful. Graph nodes in a failed state
    will block the graph and prevent candidacy for subsequent nodes.

    Yields
    ------
    `lsst.cmservice.db.campaigns_v2.Node`
        A Node ORM object that has been ``expunge``d from its ``Session``.

    Notes
    -----
    This function operates only on valid graphs (see `validate_graph()`) that
    have been built by the `graph_from_edge_list_v2()` function, where each
    graph-node is decorated with a "model" attribute referring to an expunged
    instance of ``Node``. This ``Node`` can be ``add``ed back to a ``Session``
    and manipulated in the usual way.
    """
    candidates: MutableSet[Node] = set()

    # A valid campaign graph will have only one source (START) with in_degree 0
    # and only one sink (END) with out_degree 0
    start_node = next(v for v, d in g.in_degree() if d == 0)
    end_node = next(v for v, d in g.out_degree() if d == 0)

    # For each path through the graph, evaluate the state of nodes to determine
    # which nodes are up for processing. When there are multiple paths, we have
    # parallelization and common ancestors may be evaluated more than once,
    # which is an exercise in optimization left as a TODO
    for path in nx.all_simple_paths(g, start_node, end_node):
        for graph_node in path:
            model: Node = g.nodes[graph_node]["model"]
            if model.status.is_processable_element():
                # We found a processable node in this path, stop traversal
                candidates.add(model)
                break
            if model.status.is_bad():
                # We reached a failed node in this path, it is blocked
                break
            # Otherwise the node is in a "successful" terminal state; keep
            # walking down the path.

    # the inspection should stop when there are no more nodes to check
    yield from candidates
32 changes: 30 additions & 2 deletions src/lsst/cmservice/common/jsonpatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import operator
from collections.abc import MutableMapping, MutableSequence
from collections.abc import Mapping, MutableMapping, MutableSequence
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal

Expand Down Expand Up @@ -51,7 +51,7 @@ def apply_json_patch[T: MutableMapping](op: JSONPatch, o: T) -> T:
numeric, e.g., {"1": "first", "2": "second"}
- Unsupported: JSON pointer values that refer to an entire object, e.g.,
"" -- the JSON Patch must have a root element ("/") per the model.
- Unsupported: JSON pointer values taht refer to a nameless object, e.g.,
- Unsupported: JSON pointer values that refer to a nameless object, e.g.,
"/" -- JSON allows object keys to be the empty string ("") but this is
disallowed by the application.
"""
Expand Down Expand Up @@ -222,3 +222,31 @@ def apply_json_patch[T: MutableMapping](op: JSONPatch, o: T) -> T:
raise JSONPatchError(f"Unknown JSON Patch operation: {op.op}")

return o


def apply_json_merge[T: MutableMapping](patch: Any, o: T) -> T:
"""Applies a patch to a mapping object as per the RFC7396 JSON Merge Patch.

Notably, this operation may only target a ``MutableMapping`` as an analogue
of a JSON object. This means that any keyed value in a Mapping may be
replaced, added, or removed by a JSON Merge. This is not appropriate for
patches that need to perform more tactical updates, such as modifying
elements of a ``Sequence``.

This function does not allow setting a field value in the target to `None`;
instead, any `None` value in a patch is an instruction to remove that
field from the target completely.

This function differs from the RFC in the following ways: it will not
replace the entire target object with a new mapping (i.e., the target must
be a Mapping).
"""
if isinstance(patch, Mapping):
for k, v in patch.items():
if v is None:
_ = o.pop(k, None)
else:
o[k] = apply_json_merge(v, o.get(k, {}))
return o
else:
return patch
88 changes: 24 additions & 64 deletions src/lsst/cmservice/db/campaigns_v2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""ORM Models for v2 tables and objects."""

from datetime import datetime
from typing import Any
from uuid import NAMESPACE_DNS, UUID, uuid5
Expand Down Expand Up @@ -43,15 +45,6 @@ def jsonb_column(name: str, aliases: list[str] | None = None) -> Any:
)


# NOTES
# - model validation is not triggered when table=True
# - Every object model needs to have three flavors:
# 1. the declarative model of the object's database table
# 2. the model of the manifest when creating a new object
# 3. the model of the manifest when updating an object
# 4. a response model for APIs related to the object


class BaseSQLModel(SQLModel):
__table_args__ = {"schema": config.db.table_schema}
metadata = metadata
Expand All @@ -71,10 +64,6 @@ class CampaignBase(BaseSQLModel):
metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"])
configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"])


class CampaignModel(CampaignBase):
"""model used for resource creation."""

@model_validator(mode="before")
@classmethod
def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
Expand All @@ -83,15 +72,15 @@ def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
"""
if isinstance(data, dict):
if "name" not in data:
raise ValueError("'name' must be specified.")
raise ValueError("<campaign> name missing.")
if "namespace" not in data:
data["namespace"] = _default_campaign_namespace
if "id" not in data:
data["id"] = uuid5(namespace=data["namespace"], name=data["name"])
return data


class Campaign(CampaignModel, table=True):
class Campaign(CampaignBase, table=True):
"""Model used for database operations involving campaigns_v2 table rows"""

__tablename__: str = "campaigns_v2" # type: ignore[misc]
Expand All @@ -111,6 +100,12 @@ class CampaignUpdate(BaseSQLModel):
class NodeBase(BaseSQLModel):
"""nodes_v2 db table"""

def __hash__(self) -> int:
"""A Node is hashable according to its unique ID, so it can be used in
sets and other places hashable types are required.
"""
return self.id.int

id: UUID = Field(primary_key=True)
name: str
namespace: UUID
Expand All @@ -119,33 +114,32 @@ class NodeBase(BaseSQLModel):
default=ManifestKind.other,
sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)),
)
status: StatusField | None = Field(
status: StatusField = Field(
default=StatusEnum.waiting,
sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)),
)
metadata_: dict = jsonb_column("metadata", aliases=["metadata", "metadata_"])
configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"])


class NodeModel(NodeBase):
"""model validating class for Nodes"""

@model_validator(mode="before")
@classmethod
def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
"""Validates the model based on different types of raw inputs,
where some default non-optional fields can be auto-populated.
"""
if isinstance(data, dict):
if "version" not in data:
data["version"] = 1
if "name" not in data:
raise ValueError("'name' must be specified.")
if "namespace" not in data:
data["namespace"] = _default_campaign_namespace
if (node_name := data.get("name")) is None:
raise ValueError("<node> name missing.")
if (node_namespace := data.get("namespace")) is None:
raise ValueError("<node> namespace missing.")
if (node_version := data.get("version")) is None:
data["version"] = node_version = 1
if "id" not in data:
data["id"] = uuid5(namespace=data["namespace"], name=f"""{data["name"]}.{data["version"]}""")
data["id"] = uuid5(namespace=node_namespace, name=f"{node_name}.{node_version}")
return data


class Node(NodeModel, table=True):
class Node(NodeBase, table=True):
__tablename__: str = "nodes_v2" # type: ignore[misc]

machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE")
Expand All @@ -163,28 +157,12 @@ class EdgeBase(BaseSQLModel):
configuration: dict = jsonb_column("configuration", aliases=["configuration", "data", "spec"])


class EdgeModel(EdgeBase):
"""model validating class for Edges"""

@model_validator(mode="before")
@classmethod
def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
if isinstance(data, dict):
if "name" not in data:
raise ValueError("'name' must be specified.")
if "namespace" not in data:
raise ValueError("Edges may only exist in a 'namespace'.")
if "id" not in data:
data["id"] = uuid5(namespace=data["namespace"], name=data["name"])
return data


class EdgeResponseModel(EdgeModel):
class EdgeResponseModel(EdgeBase):
source: Any
target: Any


class Edge(EdgeModel, table=True):
class Edge(EdgeBase, table=True):
__tablename__: str = "edges_v2" # type: ignore[misc]


Expand Down Expand Up @@ -216,24 +194,6 @@ class ManifestBase(BaseSQLModel):
spec: dict = jsonb_column("spec", aliases=["spec", "configuration", "data"])


class ManifestModel(ManifestBase):
"""model validating class for Manifests"""

@model_validator(mode="before")
@classmethod
def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any:
if isinstance(data, dict):
if "version" not in data:
data["version"] = 1
if "name" not in data:
raise ValueError("'name' must be specified.")
if "namespace" not in data:
data["namespace"] = _default_campaign_namespace
if "id" not in data:
data["id"] = uuid5(namespace=data["namespace"], name=f"""{data["name"]}.{data["version"]}""")
return data


class Manifest(ManifestBase, table=True):
__tablename__: str = "manifests_v2" # type: ignore[misc]

Expand Down
Loading