Skip to content

Commit 91d2530

Browse files
authored
Merge pull request #2612 from alicevision/dev/graphIO
Refactor Graph de/serialization
2 parents ebf2270 + 0594f59 commit 91d2530

19 files changed

+1537
-767
lines changed

bin/meshroom_batch

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ with meshroom.core.graph.GraphModification(graph):
154154
# initialize template pipeline
155155
loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items())
156156
if args.pipeline.lower() in loweredPipelineTemplates:
157-
graph.load(loweredPipelineTemplates[args.pipeline.lower()], setupProjectFile=False, publishOutputs=True if args.output else False)
157+
graph.initFromTemplate(loweredPipelineTemplates[args.pipeline.lower()], publishOutputs=True if args.output else False)
158158
else:
159159
# custom pipeline
160-
graph.load(args.pipeline, setupProjectFile=False, publishOutputs=True if args.output else False)
160+
graph.initFromTemplate(args.pipeline, publishOutputs=True if args.output else False)
161161

162162
def parseInputs(inputs, uniqueInitNode):
163163
"""Utility method for parsing the input and inputRecursive arguments."""

meshroom/core/attribute.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,12 @@ def _applyExpr(self):
339339
elif self.isInput and Attribute.isLinkExpression(v):
340340
# value is a link to another attribute
341341
link = v[1:-1]
342-
linkNode, linkAttr = link.split('.')
342+
linkNodeName, linkAttrName = link.split('.')
343343
try:
344-
g.addEdge(g.node(linkNode).attribute(linkAttr), self)
344+
node = g.node(linkNodeName)
345+
if not node:
346+
raise KeyError(f"Node '{linkNodeName}' not found")
347+
g.addEdge(node.attribute(linkAttrName), self)
345348
except KeyError as err:
346349
logging.warning('Connect Attribute from Expression failed.')
347350
logging.warning('Expression: "{exp}"\nError: "{err}".'.format(exp=v, err=err))

meshroom/core/graph.py

Lines changed: 282 additions & 429 deletions
Large diffs are not rendered by default.

meshroom/core/graphIO.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from enum import Enum
2+
from typing import Any, TYPE_CHECKING, Union
3+
4+
import meshroom
5+
from meshroom.core import Version
6+
from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute
7+
from meshroom.core.node import Node
8+
9+
if TYPE_CHECKING:
10+
from meshroom.core.graph import Graph
11+
12+
13+
class GraphIO:
14+
"""Centralize Graph file keys and IO version."""
15+
16+
__version__ = "2.0"
17+
18+
class Keys(object):
19+
"""File Keys."""
20+
21+
# Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value)
22+
Header = "header"
23+
NodesVersions = "nodesVersions"
24+
ReleaseVersion = "releaseVersion"
25+
FileVersion = "fileVersion"
26+
Graph = "graph"
27+
Template = "template"
28+
29+
class Features(Enum):
30+
"""File Features."""
31+
32+
Graph = "graph"
33+
Header = "header"
34+
NodesVersions = "nodesVersions"
35+
PrecomputedOutputs = "precomputedOutputs"
36+
NodesPositions = "nodesPositions"
37+
38+
@staticmethod
39+
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
40+
"""Return the list of supported features based on a file version.
41+
42+
Args:
43+
fileVersion (str, Version): the file version
44+
45+
Returns:
46+
tuple of GraphIO.Features: the list of supported features
47+
"""
48+
if isinstance(fileVersion, str):
49+
fileVersion = Version(fileVersion)
50+
51+
features = [GraphIO.Features.Graph]
52+
if fileVersion >= Version("1.0"):
53+
features += [
54+
GraphIO.Features.Header,
55+
GraphIO.Features.NodesVersions,
56+
GraphIO.Features.PrecomputedOutputs,
57+
]
58+
59+
if fileVersion >= Version("1.1"):
60+
features += [GraphIO.Features.NodesPositions]
61+
62+
return tuple(features)
63+
64+
65+
class GraphSerializer:
66+
"""Standard Graph serializer."""
67+
68+
def __init__(self, graph: "Graph") -> None:
69+
self._graph = graph
70+
71+
def serialize(self) -> dict:
72+
"""
73+
Serialize the Graph.
74+
"""
75+
return {
76+
GraphIO.Keys.Header: self.serializeHeader(),
77+
GraphIO.Keys.Graph: self.serializeContent(),
78+
}
79+
80+
@property
81+
def nodes(self) -> list[Node]:
82+
return self._graph.nodes
83+
84+
def serializeHeader(self) -> dict:
85+
"""Build and return the graph serialization header.
86+
87+
The header contains metadata about the graph, such as the:
88+
- version of the software used to create it.
89+
- version of the file format.
90+
- version of the nodes types used in the graph.
91+
- template flag.
92+
"""
93+
header: dict[str, Any] = {}
94+
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
95+
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
96+
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
97+
return header
98+
99+
def _getNodeTypesVersions(self) -> dict[str, str]:
100+
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
101+
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
102+
nodeTypesVersions = {
103+
nodeType.__name__: version
104+
for nodeType in nodeTypes
105+
if (version := meshroom.core.nodeVersion(nodeType)) is not None
106+
}
107+
# Sort them by name (to avoid random order changing from one save to another).
108+
return dict(sorted(nodeTypesVersions.items()))
109+
110+
def serializeContent(self) -> dict:
111+
"""Graph content serialization logic."""
112+
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
113+
114+
def serializeNode(self, node: Node) -> dict:
115+
"""Node serialization logic."""
116+
return node.toDict()
117+
118+
119+
class TemplateGraphSerializer(GraphSerializer):
120+
"""Serializer for serializing a graph as a template."""
121+
122+
def serializeHeader(self) -> dict:
123+
header = super().serializeHeader()
124+
header[GraphIO.Keys.Template] = True
125+
return header
126+
127+
def serializeNode(self, node: Node) -> dict:
128+
"""Adapt node serialization to template graphs.
129+
130+
Instead of getting all the inputs and internal attribute keys, only get the keys of
131+
the attributes whose value is not the default one.
132+
The output attributes, UIDs, parallelization parameters and internal folder are
133+
not relevant for templates, so they are explicitly removed from the returned dictionary.
134+
"""
135+
# For now, implemented as a post-process to update the default serialization.
136+
nodeData = super().serializeNode(node)
137+
138+
inputKeys = list(nodeData["inputs"].keys())
139+
140+
internalInputKeys = []
141+
internalInputs = nodeData.get("internalInputs", None)
142+
if internalInputs:
143+
internalInputKeys = list(internalInputs.keys())
144+
145+
for attrName in inputKeys:
146+
attribute = node.attribute(attrName)
147+
# check that attribute is not a link for choice attributes
148+
if attribute.isDefault and not attribute.isLink:
149+
del nodeData["inputs"][attrName]
150+
151+
for attrName in internalInputKeys:
152+
attribute = node.internalAttribute(attrName)
153+
# check that internal attribute is not a link for choice attributes
154+
if attribute.isDefault and not attribute.isLink:
155+
del nodeData["internalInputs"][attrName]
156+
157+
# If all the internal attributes are set to their default values, remove the entry
158+
if len(nodeData["internalInputs"]) == 0:
159+
del nodeData["internalInputs"]
160+
161+
del nodeData["outputs"]
162+
del nodeData["uid"]
163+
del nodeData["internalFolder"]
164+
del nodeData["parallelization"]
165+
166+
return nodeData
167+
168+
169+
class PartialGraphSerializer(GraphSerializer):
170+
"""Serializer to serialize a partial graph (a subset of nodes)."""
171+
172+
def __init__(self, graph: "Graph", nodes: list[Node]):
173+
super().__init__(graph)
174+
self._nodes = nodes
175+
176+
@property
177+
def nodes(self) -> list[Node]:
178+
"""Override to consider only the subset of nodes."""
179+
return self._nodes
180+
181+
def serializeNode(self, node: Node) -> dict:
182+
"""Adapt node serialization to partial graph serialization."""
183+
# NOTE: For now, implemented as a post-process to the default serialization.
184+
nodeData = super().serializeNode(node)
185+
186+
# Override input attributes with custom serialization logic, to handle attributes
187+
# connected to nodes that are not in the list of nodes to serialize.
188+
for attributeName in nodeData["inputs"]:
189+
nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName))
190+
191+
# Clear UID for non-compatibility nodes, as the custom attribute serialization
192+
# can be impacting the UID by removing connections to missing nodes.
193+
if not node.isCompatibilityNode:
194+
del nodeData["uid"]
195+
196+
return nodeData
197+
198+
def _serializeAttribute(self, attribute: Attribute) -> Any:
199+
"""
200+
Serialize `attribute` (recursively for list/groups) and deal with attributes being connected
201+
to nodes that are not part of the partial list of nodes to serialize.
202+
"""
203+
linkParam = attribute.getLinkParam()
204+
205+
if linkParam is not None:
206+
# Use standard link serialization if upstream node is part of the serialization.
207+
if linkParam.node in self.nodes:
208+
return attribute.getExportValue()
209+
# Skip link serialization otherwise.
210+
# If part of a list, this entry can be discarded.
211+
if isinstance(attribute.root, ListAttribute):
212+
return None
213+
# Otherwise, return the default value for this attribute.
214+
return attribute.defaultValue()
215+
216+
if isinstance(attribute, ListAttribute):
217+
# Recusively serialize each child of the ListAttribute, skipping those for which the attribute
218+
# serialization logic above returns None.
219+
return [
220+
exportValue
221+
for child in attribute
222+
if (exportValue := self._serializeAttribute(child)) is not None
223+
]
224+
225+
if isinstance(attribute, GroupAttribute):
226+
# Recursively serialize each child of the group attribute.
227+
return {name: self._serializeAttribute(child) for name, child in attribute.value.items()}
228+
229+
return attribute.getExportValue()
230+
231+

0 commit comments

Comments
 (0)