Skip to content

Commit 29780ee

Browse files
committed
feat(graph): Implement graph functions
1 parent 1f40e90 commit 29780ee

4 files changed

Lines changed: 346 additions & 3 deletions

File tree

src/lsst/cmservice/common/graph.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Iterable, Mapping, MutableSet, Sequence
22

33
import networkx as nx
4+
from sqlalchemy import select
45

56
from ..db import Script, ScriptDependency, Step, StepDependency
7+
from ..db.campaigns_v2 import Edge, Node
68
from ..parsing.string import parse_element_fullname
79
from .types import AnyAsyncSession
810

@@ -35,8 +37,125 @@ async def graph_from_edge_list(
3537
return g
3638

3739

40+
async def graph_from_edge_list_v2(
41+
edges: Sequence[Edge],
42+
node_type: type[Node],
43+
session: AnyAsyncSession,
44+
) -> nx.DiGraph:
45+
"""Given a sequence of Edges, create a directed graph for these
46+
edges with nodes derived from database lookups of the related objects.
47+
"""
48+
g = nx.DiGraph()
49+
g.add_edges_from([(e.source, e.target) for e in edges])
50+
51+
# The graph understands the nodes in terms of the IDs used in the edges,
52+
# but we want to hydrate the entire Node model for subsequent users of this
53+
# graph to reference without dipping back to the Database.
54+
for node in g.nodes:
55+
s = select(Node).where(Node.id == node)
56+
db_node: Node = (await session.execute(s)).scalars().one()
57+
58+
# This Node is going on an adventure where it does not need to drag its
59+
# SQLAlchemy baggage along, so we expunge it from the session before
60+
# adding it to the graph.
61+
session.expunge(db_node)
62+
g.nodes[node]["model"] = db_node
63+
64+
# TODO validate graph now raise exception, or leave it to the caller?
65+
return g
66+
67+
3868
def graph_to_dict(g: nx.DiGraph) -> Mapping:
3969
"""Renders a networkx directed graph to a mapping format suitable for JSON
4070
serialization.
4171
"""
4272
return nx.node_link_data(g, edges="edges")
73+
74+
75+
def validate_graph(g: nx.DiGraph, source: str = "START", sink: str = "END") -> bool:
76+
"""Validates a graph by asserting by traversal that a complete and correct
77+
path exists between `source` and `sink` nodes.
78+
79+
"Correct" means that there are no cycles or isolate nodes (nodes with
80+
degree 0) and no nodes with degree 1.
81+
"""
82+
try:
83+
# Test that G is a directed graph with no cycles
84+
is_valid = nx.is_directed_acyclic_graph(g)
85+
assert is_valid
86+
87+
# And that any path from source to sink exists
88+
is_valid = nx.has_path(g, source, sink)
89+
assert is_valid
90+
91+
# Guard against bad graphs where START and/or END have been connected
92+
# such that they are no longer the only source and sink
93+
...
94+
95+
# Test that there are no isolated Nodes in the graph. A node becomes
96+
# isolated if it was involved with an edge that has been removed from
97+
# G with no replacement edge added, in which case the node should also
98+
# be removed.
99+
is_valid = nx.number_of_isolates(g) == 0
100+
assert is_valid
101+
102+
# TODO Given the set of nodes in the graph, consider all paths in G
103+
# from source to sink, making sure every node appears in a path?
104+
105+
# Every node in G that is not the START/END node must have a degree
106+
# of at least 2 (one inbound and one outbound edge). If G has any
107+
# node with a degree of 1, it cannot be considered valid.
108+
g_degree_view: Iterable = nx.degree(g, (n for n in g.nodes if n not in [source, sink]))
109+
is_valid = min([d[1] for d in g_degree_view]) > 1
110+
assert is_valid
111+
except (nx.exception.NodeNotFound, AssertionError):
112+
return False
113+
return True
114+
115+
116+
def processable_graph_nodes(g: nx.DiGraph) -> Iterable[Node]:
117+
"""Traverse the graph G and produce an iterator of any nodes that are
118+
candidates for processing, i.e., their status is waiting/prepared/running
119+
and their ancestors are complete/successful. Graph nodes in a failed state
120+
will block the graph and prevent candidacy for subsequent nodes.
121+
122+
Yields
123+
------
124+
`lsst.cmservice.db.campaigns_v2.Node`
125+
A Node ORM object that has been ``expunge``d from its ``Session``.
126+
127+
Notes
128+
-----
129+
This function operates only on valid graphs (see `validate_graph()`) that
130+
have been built by the `graph_from_edge_list_v2()` function, where each
131+
graph-node is decorated with a "model" attribute referring to an expunged
132+
instance of ``Node``. This ``Node`` can be ``add``ed back to a ``Session``
133+
and manipulated in the usual way.
134+
"""
135+
processable_nodes: MutableSet[Node] = set()
136+
137+
# A valid campaign graph will have only one source (START) with in_degree 0
138+
# and only one sink (END) with out_degree 0
139+
source = next(v for v, d in g.in_degree() if d == 0)
140+
sink = next(v for v, d in g.out_degree() if d == 0)
141+
142+
# For each path through the graph, evaluate the state of nodes to determine
143+
# which nodes are up for processing. When there are multiple paths, we have
144+
# parallelization and common ancestors may be evaluated more than once,
145+
# which is an exercise in optimization left as a TODO
146+
for path in nx.all_simple_paths(g, source, sink):
147+
for n in path:
148+
node: Node = g.nodes[n]["model"]
149+
if node.status.is_processable_element():
150+
processable_nodes.add(node)
151+
# We found a processable node in this path, stop traversal
152+
break
153+
elif node.status.is_bad():
154+
# We reached a failed node in this path, it is blocked
155+
break
156+
else:
157+
# This node must be in a "successful" terminal state
158+
continue
159+
160+
# the inspection should stop when there are no more nodes to check
161+
yield from processable_nodes

src/lsst/cmservice/db/campaigns_v2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ class CampaignUpdate(BaseSQLModel):
100100
class NodeBase(BaseSQLModel):
101101
"""nodes_v2 db table"""
102102

103+
def __hash__(self) -> int:
104+
"""A Node is hashable according to its unique ID, so it can be used in
105+
sets and other places hashable types are required.
106+
"""
107+
return self.id.int
108+
103109
id: UUID = Field(primary_key=True)
104110
name: str
105111
namespace: UUID
@@ -108,7 +114,7 @@ class NodeBase(BaseSQLModel):
108114
default=ManifestKind.other,
109115
sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)),
110116
)
111-
status: StatusField | None = Field(
117+
status: StatusField = Field(
112118
default=StatusEnum.waiting,
113119
sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)),
114120
)

src/lsst/cmservice/routers/v2/edges.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def create_edge_resource(
126126
namespace: {campaign uuid}
127127
spec:
128128
source: {node name or id}
129-
target: {ndoe name or id}
129+
target: {node name or id}
130130
```
131131
"""
132132
edge_name = manifest.metadata_.name
@@ -153,6 +153,7 @@ async def create_edge_resource(
153153
# TODO the edge spec should support mappings for source/target nodes but
154154
# for now assume the provided name has `.vN` appended to it already or
155155
# default to v1
156+
# TODO support node id in spec
156157
source_node = f"{source_node}.1" if "." not in source_node else str(source_node)
157158
target_node = f"{target_node}.1" if "." not in target_node else str(target_node)
158159

tests/v2/test_graph.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Tests graph operations using v2 objects"""
2+
3+
from collections.abc import AsyncGenerator
4+
from uuid import uuid4
5+
6+
import networkx as nx
7+
import pytest
8+
import pytest_asyncio
9+
from httpx import AsyncClient
10+
11+
from lsst.cmservice.common.enums import StatusEnum
12+
from lsst.cmservice.common.graph import graph_from_edge_list_v2, processable_graph_nodes, validate_graph
13+
from lsst.cmservice.common.types import AnyAsyncSession
14+
from lsst.cmservice.db.campaigns_v2 import Edge, Node
15+
16+
pytestmark = pytest.mark.asyncio(loop_scope="module")
17+
"""All tests in this module will run in the same event loop."""
18+
19+
20+
@pytest_asyncio.fixture(scope="module", loop_scope="module")
21+
async def test_campaign(aclient: AsyncClient) -> AsyncGenerator[str]:
22+
"""Fixture managing a test campaign with two (additional) nodes."""
23+
campaign_name = uuid4().hex[-8:]
24+
node_ids = []
25+
26+
x = await aclient.post(
27+
"/cm-service/v2/campaigns",
28+
json={
29+
"apiVersion": "io.lsst.cmservice/v1",
30+
"kind": "campaign",
31+
"metadata": {"name": campaign_name},
32+
"spec": {},
33+
},
34+
)
35+
campaign_edge_url = x.headers["Edges"]
36+
campaign = x.json()
37+
38+
# create a trio of nodes for the campaign
39+
for _ in range(3):
40+
x = await aclient.post(
41+
"/cm-service/v2/nodes",
42+
json={
43+
"apiVersion": "io.lsst.cmservice/v1",
44+
"kind": "node",
45+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
46+
"spec": {},
47+
},
48+
)
49+
node = x.json()
50+
node_ids.append(node["name"])
51+
52+
# Create edges between each campaign node with parallelization
53+
_ = await aclient.post(
54+
"/cm-service/v2/edges",
55+
json={
56+
"apiVersion": "io.lsst.cmservice/v1",
57+
"kind": "edge",
58+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
59+
"spec": {
60+
"source": "START",
61+
"target": node_ids[0],
62+
},
63+
},
64+
)
65+
_ = await aclient.post(
66+
"/cm-service/v2/edges",
67+
json={
68+
"apiVersion": "io.lsst.cmservice/v1",
69+
"kind": "edge",
70+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
71+
"spec": {
72+
"source": node_ids[0],
73+
"target": node_ids[1],
74+
},
75+
},
76+
)
77+
_ = await aclient.post(
78+
"/cm-service/v2/edges",
79+
json={
80+
"apiVersion": "io.lsst.cmservice/v1",
81+
"kind": "edge",
82+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
83+
"spec": {
84+
"source": node_ids[0],
85+
"target": node_ids[2],
86+
},
87+
},
88+
)
89+
_ = await aclient.post(
90+
"/cm-service/v2/edges",
91+
json={
92+
"apiVersion": "io.lsst.cmservice/v1",
93+
"kind": "edge",
94+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
95+
"spec": {
96+
"source": node_ids[1],
97+
"target": "END",
98+
},
99+
},
100+
)
101+
_ = await aclient.post(
102+
"/cm-service/v2/edges",
103+
json={
104+
"apiVersion": "io.lsst.cmservice/v1",
105+
"kind": "edge",
106+
"metadata": {"name": uuid4().hex[-8:], "namespace": campaign["id"]},
107+
"spec": {
108+
"source": node_ids[2],
109+
"target": "END",
110+
},
111+
},
112+
)
113+
yield campaign_edge_url
114+
115+
116+
async def test_build_and_walk_graph(
117+
aclient: AsyncClient, session: AnyAsyncSession, test_campaign: str
118+
) -> None:
119+
"""Test the generation and traversal of a campaign graph as created in the
120+
``test_campaign`` fixture.
121+
122+
Test that the graph is traversed from START to END in order, and that as
123+
graph nodes are "processable" they can be handled. In this test, the status
124+
of each node is set to "accepted" and updated in the databse. The campaign
125+
graph is recreated between each stage of the mock graph processing.
126+
127+
The test campaign is a set of 3 nodes arranged in a graph:
128+
129+
```
130+
START --> A --> B --> END
131+
--> C -->
132+
```
133+
"""
134+
edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign)).json()]
135+
graph = await graph_from_edge_list_v2(edge_list, Node, session)
136+
137+
# the START node should be the only processable Node
138+
for node in processable_graph_nodes(graph):
139+
assert node.name == "START"
140+
assert node.status is StatusEnum.waiting
141+
# Add the Node back to the session and update its status
142+
session.add(node)
143+
await session.refresh(node)
144+
node.status = StatusEnum.accepted
145+
await session.commit()
146+
147+
# Repeat the graph building and traversal, this time expecting a single
148+
# node that is not "START"
149+
graph = await graph_from_edge_list_v2(edge_list, Node, session)
150+
for node in processable_graph_nodes(graph):
151+
assert node.name != "START"
152+
assert node.status is StatusEnum.waiting
153+
# Add the Node back to the session and update its status
154+
session.add(node)
155+
await session.refresh(node)
156+
node.status = StatusEnum.accepted
157+
await session.commit()
158+
159+
# Repeat the graph building and traversal, this time expecting a pair of
160+
# nodes processable in parallel
161+
graph = await graph_from_edge_list_v2(edge_list, Node, session)
162+
count = 0
163+
for node in processable_graph_nodes(graph):
164+
count += 1
165+
assert node.name != "START"
166+
assert node.status is StatusEnum.waiting
167+
# Add the Node back to the session and update its status
168+
session.add(node)
169+
await session.refresh(node)
170+
node.status = StatusEnum.accepted
171+
await session.commit()
172+
assert count == 2
173+
174+
# Finally, expect the END node
175+
graph = await graph_from_edge_list_v2(edge_list, Node, session)
176+
for node in processable_graph_nodes(graph):
177+
assert node.name == "END"
178+
assert node.status is StatusEnum.waiting
179+
# Add the Node back to the session and update its status
180+
session.add(node)
181+
await session.refresh(node)
182+
node.status = StatusEnum.accepted
183+
await session.commit()
184+
185+
186+
def test_validate_graph() -> None:
187+
"""Test basic graph validation operations using a simple DAG."""
188+
edge_list = [("A", "B"), ("B", "C"), ("C", "D"), ("C", "E"), ("D", "F"), ("E", "F")]
189+
190+
g = nx.DiGraph()
191+
g.add_edges_from(edge_list)
192+
193+
# this is a valid graph
194+
assert validate_graph(g, "A", "F")
195+
196+
# add a new parallel node with no path to sink
197+
g.add_edge("C", "CC")
198+
assert not validate_graph(g, "A", "F")
199+
200+
# create a cycle with the new node
201+
g.add_edge("CC", "A")
202+
assert not validate_graph(g, "A", "F")
203+
204+
# correct the path
205+
g.remove_edge("CC", "A")
206+
g.add_edge("CC", "F")
207+
assert validate_graph(g, "A", "F")
208+
209+
# remove the edges from a node
210+
g.remove_edge("CC", "F")
211+
g.remove_edge("C", "CC")
212+
# the graph is invalid because "CC" is now an isolate
213+
assert not validate_graph(g, "A", "F")
214+
215+
# remove the unneeded node
216+
g.remove_node("CC")
217+
assert validate_graph(g, "A", "F")

0 commit comments

Comments
 (0)