|
| 1 | +from enum import Enum |
| 2 | +from typing import Any, TYPE_CHECKING, Union |
| 3 | + |
| 4 | +import meshroom |
| 5 | +from meshroom.core import Version |
| 6 | +from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute |
| 7 | +from meshroom.core.node import Node |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from meshroom.core.graph import Graph |
| 11 | + |
| 12 | + |
| 13 | +class GraphIO: |
| 14 | + """Centralize Graph file keys and IO version.""" |
| 15 | + |
| 16 | + __version__ = "2.0" |
| 17 | + |
| 18 | + class Keys(object): |
| 19 | + """File Keys.""" |
| 20 | + |
| 21 | + # Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value) |
| 22 | + Header = "header" |
| 23 | + NodesVersions = "nodesVersions" |
| 24 | + ReleaseVersion = "releaseVersion" |
| 25 | + FileVersion = "fileVersion" |
| 26 | + Graph = "graph" |
| 27 | + Template = "template" |
| 28 | + |
| 29 | + class Features(Enum): |
| 30 | + """File Features.""" |
| 31 | + |
| 32 | + Graph = "graph" |
| 33 | + Header = "header" |
| 34 | + NodesVersions = "nodesVersions" |
| 35 | + PrecomputedOutputs = "precomputedOutputs" |
| 36 | + NodesPositions = "nodesPositions" |
| 37 | + |
| 38 | + @staticmethod |
| 39 | + def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]: |
| 40 | + """Return the list of supported features based on a file version. |
| 41 | +
|
| 42 | + Args: |
| 43 | + fileVersion (str, Version): the file version |
| 44 | +
|
| 45 | + Returns: |
| 46 | + tuple of GraphIO.Features: the list of supported features |
| 47 | + """ |
| 48 | + if isinstance(fileVersion, str): |
| 49 | + fileVersion = Version(fileVersion) |
| 50 | + |
| 51 | + features = [GraphIO.Features.Graph] |
| 52 | + if fileVersion >= Version("1.0"): |
| 53 | + features += [ |
| 54 | + GraphIO.Features.Header, |
| 55 | + GraphIO.Features.NodesVersions, |
| 56 | + GraphIO.Features.PrecomputedOutputs, |
| 57 | + ] |
| 58 | + |
| 59 | + if fileVersion >= Version("1.1"): |
| 60 | + features += [GraphIO.Features.NodesPositions] |
| 61 | + |
| 62 | + return tuple(features) |
| 63 | + |
| 64 | + |
| 65 | +class GraphSerializer: |
| 66 | + """Standard Graph serializer.""" |
| 67 | + |
| 68 | + def __init__(self, graph: "Graph") -> None: |
| 69 | + self._graph = graph |
| 70 | + |
| 71 | + def serialize(self) -> dict: |
| 72 | + """ |
| 73 | + Serialize the Graph. |
| 74 | + """ |
| 75 | + return { |
| 76 | + GraphIO.Keys.Header: self.serializeHeader(), |
| 77 | + GraphIO.Keys.Graph: self.serializeContent(), |
| 78 | + } |
| 79 | + |
| 80 | + @property |
| 81 | + def nodes(self) -> list[Node]: |
| 82 | + return self._graph.nodes |
| 83 | + |
| 84 | + def serializeHeader(self) -> dict: |
| 85 | + """Build and return the graph serialization header. |
| 86 | +
|
| 87 | + The header contains metadata about the graph, such as the: |
| 88 | + - version of the software used to create it. |
| 89 | + - version of the file format. |
| 90 | + - version of the nodes types used in the graph. |
| 91 | + - template flag. |
| 92 | + """ |
| 93 | + header: dict[str, Any] = {} |
| 94 | + header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ |
| 95 | + header[GraphIO.Keys.FileVersion] = GraphIO.__version__ |
| 96 | + header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions() |
| 97 | + return header |
| 98 | + |
| 99 | + def _getNodeTypesVersions(self) -> dict[str, str]: |
| 100 | + """Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances.""" |
| 101 | + nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)]) |
| 102 | + nodeTypesVersions = { |
| 103 | + nodeType.__name__: version |
| 104 | + for nodeType in nodeTypes |
| 105 | + if (version := meshroom.core.nodeVersion(nodeType)) is not None |
| 106 | + } |
| 107 | + # Sort them by name (to avoid random order changing from one save to another). |
| 108 | + return dict(sorted(nodeTypesVersions.items())) |
| 109 | + |
| 110 | + def serializeContent(self) -> dict: |
| 111 | + """Graph content serialization logic.""" |
| 112 | + return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)} |
| 113 | + |
| 114 | + def serializeNode(self, node: Node) -> dict: |
| 115 | + """Node serialization logic.""" |
| 116 | + return node.toDict() |
| 117 | + |
| 118 | + |
| 119 | +class TemplateGraphSerializer(GraphSerializer): |
| 120 | + """Serializer for serializing a graph as a template.""" |
| 121 | + |
| 122 | + def serializeHeader(self) -> dict: |
| 123 | + header = super().serializeHeader() |
| 124 | + header[GraphIO.Keys.Template] = True |
| 125 | + return header |
| 126 | + |
| 127 | + def serializeNode(self, node: Node) -> dict: |
| 128 | + """Adapt node serialization to template graphs. |
| 129 | + |
| 130 | + Instead of getting all the inputs and internal attribute keys, only get the keys of |
| 131 | + the attributes whose value is not the default one. |
| 132 | + The output attributes, UIDs, parallelization parameters and internal folder are |
| 133 | + not relevant for templates, so they are explicitly removed from the returned dictionary. |
| 134 | + """ |
| 135 | + # For now, implemented as a post-process to update the default serialization. |
| 136 | + nodeData = super().serializeNode(node) |
| 137 | + |
| 138 | + inputKeys = list(nodeData["inputs"].keys()) |
| 139 | + |
| 140 | + internalInputKeys = [] |
| 141 | + internalInputs = nodeData.get("internalInputs", None) |
| 142 | + if internalInputs: |
| 143 | + internalInputKeys = list(internalInputs.keys()) |
| 144 | + |
| 145 | + for attrName in inputKeys: |
| 146 | + attribute = node.attribute(attrName) |
| 147 | + # check that attribute is not a link for choice attributes |
| 148 | + if attribute.isDefault and not attribute.isLink: |
| 149 | + del nodeData["inputs"][attrName] |
| 150 | + |
| 151 | + for attrName in internalInputKeys: |
| 152 | + attribute = node.internalAttribute(attrName) |
| 153 | + # check that internal attribute is not a link for choice attributes |
| 154 | + if attribute.isDefault and not attribute.isLink: |
| 155 | + del nodeData["internalInputs"][attrName] |
| 156 | + |
| 157 | + # If all the internal attributes are set to their default values, remove the entry |
| 158 | + if len(nodeData["internalInputs"]) == 0: |
| 159 | + del nodeData["internalInputs"] |
| 160 | + |
| 161 | + del nodeData["outputs"] |
| 162 | + del nodeData["uid"] |
| 163 | + del nodeData["internalFolder"] |
| 164 | + del nodeData["parallelization"] |
| 165 | + |
| 166 | + return nodeData |
| 167 | + |
| 168 | + |
| 169 | +class PartialGraphSerializer(GraphSerializer): |
| 170 | + """Serializer to serialize a partial graph (a subset of nodes).""" |
| 171 | + |
| 172 | + def __init__(self, graph: "Graph", nodes: list[Node]): |
| 173 | + super().__init__(graph) |
| 174 | + self._nodes = nodes |
| 175 | + |
| 176 | + @property |
| 177 | + def nodes(self) -> list[Node]: |
| 178 | + """Override to consider only the subset of nodes.""" |
| 179 | + return self._nodes |
| 180 | + |
| 181 | + def serializeNode(self, node: Node) -> dict: |
| 182 | + """Adapt node serialization to partial graph serialization.""" |
| 183 | + # NOTE: For now, implemented as a post-process to the default serialization. |
| 184 | + nodeData = super().serializeNode(node) |
| 185 | + |
| 186 | + # Override input attributes with custom serialization logic, to handle attributes |
| 187 | + # connected to nodes that are not in the list of nodes to serialize. |
| 188 | + for attributeName in nodeData["inputs"]: |
| 189 | + nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName)) |
| 190 | + |
| 191 | + # Clear UID for non-compatibility nodes, as the custom attribute serialization |
| 192 | + # can be impacting the UID by removing connections to missing nodes. |
| 193 | + if not node.isCompatibilityNode: |
| 194 | + del nodeData["uid"] |
| 195 | + |
| 196 | + return nodeData |
| 197 | + |
| 198 | + def _serializeAttribute(self, attribute: Attribute) -> Any: |
| 199 | + """ |
| 200 | + Serialize `attribute` (recursively for list/groups) and deal with attributes being connected |
| 201 | + to nodes that are not part of the partial list of nodes to serialize. |
| 202 | + """ |
| 203 | + linkParam = attribute.getLinkParam() |
| 204 | + |
| 205 | + if linkParam is not None: |
| 206 | + # Use standard link serialization if upstream node is part of the serialization. |
| 207 | + if linkParam.node in self.nodes: |
| 208 | + return attribute.getExportValue() |
| 209 | + # Skip link serialization otherwise. |
| 210 | + # If part of a list, this entry can be discarded. |
| 211 | + if isinstance(attribute.root, ListAttribute): |
| 212 | + return None |
| 213 | + # Otherwise, return the default value for this attribute. |
| 214 | + return attribute.defaultValue() |
| 215 | + |
| 216 | + if isinstance(attribute, ListAttribute): |
| 217 | + # Recusively serialize each child of the ListAttribute, skipping those for which the attribute |
| 218 | + # serialization logic above returns None. |
| 219 | + return [ |
| 220 | + exportValue |
| 221 | + for child in attribute |
| 222 | + if (exportValue := self._serializeAttribute(child)) is not None |
| 223 | + ] |
| 224 | + |
| 225 | + if isinstance(attribute, GroupAttribute): |
| 226 | + # Recursively serialize each child of the group attribute. |
| 227 | + return {name: self._serializeAttribute(child) for name, child in attribute.value.items()} |
| 228 | + |
| 229 | + return attribute.getExportValue() |
| 230 | + |
| 231 | + |
0 commit comments