diff --git a/meshroom/common/qt.py b/meshroom/common/qt.py index f479ddfd2a..0fc4a1ef62 100644 --- a/meshroom/common/qt.py +++ b/meshroom/common/qt.py @@ -305,7 +305,8 @@ def _dereferenceItem(self, item): key = getattr(item, self._keyAttrName, None) if key is None: return - assert key in self._objectByKey + if key not in self._objectByKey: + raise RuntimeError(f"{key} is not in the Model: {self._objectByKey.keys()}") del self._objectByKey[key] def onRequestDeletion(self, item): diff --git a/meshroom/core/__init__.py b/meshroom/core/__init__.py index 0300647799..e3eab31630 100644 --- a/meshroom/core/__init__.py +++ b/meshroom/core/__init__.py @@ -19,6 +19,7 @@ except Exception: pass +from meshroom.core.plugins import NodePlugin, NodePluginManager, Plugin, ProcessEnv from meshroom.core.submitter import BaseSubmitter from meshroom.env import EnvVar, meshroomFolder from . import desc @@ -31,7 +32,7 @@ sessionUid = str(uuid.uuid1()) cacheFolderName = 'MeshroomCache' -nodesDesc: dict[str, desc.BaseNode] = {} +pluginManager: NodePluginManager = NodePluginManager() submitters: dict[str, BaseSubmitter] = {} pipelineTemplates: dict[str, str] = {} @@ -41,7 +42,6 @@ def hashValue(value) -> str: hashObject = hashlib.sha1(str(value).encode('utf-8')) return hashObject.hexdigest() - @contextmanager def add_to_path(p): import sys @@ -53,9 +53,15 @@ def add_to_path(p): finally: sys.path = old_path - -def loadClasses(folder, packageName, classType): +def loadClasses(folder: str, packageName: str, classType: type) -> list[type]: """ + Go over the Python module named "packageName" located in "folder" to find files + that contain classes of type "classType" and return these classes in a list. + + Args: + folder: the folder to load the module from. + packageName: the name of the module to look for nodes in. + classType: the class to look for in the files that are inspected. """ classes = [] errors = [] @@ -67,7 +73,8 @@ def loadClasses(folder, packageName, classType): try: package = importlib.import_module(packageName) - packageName = package.packageName if hasattr(package, 'packageName') else package.__name__ + packageName = package.packageName if hasattr(package, "packageName") \ + else package.__name__ packageVersion = getattr(package, "__version__", None) packagePath = os.path.dirname(package.__file__) except Exception as e: @@ -83,31 +90,34 @@ def loadClasses(folder, packageName, classType): ) return [] - for importer, pluginName, ispkg in pkgutil.iter_modules(package.__path__): - pluginModuleName = '.' + pluginName + for _, pluginName, _ in pkgutil.iter_modules(package.__path__): + pluginModuleName = "." + pluginName try: pluginMod = importlib.import_module(pluginModuleName, package=package.__name__) - plugins = [plugin for name, plugin in inspect.getmembers(pluginMod, inspect.isclass) - if plugin.__module__ == f'{package.__name__}.{pluginName}' + plugins = [plugin for _, plugin in inspect.getmembers(pluginMod, inspect.isclass) + if plugin.__module__ == f"{package.__name__}.{pluginName}" and issubclass(plugin, classType)] + if not plugins: - logging.warning(f"No class defined in plugin: {pluginModuleName}") + # Only packages/folders have __path__, single module/file do not have it. + isPackage = hasattr(pluginMod, "__path__") + # Sub-folders/Packages should not raise a warning + if not isPackage: + logging.warning(f"No class defined in plugin: {package.__name__}.{pluginName} ('{pluginMod.__file__}')") - importPlugin = True for p in plugins: - if classType == desc.Node: - nodeErrors = validateNodeDesc(p) - if nodeErrors: - errors.append(" * {}: The following parameters do not have valid default values/ranges: {}" - .format(pluginName, ", ".join(nodeErrors))) - importPlugin = False - break p.packageName = packageName p.packageVersion = packageVersion p.packagePath = packagePath - if importPlugin: - classes.extend(plugins) + if classType == desc.BaseNode: + nodePlugin = NodePlugin(p) + if nodePlugin.errors: + errors.append(" * {}: The following parameters do not have valid " \ + "default values/ranges: {}".format(pluginName, ", ".join(nodePlugin.errors))) + classes.append(nodePlugin) + else: + classes.append(p) except Exception as e: tb = traceback.extract_tb(e.__traceback__) last_call = tb[-1] @@ -124,41 +134,42 @@ def loadClasses(folder, packageName, classType): logging.warning(' The following "{package}" plugins could not be loaded:\n' '{errorMsg}\n' .format(package=packageName, errorMsg='\n'.join(errors))) - return classes + return classes -def validateNodeDesc(nodeDesc): +def loadClassesNodes(folder: str, packageName: str) -> list[NodePlugin]: """ - Check that the node has a valid description before being loaded. For the description - to be valid, the default value of every parameter needs to correspond to the type - of the parameter. - An empty returned list means that every parameter is valid, and so is the node's description. - If it is not valid, the returned list contains the names of the invalid parameters. In case - of nested parameters (parameters in groups or lists, for example), the name of the parameter - follows the name of the parent attributes. For example, if the attribute "x", contained in group - "group", is invalid, then it will be added to the list as "group:x". + Return the list of all the NodePlugins that were created following the search of the + Python module named "packageName" located in the folder "folder". + A NodePlugin is created when a file within "packageName" that contains a class inheriting + desc.BaseNode is found. Args: - nodeDesc (desc.Node): description of the node + folder: the folder to load the module from. + packageName: the name of the module to look for nodes in. Returns: - errors (list): the list of invalid parameters if there are any, empty list otherwise + list[NodePlugin]: a list of all the NodePlugins that were created based on the + module's search. If none has been created, an empty list is returned. """ - errors = [] + return loadClasses(folder, packageName, desc.BaseNode) - for param in nodeDesc.inputs: - err = param.checkValueTypes() - if err: - errors.append(err) +def loadClassesSubmitters(folder: str, packageName: str) -> list[BaseSubmitter]: + """ + Return the list of all the submitters that were found during the search of the + Python module named "packageName" that located in the folder "folder". + A submitter is found if a file within "packageName" contains a class inheriting + from BaseSubmitter. - for param in nodeDesc.outputs: - if param.value is None: - continue - err = param.checkValueTypes() - if err: - errors.append(err) + Args: + folder: the folder to load the module from. + packageName: the name of the module to look for nodes in. - return errors + Returns: + list[BaseSubmitter]: a list of all the submitters that were found during the + module's search + """ + return loadClasses(folder, packageName, BaseSubmitter) class Version: @@ -250,7 +261,8 @@ def toComponents(versionName): status = '' # If there is a status, it is placed after a "-" splitComponents = versionName.split("-", maxsplit=1) - if (len(splitComponents) > 1): # If there is no status, splitComponents is equal to [versionName] + # If there is no status, splitComponents is equal to [versionName] + if len(splitComponents) > 1: status = splitComponents[-1] return tuple([int(v) for v in splitComponents[0].split(".")]), status @@ -279,7 +291,7 @@ def micro(self): return self.components[2] -def moduleVersion(moduleName, default=None): +def moduleVersion(moduleName: str, default=None): """ Return the version of a module indicated with '__version__' keyword. Args: @@ -292,7 +304,7 @@ def moduleVersion(moduleName, default=None): return getattr(sys.modules[moduleName], "__version__", default) -def nodeVersion(nodeDesc, default=None): +def nodeVersion(nodeDesc: desc.Node, default=None): """ Return node type version for the given node description class. Args: @@ -305,38 +317,28 @@ def nodeVersion(nodeDesc, default=None): return moduleVersion(nodeDesc.__module__, default) -def registerNodeType(nodeType): - """ Register a Node Type based on a Node Description class. - - After registration, nodes of this type can be instantiated in a Graph. - """ - if nodeType.__name__ in nodesDesc: - logging.error(f"Node Desc {nodeType.__name__} is already registered.") - nodesDesc[nodeType.__name__] = nodeType - - -def unregisterNodeType(nodeType): - """ Remove 'nodeType' from the list of register node types. """ - assert nodeType.__name__ in nodesDesc - del nodesDesc[nodeType.__name__] - - -def loadNodes(folder, packageName): +def loadNodes(folder, packageName) -> list[NodePlugin]: if not os.path.isdir(folder): logging.error(f"Node folder '{folder}' does not exist.") - return + return [] - return loadClasses(folder, packageName, desc.BaseNode) + nodes = loadClassesNodes(folder, packageName) + return nodes -def loadAllNodes(folder): - for importer, package, ispkg in pkgutil.walk_packages([folder]): +def loadAllNodes(folder) -> list[Plugin]: + plugins = [] + for _, package, ispkg in pkgutil.iter_modules([folder]): if ispkg: - nodeTypes = loadNodes(folder, package) - for nodeType in nodeTypes: - registerNodeType(nodeType) - nodesStr = ', '.join([nodeType.__name__ for nodeType in nodeTypes]) - logging.debug(f'Nodes loaded [{package}]: {nodesStr}') + plugin = Plugin(package, folder) + nodePlugins = loadNodes(folder, package) + if nodePlugins: + for node in nodePlugins: + plugin.addNodePlugin(node) + nodesStr = ', '.join([node.nodeDescriptor.__name__ for node in nodePlugins]) + logging.debug(f'Nodes loaded [{package}]: {nodesStr}') + plugins.append(plugin) + return plugins def loadPluginFolder(folder): @@ -349,26 +351,29 @@ def loadPluginFolder(folder): logging.info(f"Plugin folder '{folder}' does not contain a 'meshroom' folder.") return - binFolders = [Path(folder, 'bin')] - libFolders = [Path(folder, 'lib'), Path(folder, 'lib64')] - pythonPathFolders = [Path(folder)] + binFolders + processEnv = ProcessEnv(folder) + + plugins = loadAllNodes(folder=mrFolder) + if plugins: + for plugin in plugins: + pluginManager.addPlugin(plugin) + pipelineTemplates.update(plugin.templates) - loadAllNodes(folder=mrFolder) - loadPipelineTemplates(folder=mrFolder) + return plugins def loadPluginsFolder(folder): if not os.path.isdir(folder): logging.debug(f"PluginSet folder '{folder}' does not exist.") return - + for file in os.listdir(folder): if os.path.isdir(file): subFolder = os.path.join(folder, file) loadPluginFolder(subFolder) -def registerSubmitter(s): +def registerSubmitter(s: BaseSubmitter): if s.name in submitters: logging.error(f"Submitter {s.name} is already registered.") submitters[s.name] = s @@ -379,10 +384,9 @@ def loadSubmitters(folder, packageName): logging.error(f"Submitters folder '{folder}' does not exist.") return - return loadClasses(folder, packageName, BaseSubmitter) - + return loadClassesSubmitters(folder, packageName) -def loadPipelineTemplates(folder): +def loadPipelineTemplates(folder: str): if not os.path.isdir(folder): logging.error(f"Pipeline templates folder '{folder}' does not exist.") return @@ -390,19 +394,21 @@ def loadPipelineTemplates(folder): if file.endswith(".mg") and file not in pipelineTemplates: pipelineTemplates[os.path.splitext(file)[0]] = os.path.join(folder, file) - def initNodes(): additionalNodesPath = EnvVar.getList(EnvVar.MESHROOM_NODES_PATH) - nodesFolders = [os.path.join(meshroomFolder, 'nodes')] + additionalNodesPath + nodesFolders = [os.path.join(meshroomFolder, "nodes")] + additionalNodesPath for f in nodesFolders: - loadAllNodes(folder=f) + plugins = loadAllNodes(folder=f) + if plugins: + for plugin in plugins: + pluginManager.addPlugin(plugin) def initSubmitters(): additionalPaths = EnvVar.getList(EnvVar.MESHROOM_SUBMITTERS_PATH) allSubmittersFolders = [meshroomFolder] + additionalPaths for folder in allSubmittersFolders: - subs = loadSubmitters(folder, 'submitters') + subs = loadSubmitters(folder, "submitters") for sub in subs: registerSubmitter(sub()) @@ -411,13 +417,15 @@ def initPipelines(): # Load pipeline templates: check in the default folder and any folder the user might have # added to the environment variable additionalPipelinesPath = EnvVar.getList(EnvVar.MESHROOM_PIPELINE_TEMPLATES_PATH) - pipelineTemplatesFolders = [os.path.join(meshroomFolder, 'pipelines')] + additionalPipelinesPath + pipelineTemplatesFolders = [os.path.join(meshroomFolder, "pipelines")] + additionalPipelinesPath for f in pipelineTemplatesFolders: loadPipelineTemplates(f) + for plugin in pluginManager.getPlugins().values(): + pipelineTemplates.update(plugin.templates) def initPlugins(): additionalpluginsPath = EnvVar.getList(EnvVar.MESHROOM_PLUGINS_PATH) - nodesFolders = [os.path.join(meshroomFolder, 'plugins')] + additionalpluginsPath + nodesFolders = [os.path.join(meshroomFolder, "plugins")] + additionalpluginsPath for f in nodesFolders: loadPluginFolder(folder=f) diff --git a/meshroom/core/attribute.py b/meshroom/core/attribute.py index a06e912607..4141cce519 100644 --- a/meshroom/core/attribute.py +++ b/meshroom/core/attribute.py @@ -17,15 +17,16 @@ from meshroom.core.graph import Edge -def attributeFactory(description, value, isOutput, node, root=None, parent=None): +def attributeFactory(description: str, value, isOutput: bool, node, root=None, parent=None): """ Create an Attribute based on description type. Args: description: the Attribute description - value: value of the Attribute. Will be set if not None. - isOutput: whether is Attribute is an output attribute. - node (Node): node owning the Attribute. Note that the created Attribute is not added to Node's attributes + value: value of the Attribute. Will be set if not None. + isOutput: whether the Attribute is an output attribute. + node (Node): node owning the Attribute. Note that the created Attribute is not added to \ + Node's attributes root: (optional) parent Attribute (must be ListAttribute or GroupAttribute) parent (BaseObject): (optional) the parent BaseObject if any """ @@ -53,27 +54,27 @@ class Attribute(BaseObject): VALID_IMAGE_SEMANTICS = ["image", "imageList", "sequence"] VALID_3D_EXTENSIONS = [".obj", ".stl", ".fbx", ".gltf", ".abc", ".ply"] - def __init__(self, node, attributeDesc, isOutput, root=None, parent=None): + def __init__(self, node, attributeDesc: desc.Attribute, isOutput: bool, root=None, parent=None): """ Attribute constructor Args: node (Node): the Node hosting this Attribute - attributeDesc (desc.Attribute): the description of this Attribute - isOutput (bool): whether this Attribute is an output of the Node + attributeDesc: the description of this Attribute + isOutput: whether this Attribute is an output of the Node root (Attribute): (optional) the root Attribute (List or Group) containing this one parent (BaseObject): (optional) the parent BaseObject """ super().__init__(parent) - self._name = attributeDesc.name + self._name: str = attributeDesc.name self._root = None if root is None else weakref.ref(root) self._node = weakref.ref(node) - self.attributeDesc = attributeDesc - self._isOutput = isOutput - self._label = attributeDesc.label - self._enabled = True - self._validValue = True - self._description = attributeDesc.description + self.attributeDesc: desc.Attribute = attributeDesc + self._isOutput: bool = isOutput + self._label: str = attributeDesc.label + self._enabled: bool = True + self._validValue: bool = True + self._description: str = attributeDesc.description self._invalidate = False if self._isOutput else attributeDesc.invalidate # invalidation value for output attributes @@ -91,11 +92,11 @@ def node(self): def root(self): return self._root() if self._root else None - def getName(self): + def getName(self) -> str: """ Attribute name """ return self._name - def getFullName(self): + def getFullName(self) -> str: """ Name inside the Graph: groupName.name """ if isinstance(self.root, ListAttribute): return f'{self.root.getFullName()}[{self.root.index(self)}]' @@ -103,53 +104,55 @@ def getFullName(self): return f'{self.root.getFullName()}.{self.getName()}' return self.getName() - def getFullNameToNode(self): + def getFullNameToNode(self) -> str: """ Name inside the Graph: nodeName.groupName.name """ return f'{self.node.name}.{self.getFullName()}' - def getFullNameToGraph(self): + def getFullNameToGraph(self) -> str: """ Name inside the Graph: graphName.nodeName.groupName.name """ graphName = self.node.graph.name if self.node.graph else "UNDEFINED" return f'{graphName}.{self.getFullNameToNode()}' - def asLinkExpr(self): + def asLinkExpr(self) -> str: """ Return link expression for this Attribute """ return "{" + self.getFullNameToNode() + "}" - def getType(self): + def getType(self) -> str: return self.attributeDesc.type - def _isReadOnly(self): + def _isReadOnly(self) -> bool: return not self._isOutput and self.node.isCompatibilityNode - def getBaseType(self): + def getBaseType(self) -> str: return self.getType() - def getLabel(self): + def getLabel(self) -> str: return self._label @Slot(str, result=bool) - def matchText(self, text): + def matchText(self, text: str) -> bool: return self.fullLabel.lower().find(text.lower()) > -1 - def getFullLabel(self): - """ Full Label includes the name of all parent groups, e.g. 'groupLabel subGroupLabel Label' """ + def getFullLabel(self) -> str: + """ + Full Label includes the name of all parent groups, e.g. 'groupLabel subGroupLabel Label'. + """ if isinstance(self.root, ListAttribute): return self.root.getFullLabel() elif isinstance(self.root, GroupAttribute): return f'{self.root.getFullLabel()} {self.getLabel()}' return self.getLabel() - def getFullLabelToNode(self): + def getFullLabelToNode(self) -> str: """ Label inside the Graph: nodeLabel groupLabel Label """ return f'{self.node.label} {self.getFullLabel()}' - def getFullLabelToGraph(self): + def getFullLabelToGraph(self) -> str: """ Label inside the Graph: graphName nodeLabel groupLabel Label """ graphName = self.node.graph.name if self.node.graph else "UNDEFINED" return f'{graphName} {self.getFullLabelToNode()}' - def getEnabled(self): + def getEnabled(self) -> bool: if isinstance(self.desc.enabled, types.FunctionType): try: return self.desc.enabled(self.node) @@ -205,8 +208,8 @@ def _set_value(self, value): # evaluate the function self._value = value(self) else: - # if we set a new value, we use the attribute descriptor validator to check the validity of the value - # and apply some conversion if needed + # if we set a new value, we use the attribute descriptor validator to check the + # validity of the value and apply some conversion if needed convertedValue = self.validateValue(value) self._value = convertedValue @@ -266,26 +269,27 @@ def requestNodeUpdate(self): self.node.updateInternalAttributes() @property - def isOutput(self): + def isOutput(self) -> bool: return self._isOutput @property - def isInput(self): + def isInput(self) -> bool: return not self._isOutput - def uid(self): + def uid(self) -> str: """ Compute the UID for the attribute. """ if self.isOutput: if self.desc.isDynamicValue: # If the attribute is a dynamic output, the UID is derived from the node UID. - # To guarantee that each output attribute receives a unique ID, we add the attribute name to it. + # To guarantee that each output attribute receives a unique ID, we add the attribute + # name to it. return hashValue((self.name, self.node._uid)) else: # Only dependent on the hash of its value without the cache folder. - # "/" at the end of the link is stripped to prevent having different UIDs depending on - # whether the invalidation value finishes with it or not + # "/" at the end of the link is stripped to prevent having different UIDs depending + # on whether the invalidation value finishes with it or not strippedInvalidationValue = self._invalidationValue.rstrip("/") return hashValue(strippedInvalidationValue) if self.isLink: @@ -298,13 +302,15 @@ def uid(self): return hashValue(self._value) @property - def isLink(self): + def isLink(self) -> bool: """ Whether the input attribute is a link to another attribute. """ - # note: directly use self.node.graph._edges to avoid using the property that may become invalid at some point - return self.node.graph and self.isInput and self.node.graph._edges and self in self.node.graph._edges.keys() + # note: directly use self.node.graph._edges to avoid using the property that may become + # invalid at some point + return self.node.graph and self.isInput and self.node.graph._edges and \ + self in self.node.graph._edges.keys() @staticmethod - def isLinkExpression(value): + def isLinkExpression(value) -> bool: """ Return whether the given argument is a link expression. A link expression is a string matching the {nodeName.attrName} pattern. @@ -322,12 +328,13 @@ def getLinkParam(self, recursive=False): return linkParam @property - def hasOutputConnections(self): - """ Whether the attribute has output connections, i.e is the source of at least one edge. """ + def hasOutputConnections(self) -> bool: + """ + Whether the attribute has output connections, i.e is the source of at least one edge. + """ # safety check to avoid evaluation errors if not self.node.graph or not self.node.graph.edges: return False - return next((edge for edge in self.node.graph.edges.values() if edge.src == self), None) is not None def getInputConnections(self) -> list["Edge"]: @@ -391,28 +398,29 @@ def getExportValue(self): return self.value def getEvalValue(self): - ''' + """ Return the value. If it is a string, expressions will be evaluated. - ''' + """ if isinstance(self.value, str): substituted = Template(self.value).safe_substitute(os.environ) try: varResolved = substituted.format(**self.node._cmdVars) return varResolved except (KeyError, IndexError): - # Catch KeyErrors and IndexErros to be able to open files created prior to the support of - # relative variables (when self.node._cmdVars was not used to evaluate expressions in the attribute) + # Catch KeyErrors and IndexErros to be able to open files created prior to the + # support of relative variables (when self.node._cmdVars was not used to evaluate + # expressions in the attribute) return substituted return self.value - def getValueStr(self, withQuotes=True): - ''' + def getValueStr(self, withQuotes=True) -> str: + """ Return the value formatted as a string with quotes to deal with spaces. If it is a string, expressions will be evaluated. If it is an empty string, it will returns 2 quotes. If it is an empty list, it will returns a really empty string. If it is a list with one empty string element, it will returns 2 quotes. - ''' + """ # ChoiceParam with multiple values should be combined if isinstance(self.attributeDesc, desc.ChoiceParam) and not self.attributeDesc.exclusive: # Ensure value is a list as expected @@ -421,8 +429,10 @@ def getValueStr(self, withQuotes=True): if withQuotes and v: return f'"{v}"' return v - # String, File, single value Choice are based on strings and should includes quotes to deal with spaces - if withQuotes and isinstance(self.attributeDesc, (desc.StringParam, desc.File, desc.ChoiceParam)): + # String, File, single value Choice are based on strings and should includes quotes + # to deal with spaces + if withQuotes and \ + isinstance(self.attributeDesc, (desc.StringParam, desc.File, desc.ChoiceParam)): return f'"{self.getEvalValue()}"' return str(self.getEvalValue()) @@ -436,10 +446,11 @@ def defaultValue(self): logging.warning("Failed to evaluate default value (node lambda) for attribute '{}': {}". format(self.name, e)) return None - # Need to force a copy, for the case where the value is a list (avoid reference to the desc value) + # Need to force a copy, for the case where the value is a list + # (avoid reference to the desc value) return copy.copy(self.desc.value) - def _isDefault(self): + def _isDefault(self) -> bool: return self.value == self.defaultValue() def getPrimitiveValue(self, exportDefault=True): @@ -510,7 +521,8 @@ def _is2D(self) -> bool: isDefault = Property(bool, _isDefault, notify=valueChanged) linkParam = Property(BaseObject, getLinkParam, notify=isLinkChanged) - rootLinkParam = Property(BaseObject, lambda self: self.getLinkParam(recursive=True), notify=isLinkChanged) + rootLinkParam = Property(BaseObject, lambda self: self.getLinkParam(recursive=True), + notify=isLinkChanged) node = Property(BaseObject, node.fget, constant=True) enabledChanged = Signal() enabled = Property(bool, getEnabled, setEnabled, notify=enabledChanged) @@ -522,7 +534,7 @@ def _is2D(self) -> bool: def raiseIfLink(func): - """ If Attribute instance is a link, raise a RuntimeError.""" + """ If Attribute instance is a link, raise a RuntimeError. """ def wrapper(attr, *args, **kwargs): if attr.isLink: raise RuntimeError("Can't modify connected Attribute") @@ -531,7 +543,8 @@ def wrapper(attr, *args, **kwargs): class PushButtonParam(Attribute): - def __init__(self, node, attributeDesc, isOutput, root=None, parent=None): + def __init__(self, node, attributeDesc: desc.PushButtonParam, isOutput: bool, + root=None, parent=None): super().__init__(node, attributeDesc, isOutput, root, parent) @Slot() @@ -541,7 +554,8 @@ def clicked(self): class ChoiceParam(Attribute): - def __init__(self, node, attributeDesc: desc.ChoiceParam, isOutput, root=None, parent=None): + def __init__(self, node, attributeDesc: desc.ChoiceParam, isOutput: bool, + root=None, parent=None): super().__init__(node, attributeDesc, isOutput, root, parent) self._values = None @@ -568,7 +582,7 @@ def validateValue(self, value): raise ValueError("Non exclusive ChoiceParam value should be iterable (param:{}, value:{}, type:{})". format(self.name, value, type(value))) return [self.conformValue(v) for v in value] - + def _set_value(self, value): # Handle alternative serialization for ChoiceParam with overriden values. serializedValueWithValuesOverrides = isinstance(value, dict) @@ -585,7 +599,8 @@ def setValues(self, values): self.valuesChanged.emit() def getExportValue(self): - useStandardSerialization = self.isLink or not self.desc._saveValuesOverride or self._values is None + useStandardSerialization = self.isLink or not self.desc._saveValuesOverride or \ + self._values is None if useStandardSerialization: return super().getExportValue() @@ -602,7 +617,8 @@ def getExportValue(self): class ListAttribute(Attribute): - def __init__(self, node, attributeDesc, isOutput, root=None, parent=None): + def __init__(self, node, attributeDesc: desc.ListAttribute, isOutput: bool, + root=None, parent=None): super().__init__(node, attributeDesc, isOutput, root, parent) def __len__(self): @@ -617,7 +633,7 @@ def getBaseType(self): return self.attributeDesc.elementDesc.__class__.__name__ def at(self, idx): - """ Returns child attribute at index 'idx' """ + """ Returns child attribute at index 'idx'. """ # Implement 'at' rather than '__getitem__' # since the later is called spuriously when object is used in QML return self._value.at(idx) @@ -649,7 +665,8 @@ def _set_value(self, value): def upgradeValue(self, exportedValues): if not isinstance(exportedValues, list): - if isinstance(exportedValues, ListAttribute) or Attribute.isLinkExpression(exportedValues): + if isinstance(exportedValues, ListAttribute) or \ + Attribute.isLinkExpression(exportedValues): self._set_value(exportedValues) return raise RuntimeError("ListAttribute.upgradeValue: the given value is of type " + @@ -657,7 +674,8 @@ def upgradeValue(self, exportedValues): attrs = [] for v in exportedValues: - a = attributeFactory(self.attributeDesc.elementDesc, None, self.isOutput, self.node, self) + a = attributeFactory(self.attributeDesc.elementDesc, None, self.isOutput, + self.node, self) a.upgradeValue(v) attrs.append(a) index = len(self._value) @@ -675,7 +693,8 @@ def insert(self, index, value): if self._value is None: self._value = ListModel(parent=self) values = value if isinstance(value, list) else [value] - attrs = [attributeFactory(self.attributeDesc.elementDesc, v, self.isOutput, self.node, self) for v in values] + attrs = [attributeFactory(self.attributeDesc.elementDesc, v, self.isOutput, self.node, self) + for v in values] self._value.insert(index, attrs) self.valueChanged.emit() self._applyExpr() @@ -725,27 +744,28 @@ def getExportValue(self): return self.getLinkParam().asLinkExpr() return [attr.getExportValue() for attr in self._value] - def defaultValue(self): + def defaultValue(self) -> list: return [] - def _isDefault(self): + def _isDefault(self) -> bool: return len(self._value) == 0 def getPrimitiveValue(self, exportDefault=True): if exportDefault: return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value] - else: - return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value if not attr.isDefault] + return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value + if not attr.isDefault] - def getValueStr(self, withQuotes=True): + def getValueStr(self, withQuotes=True) -> str: assert isinstance(self.value, ListModel) if self.attributeDesc.joinChar == ' ': - return self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=withQuotes) for v in self.value]) - else: - v = self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=False) for v in self.value]) - if withQuotes and v: - return f'"{v}"' - return v + return self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=withQuotes) + for v in self.value]) + v = self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=False) + for v in self.value]) + if withQuotes and v: + return f'"{v}"' + return v def updateInternals(self): super().updateInternals() @@ -753,9 +773,10 @@ def updateInternals(self): attr.updateInternals() @property - def isLinkNested(self): + def isLinkNested(self) -> bool: """ Whether the attribute or any of its elements is a link to another attribute. """ - # note: directly use self.node.graph._edges to avoid using the property that may become invalid at some point + # note: directly use self.node.graph._edges to avoid using the property that may become + # invalid at some point return self.isLink \ or self.node.graph and self.isInput and self.node.graph._edges \ and any(v in self.node.graph._edges.keys() for v in self._value) @@ -799,7 +820,8 @@ def getOutputConnections(self) -> list["Edge"]: class GroupAttribute(Attribute): - def __init__(self, node, attributeDesc, isOutput, root=None, parent=None): + def __init__(self, node, attributeDesc: desc.GroupAttribute, isOutput: bool, + root=None, parent=None): super().__init__(node, attributeDesc, isOutput, root, parent) def __getattr__(self, key): @@ -854,12 +876,12 @@ def resetToDefaultValue(self): self._value.get(attrDesc.name).resetToDefaultValue() @Slot(str, result=Attribute) - def childAttribute(self, key): + def childAttribute(self, key: str) -> Attribute: """ Get child attribute by name or None if none was found. Args: - key (str): the name of the child attribute + key: the name of the child attribute Returns: Attribute: the child attribute or None @@ -892,9 +914,8 @@ def defaultValue(self): def getPrimitiveValue(self, exportDefault=True): if exportDefault: return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items()} - else: - return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items() - if not attr.isDefault} + return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items() + if not attr.isDefault} def getValueStr(self, withQuotes=True): # add brackets if requested @@ -925,7 +946,7 @@ def updateInternals(self): attr.updateInternals() @Slot(str, result=bool) - def matchText(self, text): + def matchText(self, text: str) -> bool: return super().matchText(text) or any(c.matchText(text) for c in self._value) # Override value property diff --git a/meshroom/core/desc/node.py b/meshroom/core/desc/node.py index 5c663e019c..4d940c7592 100644 --- a/meshroom/core/desc/node.py +++ b/meshroom/core/desc/node.py @@ -39,8 +39,10 @@ class BaseNode(object): name="invalidation", label="Invalidation Message", description="A message that will invalidate the node's output folder.\n" - "This is useful for development, we can invalidate the output of the node when we modify the code.\n" - "It is displayed in bold font in the invalidation/comment messages tooltip.", + "This is useful for development, we can invalidate the output of the node " + "when we modify the code.\n" + "It is displayed in bold font in the invalidation/comment messages " + "tooltip.", value="", semantic="multiline", advanced=True, @@ -50,7 +52,8 @@ class BaseNode(object): name="comment", label="Comments", description="User comments describing this specific node instance.\n" - "It is displayed in regular font in the invalidation/comment messages tooltip.", + "It is displayed in regular font in the invalidation/comment messages " + "tooltip.", value="", semantic="multiline", invalidate=False, @@ -58,7 +61,8 @@ class BaseNode(object): StringParam( name="label", label="Node's Label", - description="Customize the default label (to replace the technical name of the node instance).", + description="Customize the default label (to replace the technical name of the node " + "instance).", value="", invalidate=False, ), diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 4a0f9aae8a..a5bff2a250 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -748,7 +748,12 @@ def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: A if dstName in outListAttributes: _recreateTargetListAttributeChildren(*outListAttributes[dstName]) try: - self.addEdge(self.attribute(srcName), self.attribute(dstName)) + srcAttr = self.attribute(srcName) + dstAttr = self.attribute(dstName) + if srcAttr is None or dstAttr is None: + logging.warning(f"Failed to restore edge {srcName}{' (missing)' if srcAttr is None else ''} -> {dstName}{' (missing)' if dstAttr is None else ''}") + continue + self.addEdge(srcAttr, dstAttr) except (KeyError, ValueError) as e: logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {e}") @@ -759,6 +764,30 @@ def upgradeAllNodes(self): for nodeName in nodeNames: self.upgradeNode(nodeName) + def reloadNodePlugins(self, nodeTypes: list[str]): + """ + Replace all the node instances of "nodeTypes" in the current graph with new node instances of the + same type. If the description of the nodes has changed, the reloaded nodes will reflect theses + changes. If "nodeTypes" is empty, then the function returns immediately. + + Args: + nodeTypes: the list of node types that will be reloaded. + """ + if not nodeTypes: + # No updated node to replace in the graph, nothing to do + return + + newNodes: dict[str, BaseNode] = {} + for node in self._nodes.values(): + if node.nodeType in nodeTypes: + newNode = nodeFactory(node.toDict(), node.nodeType, expectedUid=node._uid) + newNodes[node.name] = newNode + + # Replace in a different loop to ensure all the nodes have been looped over: when looping + # over self._nodes and replacing nodes at the same time, some nodes might not be reached + for name, node in newNodes.items(): + self.replaceNode(name, node) + @Slot(str, result=Attribute) def attribute(self, fullName): # type: (str) -> Attribute diff --git a/meshroom/core/node.py b/meshroom/core/node.py index bafe728854..50ee7ba617 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -18,7 +18,7 @@ import meshroom from meshroom.common import Signal, Variant, Property, BaseObject, Slot, ListModel, DictModel -from meshroom.core import desc, stats, hashValue, nodeVersion, Version, MrNodeType +from meshroom.core import desc, plugins, stats, hashValue, nodeVersion, Version, MrNodeType from meshroom.core.attribute import attributeFactory, ListAttribute, GroupAttribute, Attribute from meshroom.core.exception import NodeUpgradeError, UnknownNodeTypeError @@ -201,7 +201,7 @@ def fromDict(self, d): self.mrNodeType = d.get("mrNodeType", MrNodeType.NONE) if not isinstance(self.mrNodeType, MrNodeType): self.mrNodeType = MrNodeType[self.mrNodeType] - + self.nodeName = d.get("nodeName", "") self.nodeType = d.get("nodeType", "") self.packageName = d.get("packageName", "") @@ -639,10 +639,12 @@ def __init__(self, nodeType: str, position: Position = None, parent: BaseObject super().__init__(parent) self._nodeType: str = nodeType self.nodeDesc: desc.BaseNode = None + self.nodePlugin: plugins.Plugin = None # instantiate node description if nodeType is valid - if nodeType in meshroom.core.nodesDesc: - self.nodeDesc = meshroom.core.nodesDesc[nodeType]() + if meshroom.core.pluginManager.getRegisteredNodePlugin(nodeType): + self.nodeDesc = meshroom.core.pluginManager.getRegisteredNodePlugin(nodeType).nodeDescriptor() + self.nodePlugin = meshroom.core.pluginManager.getRegisteredNodePlugin(nodeType) self.packageName: str = "" self.packageVersion: str = "" @@ -1814,6 +1816,7 @@ class CompatibilityIssue(Enum): VersionConflict = 2 # mismatch between node's description version and serialized node data DescriptionConflict = 3 # mismatch between node's description attributes and serialized node data UidConflict = 4 # mismatch between computed UIDs and UIDs stored in serialized node data + PluginIssue = 5 # issue when loading the associated plugin class CompatibilityNode(BaseNode): diff --git a/meshroom/core/nodeFactory.py b/meshroom/core/nodeFactory.py index 7ca79fe53b..23997f7c83 100644 --- a/meshroom/core/nodeFactory.py +++ b/meshroom/core/nodeFactory.py @@ -54,7 +54,9 @@ def __init__( self.internalFolder = self.nodeData.get("internalFolder") self.position = Position(*self.nodeData.get("position", [])) self.uid = self.nodeData.get("uid", None) - self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None) + self.nodeDesc = None + if meshroom.core.pluginManager.isRegistered(self.nodeType): + self.nodeDesc = meshroom.core.pluginManager.getRegisteredNodePlugin(self.nodeType).nodeDescriptor def create(self) -> Union[Node, CompatibilityNode]: compatibilityIssue = self._checkCompatibilityIssues() @@ -74,6 +76,8 @@ def _normalizeNodeData(self): def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]: if self.nodeDesc is None: + if meshroom.core.pluginManager.belongsToPlugin(self.nodeType) is not None: + return CompatibilityIssue.PluginIssue return CompatibilityIssue.UnknownNodeType if not self._checkUidCompatibility(): @@ -121,13 +125,14 @@ def _checkAttributesNamesMatchDescription(self) -> bool: def _checkAttributesAreCompatibleWithDescription(self) -> bool: return ( self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs) - and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs) + and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, + self.internalInputs) and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs) ) def _checkInputAttributesNames(self) -> bool: def serializedInput(attr: desc.Attribute) -> bool: - """Filter that excludes not-serialized desc input attributes.""" + """ Filter that excludes not-serialized desc input attributes. """ if isinstance(attr, desc.PushButtonParam): # PushButtonParam are not serialized has they do not hold a value. return False @@ -138,7 +143,7 @@ def serializedInput(attr: desc.Attribute) -> bool: def _checkOutputAttributesNames(self) -> bool: def serializedOutput(attr: desc.Attribute) -> bool: - """Filter that excludes not-serialized desc output attributes.""" + """ Filter that excludes not-serialized desc output attributes. """ if attr.isDynamicValue: # Dynamic outputs values are not serialized with the node, # as their value is written in the computed output data. diff --git a/meshroom/core/plugins.py b/meshroom/core/plugins.py new file mode 100644 index 0000000000..a25ff5563e --- /dev/null +++ b/meshroom/core/plugins.py @@ -0,0 +1,427 @@ +from __future__ import annotations + +import importlib +import logging +import os +import sys + +from enum import Enum +from inspect import getfile +from pathlib import Path + +from meshroom.common import BaseObject +from meshroom.core import desc + +def validateNodeDesc(nodeDesc: desc.Node) -> list: + """ + Check that the node has a valid description before being loaded. For the description + to be valid, the default value of every parameter needs to correspond to the type + of the parameter. + An empty returned list means that every parameter is valid, and so is the node's description. + If it is not valid, the returned list contains the names of the invalid parameters. In case + of nested parameters (parameters in groups or lists, for example), the name of the parameter + follows the name of the parent attributes. For example, if the attribute "x", contained in group + "group", is invalid, then it will be added to the list as "group:x". + + Args: + nodeDesc: description of the node. + + Returns: + errors: the list of invalid parameters if there are any, empty list otherwise + """ + errors = [] + + for param in nodeDesc.inputs: + err = param.checkValueTypes() + if err: + errors.append(err) + + for param in nodeDesc.outputs: + if param.value is None: + continue + err = param.checkValueTypes() + if err: + errors.append(err) + + return errors + + +class ProcessEnv(BaseObject): + """ + Describes the environment required by a node's process. + """ + + def __init__(self, folder: str): + super().__init__() + self.binPaths: list = [Path(folder, "bin")] + self.libPaths: list = [Path(folder, "lib"), Path(folder, "lib64")] + self.pythonPathFolders: list = [Path(folder)] + self.binPaths + + +class NodePluginStatus(Enum): + """ + Loading status for NodePlugin objects. + """ + NOT_LOADED = 0 # The node plugin exists but is not loaded and cannot be used (not registered) + LOADED = 1 # The node plugin is currently loaded and functional (it has been registered) + DESC_ERROR = 2 # The node plugin exists but has an invalid description + LOADING_ERROR = 3 # The node plugin exists and is valid but could not be successfully registered + ERROR = 4 # Error when importing the node plugin from its module + + +class Plugin(BaseObject): + """ + A collection of node plugins. + + Members: + name: the name of the plugin (e.g. name of the Python module containing the node plugins) + path: the absolute path of the plugin + _nodePlugins: dictionary mapping the name of a node plugin contained in the plugin + to its corresponding NodePlugin object + _templates: dictionary mapping the name of templates (.mg files) associated to the plugin + with their absolute paths + processEnv: the environment required for the nodes' processes to be correctly executed + """ + + def __init__(self, name: str, path: str): + super().__init__() + + self._name: str = name + self._path: str = path + + self._nodePlugins: dict[str: NodePlugin] = {} + self._templates: dict[str: str] = {} + self._processEnv: ProcessEnv = ProcessEnv(path) + + self.loadTemplates() + + @property + def name(self): + """ Return the name of the plugin. """ + return self._name + + @property + def path(self): + """ Return the absolute path of the plugin. """ + return self._path + + @property + def nodes(self): + """ + Return the dictionary containing the NodePlugin objects associated to + the plugin. + """ + return self._nodePlugins + + @property + def templates(self): + """ Return the list of templates associated to the plugin. """ + return self._templates + + @property + def processEnv(self): + """ Return the environment required to successfully execute processes. """ + return self._processEnv + + def addNodePlugin(self, nodePlugin: NodePlugin): + """ + Add a node plugin to the current plugin object and assign it as its containing plugin. + The node plugin is added to the dictionary of node plugins with the name of the node + descriptor as its key. + + Args: + nodePlugin: the NodePlugin object to add to the Plugin. + """ + self._nodePlugins[nodePlugin.nodeDescriptor.__name__] = nodePlugin + nodePlugin.plugin = self + + def removeNodePlugin(self, name: str): + """ + Remove a node plugin from the current plugin object and delete any container relationship. + + Args: + name: the name of the NodePlugin to remove. + """ + if name in self._nodePlugins: + self._nodePlugins[name].plugin = None + del self._nodePlugins[name] + else: + logging.warning(f"Node plugin {name} is not part of the plugin {self.name}.") + + def loadTemplates(self): + """ + Load all the pipeline templates that are available within the plugin folder. + Whenever this method is called, the list of templates for the plugin is cleared, + before being filled again. + """ + self._templates.clear() + for file in os.listdir(self.path): + if file.endswith(".mg"): + self._templates[os.path.splitext(file)[0]] = os.path.join(self.path, file) + + def containsNodePlugin(self, name: str) -> bool: + """ + Return whether the node plugin "name" is part of the plugin, independently from its + status. + + Args: + name: the name of the node plugin to be checked. + """ + return name in self._nodePlugins + + +class NodePlugin(BaseObject): + """ + Based on a node description, a NodePlugin represents a loadable node. + + Members: + plugin: the Plugin object that contains this node plugin + path: absolute path to the file containing the node's description + nodeDescriptor: the description of the node + status: the loading status on the node plugin + errors: the list of errors (if there are any) when validating the description + of the node or attempting to load it + processEnv: the environment required for the node plugin's process. It can either + be specific to this node plugin, or be common for all the node plugins within + the plugin + timestamp: the timestamp corresponding to the last time the node description's file has been + modified + """ + + def __init__(self, nodeDesc: desc.Node, plugin: Plugin = None): + super().__init__() + self.plugin: Plugin = plugin + self.path: str = Path(getfile(nodeDesc)).resolve().as_posix() + self.nodeDescriptor: desc.Node = nodeDesc + + self.status: NodePluginStatus = NodePluginStatus.NOT_LOADED + self.errors: list[str] = validateNodeDesc(nodeDesc) + + if self.errors: + self.status = NodePluginStatus.DESC_ERROR + + self._processEnv = None + self._timestamp = os.path.getmtime(self.path) + + def reload(self) -> bool: + """ + Reload the node plugin and update its status accordingly. If the timestamp of the node plugin's + path has not changed since the last time the plugin has been loaded, then nothing will happen. + + Returns: + bool: True if the node plugin has successfully been reloaded (i.e. there was no error, and + some changes were made since its last loading), False otherwise. + """ + timestamp = 0.0 + try: + timestamp = os.path.getmtime(self.path) + except FileNotFoundError: + self.status = NodePluginStatus.ERROR + logging.error(f"[Reload] {self.nodeDescriptor.__name__}: The path at {self.path} was not " + "not found.") + return False + + if self._timestamp == timestamp: + logging.info(f"[Reload] {self.nodeDescriptor.__name__}: Not reloading. The node description " + f"at {self.path} has not been modified since the last load.") + return False + + updated = importlib.reload(sys.modules.get(self.nodeDescriptor.__module__)) + descriptor = getattr(updated, self.nodeDescriptor.__name__) + + if not descriptor: + self.status = NodePluginStatus.ERROR + logging.error(f"[Reload] {self.nodeDescriptor.__name__}: The node description at {self.path} " + "was not found.") + return False + + self.errors = validateNodeDesc(descriptor) + if self.errors: + self.status = NodePluginStatus.DESC_ERROR + logging.error(f"[Reload] {self.nodeDescriptor.__name__}: The node description at {self.path} " + "has description errors.") + return False + + self.nodeDescriptor = descriptor + self._timestamp = timestamp + self.status = NodePluginStatus.NOT_LOADED + logging.info(f"[Reload] {self.nodeDescriptor.__name__}: Successful reloading.") + return True + + @property + def plugin(self): + """ + Return the Plugin object that contains this node plugin. + If the node plugin has not been assigned to a plugin yet, this value will + be set to None. + """ + return self._plugin + + @plugin.setter + def plugin(self, plugin: Plugin): + self._plugin = plugin + + @property + def processEnv(self): + """" + Return the process environment that is specific to the node plugin if it has any. + Otherwise, the Plugin's is returned. + """ + if self._processEnv: + return self._processEnv + if self.plugin: + return self.plugin.processEnv + return None + + +class NodePluginManager(BaseObject): + """ + Manager for all the loaded Plugin objects as well as the registered NodePlugin objects. + + Members: + _plugins: dictionary containing all the loaded Plugins, with their name as the key + _nodePlugins: dictionary containing all the NodePlugins that have been registered + (a NodePlugin may exist without having been registered) with their name as + the key + """ + + def __init__(self): + super().__init__() + + self._plugins: dict[str: Plugin] = {} # loaded plugins + self._nodePlugins: dict[str: NodePlugin] = {} # registered node plugins + + def isRegistered(self, name: str) -> bool: + """ + Return whether the node plugin has been registered already. + + Args: + name: the name of the node plugin whose registration needs to be checked. + """ + return name in self._nodePlugins + + def belongsToPlugin(self, name: str) -> Plugin: + """ + Check whether the node plugin belongs to a loaded plugin, independently from + whether it has been registered or not. + + Args: + name: the name of the node plugin that needs to be searched for across plugins. + + Returns: + Plugin | None: the Plugin the node belongs to if it exists, None otherwise. + """ + for plugin in self._plugins.values(): + if plugin.containsNodePlugin(name): + return plugin + return None + + def getPlugins(self) -> dict[str: Plugin]: + """ + Return a dictionary containing all the loaded Plugins, with {key, value} = + {name, Plugin}. + """ + return self._plugins + + def getPlugin(self, name: str) -> Plugin: + """ + Return the loaded Plugin object named "name". + + Args: + name: the name of the Plugin, used upon its loading. + + Returns: + Plugin | None: the loaded Plugin object if it exists, None otherwise. + """ + if name in self._plugins: + return self._plugins[name] + return None + + def addPlugin(self, plugin: Plugin, registerNodePlugins: bool = True): + """ + Load a Plugin object. + + Args: + plugin: the Plugin to load and add to the list of loaded plugins. + registerNodePlugins: True if all the NodePlugins from the plugin should be registered + at the same time the plugin is being loaded. Otherwise, the + NodePlugins will have to be registered at a later occasion. + """ + if not self.getPlugin(plugin.name): + self._plugins[plugin.name] = plugin + if registerNodePlugins: + for node in plugin.nodes: + self.registerNode(plugin.nodes[node]) + + def removePlugin(self, plugin: Plugin, unregisterNodePlugins: bool = True): + """ + Remove a loaded Plugin object. + + Args: + plugin: the Plugin to remove from the list of loaded plugins. + unregisterNodePlugins: True if all the nodes from the plugin should be unregistered (if they + are registered) at the same time as the plugin is unloaded. Otherwise, + the registered NodePlugins will remain while the Plugin itself will + be unloaded. + """ + if self.getPlugin(plugin.name): + if unregisterNodePlugins: + for node in plugin.nodes.values(): + self.unregisterNode(node) + del self._plugins[plugin.name] + + def getRegisteredNodePlugins(self) -> dict[str: NodePlugin]: + """ + Return a dictionary containing all the registered NodePlugins, with + {key, value} = {name, NodePlugin}. + """ + return self._nodePlugins + + def getRegisteredNodePlugin(self, name: str) -> NodePlugin: + """ + Return the NodePlugin object that has been registered under the name "name" if it exists. + + Args: + name: the name of the NodePlugin used for its registration. + + Returns: + NodePlugin | None: the loaded NodePlugin object if it exists, None otherwise. + """ + if self.isRegistered(name): + return self._nodePlugins[name] + return None + + def registerNode(self, nodePlugin: NodePlugin): + """ + Register a node plugin. A registered node plugin will become instantiable. + If it is already registered, or if there is an issue with the node description, + the node plugin will not be registered and its status will be updated. + + Args: + nodePlugin: the node plugin to register. + """ + name = nodePlugin.nodeDescriptor.__name__ + if not self.isRegistered(name) and nodePlugin.status not in (NodePluginStatus.DESC_ERROR, + NodePluginStatus.ERROR): + try: + self._nodePlugins[name] = nodePlugin + nodePlugin.status = NodePluginStatus.LOADED + except Exception as e: + logging.error(f"NodePlugin {name} could not be loaded: {e}") + nodePlugin.status = NodePluginStatus.LOADING_ERROR + + def unregisterNode(self, nodePlugin: NodePlugin): + """ + Unregister a node plugin. When unregistered, a node plugin cannot be instantiated anymore. + If it is not registered already, nothing happens. + + Args: + nodePlugin: the node plugin to unregister. + """ + name = nodePlugin.nodeDescriptor.__name__ + if self.isRegistered(name): + if nodePlugin.status != NodePluginStatus.LOADED: + logging.warning(f"NodePlugin {name} is registered but is not correctly loaded.") + else: + nodePlugin.status = NodePluginStatus.NOT_LOADED + del self._nodePlugins[name] diff --git a/meshroom/core/test.py b/meshroom/core/test.py index 728122fb9f..2ea2a1c07d 100644 --- a/meshroom/core/test.py +++ b/meshroom/core/test.py @@ -28,10 +28,10 @@ def checkTemplateVersions(path: str, nodesAlreadyLoaded: bool = False) -> bool: for _, nodeData in graphData.items(): nodeType = nodeData["nodeType"] - if not nodeType in meshroom.core.nodesDesc: + if not meshroom.core.pluginManager.isRegistered(nodeType): return False - nodeDesc = meshroom.core.nodesDesc[nodeType] + nodeDesc = meshroom.core.pluginManager.getRegisteredNodePlugin(nodeType) currentNodeVersion = meshroom.core.nodeVersion(nodeDesc) inputs = nodeData.get("inputs", {}) @@ -60,9 +60,9 @@ def checkTemplateVersions(path: str, nodesAlreadyLoaded: bool = False) -> bool: finally: if not nodesAlreadyLoaded: - nodeTypes = [nodeType for _, nodeType in meshroom.core.nodesDesc.items()] - for nodeType in nodeTypes: - unregisterNodeType(nodeType) + nodePlugins = meshroom.core.pluginManager.getRegisteredNodePlugins() + for node in nodePlugins: + meshroom.core.pluginManager.unregisterNode(node) def checkAllTemplatesVersions() -> bool: diff --git a/meshroom/ui/app.py b/meshroom/ui/app.py index e83c90e52b..37a44c974e 100644 --- a/meshroom/ui/app.py +++ b/meshroom/ui/app.py @@ -13,7 +13,7 @@ from PySide6.QtWidgets import QApplication import meshroom -from meshroom.core import nodesDesc +from meshroom.core import pluginManager from meshroom.core.taskManager import TaskManager from meshroom.common import Property, Variant, Signal, Slot @@ -261,7 +261,7 @@ def __init__(self, inputArgs): self.engine.addImportPath(qmlDir) # expose available node types that can be instantiated - self.engine.rootContext().setContextProperty("_nodeTypes", {n: {"category": nodesDesc[n].category} for n in sorted(nodesDesc.keys())}) + self.engine.rootContext().setContextProperty("_nodeTypes", {n: {"category": pluginManager.getRegisteredNodePlugins()[n].nodeDescriptor.category} for n in sorted(pluginManager.getRegisteredNodePlugins().keys())}) # instantiate Reconstruction object self._undoStack = commands.UndoStack(self) diff --git a/meshroom/ui/components/thumbnail.py b/meshroom/ui/components/thumbnail.py index acfd1591f7..06203dbc54 100644 --- a/meshroom/ui/components/thumbnail.py +++ b/meshroom/ui/components/thumbnail.py @@ -129,7 +129,7 @@ def clean(): # Compute storage duration since last usage of thumbnail lastUsage = f_stat.st_mtime storageTime = now - lastUsage - logging.debug(f'[ThumbnailCache] Thumbnail {f_name} has been stored for {storageTime}s') + # logging.debug(f'[ThumbnailCache] Thumbnail {f_name} has been stored for {storageTime}s') if storageTime > ThumbnailCache.storageTimeLimit * 3600 * 24: # Mark as removable if storage time exceeds limit diff --git a/meshroom/ui/qml/Application.qml b/meshroom/ui/qml/Application.qml index de7201449c..648992f183 100644 --- a/meshroom/ui/qml/Application.qml +++ b/meshroom/ui/qml/Application.qml @@ -552,6 +552,16 @@ Page { } } + Action { + id: reloadPluginsAction + property string tooltip: "Reload the source code for all nodes from all registered plugins" + text: "Reload Plugins Source Code" + shortcut: "Ctrl+Shift+R" + onTriggered: { + _reconstruction.reloadPlugins() + } + } + Action { id: undoAction @@ -830,6 +840,12 @@ Page { ToolTip.visible: hovered ToolTip.text: removeImagesFromAllGroupsAction.tooltip } + + MenuItem { + action: reloadPluginsAction + ToolTip.visible: hovered + ToolTip.text: reloadPluginsAction.tooltip + } } MenuSeparator { } Action { diff --git a/meshroom/ui/reconstruction.py b/meshroom/ui/reconstruction.py index b6fadba78d..62ff4ccee3 100755 --- a/meshroom/ui/reconstruction.py +++ b/meshroom/ui/reconstruction.py @@ -537,7 +537,7 @@ def initActiveNodes(self): # For all nodes declared to be accessed by the UI usedNodeTypes = {j for i in self.activeNodeCategories.values() for j in i} allUiNodes = set(self.uiNodes) | usedNodeTypes - allLoadedNodeTypes = set(meshroom.core.nodesDesc.keys()) + allLoadedNodeTypes = set(meshroom.core.pluginManager.getRegisteredNodePlugins().keys()) for nodeType in allUiNodes: self._activeNodes.add(ActiveNode(nodeType, parent=self)) @@ -552,6 +552,21 @@ def onCameraInitChanged(self): nodes = self._graph.dfsOnDiscover(startNodes=[self._cameraInit], reverse=True)[0] self.setActiveNodes(nodes) + @Slot() + def reloadPlugins(self): + """ + Reload all the NodePlugins from all the registered plugins. + The nodes in the graph will be updated to match the changes in the description, if + there was any. + """ + nodeTypes: list[str] = [] + for plugin in meshroom.core.pluginManager.getPlugins().values(): + for node in plugin.nodes.values(): + if node.reload(): + nodeTypes.append(node.nodeDescriptor.__name__) + + self._graph.reloadNodePlugins(nodeTypes) + @Slot() @Slot(str) def new(self, pipeline=None): @@ -684,7 +699,7 @@ def setupTempCameraInit(self, node, attrName): if not sfmFile or not os.path.isfile(sfmFile): self.tempCameraInit = None return - nodeDesc = meshroom.core.nodesDesc["CameraInit"]() + nodeDesc = meshroom.core.pluginManager.getRegisteredNodePlugin("CameraInit") views, intrinsics = nodeDesc.readSfMData(sfmFile) tmpCameraInit = Node("CameraInit", viewpoints=views, intrinsics=intrinsics) tmpCameraInit.locked = True @@ -736,14 +751,16 @@ def lastNodeOfType(self, nodeTypes, startNode, preferredStatus=None): """ if not startNode: return None - nodes = self._graph.dfsOnDiscover(startNodes=[startNode], filterTypes=nodeTypes, reverse=True)[0] + nodes = self._graph.dfsOnDiscover(startNodes=[startNode], + filterTypes=nodeTypes, reverse=True)[0] if not nodes: return None # order the nodes according to their depth in the graph, then according to their name nodes.sort(key=lambda n: (n.depth, n.name)) node = nodes[-1] if preferredStatus: - node = next((n for n in reversed(nodes) if n.getGlobalStatus() == preferredStatus), node) + node = next((n for n in reversed(nodes) + if n.getGlobalStatus() == preferredStatus), node) return node def addSfmAugmentation(self, withMVS=False): @@ -800,7 +817,8 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): This method allows to reduce process time by doing it on Python side. Args: - {images, videos, panoramaInfo, meshroomScenes, otherFiles}: Map containing the lists of paths for recognized images, videos, Meshroom scenes and other files. + {images, videos, panoramaInfo, meshroomScenes, otherFiles}: Map containing the + lists of paths for recognized images, videos, Meshroom scenes and other files. Node: cameraInit node used to add new images to it QPoint: position to locate the node (usually the mouse position) """ @@ -821,7 +839,8 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): else: p = position cameraInit = self.addNewNode("CameraInit", position=p) - self._workerThreads.apply_async(func=self.importImagesSync, args=(filesByType["images"], cameraInit,)) + self._workerThreads.apply_async(func=self.importImagesSync, + args=(filesByType["images"], cameraInit,)) if filesByType["videos"]: if self.nodes: boundingBox = self.layout.boundingBox() @@ -840,7 +859,8 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): newVideoNodeMessage, "Warning: You need to manually compute the KeyframeSelection node \n" "and then reimport the created images into Meshroom for the reconstruction.\n\n" - "If you know the Camera Make/Model, it is highly recommended to declare them in the Node." + "If you know the Camera Make/Model, it is highly recommended to declare " + "them in the Node." )) if filesByType["panoramaInfo"]: @@ -848,15 +868,15 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): self.error.emit( Message( "Multiple XML files in input", - "Ignore the xml Panorama files:\n\n'{}'.".format(',\n'.join(filesByType["panoramaInfo"])), + "Ignore the XML Panorama files:\n\n'{}'.".format(',\n'.join(filesByType["panoramaInfo"])), "", )) else: - panoramaInitNodes = self.graph.nodesOfType('PanoramaInit') + panoramaInitNodes = self.graph.nodesOfType("PanoramaInit") for panoramaInfoFile in filesByType["panoramaInfo"]: for panoramaInitNode in panoramaInitNodes: - panoramaInitNode.attribute('initializeCameras').value = 'File' - panoramaInitNode.attribute('config').value = panoramaInfoFile + panoramaInitNode.attribute("initializeCameras").value = "File" + panoramaInitNode.attribute("config").value = panoramaInfoFile if panoramaInitNodes: self.info.emit( Message( @@ -918,7 +938,11 @@ def getFilesByTypeFromDrop(self, urls): filesByType.extend(multiview.findFilesByTypeInFolder(localFile)) else: filesByType.addFile(localFile) - return {"images": filesByType.images, "videos": filesByType.videos, "panoramaInfo": filesByType.panoramaInfo, "meshroomScenes": filesByType.meshroomScenes, "other": filesByType.other} + return {"images": filesByType.images, + "videos": filesByType.videos, + "panoramaInfo": filesByType.panoramaInfo, + "meshroomScenes": filesByType.meshroomScenes, + "other": filesByType.other} def importImagesFromFolder(self, path, recursive=False): """ diff --git a/tests/__init__.py b/tests/__init__.py index 8a6bd273ca..6ebeba20f3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,8 +1,12 @@ import os -from meshroom.core import loadAllNodes, initPipelines +from meshroom.core import loadAllNodes +from meshroom.core import pluginManager + +plugins = loadAllNodes(os.path.join(os.path.dirname(__file__), "nodes")) +for plugin in plugins: + pluginManager.addPlugin(plugin) -loadAllNodes(os.path.join(os.path.dirname(__file__), "nodes")) if os.getenv("MESHROOM_PIPELINE_TEMPLATES_PATH", False): os.environ["MESHROOM_PIPELINE_TEMPLATES_PATH"] += os.pathsep + os.path.dirname(os.path.realpath(__file__)) else: diff --git a/tests/nodes/test/appendFiles.py b/tests/nodes/test/appendFiles.py index 9f9c0b236e..abaa8575b1 100644 --- a/tests/nodes/test/appendFiles.py +++ b/tests/nodes/test/appendFiles.py @@ -39,4 +39,3 @@ class AppendFiles(desc.CommandLineNode): value='{nodeCacheFolder}/appendText.txt', ) ] - diff --git a/tests/nodes/test/ls.py b/tests/nodes/test/ls.py index d359936868..9d683e174c 100644 --- a/tests/nodes/test/ls.py +++ b/tests/nodes/test/ls.py @@ -2,21 +2,21 @@ class Ls(desc.CommandLineNode): - commandLine = 'ls {inputValue} > {outputValue}' + commandLine = "ls {inputValue} > {outputValue}" inputs = [ desc.File( - name='input', - label='Input', - description='''''', - value='', + name="input", + label="Input", + description="", + value="", ) ] outputs = [ desc.File( - name='output', - label='Output', - description='''''', - value='{nodeCacheFolder}/ls.txt', + name="output", + label="Output", + description="", + value="{nodeCacheFolder}/ls.txt", ) ] diff --git a/tests/plugins/meshroom/pluginA/PluginANodeA.py b/tests/plugins/meshroom/pluginA/PluginANodeA.py new file mode 100644 index 0000000000..95182ec993 --- /dev/null +++ b/tests/plugins/meshroom/pluginA/PluginANodeA.py @@ -0,0 +1,22 @@ +__version__ = "1.0" + +from meshroom.core import desc + +class PluginANodeA(desc.Node): + inputs = [ + desc.File( + name="input", + label="Input", + description="", + value="", + ), + ] + + outputs = [ + desc.File( + name="output", + label="Output", + description="", + value="", + ), + ] \ No newline at end of file diff --git a/tests/plugins/meshroom/pluginA/PluginANodeB.py b/tests/plugins/meshroom/pluginA/PluginANodeB.py new file mode 100644 index 0000000000..ff048e943c --- /dev/null +++ b/tests/plugins/meshroom/pluginA/PluginANodeB.py @@ -0,0 +1,28 @@ +__version__ = "1.0" + +from meshroom.core import desc + +class PluginANodeB(desc.Node): + inputs = [ + desc.File( + name="input", + label="Input", + description="", + value="", + ), + desc.IntParam( + name="int", + label="Integer", + description="", + value=1, + ), + ] + + outputs = [ + desc.File( + name="output", + label="Output", + description="", + value="", + ), + ] \ No newline at end of file diff --git a/tests/plugins/meshroom/pluginA/__init__.py b/tests/plugins/meshroom/pluginA/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/meshroom/pluginB/PluginBNodeA.py b/tests/plugins/meshroom/pluginB/PluginBNodeA.py new file mode 100644 index 0000000000..f25ce6b0b8 --- /dev/null +++ b/tests/plugins/meshroom/pluginB/PluginBNodeA.py @@ -0,0 +1,22 @@ +__version__ = "1.0" + +from meshroom.core import desc + +class PluginBNodeA(desc.Node): + inputs = [ + desc.File( + name="input", + label="Input", + description="", + value="", + ), + ] + + outputs = [ + desc.File( + name="output", + label="Output", + description="", + value="", + ), + ] \ No newline at end of file diff --git a/tests/plugins/meshroom/pluginB/PluginBNodeB.py b/tests/plugins/meshroom/pluginB/PluginBNodeB.py new file mode 100644 index 0000000000..613d39362e --- /dev/null +++ b/tests/plugins/meshroom/pluginB/PluginBNodeB.py @@ -0,0 +1,28 @@ +__version__ = "1.0" + +from meshroom.core import desc + +class PluginBNodeB(desc.Node): + inputs = [ + desc.File( + name="input", + label="Input", + description="", + value="", + ), + desc.IntParam( + name="int", + label="Integer", + description="", + value="not an integer", + ), + ] + + outputs = [ + desc.File( + name="output", + label="Output", + description="", + value="", + ), + ] \ No newline at end of file diff --git a/tests/plugins/meshroom/pluginB/__init__.py b/tests/plugins/meshroom/pluginB/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/meshroom/sharedTemplate.mg b/tests/plugins/meshroom/sharedTemplate.mg new file mode 100644 index 0000000000..4e265f8c9f --- /dev/null +++ b/tests/plugins/meshroom/sharedTemplate.mg @@ -0,0 +1,10 @@ +{ + "header": { + "releaseVersion": "2025.1.0-develop", + "fileVersion": "2.0", + "nodesVersions": {}, + "template": true + }, + "graph": { + } +} \ No newline at end of file diff --git a/tests/test_attributeChoiceParam.py b/tests/test_attributeChoiceParam.py index 6527f99d1d..094bb3a12b 100644 --- a/tests/test_attributeChoiceParam.py +++ b/tests/test_attributeChoiceParam.py @@ -1,6 +1,7 @@ -from meshroom.core import desc, registerNodeType, unregisterNodeType +from meshroom.core import desc, pluginManager from meshroom.core.graph import Graph, loadGraph +from .utils import registerNodeDesc, unregisterNodeDesc class NodeWithChoiceParams(desc.Node): inputs = [ @@ -52,13 +53,14 @@ class NodeWithChoiceParamsSavingValuesOverride(desc.Node): class TestChoiceParam: + @classmethod def setup_class(cls): - registerNodeType(NodeWithChoiceParams) + registerNodeDesc(NodeWithChoiceParams) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithChoiceParams) + unregisterNodeDesc(NodeWithChoiceParams) def test_customValueIsSerialized(self, graphSavedOnDisk): graph: Graph = graphSavedOnDisk @@ -84,7 +86,7 @@ def test_overridenValuesAreNotSerialized(self, graphSavedOnDisk): graph: Graph = graphSavedOnDisk node = graph.addNewNode(NodeWithChoiceParams.__name__) node.choice.values = ["D", "E", "F"] - + graph.save() loadedGraph = loadGraph(graph.filepath) @@ -117,13 +119,14 @@ def test_connectionsAreSerialized(self, graphSavedOnDisk): class TestChoiceParamSavingCustomValues: + @classmethod def setup_class(cls): - registerNodeType(NodeWithChoiceParamsSavingValuesOverride) + registerNodeDesc(NodeWithChoiceParamsSavingValuesOverride) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithChoiceParamsSavingValuesOverride) + unregisterNodeDesc(NodeWithChoiceParamsSavingValuesOverride) def test_customValueIsSerialized(self, graphSavedOnDisk): graph: Graph = graphSavedOnDisk @@ -143,7 +146,7 @@ def test_overridenValuesAreSerialized(self, graphSavedOnDisk): node = graph.addNewNode(NodeWithChoiceParamsSavingValuesOverride.__name__) node.choice.values = ["D", "E", "F"] node.choiceMulti.values = ["D", "E", "F"] - + graph.save() loadedGraph = loadGraph(graph.filepath) @@ -166,4 +169,4 @@ def test_connectionsAreSerialized(self, graphSavedOnDisk): loadedNodeA = loadedGraph.node(nodeA.name) loadedNodeB = loadedGraph.node(nodeB.name) assert loadedNodeB.choice.linkParam == loadedNodeA.choice - assert loadedNodeB.choiceMulti.linkParam == loadedNodeA.choiceMulti \ No newline at end of file + assert loadedNodeB.choiceMulti.linkParam == loadedNodeA.choiceMulti diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 174380bc6a..3d3ea4d811 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -7,20 +7,21 @@ from typing import Type import pytest -import meshroom.core -from meshroom.core import desc, registerNodeType, unregisterNodeType +from meshroom.core import desc, pluginManager +from meshroom.core.plugins import NodePlugin from meshroom.core.exception import GraphCompatibilityError, NodeUpgradeError from meshroom.core.graph import Graph, loadGraph from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node -from .utils import registeredNodeTypes, overrideNodeTypeVersion +from .utils import registeredNodeTypes, overrideNodeTypeVersion, registerNodeDesc, unregisterNodeDesc SampleGroupV1 = [ desc.IntParam(name="a", label="a", description="", value=0, range=None), desc.ListAttribute( name="b", - elementDesc=desc.FloatParam(name="p", label="", description="", value=0.0, range=None), + elementDesc=desc.FloatParam(name="p", label="", + description="", value=0.0, range=None), label="b", description="", ) @@ -30,7 +31,8 @@ desc.IntParam(name="a", label="a", description="", value=0, range=None), desc.ListAttribute( name="b", - elementDesc=desc.GroupAttribute(name="p", label="", description="", groupDesc=SampleGroupV1), + elementDesc=desc.GroupAttribute(name="p", label="", + description="", groupDesc=SampleGroupV1), label="b", description="", ) @@ -39,10 +41,12 @@ # SampleGroupV3 is SampleGroupV2 with one more int parameter SampleGroupV3 = [ desc.IntParam(name="a", label="a", description="", value=0, range=None), - desc.IntParam(name="notInSampleGroupV2", label="notInSampleGroupV2", description="", value=0, range=None), + desc.IntParam(name="notInSampleGroupV2", label="notInSampleGroupV2", + description="", value=0, range=None), desc.ListAttribute( name="b", - elementDesc=desc.GroupAttribute(name="p", label="", description="", groupDesc=SampleGroupV1), + elementDesc=desc.GroupAttribute(name="p", label="", + description="", groupDesc=SampleGroupV1), label="b", description="", ) @@ -52,11 +56,12 @@ class SampleNodeV1(desc.Node): """ Version 1 Sample Node """ inputs = [ - desc.File(name='input', label='Input', description='', value='',), - desc.StringParam(name='paramA', label='ParamA', description='', value='', invalidate=False) # No impact on UID + desc.File(name="input", label="Input", description="", value=""), + desc.StringParam(name="paramA", label="ParamA", description="", + value="", invalidate=False) # No impact on UID ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -65,11 +70,12 @@ class SampleNodeV2(desc.Node): * 'input' has been renamed to 'in' """ inputs = [ - desc.File(name='in', label='Input', description='', value='',), - desc.StringParam(name='paramA', label='ParamA', description='', value='', invalidate=False), # No impact on UID + desc.File(name="in", label="Input", description="", value=""), + desc.StringParam(name="paramA", label="ParamA", description="", + value="", invalidate=False), # No impact on UID ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -79,10 +85,10 @@ class SampleNodeV3(desc.Node): * 'paramA' has been removed' """ inputs = [ - desc.File(name='in', label='Input', description='', value='',), + desc.File(name="in", label="Input", description="", value=""), ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -92,14 +98,14 @@ class SampleNodeV4(desc.Node): * 'paramA' has been added """ inputs = [ - desc.File(name='in', label='Input', description='', value='',), - desc.ListAttribute(name='paramA', label='ParamA', + desc.File(name="in", label="Input", description="", value=""), + desc.ListAttribute(name="paramA", label="ParamA", elementDesc=desc.GroupAttribute( - groupDesc=SampleGroupV1, name='gA', label='gA', description=''), - description='') + groupDesc=SampleGroupV1, name="gA", label="gA", description=""), + description="") ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -109,14 +115,14 @@ class SampleNodeV5(desc.Node): * 'paramA' elementDesc has changed from SampleGroupV1 to SampleGroupV2 """ inputs = [ - desc.File(name='in', label='Input', description='', value=''), - desc.ListAttribute(name='paramA', label='ParamA', + desc.File(name="in", label="Input", description="", value=""), + desc.ListAttribute(name="paramA", label="ParamA", elementDesc=desc.GroupAttribute( - groupDesc=SampleGroupV2, name='gA', label='gA', description=''), - description='') + groupDesc=SampleGroupV2, name="gA", label="gA", description=""), + description="") ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -126,24 +132,25 @@ class SampleNodeV6(desc.Node): * 'paramA' elementDesc has changed from SampleGroupV2 to SampleGroupV3 """ inputs = [ - desc.File(name='in', label='Input', description='', value=''), - desc.ListAttribute(name='paramA', label='ParamA', + desc.File(name="in", label="Input", description="", value=""), + desc.ListAttribute(name="paramA", label="ParamA", elementDesc=desc.GroupAttribute( - groupDesc=SampleGroupV3, name='gA', label='gA', description=''), - description='') + groupDesc=SampleGroupV3, name="gA", label="gA", description=""), + description="") ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] class SampleInputNodeV1(desc.InputNode): """ Version 1 Sample Input Node """ inputs = [ - desc.StringParam(name='path', label='path', description='', value='', invalidate=False) # No impact on UID + desc.StringParam(name="path", label="Path", description="", + value="", invalidate=False) # No impact on UID ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] @@ -152,50 +159,52 @@ class SampleInputNodeV2(desc.InputNode): * 'path' has been renamed to 'in' """ inputs = [ - desc.StringParam(name='in', label='path', description='', value='', invalidate=False) # No impact on UID + desc.StringParam(name="in", label="path", description="", + value="", invalidate=False) # No impact on UID ] outputs = [ - desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") + desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}") ] def replaceNodeTypeDesc(nodeType: str, nodeDesc: Type[desc.Node]): """Change the `nodeDesc` associated to `nodeType`.""" - meshroom.core.nodesDesc[nodeType] = nodeDesc + pluginManager.getRegisteredNodePlugins()[nodeType] = NodePlugin(nodeDesc) def test_unknown_node_type(): """ Test compatibility behavior for unknown node type. """ - registerNodeType(SampleNodeV1) - g = Graph('') + registerNodeDesc(SampleNodeV1) + g = Graph("") n = g.addNewNode("SampleNodeV1", input="/dev/null", paramA="foo") graphFile = os.path.join(tempfile.mkdtemp(), "test_unknown_node_type.mg") g.save(graphFile) internalFolder = n.internalFolder nodeName = n.name - unregisterNodeType(SampleNodeV1) + unregisterNodeDesc(SampleNodeV1) - # reload file + + # Reload file g = loadGraph(graphFile) os.remove(graphFile) assert len(g.nodes) == 1 n = g.node(nodeName) # SampleNodeV1 is now an unknown type - # check node instance type and compatibility issue type + # Check node instance type and compatibility issue type assert isinstance(n, CompatibilityNode) assert n.issue == CompatibilityIssue.UnknownNodeType - # check if attributes are properly restored + # Check if attributes are properly restored assert len(n.attributes) == 3 assert n.input.isInput assert n.output.isOutput - # check if internal folder + # Check if internal folder assert n.internalFolder == internalFolder - # upgrade can't be perform on unknown node types + # Upgrade can't be perform on unknown node types assert not n.canUpgrade with pytest.raises(NodeUpgradeError): g.upgradeNode(nodeName) @@ -205,20 +214,20 @@ def test_description_conflict(): """ Test compatibility behavior for conflicting node descriptions. """ - # copy registered node types to be able to restore them - originalNodeTypes = copy.copy(meshroom.core.nodesDesc) + # Copy registered node types to be able to restore them + originalNodeTypes = copy.deepcopy(pluginManager.getRegisteredNodePlugins()) nodeTypes = [SampleNodeV1, SampleNodeV2, SampleNodeV3, SampleNodeV4, SampleNodeV5] nodes = [] - g = Graph('') + g = Graph("") - # register and instantiate instances of all node types except last one + # Register and instantiate instances of all node types except last one for nt in nodeTypes[:-1]: - registerNodeType(nt) + registerNodeDesc(nt) n = g.addNewNode(nt.__name__) if nt == SampleNodeV4: - # initialize list attribute with values to create a conflict with V5 + # Initialize list attribute with values to create a conflict with V5 n.paramA.value = [{'a': 0, 'b': [1.0, 2.0]}] nodes.append(n) @@ -226,15 +235,15 @@ def test_description_conflict(): graphFile = os.path.join(tempfile.mkdtemp(), "test_description_conflict.mg") g.save(graphFile) - # reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances) + # Reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances) loadGraph(graphFile, strictCompatibility=True) - # offset node types register to create description conflicts - # each node type name now reference the next one's implementation + # Offset node types register to create description conflicts + # Each node type name now reference the next one's implementation for i, nt in enumerate(nodeTypes[:-1]): - meshroom.core.nodesDesc[nt.__name__] = nodeTypes[i+1] + pluginManager.getRegisteredNodePlugins()[nt.__name__] = NodePlugin(nodeTypes[i + 1]) - # reload file + # Reload file g = loadGraph(graphFile) os.remove(graphFile) @@ -246,7 +255,7 @@ def test_description_conflict(): assert isinstance(compatNode, CompatibilityNode) assert srcNode.internalFolder == compatNode.internalFolder - # case by case description conflict verification + # Case by case description conflict verification if isinstance(srcNode.nodeDesc, SampleNodeV1): # V1 => V2: 'input' has been renamed to 'in' assert len(compatNode.attributes) == 3 @@ -254,27 +263,29 @@ def test_description_conflict(): assert hasattr(compatNode, "input") assert not hasattr(compatNode, "in") - # perform upgrade + # Perform upgrade upgradedNode = g.upgradeNode(nodeName) - assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2) + assert isinstance(upgradedNode, Node) and \ + isinstance(upgradedNode.nodeDesc, SampleNodeV2) assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"] assert not hasattr(upgradedNode, "input") assert hasattr(upgradedNode, "in") - # check uid has changed (not the same set of attributes) + # Check UID has changed (not the same set of attributes) assert upgradedNode.internalFolder != srcNode.internalFolder elif isinstance(srcNode.nodeDesc, SampleNodeV2): - # V2 => V3: 'paramA' has been removed' + # V2 => V3: 'paramA' has been removed assert len(compatNode.attributes) == 3 assert hasattr(compatNode, "paramA") - # perform upgrade + # Perform upgrade upgradedNode = g.upgradeNode(nodeName) - assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3) + assert isinstance(upgradedNode, Node) and \ + isinstance(upgradedNode.nodeDesc, SampleNodeV3) assert not hasattr(upgradedNode, "paramA") - # check uid is identical (paramA not part of uid) + # Check UID is identical (paramA not part of UID) assert upgradedNode.internalFolder == srcNode.internalFolder elif isinstance(srcNode.nodeDesc, SampleNodeV3): @@ -282,9 +293,10 @@ def test_description_conflict(): assert len(compatNode.attributes) == 2 assert not hasattr(compatNode, "paramA") - # perform upgrade + # Perform upgrade upgradedNode = g.upgradeNode(nodeName) - assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4) + assert isinstance(upgradedNode, Node) and \ + isinstance(upgradedNode.nodeDesc, SampleNodeV4) assert hasattr(upgradedNode, "paramA") assert isinstance(upgradedNode.paramA.attributeDesc, desc.ListAttribute) @@ -298,32 +310,33 @@ def test_description_conflict(): groupAttribute = compatNode.paramA.attributeDesc.elementDesc assert isinstance(groupAttribute, desc.GroupAttribute) - # check that Compatibility node respect SampleGroupV1 description + # Check that Compatibility node respect SampleGroupV1 description for elt in groupAttribute.groupDesc: - assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__) + assert isinstance(elt, + next(a for a in SampleGroupV1 if a.name == elt.name).__class__) - # perform upgrade + # Perform upgrade upgradedNode = g.upgradeNode(nodeName) - assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5) + assert isinstance(upgradedNode, Node) and \ + isinstance(upgradedNode.nodeDesc, SampleNodeV5) assert hasattr(upgradedNode, "paramA") - # parameter was incompatible, value could not be restored + # Parameter was incompatible, value could not be restored assert upgradedNode.paramA.isDefault assert upgradedNode.internalFolder != srcNode.internalFolder else: raise ValueError("Unexpected node type: " + srcNode.nodeType) - # restore original node types - meshroom.core.nodesDesc = originalNodeTypes - + # Restore original node types + pluginManager._nodePlugins = originalNodeTypes def test_upgradeAllNodes(): - registerNodeType(SampleNodeV1) - registerNodeType(SampleNodeV2) - registerNodeType(SampleInputNodeV1) - registerNodeType(SampleInputNodeV2) + registerNodeDesc(SampleNodeV1) + registerNodeDesc(SampleNodeV2) + registerNodeDesc(SampleInputNodeV1) + registerNodeDesc(SampleInputNodeV2) - g = Graph('') + g = Graph("") n1 = g.addNewNode("SampleNodeV1") n2 = g.addNewNode("SampleNodeV2") n3 = g.addNewNode("SampleInputNodeV1") @@ -335,74 +348,78 @@ def test_upgradeAllNodes(): graphFile = os.path.join(tempfile.mkdtemp(), "test_description_conflict.mg") g.save(graphFile) - # make SampleNodeV2 and SampleInputNodeV2 an unknown type - unregisterNodeType(SampleNodeV2) - unregisterNodeType(SampleInputNodeV2) - # replace SampleNodeV1 by SampleNodeV2 and SampleInputNodeV1 by SampleInputNodeV2 - meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2 - meshroom.core.nodesDesc[SampleInputNodeV1.__name__] = SampleInputNodeV2 + # Replace SampleNodeV1 by SampleNodeV2 and SampleInputNodeV1 by SampleInputNodeV2 + pluginManager.getRegisteredNodePlugins()[SampleNodeV1.__name__] = \ + pluginManager.getRegisteredNodePlugin(SampleNodeV2.__name__) + pluginManager.getRegisteredNodePlugins()[SampleInputNodeV1.__name__] = \ + pluginManager.getRegisteredNodePlugin(SampleInputNodeV2.__name__) + + # Make SampleNodeV2 and SampleInputNodeV2 an unknown type + unregisterNodeDesc(SampleNodeV2) + unregisterNodeDesc(SampleInputNodeV2) - # reload file + # Reload file g = loadGraph(graphFile) os.remove(graphFile) - # both nodes are CompatibilityNodes + # Both nodes are CompatibilityNodes assert len(g.compatibilityNodes) == 4 assert g.node(n1Name).canUpgrade # description conflict assert not g.node(n2Name).canUpgrade # unknown type assert g.node(n3Name).canUpgrade # description conflict assert not g.node(n4Name).canUpgrade # unknown type - # upgrade all upgradable nodes + # Upgrade all upgradable nodes g.upgradeAllNodes() - # only the nodes with an unknown type have not been upgraded + # Only the nodes with an unknown type have not been upgraded assert len(g.compatibilityNodes) == 2 assert n2Name in g.compatibilityNodes.keys() assert n4Name in g.compatibilityNodes.keys() - unregisterNodeType(SampleNodeV1) - unregisterNodeType(SampleInputNodeV1) + unregisterNodeDesc(SampleNodeV1) + unregisterNodeDesc(SampleInputNodeV1) def test_conformUpgrade(): - registerNodeType(SampleNodeV5) - registerNodeType(SampleNodeV6) + registerNodeDesc(SampleNodeV5) + registerNodeDesc(SampleNodeV6) - g = Graph('') + g = Graph("") n1 = g.addNewNode("SampleNodeV5") - n1.paramA.value = [{'a': 0, 'b': [{'a': 0, 'b': [1.0, 2.0]}, {'a': 1, 'b': [1.0, 2.0]}]}] + n1.paramA.value = [{"a": 0, "b": [{"a": 0, "b": [1.0, 2.0]}, {"a": 1, "b": [1.0, 2.0]}]}] n1Name = n1.name graphFile = os.path.join(tempfile.mkdtemp(), "test_conform_upgrade.mg") g.save(graphFile) - # replace SampleNodeV5 by SampleNodeV6 - meshroom.core.nodesDesc[SampleNodeV5.__name__] = SampleNodeV6 + # Replace SampleNodeV5 by SampleNodeV6 + pluginManager.getRegisteredNodePlugins()[SampleNodeV5.__name__] = \ + pluginManager.getRegisteredNodePlugin(SampleNodeV6.__name__) - # reload file + # Reload file g = loadGraph(graphFile) os.remove(graphFile) - # node is a CompatibilityNode + # Node is a CompatibilityNode assert len(g.compatibilityNodes) == 1 assert g.node(n1Name).canUpgrade - # upgrade all upgradable nodes + # Upgrade all upgradable nodes g.upgradeAllNodes() - # only the node with an unknown type has not been upgraded + # Only the node with an unknown type has not been upgraded assert len(g.compatibilityNodes) == 0 upgradedNode = g.node(n1Name) - # check upgrade + # Check upgrade assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV6) - # check conformation + # Check conformation assert len(upgradedNode.paramA.value) == 1 - unregisterNodeType(SampleNodeV5) - unregisterNodeType(SampleNodeV6) + unregisterNodeDesc(SampleNodeV5) + unregisterNodeDesc(SampleNodeV6) class TestGraphLoadingWithStrictCompatibility: @@ -418,7 +435,6 @@ def test_failsOnUnknownNodeType(self, graphSavedOnDisk): def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk): - with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): graph: Graph = graphSavedOnDisk graph.addNewNode(SampleNodeV1.__name__) @@ -433,7 +449,6 @@ def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk): class TestGraphTemplateLoading: def test_failsOnUnknownNodeTypeError(self, graphSavedOnDisk): - with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): graph: Graph = graphSavedOnDisk graph.addNewNode(SampleNodeV1.__name__) @@ -482,7 +497,7 @@ def test_loadingConflictingNodeVersionCreatesCompatibilityNodes(self, graphSaved with overrideNodeTypeVersion(SampleNodeV1, "1.0"): node = graph.addNewNode(SampleNodeV1.__name__) graph.save() - + with overrideNodeTypeVersion(SampleNodeV1, "2.0"): otherGraph = Graph("") otherGraph.load(graph.filepath) @@ -496,7 +511,7 @@ def test_loadingUnspecifiedNodeVersionAssumesCurrentVersion(self, graphSavedOnDi with registeredNodeTypes([SampleNodeV1]): graph.addNewNode(SampleNodeV1.__name__) graph.save() - + with overrideNodeTypeVersion(SampleNodeV1, "2.0"): otherGraph = Graph("") otherGraph.load(graph.filepath) @@ -508,7 +523,8 @@ class UidTestingNodeV1(desc.Node): inputs = [ desc.File(name="input", label="Input", description="", value="", invalidate=True), ] - outputs = [desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}")] + outputs = [desc.File(name="output", label="Output", + description="", value="{nodeCacheFolder}")] class UidTestingNodeV2(desc.Node): @@ -554,7 +570,8 @@ class UidTestingNodeV3(desc.Node): description="", ), ] - outputs = [desc.File(name="output", label="Output", description="", value="{nodeCacheFolder}")] + outputs = [desc.File(name="output", label="Output", + description="", value="{nodeCacheFolder}")] class TestUidConflict: @@ -626,7 +643,8 @@ def checkNodeAConnectionsToNodeB(): assert len(loadedGraph.compatibilityNodes) == 0 - def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, graphSavedOnDisk): + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection( + self, graphSavedOnDisk): with registeredNodeTypes([UidTestingNodeV1, UidTestingNodeV2]): graph: Graph = graphSavedOnDisk nodeA = graph.addNewNode(UidTestingNodeV2.__name__) @@ -640,7 +658,8 @@ def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, loadedGraph = loadGraph(graph.filepath) assert len(loadedGraph.compatibilityNodes) == 1 - def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection(self, graphSavedOnDisk): + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection( + self,graphSavedOnDisk): with registeredNodeTypes([UidTestingNodeV2, UidTestingNodeV3]): graph: Graph = graphSavedOnDisk nodeA = graph.addNewNode(UidTestingNodeV2.__name__) diff --git a/tests/test_graph.py b/tests/test_graph.py index 003bc6ed2b..280cd46f7d 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,11 +2,11 @@ def test_depth(): - graph = Graph('Tests tasks depth') + graph = Graph("Tests tasks depth") - tA = graph.addNewNode('Ls', input='/tmp') - tB = graph.addNewNode('AppendText', inputText='echo B') - tC = graph.addNewNode('AppendText', inputText='echo C') + tA = graph.addNewNode("Ls", input="/tmp") + tB = graph.addNewNode("AppendText", inputText="echo B") + tC = graph.addNewNode("AppendText", inputText="echo C") graph.addEdges( (tA.output, tB.input), @@ -19,19 +19,19 @@ def test_depth(): def test_depth_diamond_graph(): - graph = Graph('Tests tasks depth') + graph = Graph("Tests tasks depth") - tA = graph.addNewNode('Ls', input='/tmp') - tB = graph.addNewNode('AppendText', inputText='echo B') - tC = graph.addNewNode('AppendText', inputText='echo C') - tD = graph.addNewNode('AppendFiles') + tA = graph.addNewNode("Ls", input="/tmp") + tB = graph.addNewNode("AppendText", inputText="echo B") + tC = graph.addNewNode("AppendText", inputText="echo C") + tD = graph.addNewNode("AppendFiles") graph.addEdges( (tA.output, tB.input), (tA.output, tC.input), (tB.output, tD.input), (tC.output, tD.input2), - ) + ) assert tA.depth == 0 assert tB.depth == 1 @@ -58,13 +58,13 @@ def test_depth_diamond_graph(): def test_depth_diamond_graph2(): - graph = Graph('Tests tasks depth') + graph = Graph("Tests tasks depth") - tA = graph.addNewNode('Ls', input='/tmp') - tB = graph.addNewNode('AppendText', inputText='echo B') - tC = graph.addNewNode('AppendText', inputText='echo C') - tD = graph.addNewNode('AppendText', inputText='echo D') - tE = graph.addNewNode('AppendFiles') + tA = graph.addNewNode("Ls", input="/tmp") + tB = graph.addNewNode("AppendText", inputText="echo B") + tC = graph.addNewNode("AppendText", inputText="echo C") + tD = graph.addNewNode("AppendText", inputText="echo D") + tE = graph.addNewNode("AppendFiles") # C # / \ # /---/---->\ @@ -81,7 +81,7 @@ def test_depth_diamond_graph2(): (tB.output, tE.input2), (tC.output, tE.input3), (tD.output, tE.input4), - ) + ) assert tA.depth == 0 assert tB.depth == 1 @@ -116,14 +116,13 @@ def test_depth_diamond_graph2(): def test_transitive_reduction(): + graph = Graph("Tests tasks depth") - graph = Graph('Tests tasks depth') - - tA = graph.addNewNode('Ls', input='/tmp') - tB = graph.addNewNode('AppendText', inputText='echo B') - tC = graph.addNewNode('AppendText', inputText='echo C') - tD = graph.addNewNode('AppendText', inputText='echo D') - tE = graph.addNewNode('AppendFiles') + tA = graph.addNewNode("Ls", input="/tmp") + tB = graph.addNewNode("AppendText", inputText="echo B") + tC = graph.addNewNode("AppendText", inputText="echo C") + tD = graph.addNewNode("AppendText", inputText="echo D") + tE = graph.addNewNode("AppendFiles") # C # / \ # /---/---->\ @@ -141,7 +140,7 @@ def test_transitive_reduction(): (tB.output, tE.input4), (tC.output, tE.input3), (tD.output, tE.input2), - ) + ) flowEdges = graph.flowEdges() flowEdgesRes = [(tB, tA), @@ -153,24 +152,24 @@ def test_transitive_reduction(): assert set(flowEdgesRes) == set(flowEdges) assert len(graph._nodesMinMaxDepths) == len(graph.nodes) - for node, (minDepth, maxDepth) in graph._nodesMinMaxDepths.items(): + for node, (_, maxDepth) in graph._nodesMinMaxDepths.items(): assert node.depth == maxDepth def test_graph_reverse_dfsOnDiscover(): - graph = Graph('Test dfsOnDiscover(reverse=True)') + graph = Graph("Test dfsOnDiscover(reverse=True)") # ------------\ # / ~ C - E - F # A - B # ~ D - A = graph.addNewNode('Ls', input='/tmp') - B = graph.addNewNode('AppendText', inputText=A.output) - C = graph.addNewNode('AppendText', inputText=B.output) - D = graph.addNewNode('AppendText', inputText=B.output) - E = graph.addNewNode('Ls', input=C.output) - F = graph.addNewNode('AppendText', input=A.output, inputText=E.output) + A = graph.addNewNode("Ls", input="/tmp") + B = graph.addNewNode("AppendText", inputText=A.output) + C = graph.addNewNode("AppendText", inputText=B.output) + D = graph.addNewNode("AppendText", inputText=B.output) + E = graph.addNewNode("Ls", input=C.output) + F = graph.addNewNode("AppendText", input=A.output, inputText=E.output) # Get all nodes from A (use set, order not guaranteed) nodes = graph.dfsOnDiscover(startNodes=[A], reverse=True)[0] @@ -179,7 +178,7 @@ def test_graph_reverse_dfsOnDiscover(): nodes = graph.dfsOnDiscover(startNodes=[B], reverse=True)[0] assert set(nodes) == {B, D, C, E, F} # Get all nodes of type AppendText from B - nodes = graph.dfsOnDiscover(startNodes=[B], filterTypes=['AppendText'], reverse=True)[0] + nodes = graph.dfsOnDiscover(startNodes=[B], filterTypes=["AppendText"], reverse=True)[0] assert set(nodes) == {B, D, C, F} # Get all nodes from C (order guaranteed) nodes = graph.dfsOnDiscover(startNodes=[C], reverse=True)[0] @@ -190,7 +189,7 @@ def test_graph_reverse_dfsOnDiscover(): def test_graph_dfsOnDiscover(): - graph = Graph('Test dfsOnDiscover(reverse=False)') + graph = Graph("Test dfsOnDiscover(reverse=False)") # ------------\ # / ~ C - E - F @@ -198,13 +197,13 @@ def test_graph_dfsOnDiscover(): # ~ D # G - G = graph.addNewNode('Ls', input='/tmp') - A = graph.addNewNode('Ls', input='/tmp') - B = graph.addNewNode('AppendText', inputText=A.output) - C = graph.addNewNode('AppendText', inputText=B.output) - D = graph.addNewNode('AppendText', input=G.output, inputText=B.output) - E = graph.addNewNode('Ls', input=C.output) - F = graph.addNewNode('AppendText', input=A.output, inputText=E.output) + G = graph.addNewNode("Ls", input="/tmp") + A = graph.addNewNode("Ls", input="/tmp") + B = graph.addNewNode("AppendText", inputText=A.output) + C = graph.addNewNode("AppendText", inputText=B.output) + D = graph.addNewNode("AppendText", input=G.output, inputText=B.output) + E = graph.addNewNode("Ls", input=C.output) + F = graph.addNewNode("AppendText", input=A.output, inputText=E.output) # Get all nodes from A (use set, order not guaranteed) nodes = graph.dfsOnDiscover(startNodes=[A], reverse=False)[0] @@ -219,7 +218,7 @@ def test_graph_dfsOnDiscover(): nodes = graph.dfsOnDiscover(startNodes=[F], reverse=False)[0] assert set(nodes) == {A, B, C, E, F} # Get all nodes of type AppendText from C - nodes = graph.dfsOnDiscover(startNodes=[C], filterTypes=['AppendText'], reverse=False)[0] + nodes = graph.dfsOnDiscover(startNodes=[C], filterTypes=["AppendText"], reverse=False)[0] assert set(nodes) == {B, C} # Get all nodes from D (order guaranteed) nodes = graph.dfsOnDiscover(startNodes=[D], longestPathFirst=True, reverse=False)[0] @@ -230,21 +229,21 @@ def test_graph_dfsOnDiscover(): def test_graph_nodes_sorting(): - graph = Graph('') + graph = Graph("") - ls0 = graph.addNewNode('Ls') - ls1 = graph.addNewNode('Ls') - ls2 = graph.addNewNode('Ls') + ls0 = graph.addNewNode("Ls") + ls1 = graph.addNewNode("Ls") + ls2 = graph.addNewNode("Ls") - assert graph.nodesOfType('Ls', sortedByIndex=True) == [ls0, ls1, ls2] + assert graph.nodesOfType("Ls", sortedByIndex=True) == [ls0, ls1, ls2] - graph = Graph('') + graph = Graph("") # 'Random' creation order (what happens when loading a file) - ls2 = graph.addNewNode('Ls', name='Ls_2') - ls0 = graph.addNewNode('Ls', name='Ls_0') - ls1 = graph.addNewNode('Ls', name='Ls_1') + ls2 = graph.addNewNode("Ls", name="Ls_2") + ls0 = graph.addNewNode("Ls", name="Ls_0") + ls1 = graph.addNewNode("Ls", name="Ls_1") - assert graph.nodesOfType('Ls', sortedByIndex=True) == [ls0, ls1, ls2] + assert graph.nodesOfType("Ls", sortedByIndex=True) == [ls0, ls1, ls2] def test_duplicate_nodes(): @@ -256,24 +255,25 @@ def test_duplicate_nodes(): # \ \ # ---------- n3 - g = Graph('') - n0 = g.addNewNode('Ls', input='/tmp') - n1 = g.addNewNode('Ls', input=n0.output) - n2 = g.addNewNode('Ls', input=n1.output) - n3 = g.addNewNode('AppendFiles', input=n1.output, input2=n2.output) + g = Graph("") + n0 = g.addNewNode("Ls", input="/tmp") + n1 = g.addNewNode("Ls", input=n0.output) + n2 = g.addNewNode("Ls", input=n1.output) + n3 = g.addNewNode("AppendFiles", input=n1.output, input2=n2.output) - # duplicate from n1 + # Duplicate from n1 nodes_to_duplicate, _ = g.dfsOnDiscover(startNodes=[n1], reverse=True, dependenciesOnly=True) nMap = g.duplicateNodes(srcNodes=nodes_to_duplicate) for s, duplicated in nMap.items(): for d in duplicated: assert s.nodeType == d.nodeType - # check number of duplicated nodes and that every parent node has been duplicated once - assert len(nMap) == 3 and all([len(nMap[i]) == 1 for i in nMap.keys()]) + # Check number of duplicated nodes and that every parent node has been duplicated once + assert len(nMap) == 3 and \ + all([len(nMap[i]) == 1 for i in nMap.keys()]) - # check connections - # access directly index 0 because we know there is a single duplicate for each parent node + # Check connections + # Access directly index 0 because we know there is a single duplicate for each parent node assert nMap[n1][0].input.getLinkParam() == n0.output assert nMap[n2][0].input.getLinkParam() == nMap[n1][0].output assert nMap[n3][0].input.getLinkParam() == nMap[n1][0].output diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py index b742069192..e9de490829 100644 --- a/tests/test_graphIO.py +++ b/tests/test_graphIO.py @@ -255,7 +255,8 @@ def test_listAttributeToListAttributeConnectionIsSerialized(self): otherGraph = Graph("") otherGraph._deserialize(graph.serializePartial([nodeA, nodeB])) - assert otherGraph.node(nodeB.name).listInput.linkParam == otherGraph.node(nodeA.name).listInput + assert otherGraph.node(nodeB.name).listInput.linkParam == \ + otherGraph.node(nodeA.name).listInput def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self): graph = Graph("") diff --git a/tests/test_invalidation.py b/tests/test_invalidation.py index 9dfe0879e7..f9ef972c86 100644 --- a/tests/test_invalidation.py +++ b/tests/test_invalidation.py @@ -1,28 +1,30 @@ #!/usr/bin/env python # coding:utf-8 from meshroom.core.graph import Graph -from meshroom.core import desc, registerNodeType +from meshroom.core import desc, pluginManager + +from .utils import registerNodeDesc class SampleNode(desc.Node): """ Sample Node for unit testing """ inputs = [ - desc.File(name='input', label='Input', description='', value='',), - desc.StringParam(name='paramA', label='ParamA', description='', value='', invalidate=False) # No impact on UID + desc.File(name="input", label="Input", description="", value="",), + desc.StringParam(name="paramA", label="ParamA", + description="", value="", + invalidate=False) # No impact on UID ] outputs = [ desc.File(name='output', label='Output', description='', value="{nodeCacheFolder}") ] - -registerNodeType(SampleNode) - +registerNodeDesc(SampleNode) # register standalone NodePlugin def test_output_invalidation(): - graph = Graph('') - n1 = graph.addNewNode('SampleNode', input='/tmp') - n2 = graph.addNewNode('SampleNode') - n3 = graph.addNewNode('SampleNode') + graph = Graph("") + n1 = graph.addNewNode("SampleNode", input="/tmp") + n2 = graph.addNewNode("SampleNode") + n3 = graph.addNewNode("SampleNode") graph.addEdges( (n1.output, n2.input), @@ -52,9 +54,9 @@ def test_inputLinkInvalidation(): """ Input links should not change the invalidation. """ - graph = Graph('') - n1 = graph.addNewNode('SampleNode') - n2 = graph.addNewNode('SampleNode') + graph = Graph("") + n1 = graph.addNewNode("SampleNode") + n2 = graph.addNewNode("SampleNode") graph.addEdges((n1.input, n2.input)) assert n1.input.uid() == n2.input.uid() diff --git a/tests/test_nodeAttributeChangedCallback.py b/tests/test_nodeAttributeChangedCallback.py index 1ab088ead7..1628d61248 100644 --- a/tests/test_nodeAttributeChangedCallback.py +++ b/tests/test_nodeAttributeChangedCallback.py @@ -1,9 +1,11 @@ # coding:utf-8 from meshroom.core.graph import Graph, loadGraph, executeGraph -from meshroom.core import desc, registerNodeType, unregisterNodeType +from meshroom.core import desc, pluginManager from meshroom.core.node import Node +from .utils import registerNodeDesc, unregisterNodeDesc + class NodeWithAttributeChangedCallback(desc.BaseNode): """ @@ -37,13 +39,14 @@ def processChunk(self, chunk): class TestNodeWithAttributeChangedCallback: + @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributeChangedCallback) + registerNodeDesc(NodeWithAttributeChangedCallback) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributeChangedCallback) + unregisterNodeDesc(NodeWithAttributeChangedCallback) def test_assignValueTriggersCallback(self): node = Node(NodeWithAttributeChangedCallback.__name__) @@ -68,13 +71,14 @@ def test_assignNonDefaultValueTriggersCallback(self): class TestAttributeCallbackTriggerInGraph: + @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributeChangedCallback) + registerNodeDesc(NodeWithAttributeChangedCallback) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributeChangedCallback) + unregisterNodeDesc(NodeWithAttributeChangedCallback) def test_connectionTriggersCallback(self): graph = Graph("") @@ -219,7 +223,7 @@ class NodeWithCompoundAttributes(desc.BaseNode): desc.IntParam( name="int", label="Int", description="", value=0, range=None ) - ], + ], ) ), desc.GroupAttribute( @@ -241,15 +245,16 @@ class NodeWithCompoundAttributes(desc.BaseNode): class TestAttributeCallbackBehaviorWithUpstreamCompoundAttributes: + @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributeChangedCallback) - registerNodeType(NodeWithCompoundAttributes) + registerNodeDesc(NodeWithAttributeChangedCallback) + registerNodeDesc(NodeWithCompoundAttributes) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributeChangedCallback) - unregisterNodeType(NodeWithCompoundAttributes) + unregisterNodeDesc(NodeWithAttributeChangedCallback) + unregisterNodeDesc(NodeWithCompoundAttributes) def test_connectionToListElement(self): graph = Graph("") @@ -313,7 +318,8 @@ def test_connectionToListElementInGroup(self): class NodeWithDynamicOutputValue(desc.BaseNode): """ - A Node containing an output attribute which value is computed dynamically during graph execution. + A Node containing an output attribute which value is computed dynamically + during graph execution. """ inputs = [ @@ -340,15 +346,18 @@ def processChunk(self, chunk): class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs: + # nodePluginAttributeChangedCallback = NodePlugin(NodeWithAttributeChangedCallback) + # nodePluginDynamicOutputValue = NodePlugin(NodeWithDynamicOutputValue) + @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributeChangedCallback) - registerNodeType(NodeWithDynamicOutputValue) + registerNodeDesc(NodeWithAttributeChangedCallback) + registerNodeDesc(NodeWithDynamicOutputValue) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributeChangedCallback) - unregisterNodeType(NodeWithDynamicOutputValue) + unregisterNodeDesc(NodeWithAttributeChangedCallback) + unregisterNodeDesc(NodeWithDynamicOutputValue) def test_connectingUncomputedDynamicOutputDoesNotTriggerDownstreamAttributeChangedCallback( self, @@ -390,7 +399,6 @@ def test_dynamicOutputValueComputeDoesNotTriggerDownstreamAttributeChangedCallba assert nodeB.input.value == 20 assert nodeB.affectedInput.value == 0 - def test_clearingDynamicOutputValueDoesNotTriggerDownstreamAttributeChangedCallback( self, graphSavedOnDisk ): @@ -434,11 +442,11 @@ def test_loadingGraphWithComputedDynamicOutputValueDoesNotTriggerDownstreamAttri class TestAttributeCallbackBehaviorOnGraphImport: @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributeChangedCallback) + registerNodeDesc(NodeWithAttributeChangedCallback) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributeChangedCallback) + unregisterNodeDesc(NodeWithAttributeChangedCallback) def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self): graph = Graph("") @@ -450,9 +458,8 @@ def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self): nodeA.input.value = 5 nodeB.affectedInput.value = 2 - + otherGraph = Graph("") otherGraph.importGraphContent(graph) assert otherGraph.node(nodeB.name).affectedInput.value == 2 - diff --git a/tests/test_nodeCallbacks.py b/tests/test_nodeCallbacks.py index 149151e803..9fbc676ca7 100644 --- a/tests/test_nodeCallbacks.py +++ b/tests/test_nodeCallbacks.py @@ -1,7 +1,9 @@ -from meshroom.core import desc, registerNodeType, unregisterNodeType +from meshroom.core import desc, pluginManager from meshroom.core.node import Node from meshroom.core.graph import Graph, loadGraph +from .utils import registerNodeDesc, unregisterNodeDesc + class NodeWithCreationCallback(desc.InputNode): """Node defining an 'onNodeCreated' callback, triggered a new node is added to a Graph.""" @@ -22,13 +24,14 @@ def onNodeCreated(cls, node: Node): class TestNodeCreationCallback: + @classmethod def setup_class(cls): - registerNodeType(NodeWithCreationCallback) + registerNodeDesc(NodeWithCreationCallback) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithCreationCallback) + unregisterNodeDesc(NodeWithCreationCallback) def test_notTriggeredOnNodeInstantiation(self): node = Node(NodeWithCreationCallback.__name__) diff --git a/tests/test_nodeCommandLineFormatting.py b/tests/test_nodeCommandLineFormatting.py index e7f2adf7a0..9c5ea69b05 100644 --- a/tests/test_nodeCommandLineFormatting.py +++ b/tests/test_nodeCommandLineFormatting.py @@ -2,13 +2,16 @@ # coding:utf-8 from meshroom.core.graph import Graph, loadGraph, executeGraph -from meshroom.core import desc, registerNodeType, unregisterNodeType +from meshroom.core import desc, pluginManager from meshroom.core.node import Node +from .utils import registerNodeDesc, unregisterNodeDesc + class NodeWithAttributesNeedingFormatting(desc.Node): """ - A node containing list, file, choice and group attributes in order to test the formatting of the command line. + A node containing list, file, choice and group attributes in order to test the + formatting of the command line. """ inputs = [ desc.ListAttribute( @@ -99,13 +102,14 @@ class NodeWithAttributesNeedingFormatting(desc.Node): ] class TestCommandLineFormatting: + @classmethod def setup_class(cls): - registerNodeType(NodeWithAttributesNeedingFormatting) + registerNodeDesc(NodeWithAttributesNeedingFormatting) @classmethod def teardown_class(cls): - unregisterNodeType(NodeWithAttributesNeedingFormatting) + unregisterNodeDesc(NodeWithAttributesNeedingFormatting) def test_formatting_listOfFiles(self): inputImages = ["/non/existing/fileA", "/non/existing/with space/fileB"] @@ -128,23 +132,28 @@ def test_formatting_listOfFiles(self): # Assert that extending values when the list is not empty is working node.images.extend(inputImages) - assert node.images.getValueStr() == '"single value with space" "{}" "{}"'.format(inputImages[0], - inputImages[1]) + assert node.images.getValueStr() == \ + '"single value with space" "{}" "{}"'.format(inputImages[0], + inputImages[1]) - # Values are not retrieved as strings in the command line, so quotes around them are not expected - assert node._cmdVars["imagesValue"] == 'single value with space {} {}'.format(inputImages[0], - inputImages[1]) + # Values are not retrieved as strings in the command line, so quotes around them are + # not expected + assert node._cmdVars["imagesValue"] == \ + 'single value with space {} {}'.format(inputImages[0], + inputImages[1]) def test_formatting_strings(self): graph = Graph("") node = graph.addNewNode("NodeWithAttributesNeedingFormatting") node._buildCmdVars() - # Assert an empty File attribute generates empty quotes when requesting its value as a string + # Assert an empty File attribute generates empty quotes when requesting its value as + # a string assert node.input.getValueStr() == '""' assert node._cmdVars["inputValue"] == "" - # Assert a Choice attribute with a non-empty default value is surrounded with quotes when requested as a string + # Assert a Choice attribute with a non-empty default value is surrounded with quotes + # when requested as a string assert node.method.getValueStr() == '"MethodC"' assert node._cmdVars["methodValue"] == "MethodC" @@ -154,14 +163,18 @@ def test_formatting_strings(self): # Assert that the list with one empty value generates empty quotes node.images.extend("") - assert node.images.getValueStr() == '""', "A list with one empty string should generate empty quotes" - assert node._cmdVars["imagesValue"] == "", "The value is always only the value, so empty here" + assert node.images.getValueStr() == '""', \ + "A list with one empty string should generate empty quotes" + assert node._cmdVars["imagesValue"] == "", \ + "The value is always only the value, so empty here" # Assert that a list with 2 empty strings generates quotes node.images.extend("") - assert node.images.getValueStr() == '"" ""', "A list with 2 empty strings should generate quotes" + assert node.images.getValueStr() == '"" ""', \ + "A list with 2 empty strings should generate quotes" assert node._cmdVars["imagesValue"] == ' ', \ - "The value is always only the value, so 2 empty strings with the space separator in the middle" + "The value is always only the value, so 2 empty strings with the " \ + "space separator in the middle" def test_formatting_groups(self): graph = Graph("") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a536a04341..ad5e44a5da 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,7 +44,7 @@ def test_pipeline(): if attr.isOutput and attr.enabled: otherAttr = otherNode.attribute(key) assert attr.uid() != otherAttr.uid() - + # Test serialization/deserialization on both graphs for graph in [graph1, graph2]: filename = tempfile.mktemp() diff --git a/tests/test_plugins.py b/tests/test_plugins.py new file mode 100644 index 0000000000..4ee2e0a4b9 --- /dev/null +++ b/tests/test_plugins.py @@ -0,0 +1,215 @@ +# coding:utf-8 + +from meshroom.core import desc, pluginManager, loadClassesNodes +from meshroom.core.plugins import NodePluginStatus, Plugin + +import os +import time + +class TestPluginWithValidNodesOnly: + plugin = None + + @classmethod + def setup_class(cls): + folder = os.path.join(os.path.dirname(__file__), "plugins", "meshroom") + package = "pluginA" + cls.plugin = Plugin(package, folder) + nodes = loadClassesNodes(folder, package) + for node in nodes: + cls.plugin.addNodePlugin(node) + pluginManager.addPlugin(cls.plugin) + + @classmethod + def teardown_class(cls): + for node in cls.plugin.nodes.values(): + pluginManager.unregisterNode(node) + cls.plugin = None + + def test_loadedPlugin(self): + # Assert that there are loaded plugins, and that "pluginA" is one of them + assert len(pluginManager.getPlugins()) >= 1 + plugin = pluginManager.getPlugin("pluginA") + assert plugin == self.plugin + assert str(plugin.path) == os.path.join(os.path.dirname(__file__), "plugins", "meshroom") + + # Assert that the nodes of pluginA have been successfully registered + assert len(pluginManager.getRegisteredNodePlugins()) >= 2 + for nodeName, nodePlugin in plugin.nodes.items(): + assert nodePlugin.status == NodePluginStatus.LOADED + assert pluginManager.isRegistered(nodeName) + + # Assert the template has been loaded + assert len(plugin.templates) == 1 + name = list(plugin.templates.keys())[0] + assert name == "sharedTemplate" + assert plugin.templates[name] == os.path.join(str(plugin.path), "sharedTemplate.mg") + + def test_unloadPlugin(self): + plugin = pluginManager.getPlugin("pluginA") + assert plugin == self.plugin + + # Unload the plugin without unregistering the nodes + pluginManager.removePlugin(plugin, unregisterNodePlugins=False) + + # Assert the plugin is not loaded anymore + assert pluginManager.getPlugin(plugin.name) is None + + # Assert the nodes are still registered and belong to an unloaded plugin + for nodeName, nodePlugin in plugin.nodes.items(): + assert nodePlugin.status == NodePluginStatus.LOADED + assert pluginManager.isRegistered(nodeName) + assert pluginManager.belongsToPlugin(nodeName) is None + + # Re-add the plugin + pluginManager.addPlugin(plugin, registerNodePlugins=False) + assert pluginManager.getPlugin(plugin.name) + + # Unload the plugin with a full unregistration of the nodes + pluginManager.removePlugin(plugin) + + # Assert the plugin is not loaded anymore + assert pluginManager.getPlugin(plugin.name) is None + + # Assert the nodes have been successfully unregistered + for nodeName, nodePlugin in plugin.nodes.items(): + assert nodePlugin.status == NodePluginStatus.NOT_LOADED + assert not pluginManager.isRegistered(nodeName) + + # Re-add the plugin and re-register the nodes + pluginManager.addPlugin(plugin) + assert pluginManager.getPlugin(plugin.name) + for nodeName, nodePlugin in plugin.nodes.items(): + assert nodePlugin.status == NodePluginStatus.LOADED + assert pluginManager.isRegistered(nodeName) + + def test_updateRegisteredNodes(self): + nbRegisteredNodes = len(pluginManager.getRegisteredNodePlugins()) + plugin = pluginManager.getPlugin("pluginA") + assert plugin == self.plugin + nodeA = pluginManager.getRegisteredNodePlugin("PluginANodeA") + nodeAName = nodeA.nodeDescriptor.__name__ + + # Unregister a node + assert nodeA + pluginManager.unregisterNode(nodeA) + + # Check that the node has been fully unregistered: + # - its status is "NOT_LOADED" + # - it is still part of pluginA + # - it is not in the list of registered plugins anymore (and returns None when requested) + assert nodeA.status == NodePluginStatus.NOT_LOADED + assert plugin.containsNodePlugin(nodeAName) + assert nodeA.plugin == plugin + + assert pluginManager.getRegisteredNodePlugin(nodeAName) is None + assert nodeAName not in pluginManager.getRegisteredNodePlugins() + assert len(pluginManager.getRegisteredNodePlugins()) == nbRegisteredNodes - 1 + + # Re-register the node + pluginManager.registerNode(nodeA) + + assert nodeA.status == NodePluginStatus.LOADED + assert pluginManager.getRegisteredNodePlugin(nodeAName) + assert len(pluginManager.getRegisteredNodePlugins()) == nbRegisteredNodes + + +class TestPluginWithInvalidNodes: + plugin = None + + @classmethod + def setup_class(cls): + folder = os.path.join(os.path.dirname(__file__), "plugins", "meshroom") + package = "pluginB" + cls.plugin = Plugin(package, folder) + nodes = loadClassesNodes(folder, package) + for node in nodes: + cls.plugin.addNodePlugin(node) + pluginManager.addPlugin(cls.plugin) + + @classmethod + def teardown_class(cls): + for node in cls.plugin.nodes.values(): + pluginManager.unregisterNode(node) + cls.plugin = None + + def test_loadedPlugin(self): + # Assert that there are loaded plugins, and that "pluginB" is one of them + assert len(pluginManager.getPlugins()) >= 1 + plugin = pluginManager.getPlugin("pluginB") + assert plugin == self.plugin + assert str(plugin.path) == os.path.join(os.path.dirname(__file__), "plugins", "meshroom") + + # Assert that PluginBNodeA is successfully registered + assert pluginManager.isRegistered("PluginBNodeA") + assert plugin.nodes["PluginBNodeA"].status == NodePluginStatus.LOADED + assert plugin.nodes["PluginBNodeA"].plugin == plugin + + # Assert that PluginBNodeB has not been registered (description error) + assert not pluginManager.isRegistered("PluginBNodeB") + assert plugin.nodes["PluginBNodeB"].status == NodePluginStatus.DESC_ERROR + assert plugin.nodes["PluginBNodeB"].plugin == plugin + + # Assert the template has been loaded + assert len(plugin.templates) == 1 + name = list(plugin.templates.keys())[0] + assert name == "sharedTemplate" + assert plugin.templates[name] == os.path.join(str(plugin.path), "sharedTemplate.mg") + + def test_reloadNodePlugin(self): + plugin = pluginManager.getPlugin("pluginB") + assert plugin == self.plugin + node = plugin.nodes["PluginBNodeB"] + nodeName = node.nodeDescriptor.__name__ + + # Check that the node has not been registered + assert node.status == NodePluginStatus.DESC_ERROR + assert not pluginManager.isRegistered(nodeName) + + # Check that the node cannot be registered + pluginManager.registerNode(node) + assert not pluginManager.isRegistered(nodeName) + + # Replace directly in the node file the line that fails the validation + # on the description with a line that will pass + originalFileContent = None + with open(node.path, "r") as f: + originalFileContent = f.read() + + replaceFileContent = originalFileContent.replace('"not an integer"', '1') + with open(node.path, "w") as f: + f.write(replaceFileContent) + + # Reload the node and assert it is valid + node.reload() + assert node.status == NodePluginStatus.NOT_LOADED + + # Attempt to register node plugin + pluginManager.registerNode(node) + assert pluginManager.isRegistered(nodeName) + + # Reload the node again without any change + node.reload() + assert pluginManager.isRegistered(nodeName) + + # Hack to ensure that the timestamp of the file will be different after being rewritten + # Without it, on some systems, the operation is too fast and the timestamp does not change, + # cause the test to fail + time.sleep(0.1) + + # Restore the node file to its original state (with a description error) + with open(node.path, "w") as f: + f.write(originalFileContent) + + timestampOr2 = os.path.getmtime(node.path) + print(f"New timestamp: {timestampOr2}") + print(os.stat(node.path)) + + # Reload the node and assert it is invalid while still registered + node.reload() + assert node.status == NodePluginStatus.DESC_ERROR + assert pluginManager.isRegistered(nodeName) + + # Unregister it + pluginManager.unregisterNode(node) + assert node.status == NodePluginStatus.DESC_ERROR # Not NOT_LOADED + assert not pluginManager.isRegistered(nodeName) diff --git a/tests/utils.py b/tests/utils.py index c279a0ad9c..93b4bf6d90 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,24 +1,26 @@ from contextlib import contextmanager from unittest.mock import patch -from typing import Type import meshroom -from meshroom.core import registerNodeType, unregisterNodeType -from meshroom.core import desc +from meshroom.core import desc, pluginManager +from meshroom.core.plugins import NodePlugin, NodePluginStatus @contextmanager -def registeredNodeTypes(nodeTypes: list[Type[desc.Node]]): +def registeredNodeTypes(nodeTypes: list[desc.Node]): + nodePluginsList = {} for nodeType in nodeTypes: - registerNodeType(nodeType) + nodePlugin = NodePlugin(nodeType) + pluginManager.registerNode(nodePlugin) + nodePluginsList[nodeType] = nodePlugin yield for nodeType in nodeTypes: - unregisterNodeType(nodeType) + pluginManager.unregisterNode(nodePluginsList[nodeType]) @contextmanager -def overrideNodeTypeVersion(nodeType: Type[desc.Node], version: str): - """Helper context manager to override the version of a given node type.""" +def overrideNodeTypeVersion(nodeType: desc.Node, version: str): + """ Helper context manager to override the version of a given node type. """ unpatchedFunc = meshroom.core.nodeVersion with patch.object( meshroom.core, @@ -26,3 +28,16 @@ def overrideNodeTypeVersion(nodeType: Type[desc.Node], version: str): side_effect=lambda type: version if type is nodeType else unpatchedFunc(type), ): yield + +def registerNodeDesc(nodeDesc: desc.Node): + name = nodeDesc.__name__ + if not pluginManager.isRegistered(name): + pluginManager._nodePlugins[name] = NodePlugin(nodeDesc) + pluginManager._nodePlugins[name].status = NodePluginStatus.LOADED + +def unregisterNodeDesc(nodeDesc: desc.Node): + name = nodeDesc.__name__ + if pluginManager.isRegistered(name): + plugin = pluginManager.getRegisteredNodePlugin(name) + plugin.status = NodePluginStatus.NOT_LOADED + del pluginManager._nodePlugins[name]