Skip to content

Fix output spec + insert clone for constant_prop_pass. #11209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions exir/passes/constant_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,37 @@ def create_constant_nodes_and_return_specs(
return name_to_spec_dict


def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
"""
Update the output node and output specs in the exported program.
In case a constant node is used as output, we replace it with a clone of the constant node.
"""
# Dict [node.name -> InputSpec]
updated_constant_placeholders = get_constant_placeholder_dict(exported_program)
output = exported_program.graph.find_nodes(op="output")[0]
output_nodes = cast(list[torch.fx.Node], list(output.args[0]))
output_specs = exported_program.graph_signature.output_specs
assert len(output_nodes) == len(output_specs)

for i in range(len(output_specs)):
out_node = output_nodes[i]
if out_node not in updated_constant_placeholders:
continue

with exported_program.graph.inserting_after(out_node):
new_node = exported_program.graph.call_function(
exir_ops.edge.aten.clone.default, (out_node,)
)
assert "val" in out_node.meta
new_node.meta["val"] = out_node.meta["val"]
output_nodes[i] = new_node

# Update the constant-propagated output node.
output_specs[i].arg = TensorArgument(name=output_nodes[i].name)

output.args = (output_nodes,)


def constant_prop_pass(
exported_program: ExportedProgram,
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
Expand Down Expand Up @@ -341,12 +372,12 @@ def constant_prop_pass(

# Generate new input spec.
new_input_specs = []
for node in exported_program.graph.nodes:
if node.op != "placeholder":
continue
for node in exported_program.graph.find_nodes(op="placeholder"):
new_input_specs.append(name_to_spec_dict[node.name])
exported_program.graph_signature.input_specs = new_input_specs

_update_output_node_and_specs(exported_program)

# Cleanup the graph.
exported_program.graph.eliminate_dead_code()
exported_program.graph_module.recompile()
Expand Down
28 changes: 28 additions & 0 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,34 @@ def forward(self, x):
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
).run(gm.code)

def test_constant_prop_for_output(self) -> None:
class Add(torch.nn.Module):
def forward(self) -> torch.Tensor:
return torch.add(torch.tensor(3), torch.tensor(5))

add = Add()

edge = to_edge(
export(add, (), strict=True),
compile_config=EdgeCompileConfig(_skip_dim_order=False),
)
# Check there is a lifted tensor followed by a to_copy node
FileCheck().check("c_lifted_tensor_0").check("c_lifted_tensor_1").run(
edge.exported_program().graph_module.code
)

edge._edge_programs["forward"] = constant_prop_pass(
edge.exported_program("forward")
)

# Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
FileCheck().check_not("c_lifted_tensor_").check("_prop_tensor_constant").run(
edge.exported_program().graph_module.code
)
# Validate that the program successfully passes validation to executorch:
edge.exported_program()._validate()
edge.to_executorch()

def test_constant_prop_pass_for_add(self) -> None:
class Add(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
Loading