Skip to content

Commit 00534c8

Browse files
hsharma35facebook-github-bot
authored andcommitted
Fix output spec + insert clone for constant_prop_pass. (pytorch#11209)
Summary: Pull Request resolved: pytorch#11209 In case where a constant propagated node is returned by the exported program, `_validate()` fails with `SpecViolationError` with signature: ``` User output <SOME_CONSTANT_PROP_NODE> is not in the correct order or is not found in the exported program's user_output list. ``` This diff does two things: 1. Update output spec when propagated constants are output of the program. 2. Insert clone op on the constant prop tensor before sending them to output node. This avoid memory planning related errors. Reviewed By: angelayi Differential Revision: D75473310
1 parent 1bc36c7 commit 00534c8

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,37 @@ def create_constant_nodes_and_return_specs(
295295
return name_to_spec_dict
296296

297297

298+
def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
299+
"""
300+
Update the output node and output specs in the exported program.
301+
In case a constant node is used as output, we replace it with a clone of the constant node.
302+
"""
303+
# Dict [node.name -> InputSpec]
304+
updated_constant_placeholders = get_constant_placeholder_dict(exported_program)
305+
output = exported_program.graph.find_nodes(op="output")[0]
306+
output_nodes = cast(list[torch.fx.Node], list(output.args[0]))
307+
output_specs = exported_program.graph_signature.output_specs
308+
assert len(output_nodes) == len(output_specs)
309+
310+
for i in range(len(output_specs)):
311+
out_node = output_nodes[i]
312+
if out_node not in updated_constant_placeholders:
313+
continue
314+
315+
with exported_program.graph.inserting_after(out_node):
316+
new_node = exported_program.graph.call_function(
317+
exir_ops.edge.aten.clone.default, (out_node,)
318+
)
319+
assert "val" in out_node.meta
320+
new_node.meta["val"] = out_node.meta["val"]
321+
output_nodes[i] = new_node
322+
323+
# Update the constant-propagated output node.
324+
output_specs[i].arg = TensorArgument(name=output_nodes[i].name)
325+
326+
output.args = (output_nodes,)
327+
328+
298329
def constant_prop_pass(
299330
exported_program: ExportedProgram,
300331
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
@@ -341,12 +372,12 @@ def constant_prop_pass(
341372

342373
# Generate new input spec.
343374
new_input_specs = []
344-
for node in exported_program.graph.nodes:
345-
if node.op != "placeholder":
346-
continue
375+
for node in exported_program.graph.find_nodes(op="placeholder"):
347376
new_input_specs.append(name_to_spec_dict[node.name])
348377
exported_program.graph_signature.input_specs = new_input_specs
349378

379+
_update_output_node_and_specs(exported_program)
380+
350381
# Cleanup the graph.
351382
exported_program.graph.eliminate_dead_code()
352383
exported_program.graph_module.recompile()

exir/tests/test_passes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,34 @@ def forward(self, x):
10261026
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
10271027
).run(gm.code)
10281028

1029+
def test_constant_prop_for_output(self) -> None:
1030+
class Add(torch.nn.Module):
1031+
def forward(self) -> torch.Tensor:
1032+
return torch.add(torch.tensor(3), torch.tensor(5))
1033+
1034+
add = Add()
1035+
1036+
edge = to_edge(
1037+
export(add, (), strict=True),
1038+
compile_config=EdgeCompileConfig(_skip_dim_order=False),
1039+
)
1040+
# Check there is a lifted tensor followed by a to_copy node
1041+
FileCheck().check("c_lifted_tensor_0").check("c_lifted_tensor_1").run(
1042+
edge.exported_program().graph_module.code
1043+
)
1044+
1045+
edge._edge_programs["forward"] = constant_prop_pass(
1046+
edge.exported_program("forward")
1047+
)
1048+
1049+
# Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
1050+
FileCheck().check_not("c_lifted_tensor_").check("_prop_tensor_constant").run(
1051+
edge.exported_program().graph_module.code
1052+
)
1053+
# Validate that the program successfully passes validation to executorch:
1054+
edge.exported_program()._validate()
1055+
edge.to_executorch()
1056+
10291057
def test_constant_prop_pass_for_add(self) -> None:
10301058
class Add(torch.nn.Module):
10311059
def forward(self, x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)