diff --git a/torch_dag/core/unstructured_to_structured.py b/torch_dag/core/unstructured_to_structured.py index 3cba33e..360fe67 100644 --- a/torch_dag/core/unstructured_to_structured.py +++ b/torch_dag/core/unstructured_to_structured.py @@ -236,7 +236,7 @@ def build_dag_by_tracing(self) -> dag_module.DagModule: for node in graph.nodes: num_predecessors = len(node._input_nodes) if num_predecessors == 0 and node.target in model.state_dict(): # free parameter node - module = structured_modules.ParameterModule(param=torch.nn.Parameter(model.state_dict()[node.target])) + module = structured_modules.ParameterModule(param=torch.nn.Parameter(model.get_parameter(node.target), requires_grad = model.get_parameter(node.target).requires_grad)) vertex = dag.add_vertex( name=node.name, module=module, @@ -395,6 +395,9 @@ def convert_node(self, node: torch.fx.node.Node, modules_dict, state_dict): elif node.target == torch.square: return structured_modules.PowerModule(pow=2) + elif node.target == torch.add: + return structured_modules.AddModule() + elif node.target == torch.addcmul: return structured_modules.AddcmulModule()