Skip to content

Commit b9e35c4

Browse files
Merge pull request #56 from TCLResearchEurope/fix-requires-grad
Fix requires grad
2 parents d6ce90f + 098b531 commit b9e35c4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torch_dag/core/unstructured_to_structured.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def build_dag_by_tracing(self) -> dag_module.DagModule:
236236
for node in graph.nodes:
237237
num_predecessors = len(node._input_nodes)
238238
if num_predecessors == 0 and node.target in model.state_dict(): # free parameter node
239-
module = structured_modules.ParameterModule(param=torch.nn.Parameter(model.state_dict()[node.target]))
239+
module = structured_modules.ParameterModule(param=torch.nn.Parameter(model.get_parameter(node.target), requires_grad = model.get_parameter(node.target).requires_grad))
240240
vertex = dag.add_vertex(
241241
name=node.name,
242242
module=module,
@@ -395,6 +395,9 @@ def convert_node(self, node: torch.fx.node.Node, modules_dict, state_dict):
395395
elif node.target == torch.square:
396396
return structured_modules.PowerModule(pow=2)
397397

398+
elif node.target == torch.add:
399+
return structured_modules.AddModule()
400+
398401
elif node.target == torch.addcmul:
399402
return structured_modules.AddcmulModule()
400403

0 commit comments

Comments
 (0)