diff --git a/meshroom/core/attribute.py b/meshroom/core/attribute.py index 4141cce519..610d86a07c 100644 --- a/meshroom/core/attribute.py +++ b/meshroom/core/attribute.py @@ -2,6 +2,7 @@ import copy import os import re +from typing import Optional import weakref import types import logging @@ -10,6 +11,7 @@ from string import Template from meshroom.common import BaseObject, Property, Variant, Signal, ListModel, DictModel, Slot from meshroom.core import desc, hashValue +from meshroom.core.exception import InvalidEdgeError from typing import TYPE_CHECKING @@ -80,6 +82,7 @@ def __init__(self, node, attributeDesc: desc.Attribute, isOutput: bool, root=Non # invalidation value for output attributes self._invalidationValue = "" + self._linkExpression: Optional[str] = None self._value = None self.initValue() @@ -201,9 +204,9 @@ def _set_value(self, value): if self._value == value: return - if isinstance(value, Attribute) or Attribute.isLinkExpression(value): - # if we set a link to another attribute - self._value = value + if self._handleLinkValue(value): + return + elif isinstance(value, types.FunctionType): # evaluate the function self._value = value(self) @@ -228,6 +231,27 @@ def _set_value(self, value): self.valueChanged.emit() self.validValueChanged.emit() + def _handleLinkValue(self, value) -> bool: + """ + Handle assignment of a link if `value` is a serialized link expression or in-memory Attribute reference. + + Returns: Whether the value has been handled as a link, False otherwise. + """ + isAttribute = isinstance(value, Attribute) + isLinkExpression = Attribute.isLinkExpression(value) + + if not isAttribute and not isLinkExpression: + return False + + if isAttribute: + self._linkExpression = value.asLinkExpr() + # If the value is a direct reference to an attribute, it can be directly converted to an edge as + # the source attribute already exists in memory. + self._applyExpr() + elif isLinkExpression: + self._linkExpression = value + return True + @Slot() def _onValueChanged(self): self.node._onAttributeChanged(self) @@ -369,26 +393,30 @@ def _applyExpr(self): this function convert the expression into a real edge in the graph and clear the string value. """ - v = self._value - g = self.node.graph - if not g: + if not self.isInput or not self._linkExpression: return - if isinstance(v, Attribute): - g.addEdge(v, self) - self.resetToDefaultValue() - elif self.isInput and Attribute.isLinkExpression(v): - # value is a link to another attribute - link = v[1:-1] - linkNodeName, linkAttrName = link.split('.') - try: - node = g.node(linkNodeName) - if not node: - raise KeyError(f"Node '{linkNodeName}' not found") - g.addEdge(node.attribute(linkAttrName), self) - except KeyError as err: - logging.warning('Connect Attribute from Expression failed.') - logging.warning(f'Expression: "{v}"\nError: "{err}".') - self.resetToDefaultValue() + + if not (graph := self.node.graph): + return + + link = self._linkExpression[1:-1] + linkNodeName, linkAttrName = link.split(".") + try: + node = graph.node(linkNodeName) + if node is None: + raise InvalidEdgeError(self.fullNameToNode, link, "Source node does not exist") + attr = node.attribute(linkAttrName) + if attr is None: + raise InvalidEdgeError(self.fullNameToNode, link, "Source attribute does not exist") + graph.addEdge(attr, self) + except InvalidEdgeError as err: + logging.warning(err) + except Exception as err: + logging.warning("Unexpected error happened during edge creation") + logging.warning(f"Expression '{self._linkExpression}': {err}") + + self._linkExpression = None + self.resetToDefaultValue() def getExportValue(self): if self.isLink: @@ -480,6 +508,20 @@ def _is2D(self) -> bool: return False return next((imageSemantic for imageSemantic in Attribute.VALID_IMAGE_SEMANTICS if self.desc.semantic == imageSemantic), None) is not None + + @Slot(BaseObject, result=bool) + def validateConnectionFrom(self, otherAttribute: "Attribute") -> bool: + """ Check if the given attribute can be conected to the current Attribute + """ + return self._validateConnectionFrom(otherAttribute) + + def _validateConnectionFrom(self, otherAttribute: "Attribute") -> bool: + """ Implementation of the connection validation + + .. note: + Override this method to use custom connection validation logic + """ + return self.baseType == otherAttribute.baseType name = Property(str, getName, constant=True) fullName = Property(str, getFullName, constant=True) @@ -651,9 +693,8 @@ def resetToDefaultValue(self): def _set_value(self, value): if self.node.graph: self.remove(0, len(self)) - # Link to another attribute - if isinstance(value, ListAttribute) or Attribute.isLinkExpression(value): - self._value = value + if self._handleLinkValue(value): + return # New value else: # During initialization self._value may not be set @@ -664,6 +705,9 @@ def _set_value(self, value): self.requestGraphUpdate() def upgradeValue(self, exportedValues): + if self._handleLinkValue(exportedValues): + return + if not isinstance(exportedValues, list): if isinstance(exportedValues, ListAttribute) or \ Attribute.isLinkExpression(exportedValues): @@ -731,9 +775,7 @@ def uid(self): return super().uid() def _applyExpr(self): - if not self.node.graph: - return - if isinstance(self._value, ListAttribute) or Attribute.isLinkExpression(self._value): + if self._linkExpression: super()._applyExpr() else: for value in self._value: @@ -817,7 +859,6 @@ def getOutputConnections(self) -> list["Edge"]: hasOutputConnections = Property(bool, hasOutputConnections.fget, notify=Attribute.hasOutputConnectionsChanged) - class GroupAttribute(Attribute): def __init__(self, node, attributeDesc: desc.GroupAttribute, isOutput: bool, @@ -833,7 +874,34 @@ def __getattr__(self, key): except KeyError: raise AttributeError(key) + def _get_value(self): + linkedParam = self.getLinkParam() + + if not linkedParam: + return self._value + + def linkAttributesValues(srcAttr, dstAttr): + + for i, attrDesc in enumerate(dstAttr.desc._groupDesc): + linkedAttrDesc = srcAttr.desc._groupDesc[i] + + subSrcAttr = srcAttr._value.get(linkedAttrDesc.name) + subDstAttr = dstAttr._value.get(attrDesc.name) + + if isinstance(linkedAttrDesc, desc.GroupAttribute) and isinstance(attrDesc, desc.GroupAttribute): + linkAttributesValues(subSrcAttr, subDstAttr) + else: + subDstAttr.value = subSrcAttr.value + + # If linked, the driver attributes values are copied to the current attribute + linkAttributesValues(linkedParam, self) + + return self._value + def _set_value(self, exportedValue): + if self._handleLinkValue(exportedValue): + return + value = self.validateValue(exportedValue) if isinstance(value, dict): # set individual child attribute values @@ -848,6 +916,8 @@ def _set_value(self, exportedValue): raise AttributeError(f"Failed to set on GroupAttribute: {str(value)}") def upgradeValue(self, exportedValue): + if self._handleLinkValue(exportedValue): + return value = self.validateValue(exportedValue) if isinstance(value, dict): # set individual child attribute values @@ -892,6 +962,9 @@ def childAttribute(self, key: str) -> Attribute: return None def uid(self): + if self.isLink: + return super().uid() + uids = [] for k, v in self._value.items(): if v.enabled and v.invalidate: @@ -899,13 +972,20 @@ def uid(self): return hashValue(uids) def _applyExpr(self): - for value in self._value: - value._applyExpr() + if self._linkExpression: + super()._applyExpr() + else: + for value in self._value: + value._applyExpr() def getExportValue(self): - return {key: attr.getExportValue() for key, attr in self._value.objects.items()} + if linkParam := self.getLinkParam(): + return linkParam.asLinkExpr() + return {key: attr.getExportValue() for key, attr in self._value.items()} def _isDefault(self): + if linkParam := self.getLinkParam(): + return linkParam._isDefault() return all(v.isDefault for v in self._value) def defaultValue(self): @@ -949,6 +1029,33 @@ def updateInternals(self): def matchText(self, text: str) -> bool: return super().matchText(text) or any(c.matchText(text) for c in self._value) + def _validateConnectionFrom(self, otherAttribute:"Attribute") -> bool: + + isValid = super()._validateConnectionFrom(otherAttribute=otherAttribute) + + if not isValid: + return False + + return self._haveSameStructure(otherAttribute) + + def _haveSameStructure(self, otherAttribute: "Attribute") -> bool: + """ Does the given attribute have the same number of attributes, and all ordered attributes have the same baseType + """ + + if isinstance(otherAttribute._value, Iterable) and len(otherAttribute._value) != len(self._value): + return False + + for i, attr in enumerate(self._value): + otherAttr = list(otherAttribute._value)[i] + if isinstance(attr, GroupAttribute): + return attr._haveSameStructure(otherAttr) + elif not otherAttr: + return False + elif attr.baseType != otherAttr.baseType: + return False + + return True + # Override value property - value = Property(Variant, Attribute._get_value, _set_value, notify=Attribute.valueChanged) + value = Property(Variant, _get_value, _set_value, notify=Attribute.valueChanged) isDefault = Property(bool, _isDefault, notify=Attribute.valueChanged) diff --git a/meshroom/core/exception.py b/meshroom/core/exception.py index 4443a8962f..f858ad0f93 100644 --- a/meshroom/core/exception.py +++ b/meshroom/core/exception.py @@ -11,6 +11,13 @@ class GraphException(MeshroomException): pass +class InvalidEdgeError(GraphException): + """Raised when an edge between two attributes cannot be created.""" + + def __init__(self, srcAttrName: str, dstAttrName: str, msg: str) -> None: + super().__init__(f"Failed to connect {srcAttrName}->{dstAttrName}: {msg}") + + class GraphCompatibilityError(GraphException): """ Raised when node compatibility issues occur when loading a graph. @@ -57,3 +64,8 @@ class StopGraphVisit(GraphVisitMessage): class StopBranchVisit(GraphVisitMessage): """ Immediately stop branch visit. """ pass + + +class CyclicDependencyError(Exception): + """ Raised if a cyclic dependency is find in a DAG graph """ + pass diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index a5bff2a250..1aa5ccca8b 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -15,7 +15,7 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute -from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit +from meshroom.core.exception import GraphCompatibilityError, InvalidEdgeError, StopGraphVisit, StopBranchVisit, CyclicDependencyError from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory @@ -134,6 +134,14 @@ def finishVertex(self, u, g): pass +class DAGVisitor(Visitor): + + def backEdge(self, e, g): + """ Is invoked on the back edges in the graph. Means that there is a cyclic dependency in the visited graph """ + + raise CyclicDependencyError("A cyclic dependency exists on the current DAG") + + def changeTopology(func): """ Graph methods modifying the graph topology (add/remove edges or nodes) @@ -498,41 +506,38 @@ def addNode(self, node, uniqueName=None): node._applyExpr() return node - def copyNode(self, srcNode, withEdges=False): + def copyNode(self, srcNode: Node, withEdges: bool=False): """ Get a copy instance of a node outside the graph. Args: - srcNode (Node): the node to copy - withEdges (bool): whether to copy edges + srcNode: the node to copy + withEdges: whether to copy edges Returns: - Node, dict: the created node instance, - a dictionary of linked attributes with their original value (empty if withEdges is True) + The created node instance and the mapping of skipped edge per attribute (always empty if `withEdges` is True). """ + def _removeLinkExpressions(attribute: Attribute, removed: dict[Attribute, str]): + """Recursively remove link expressions from the given root `attribute`.""" + # Link expressions are only stored on input attributes. + if attribute.isOutput: + return + + if attribute._linkExpression: + removed[attribute] = attribute._linkExpression + attribute._linkExpression = None + elif isinstance(attribute, (ListAttribute, GroupAttribute)): + for child in attribute.value: + _removeLinkExpressions(child, removed) + with GraphModification(self): - # create a new node of the same type and with the same attributes values - # keep links as-is so that CompatibilityNodes attributes can be created with correct automatic description - # (File params for link expressions) - node = nodeFactory(srcNode.toDict(), srcNode.nodeType) # use nodeType as name - # skip edges: filter out attributes which are links by resetting default values + node = nodeFactory(srcNode.toDict(), name=srcNode.nodeType) + skippedEdges = {} if not withEdges: - for n, attr in node.attributes.items(): - if attr.isOutput: - # edges are declared in input with an expression linking - # to another param (which could be an output) - continue - # find top-level links - if Attribute.isLinkExpression(attr.value): - skippedEdges[attr] = attr.value - attr.resetToDefaultValue() - # find links in ListAttribute children - elif isinstance(attr, (ListAttribute, GroupAttribute)): - for child in attr.value: - if Attribute.isLinkExpression(child.value): - skippedEdges[child] = child.value - child.resetToDefaultValue() + for _, attr in node.attributes.items(): + _removeLinkExpressions(attr, skippedEdges) + return node, skippedEdges def duplicateNodes(self, srcNodes): @@ -892,11 +897,11 @@ def getRootNodes(self, dependenciesOnly): return set(self._nodes) - nodesWithInputLink @changeTopology - def addEdge(self, srcAttr, dstAttr): - assert isinstance(srcAttr, Attribute) - assert isinstance(dstAttr, Attribute) - if srcAttr.node.graph != self or dstAttr.node.graph != self: - raise RuntimeError('The attributes of the edge should be part of a common graph.') + def addEdge(self, srcAttr: Attribute, dstAttr: Attribute) -> "Edge": + if not (srcAttr.node.graph == dstAttr.node.graph == self): + raise InvalidEdgeError( + srcAttr.fullNameToGraph, dstAttr.fullNameToGraph, "Attributes do not belong to this Graph" + ) if dstAttr in self.edges.keys(): raise RuntimeError(f'Destination attribute "{dstAttr.getFullNameToNode()}" is already connected.') edge = Edge(srcAttr, dstAttr) @@ -1033,7 +1038,7 @@ def dfsOnFinish(self, startNodes=None, longestPathFirst=False, reverse=False, de """ nodes = [] edges = [] - visitor = Visitor(reverse=reverse, dependenciesOnly=dependenciesOnly) + visitor = DAGVisitor(reverse=reverse, dependenciesOnly=dependenciesOnly) visitor.finishVertex = lambda vertex, graph: nodes.append(vertex) visitor.finishEdge = lambda edge, graph: edges.append(edge) self.dfs(visitor=visitor, startNodes=startNodes, longestPathFirst=longestPathFirst) @@ -1058,7 +1063,7 @@ def dfsOnDiscover(self, startNodes=None, filterTypes=None, longestPathFirst=Fals """ nodes = [] edges = [] - visitor = Visitor(reverse=reverse, dependenciesOnly=dependenciesOnly) + visitor = DAGVisitor(reverse=reverse, dependenciesOnly=dependenciesOnly) def discoverVertex(vertex, graph): if not filterTypes or vertex.nodeType in filterTypes: @@ -1082,7 +1087,7 @@ def dfsToProcess(self, startNodes=None): """ nodes = [] edges = [] - visitor = Visitor(reverse=False, dependenciesOnly=True) + visitor = DAGVisitor(reverse=False, dependenciesOnly=True) def discoverVertex(vertex, graph): if vertex.hasStatus(Status.SUCCESS): @@ -1138,7 +1143,7 @@ def updateNodesTopologicalData(self): self._computationBlocked.clear() compatNodes = [] - visitor = Visitor(reverse=False, dependenciesOnly=False) + visitor = DAGVisitor(reverse=False, dependenciesOnly=False) def discoverVertex(vertex, graph): # initialize depths @@ -1196,7 +1201,7 @@ def dfsMaxEdgeLength(self, startNodes=None, dependenciesOnly=True): """ nodesStack = [] edgesScore = defaultdict(int) - visitor = Visitor(reverse=False, dependenciesOnly=dependenciesOnly) + visitor = DAGVisitor(reverse=False, dependenciesOnly=dependenciesOnly) def finishEdge(edge, graph): u, v = edge @@ -1279,7 +1284,7 @@ def canSubmitOrCompute(self, startNode): if startNode.isAlreadySubmittedOrFinished(): return 0 - class SCVisitor(Visitor): + class SCVisitor(DAGVisitor): def __init__(self, reverse, dependenciesOnly): super().__init__(reverse, dependenciesOnly) diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index dcbcac0f49..653c8af6ab 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -6,6 +6,7 @@ from PySide6.QtCore import Property, Signal from meshroom.core.attribute import ListAttribute, Attribute +from meshroom.core.exception import CyclicDependencyError from meshroom.core.graph import Graph, GraphModification from meshroom.core.node import Position, CompatibilityIssue from meshroom.core.nodeFactory import nodeFactory @@ -314,11 +315,17 @@ def __init__(self, graph, src, dst, parent=None): self.dstAttr = dst.getFullNameToNode() self.setText(f"Connect '{self.srcAttr}'->'{self.dstAttr}'") - if src.baseType != dst.baseType: - raise ValueError(f"Attribute types are not compatible and cannot be connected: '{self.srcAttr}'({src.baseType})->'{self.dstAttr}'({dst.baseType})") + if not dst.validateConnectionFrom(src): + raise ValueError(f"Attribute are not compatible and cannot be connected: '{self.srcAttr}'({src.baseType})->'{self.dstAttr}'({dst.baseType})") def redoImpl(self): - self.graph.addEdge(self.graph.attribute(self.srcAttr), self.graph.attribute(self.dstAttr)) + + try: + self.graph.addEdge(self.graph.attribute(self.srcAttr), self.graph.attribute(self.dstAttr)) + except CyclicDependencyError: + self.graph.removeEdge(self.graph.attribute(self.dstAttr)) + return False + return True def undoImpl(self): diff --git a/meshroom/ui/qml/GraphEditor/AttributeItemDelegate.qml b/meshroom/ui/qml/GraphEditor/AttributeItemDelegate.qml index 7a830d6e3f..8e9d28ce50 100644 --- a/meshroom/ui/qml/GraphEditor/AttributeItemDelegate.qml +++ b/meshroom/ui/qml/GraphEditor/AttributeItemDelegate.qml @@ -769,7 +769,7 @@ RowLayout { var obj = cpt.createObject(groupItem, { 'model': Qt.binding(function() { return attribute.value }), - 'readOnly': Qt.binding(function() { return root.readOnly }), + 'readOnly': Qt.binding(function() { return !root.editable }), 'labelWidth': 100, // Reduce label width for children (space gain) 'objectsHideable': Qt.binding(function() { return root.objectsHideable }), 'filterText': Qt.binding(function() { return root.filterText }), diff --git a/meshroom/ui/qml/GraphEditor/AttributePin.qml b/meshroom/ui/qml/GraphEditor/AttributePin.qml index d435601bd5..2486dddd44 100755 --- a/meshroom/ui/qml/GraphEditor/AttributePin.qml +++ b/meshroom/ui/qml/GraphEditor/AttributePin.qml @@ -106,7 +106,7 @@ RowLayout { // Check if attributes are compatible to create a valid connection if (root.readOnly // Cannot connect on a read-only attribute || drag.source.objectName != inputDragTarget.objectName // Not an edge connector - || drag.source.baseType !== inputDragTarget.baseType // Not the same base type + || !inputDragTarget.attribute.validateConnectionFrom(drag.source.attribute) // || drag.source.nodeItem === inputDragTarget.nodeItem // Connection between attributes of the same node || (drag.source.isList && childrenRepeater.count) // Source/target are lists but target already has children || drag.source.connectorType === "input" // Refuse to connect an "input pin" on another one (input attr can be connected to input attr, but not the graphical pin) @@ -256,7 +256,7 @@ RowLayout { onEntered: function(drag) { // Check if attributes are compatible to create a valid connection if (drag.source.objectName != outputDragTarget.objectName // Not an edge connector - || drag.source.baseType !== outputDragTarget.baseType // Not the same base type + || !outputDragTarget.attribute.validateConnectionFrom(drag.source.attribute) // || drag.source.nodeItem === outputDragTarget.nodeItem // Connection between attributes of the same node || (!drag.source.isList && outputDragTarget.isList) // Connection between a list and a simple attribute || (drag.source.isList && childrenRepeater.count) // Source/target are lists but target already has children diff --git a/tests/nodes/test/color.py b/tests/nodes/test/color.py new file mode 100644 index 0000000000..83df6bee1c --- /dev/null +++ b/tests/nodes/test/color.py @@ -0,0 +1,40 @@ +from meshroom.core import desc + +class Color(desc.Node): + + inputs = [ + desc.GroupAttribute( + name="rgb", + label="rgb", + description="rgb", + exposed=True, + groupDesc=[ + desc.FloatParam(name="r", label="r", description="r", value=0.0), + desc.FloatParam(name="g", label="g", description="g", value=0.0), + desc.FloatParam(name="b", label="b", description="b", value=0.0) + ] + ) + ] + +class NestedColor(desc.Node): + + inputs = [ + desc.GroupAttribute( + name="rgb", + label="rgb", + description="rgb", + exposed=True, + groupDesc=[ + desc.FloatParam(name="r", label="r", description="r", value=0.0), + desc.FloatParam(name="g", label="g", description="g", value=0.0), + desc.FloatParam(name="b", label="b", description="b", value=0.0), + desc.GroupAttribute(label="test", name="test", description="", + groupDesc=[ + desc.FloatParam(name="r", label="r", description="r", value=0.0), + desc.FloatParam(name="g", label="g", description="g", value=0.0), + desc.FloatParam(name="b", label="b", description="b", value=0.0), + + ]) + ] + ) + ] \ No newline at end of file diff --git a/tests/nodes/test/position.py b/tests/nodes/test/position.py new file mode 100644 index 0000000000..45630a5d55 --- /dev/null +++ b/tests/nodes/test/position.py @@ -0,0 +1,40 @@ +from meshroom.core import desc + + +class Position(desc.Node): + + inputs = [ + desc.GroupAttribute( + name="xyz", + label="xyz", + description="xyz", + exposed=True, + groupDesc=[ + desc.FloatParam(name="x", label="x", description="x", value=0.0), + desc.FloatParam(name="y", label="z", description="z", value=0.0), + desc.FloatParam(name="z", label="z", description="z", value=0.0) + ] + ) + ] + +class NestedPosition(desc.Node): + + inputs = [ + desc.GroupAttribute( + name="xyz", + label="xyz", + description="xyz", + exposed=True, + groupDesc=[ + desc.FloatParam(name="x", label="x", description="x", value=0.0), + desc.FloatParam(name="y", label="z", description="z", value=0.0), + desc.FloatParam(name="z", label="z", description="z", value=0.0), + desc.GroupAttribute(label="test", name="test", description="", + groupDesc=[ + desc.FloatParam(name="x", label="x", description="x", value=0.0), + desc.FloatParam(name="y", label="z", description="z", value=0.0), + desc.FloatParam(name="z", label="z", description="z", value=0.0), + ]) + ] + ) + ] \ No newline at end of file diff --git a/tests/nodes/test/test.py b/tests/nodes/test/test.py new file mode 100644 index 0000000000..44bb3aa254 --- /dev/null +++ b/tests/nodes/test/test.py @@ -0,0 +1,23 @@ +from meshroom.core import desc + +class NestedTest(desc.Node): + + inputs = [ + desc.GroupAttribute( + name="xyz", + label="xyz", + description="xyz", + exposed=True, + groupDesc=[ + desc.FloatParam(name="x", label="x", description="x", value=0.0), + desc.FloatParam(name="y", label="z", description="z", value=0.0), + desc.FloatParam(name="z", label="z", description="z", value=0.0), + desc.GroupAttribute(label="test", name="test", description="", + groupDesc=[ + desc.StringParam(name="x", label="x", description="x", value="test"), + desc.FloatParam(name="y", label="z", description="z", value=0.0), + desc.FloatParam(name="z", label="z", description="z", value=0.0), + ]) + ] + ) + ] \ No newline at end of file diff --git a/tests/test_graph.py b/tests/test_graph.py index 280cd46f7d..816e0c2890 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,7 @@ from meshroom.core.graph import Graph +from meshroom.core.exception import CyclicDependencyError + +import pytest def test_depth(): @@ -278,3 +281,20 @@ def test_duplicate_nodes(): assert nMap[n2][0].input.getLinkParam() == nMap[n1][0].output assert nMap[n3][0].input.getLinkParam() == nMap[n1][0].output assert nMap[n3][0].input2.getLinkParam() == nMap[n2][0].output + + +def test_acyclic_connection_should_raise_error(): + + # Given + graph = Graph("Test acyclic connection") + tB = graph.addNewNode("AppendText", inputText="echo B") + tC = graph.addNewNode("AppendText", inputText="echo C") + + graph.addEdge(tB.output, tC.input) + + # When + # Then + with pytest.raises(CyclicDependencyError): + graph.addEdge(tC.output, tB.input) + + diff --git a/tests/test_groupAttributes.py b/tests/test_groupAttributes.py new file mode 100644 index 0000000000..5eb859fce8 --- /dev/null +++ b/tests/test_groupAttributes.py @@ -0,0 +1,85 @@ +from meshroom.core.graph import Graph +import math +import logging + +logger = logging.getLogger('test') + + +def test_groupAttributes_with_same_structure_can_be_linked_and_only_calue_is_copied(): + + # Given + graph = Graph() + position = graph.addNewNode("Position") + color = graph.addNewNode("Color") + + # When + graph.addEdge(position.xyz, color.rgb) + position.xyz.x.value = 1.0 + position.xyz.y.value = 2.0 + position.xyz.z.value = 3.0 + + # Then + assert color.rgb.value != position.xyz.value + assert math.isclose(color.rgb.r.value, position.xyz.x.value) + assert math.isclose(color.rgb.g.value, position.xyz.y.value) + assert math.isclose(color.rgb.b.value, position.xyz.z.value) + assert math.isclose(color.rgb.r.value, 1.0) + assert math.isclose(color.rgb.g.value, 2.0) + assert math.isclose(color.rgb.b.value, 3.0) + +def test_groupAttributes_with_same_nested_structure_can_be_linked_and_only_calue_is_copied(): + + # Given + graph = Graph() + nestedColor = graph.addNewNode("NestedColor") + nestedPosition = graph.addNewNode("NestedPosition") + + # When + graph.addEdge(nestedPosition.xyz, nestedColor.rgb) + nestedPosition.xyz.x.value = 1.0 + nestedPosition.xyz.y.value = 2.0 + nestedPosition.xyz.z.value = 3.0 + nestedPosition.xyz.test.x.value = 4.0 + nestedPosition.xyz.test.y.value = 5.0 + nestedPosition.xyz.test.z.value = 6.0 + + # Then + assert nestedColor.rgb.value != nestedPosition.xyz.test.value + assert math.isclose(nestedColor.rgb.r.value, nestedPosition.xyz.x.value) + assert math.isclose(nestedColor.rgb.g.value, nestedPosition.xyz.y.value) + assert math.isclose(nestedColor.rgb.b.value, nestedPosition.xyz.z.value) + assert math.isclose(nestedColor.rgb.test.r.value, nestedPosition.xyz.test.x.value) + assert math.isclose(nestedColor.rgb.test.g.value, nestedPosition.xyz.test.y.value) + assert math.isclose(nestedColor.rgb.test.b.value, nestedPosition.xyz.test.z.value) + assert math.isclose(nestedColor.rgb.r.value, 1.0) + assert math.isclose(nestedColor.rgb.g.value, 2.0) + assert math.isclose(nestedColor.rgb.b.value, 3.0) + assert math.isclose(nestedColor.rgb.test.r.value, 4.0) + assert math.isclose(nestedColor.rgb.test.g.value, 5.0) + assert math.isclose(nestedColor.rgb.test.b.value, 6.0) + +def test_groupAttributes_with_smae_structure_should_allow_connection(): + + # Given + graph = Graph() + nestedPosition = graph.addNewNode("NestedPosition") + nestedColor = graph.addNewNode("NestedColor") + + # When + acceptedConnection = nestedPosition.xyz.validateConnectionFrom(nestedColor.rgb) + + # Then + assert acceptedConnection == True + +def test_groupAttributes_with_different_structure_should_not_allow_connection(): + + # Given + graph = Graph() + nestedPosition = graph.addNewNode("NestedPosition") + nestedTest = graph.addNewNode("NestedTest") + + # When + acceptedConnection = nestedPosition.xyz.validateConnectionFrom(nestedTest.xyz) + + # Then + assert acceptedConnection == False \ No newline at end of file