Skip to content

Commit 383cb03

Browse files
Copilotgramalingamjustinchubytitaiwangms
authored
Create a identity-elimination pass (onnx#121)
This PR implements a new `IdentityEliminationPass` that removes redundant Identity nodes from ONNX graphs according to the rules specified in the issue. ## Implementation The pass handles three distinct scenarios for Identity nodes of the form `y = Identity(x)`: 1. **Case 1**: When `y` is not a graph output → Replace all uses of `y` with `x` and remove the node 2. **Case 2**: When `y` is a graph output but `x` is not a graph input → Eliminate the node but rename `x` to `y` 3. **Case 3**: When both `y` is a graph output and `x` is a graph input → Keep the node (cannot eliminate) ## Example ```python import onnx_ir as ir from onnx_ir.passes.common import IdentityEliminationPass # Before: input1 -> Identity -> Add -> Identity -> output # After: input1 -> Add -> output (with proper renaming) pass_instance = IdentityEliminationPass() result = pass_instance(model) ``` ## Key Features - **Comprehensive elimination**: Handles all three scenarios correctly - **Subgraph support**: Processes Identity nodes in functions and node attributes - **Chain elimination**: Can eliminate chains of Identity nodes in a single pass - **Safe operation**: Uses `graph.remove(node, safe=True)` to ensure graph consistency - **Edge case handling**: Properly handles invalid nodes, None inputs, and empty graphs ## Testing Added comprehensive test suite with 9 test cases covering: - All three elimination scenarios - Multiple Identity nodes in complex graphs - Chain of Identity nodes - Edge cases (invalid nodes, None inputs, empty graphs) - Integration with PassManager All tests pass and the implementation follows established patterns in the codebase. Fixes onnx#120. <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
1 parent 2481fdd commit 383cb03

File tree

3 files changed

+731
-0
lines changed

3 files changed

+731
-0
lines changed

src/onnx_ir/passes/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"ClearMetadataAndDocStringPass",
88
"CommonSubexpressionEliminationPass",
99
"DeduplicateInitializersPass",
10+
"IdentityEliminationPass",
1011
"InlinePass",
1112
"LiftConstantsToInitializersPass",
1213
"LiftSubgraphInitializersToMainGraphPass",
@@ -30,6 +31,9 @@
3031
LiftSubgraphInitializersToMainGraphPass,
3132
RemoveInitializersFromInputsPass,
3233
)
34+
from onnx_ir.passes.common.identity_elimination import (
35+
IdentityEliminationPass,
36+
)
3337
from onnx_ir.passes.common.initializer_deduplication import (
3438
DeduplicateInitializersPass,
3539
)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Identity elimination pass for removing redundant Identity nodes."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"IdentityEliminationPass",
9+
]
10+
11+
import logging
12+
13+
import onnx_ir as ir
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class IdentityEliminationPass(ir.passes.InPlacePass):
19+
"""Pass for eliminating redundant Identity nodes.
20+
21+
This pass removes Identity nodes according to the following rules:
22+
1. For any node of the form `y = Identity(x)`, where `y` is not an output
23+
of any graph, replace all uses of `y` with a use of `x`, and remove the node.
24+
2. If `y` is an output of a graph, and `x` is not an input of any graph,
25+
we can still do the elimination, but the value `x` should be renamed to be `y`.
26+
3. If `y` is a graph-output and `x` is a graph-input, we cannot eliminate
27+
the node. It should be retained.
28+
"""
29+
30+
def call(self, model: ir.Model) -> ir.passes.PassResult:
31+
"""Main entry point for the identity elimination pass."""
32+
modified = False
33+
34+
# Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
35+
for node in ir.traversal.RecursiveGraphIterator(model.graph):
36+
if self._try_eliminate_identity_node(node):
37+
modified = True
38+
39+
# Process nodes in functions
40+
for function in model.functions.values():
41+
for node in ir.traversal.RecursiveGraphIterator(function):
42+
if self._try_eliminate_identity_node(node):
43+
modified = True
44+
45+
if modified:
46+
logger.info("Identity elimination pass modified the model")
47+
48+
return ir.passes.PassResult(model, modified=modified)
49+
50+
def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
51+
"""Try to eliminate a single identity node. Returns True if modified."""
52+
if node.op_type != "Identity" or node.domain != "":
53+
return False
54+
55+
if len(node.inputs) != 1 or len(node.outputs) != 1:
56+
# Invalid Identity node, skip
57+
return False
58+
59+
input_value = node.inputs[0]
60+
output_value = node.outputs[0]
61+
62+
if input_value is None:
63+
# Cannot eliminate if input is None
64+
return False
65+
66+
# Get the graph that contains this node
67+
graph_like = node.graph
68+
assert graph_like is not None, "Node must be in a graph"
69+
70+
output_is_graph_output = output_value.is_graph_output()
71+
input_is_graph_input = input_value.is_graph_input()
72+
73+
# Case 3: Both output is graph output and input is graph input - keep the node
74+
if output_is_graph_output and input_is_graph_input:
75+
return False
76+
77+
# Case 1 & 2 (merged): Eliminate the identity node
78+
# Replace all uses of output with input
79+
ir.convenience.replace_all_uses_with(output_value, input_value)
80+
81+
# If output is a graph output, we need to rename input and update graph outputs
82+
if output_is_graph_output:
83+
# Store the original output name
84+
original_output_name = output_value.name
85+
86+
# Update the input value to have the output's name
87+
input_value.name = original_output_name
88+
89+
# Update graph outputs to point to the input value
90+
for idx, graph_output in enumerate(graph_like.outputs):
91+
if graph_output is output_value:
92+
graph_like.outputs[idx] = input_value
93+
94+
# Remove the identity node
95+
graph_like.remove(node, safe=True)
96+
logger.debug("Eliminated identity node: %s", node)
97+
return True

0 commit comments

Comments
 (0)