Skip to content

Commit 7daf298

Browse files
committed
feat(api): Implement v2 campaign graph api
1 parent 29780ee commit 7daf298

3 files changed

Lines changed: 111 additions & 11 deletions

File tree

src/lsst/cmservice/common/graph.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Iterable, Mapping, MutableSet, Sequence
2+
from typing import Literal
23

34
import networkx as nx
45
from sqlalchemy import select
@@ -39,14 +40,31 @@ async def graph_from_edge_list(
3940

4041
async def graph_from_edge_list_v2(
4142
edges: Sequence[Edge],
42-
node_type: type[Node],
4343
session: AnyAsyncSession,
44+
node_type: type[Node] = Node,
45+
node_view: Literal["simple", "model"] = "model",
4446
) -> nx.DiGraph:
4547
"""Given a sequence of Edges, create a directed graph for these
4648
edges with nodes derived from database lookups of the related objects.
49+
50+
Parameters
51+
----------
52+
edges: Sequence[Edge]
53+
The list of edges forming the graph
54+
55+
node_type: type
56+
The pydantic or sqlmodel class representing the graph node model
57+
58+
node_view: "simple" or "model"
59+
Whether the node metadata in the graph should be simplified (dict) or
60+
using the full expunged model form.
61+
62+
session
63+
An async database session
4764
"""
4865
g = nx.DiGraph()
4966
g.add_edges_from([(e.source, e.target) for e in edges])
67+
relabel_mapping = {}
5068

5169
# The graph understands the nodes in terms of the IDs used in the edges,
5270
# but we want to hydrate the entire Node model for subsequent users of this
@@ -59,7 +77,19 @@ async def graph_from_edge_list_v2(
5977
# SQLAlchemy baggage along, so we expunge it from the session before
6078
# adding it to the graph.
6179
session.expunge(db_node)
62-
g.nodes[node]["model"] = db_node
80+
if node_view == "simple":
81+
# for the simple node view, the goal is to minimize the amount of
82+
# data attached to the node and ensure that this data is json-
83+
# serializable and otherwise appropriate for an API response
84+
g.nodes[node]["id"] = str(db_node.id)
85+
g.nodes[node]["status"] = db_node.status.name
86+
g.nodes[node]["kind"] = db_node.kind.name
87+
relabel_mapping[node] = db_node.name
88+
else:
89+
g.nodes[node]["model"] = db_node
90+
91+
if relabel_mapping:
92+
g = nx.relabel_nodes(g, mapping=relabel_mapping, copy=False)
6393

6494
# TODO validate graph now raise exception, or leave it to the caller?
6595
return g
@@ -68,6 +98,11 @@ async def graph_from_edge_list_v2(
6898
def graph_to_dict(g: nx.DiGraph) -> Mapping:
6999
"""Renders a networkx directed graph to a mapping format suitable for JSON
70100
serialization.
101+
102+
Notes
103+
-----
104+
The "edges" attribute name in the node link data is "edges" instead of the
105+
default "links".
71106
"""
72107
return nx.node_link_data(g, edges="edges")
73108

src/lsst/cmservice/routers/v2/campaigns.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
representing campaign objects within CM-Service.
55
"""
66

7-
from collections.abc import Sequence
7+
from collections.abc import Mapping, Sequence
88
from typing import TYPE_CHECKING, Annotated
99
from uuid import UUID, uuid5
1010

@@ -13,6 +13,7 @@
1313
from sqlmodel import col, select
1414
from sqlmodel.ext.asyncio.session import AsyncSession
1515

16+
from ...common.graph import graph_from_edge_list_v2, graph_to_dict
1617
from ...common.logging import LOGGER
1718
from ...db.campaigns_v2 import Campaign, CampaignUpdate, Edge, Node
1819
from ...db.manifests_v2 import CampaignManifest
@@ -40,7 +41,9 @@ async def read_campaign_collection(
4041
limit: Annotated[int, Query(le=100)] = 10,
4142
offset: Annotated[int, Query()] = 0,
4243
) -> Sequence[Campaign]:
43-
"""..."""
44+
"""A paginated API returning a list of all Campaigns known to the
45+
application.
46+
"""
4447
try:
4548
campaigns = await session.exec(select(Campaign).offset(offset).limit(limit))
4649

@@ -187,7 +190,9 @@ async def read_campaign_node_collection(
187190
limit: Annotated[int, Query(le=100)] = 10,
188191
offset: Annotated[int, Query()] = 0,
189192
) -> Sequence[Node]:
190-
# This is a convenience api that could also be `/nodes?campaign=...
193+
"""A paginated API returning a list of all Nodes in the namespace of a
194+
single Campaign.
195+
"""
191196

192197
# The input could be a campaign UUID or it could be a literal name.
193198
# TODO this could just as well be a campaign query with a join to nodes
@@ -222,7 +227,10 @@ async def read_campaign_edge_collection(
222227
*,
223228
resolve_names: bool = False,
224229
) -> Sequence[Edge]:
225-
# This is a convenience api that could also be `/edges?campaign=...
230+
"""A paginated API returning a list of all Edges in the namespace of a
231+
single Campaign. This list of Edges can be used to construct the Campaign
232+
graph.
233+
"""
226234

227235
# The input could be a campaign UUID or it could be a literal name.
228236
# This is why raw SQL is better than ORMs
@@ -301,6 +309,7 @@ async def create_campaign_resource(
301309
session: Annotated[AsyncSession, Depends(db_session_dependency)],
302310
manifest: CampaignManifest,
303311
) -> Campaign:
312+
"""An API to create a Campaign from an appropriate Manifest."""
304313
# Create a campaign spec from the manifest, delegating the creation of new
305314
# dynamic fields to the model validation method, -OR- create new dynamic
306315
# fields here.
@@ -333,3 +342,46 @@ async def create_campaign_resource(
333342
)
334343

335344
return campaign
345+
346+
347+
@router.get(
348+
"/{campaign_name_or_id}/graph",
349+
status_code=200,
350+
summary="Construct and return a Campaign's graph of nodes",
351+
)
352+
async def read_campaign_graph(
353+
request: Request,
354+
response: Response,
355+
campaign_name_or_id: str,
356+
session: Annotated[AsyncSession, Depends(db_session_dependency)],
357+
) -> Mapping:
358+
"""Reads the graph resource for a campaign and returns its JSON represent-
359+
ation as serialized by the ``networkx.node_link_data()` function, i.e, the
360+
"node-link format".
361+
"""
362+
363+
# The input could be a campaign UUID or it could be a literal name.
364+
campaign_id: UUID | None
365+
try:
366+
campaign_id = UUID(campaign_name_or_id)
367+
except ValueError:
368+
s = select(Campaign.id).where(Campaign.name == campaign_name_or_id)
369+
campaign_id = (await session.exec(s)).one_or_none()
370+
371+
if campaign_id is None:
372+
raise HTTPException(status_code=404, detail="No such campaign found.")
373+
374+
# Fetch the Edges for the campaign
375+
statement = select(Edge).filter_by(namespace=campaign_id)
376+
edges = (await session.exec(statement)).all()
377+
378+
# Organize the edges into a graph. The graph nodes are annotated with their
379+
# current database attributes.
380+
# TODO it makes sense for the graph to include expunged Nodes in the meta-
381+
# data for campaign processing, but for the purposes of this api route,
382+
# only the most relevant information should be associated with each node,
383+
# e.g., its name, status, id, and its URL
384+
graph = await graph_from_edge_list_v2(edges=edges, node_type=Node, session=session, node_view="simple")
385+
386+
response.headers["Self"] = ""
387+
return graph_to_dict(graph)

tests/v2/test_graph.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from lsst.cmservice.common.enums import StatusEnum
1212
from lsst.cmservice.common.graph import graph_from_edge_list_v2, processable_graph_nodes, validate_graph
1313
from lsst.cmservice.common.types import AnyAsyncSession
14-
from lsst.cmservice.db.campaigns_v2 import Edge, Node
14+
from lsst.cmservice.db.campaigns_v2 import Edge
1515

1616
pytestmark = pytest.mark.asyncio(loop_scope="module")
1717
"""All tests in this module will run in the same event loop."""
@@ -132,7 +132,7 @@ async def test_build_and_walk_graph(
132132
```
133133
"""
134134
edge_list = [Edge.model_validate(edge) for edge in (await aclient.get(test_campaign)).json()]
135-
graph = await graph_from_edge_list_v2(edge_list, Node, session)
135+
graph = await graph_from_edge_list_v2(edge_list, session)
136136

137137
# the START node should be the only processable Node
138138
for node in processable_graph_nodes(graph):
@@ -146,7 +146,7 @@ async def test_build_and_walk_graph(
146146

147147
# Repeat the graph building and traversal, this time expecting a single
148148
# node that is not "START"
149-
graph = await graph_from_edge_list_v2(edge_list, Node, session)
149+
graph = await graph_from_edge_list_v2(edge_list, session)
150150
for node in processable_graph_nodes(graph):
151151
assert node.name != "START"
152152
assert node.status is StatusEnum.waiting
@@ -158,7 +158,7 @@ async def test_build_and_walk_graph(
158158

159159
# Repeat the graph building and traversal, this time expecting a pair of
160160
# nodes processable in parallel
161-
graph = await graph_from_edge_list_v2(edge_list, Node, session)
161+
graph = await graph_from_edge_list_v2(edge_list, session)
162162
count = 0
163163
for node in processable_graph_nodes(graph):
164164
count += 1
@@ -172,7 +172,7 @@ async def test_build_and_walk_graph(
172172
assert count == 2
173173

174174
# Finally, expect the END node
175-
graph = await graph_from_edge_list_v2(edge_list, Node, session)
175+
graph = await graph_from_edge_list_v2(edge_list, session)
176176
for node in processable_graph_nodes(graph):
177177
assert node.name == "END"
178178
assert node.status is StatusEnum.waiting
@@ -215,3 +215,16 @@ def test_validate_graph() -> None:
215215
# remove the unneeded node
216216
g.remove_node("CC")
217217
assert validate_graph(g, "A", "F")
218+
219+
220+
async def test_campaign_graph_route(aclient: AsyncClient, test_campaign: str) -> None:
221+
"""Tests the acquisition of a serialized graph from a REST endpoint and
222+
the subsequent reconstruction of a valid graph from the node-link data.
223+
"""
224+
graph_url = test_campaign.replace("/edges", "/graph")
225+
x = await aclient.get(graph_url)
226+
assert x.is_success
227+
228+
# Test reconstruction of the serialized graph
229+
graph = nx.node_link_graph(x.json(), edges="edges")
230+
assert validate_graph(graph)

0 commit comments

Comments
 (0)