Skip to content

Commit 926397c

Browse files
justinchubyCopilot
andauthored
Add methods to resize inputs and outputs in Node class (onnx#244)
This pull request introduces new methods to dynamically resize the inputs and outputs of a `Node` in the ONNX IR core, making graph manipulation more flexible and robust. It updates the API to support these operations and adds comprehensive tests to ensure correct behavior, including edge cases and error handling. ### API Enhancements for Node Inputs and Outputs * Added `resize_inputs` method to `Node`, allowing the number of inputs to be increased (by adding `None` values) or decreased (removing extra inputs and cleaning up their uses). The setter error message for `inputs` was also updated to reference the new method. * Added `resize_outputs` method to `Node`, enabling dynamic resizing of outputs. Outputs can be increased (new `Value` objects created) or decreased, but removal is only allowed if the outputs have no uses; otherwise, a `ValueError` is raised. The setter error message for `outputs` was updated accordingly. ### Comprehensive Testing * Added extensive unit tests for `resize_inputs`, covering increasing, decreasing, unchanged size, zeroing, growing from zero, and preservation of `None` inputs. * Added thorough unit tests for `resize_outputs`, including increasing, decreasing, unchanged size, zeroing, growing from zero, and error handling when attempting to remove outputs that are still in use. ### Why not create `pop` and `add` to the inputs/outputs object Currently inputs and outputs are implemented as `tuple`, which means there are a lot of methods associated to it that we get for free. Since there is not a `usertuple` class in python, defining an immutable sequence that supports all of the tuple methods may be complex (have to implement the full Sequence interface). And since the current interface says Node.inputs, Node.outputs return Sequence interfaces, having selected mutable methods on them (pop, add) makes it confusing. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 135eccc commit 926397c

File tree

3 files changed

+292
-5
lines changed

3 files changed

+292
-5
lines changed

src/onnx_ir/_core.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,9 +1843,35 @@ def inputs(self) -> Sequence[Value | None]:
18431843
@inputs.setter
18441844
def inputs(self, _: Any) -> None:
18451845
raise AttributeError(
1846-
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
1846+
"Node.inputs cannot be assigned to. Please use 'resize_inputs' and "
1847+
"'replace_input_with' instead."
18471848
)
18481849

1850+
def resize_inputs(self, new_size: int, /) -> None:
1851+
"""Resize the inputs of the node.
1852+
1853+
If the new size is greater than the current size, new inputs are added as None.
1854+
If the new size is less than the current size, the extra inputs are removed.
1855+
1856+
After ``inputs`` is resized, you can use :meth:`replace_input_with` to set the new inputs.
1857+
1858+
.. versionadded:: 0.1.13
1859+
1860+
Args:
1861+
new_size: The new number of inputs.
1862+
"""
1863+
current_size = len(self._inputs)
1864+
if new_size == current_size:
1865+
return
1866+
if new_size < current_size:
1867+
# Remove extra inputs
1868+
for i in range(new_size, current_size):
1869+
self.replace_input_with(i, None)
1870+
self._inputs = self._inputs[:new_size]
1871+
else:
1872+
# Add new inputs as None
1873+
self._inputs = self._inputs + (None,) * (new_size - current_size)
1874+
18491875
def predecessors(self) -> Sequence[Node]:
18501876
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
18511877
# Use the ordered nature of a dictionary to deduplicate the nodes
@@ -1920,15 +1946,54 @@ def append(self, /, nodes: Node | Iterable[Node]) -> None:
19201946
def outputs(self) -> Sequence[Value]:
19211947
"""The output values of the node.
19221948
1923-
The outputs are immutable. To change the outputs, create a new node and
1924-
replace the inputs of the using nodes of this node's outputs by calling
1925-
:meth:`replace_input_with` on the using nodes of this node's outputs.
1949+
The outputs are always attached to this node once initialized (immutable),
1950+
except that the list can be resized to remove or add outputs.
1951+
1952+
Use :meth:`resize_outputs` to change the number of outputs of the node.
19261953
"""
19271954
return self._outputs
19281955

19291956
@outputs.setter
19301957
def outputs(self, _: Sequence[Value]) -> None:
1931-
raise AttributeError("outputs is immutable. Please create a new node instead.")
1958+
raise AttributeError(
1959+
"Node.outputs cannot be assigned to. Please use 'resize_outputs' or create a new node instead."
1960+
)
1961+
1962+
def resize_outputs(self, new_size: int, /) -> None:
1963+
"""Resize the outputs of the node.
1964+
1965+
If the new size is greater than the current size, new output values are created.
1966+
If the new size is less than the current size, the extra output values are removed.
1967+
The removed output values must not have any uses.
1968+
1969+
.. versionadded:: 0.1.13
1970+
1971+
Args:
1972+
new_size: The new number of outputs.
1973+
1974+
Raises:
1975+
ValueError: If the new size is less than the current size and
1976+
the removed outputs have uses.
1977+
"""
1978+
current_size = len(self._outputs)
1979+
if new_size == current_size:
1980+
return
1981+
if new_size < current_size:
1982+
# Check that the removed outputs have no uses
1983+
for output in self._outputs[new_size:]:
1984+
if output.uses():
1985+
raise ValueError(
1986+
f"Cannot remove output {output} because it has uses: {output.uses()}"
1987+
)
1988+
for output in self._outputs[new_size:]:
1989+
# Detach the output from this node
1990+
output._producer = None # pylint: disable=protected-access
1991+
output._index = -1 # pylint: disable=protected-access
1992+
self._outputs = self._outputs[:new_size]
1993+
else:
1994+
# Create new outputs
1995+
new_outputs = [Value(self, index=i) for i in range(current_size, new_size)]
1996+
self._outputs = self._outputs + tuple(new_outputs)
19321997

19331998
@property
19341999
def attributes(self) -> _graph_containers.Attributes:

src/onnx_ir/_core_test.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,6 +1507,207 @@ def test_attributes_get_tensors(self):
15071507
node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1]
15081508
)
15091509

1510+
def test_resize_inputs_increase_size(self):
1511+
"""Test that resize_inputs increases the number of inputs by adding None values."""
1512+
v0 = _core.Value(name="v0")
1513+
v1 = _core.Value(name="v1")
1514+
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)
1515+
1516+
self.assertEqual(len(node.inputs), 2)
1517+
self.assertIs(node.inputs[0], v0)
1518+
self.assertIs(node.inputs[1], v1)
1519+
1520+
# Resize to 4 inputs
1521+
node.resize_inputs(4)
1522+
1523+
self.assertEqual(len(node.inputs), 4)
1524+
self.assertIs(node.inputs[0], v0)
1525+
self.assertIs(node.inputs[1], v1)
1526+
self.assertIsNone(node.inputs[2])
1527+
self.assertIsNone(node.inputs[3])
1528+
1529+
def test_resize_inputs_decrease_size(self):
1530+
"""Test that resize_inputs decreases the number of inputs and removes uses."""
1531+
v0 = _core.Value(name="v0")
1532+
v1 = _core.Value(name="v1")
1533+
v2 = _core.Value(name="v2")
1534+
node = _core.Node("", "TestOp", inputs=(v0, v1, v2), num_outputs=1)
1535+
1536+
self.assertEqual(len(node.inputs), 3)
1537+
# Check that node is in v2's uses
1538+
self.assertEqual(len(v2.uses()), 1)
1539+
self.assertIn(_core.Usage(node, 2), v2.uses())
1540+
1541+
# Resize to 2 inputs (remove v2)
1542+
node.resize_inputs(2)
1543+
1544+
self.assertEqual(len(node.inputs), 2)
1545+
self.assertIs(node.inputs[0], v0)
1546+
self.assertIs(node.inputs[1], v1)
1547+
# Check that node is no longer in v2's uses
1548+
self.assertEqual(len(v2.uses()), 0)
1549+
1550+
def test_resize_inputs_same_size(self):
1551+
"""Test that resize_inputs does nothing when size is unchanged."""
1552+
v0 = _core.Value(name="v0")
1553+
v1 = _core.Value(name="v1")
1554+
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)
1555+
1556+
# Resize to same size
1557+
node.resize_inputs(2)
1558+
1559+
self.assertEqual(len(node.inputs), 2)
1560+
self.assertIs(node.inputs[0], v0)
1561+
self.assertIs(node.inputs[1], v1)
1562+
1563+
def test_resize_inputs_to_zero(self):
1564+
"""Test that resize_inputs can reduce inputs to zero."""
1565+
v0 = _core.Value(name="v0")
1566+
v1 = _core.Value(name="v1")
1567+
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)
1568+
1569+
node.resize_inputs(0)
1570+
1571+
self.assertEqual(len(node.inputs), 0)
1572+
self.assertEqual(node.inputs, ())
1573+
# Check that uses are removed
1574+
self.assertEqual(len(v0.uses()), 0)
1575+
self.assertEqual(len(v1.uses()), 0)
1576+
1577+
def test_resize_inputs_from_zero(self):
1578+
"""Test that resize_inputs can increase from zero inputs."""
1579+
node = _core.Node("", "TestOp", inputs=(), num_outputs=1)
1580+
1581+
self.assertEqual(len(node.inputs), 0)
1582+
1583+
node.resize_inputs(3)
1584+
1585+
self.assertEqual(len(node.inputs), 3)
1586+
self.assertIsNone(node.inputs[0])
1587+
self.assertIsNone(node.inputs[1])
1588+
self.assertIsNone(node.inputs[2])
1589+
1590+
def test_resize_inputs_preserves_none_inputs(self):
1591+
"""Test that resize_inputs preserves None inputs when decreasing size."""
1592+
v0 = _core.Value(name="v0")
1593+
node = _core.Node("", "TestOp", inputs=(v0, None, None), num_outputs=1)
1594+
1595+
node.resize_inputs(2)
1596+
1597+
self.assertEqual(len(node.inputs), 2)
1598+
self.assertIs(node.inputs[0], v0)
1599+
self.assertIsNone(node.inputs[1])
1600+
1601+
def test_resize_outputs_increase_size(self):
1602+
"""Test that resize_outputs increases the number of outputs."""
1603+
v0 = _core.Value(name="v0")
1604+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)
1605+
1606+
self.assertEqual(len(node.outputs), 2)
1607+
old_output_0 = node.outputs[0]
1608+
old_output_1 = node.outputs[1]
1609+
1610+
# Resize to 4 outputs
1611+
node.resize_outputs(4)
1612+
1613+
self.assertEqual(len(node.outputs), 4)
1614+
# Verify old outputs are preserved
1615+
self.assertIs(node.outputs[0], old_output_0)
1616+
self.assertIs(node.outputs[1], old_output_1)
1617+
# Verify new outputs are created
1618+
self.assertIsNotNone(node.outputs[2])
1619+
self.assertIsNotNone(node.outputs[3])
1620+
# Verify new outputs have correct producer and index
1621+
self.assertIs(node.outputs[2].producer(), node)
1622+
self.assertIs(node.outputs[3].producer(), node)
1623+
self.assertEqual(node.outputs[2].index(), 2)
1624+
self.assertEqual(node.outputs[3].index(), 3)
1625+
1626+
def test_resize_outputs_decrease_size(self):
1627+
"""Test that resize_outputs decreases the number of outputs when they have no uses."""
1628+
v0 = _core.Value(name="v0")
1629+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3)
1630+
1631+
self.assertEqual(len(node.outputs), 3)
1632+
old_output_0 = node.outputs[0]
1633+
1634+
# Resize to 1 output
1635+
node.resize_outputs(1)
1636+
1637+
self.assertEqual(len(node.outputs), 1)
1638+
self.assertIs(node.outputs[0], old_output_0)
1639+
1640+
def test_resize_outputs_decrease_size_raises_when_output_has_uses(self):
1641+
"""Test that resize_outputs raises ValueError when removing outputs with uses."""
1642+
v0 = _core.Value(name="v0")
1643+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3)
1644+
# Create a consumer for the third output
1645+
_consumer = _core.Node("", "Consumer", inputs=(node.outputs[2],), num_outputs=1)
1646+
1647+
self.assertEqual(len(node.outputs[2].uses()), 1)
1648+
1649+
# Try to resize to 2 outputs (remove the third one)
1650+
with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"):
1651+
node.resize_outputs(2)
1652+
1653+
# Verify outputs are unchanged
1654+
self.assertEqual(len(node.outputs), 3)
1655+
1656+
def test_resize_outputs_same_size(self):
1657+
"""Test that resize_outputs does nothing when size is unchanged."""
1658+
v0 = _core.Value(name="v0")
1659+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)
1660+
1661+
old_outputs = node.outputs
1662+
1663+
# Resize to same size
1664+
node.resize_outputs(2)
1665+
1666+
self.assertEqual(len(node.outputs), 2)
1667+
self.assertIs(node.outputs[0], old_outputs[0])
1668+
self.assertIs(node.outputs[1], old_outputs[1])
1669+
1670+
def test_resize_outputs_to_zero(self):
1671+
"""Test that resize_outputs can reduce outputs to zero."""
1672+
v0 = _core.Value(name="v0")
1673+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)
1674+
1675+
node.resize_outputs(0)
1676+
1677+
self.assertEqual(len(node.outputs), 0)
1678+
self.assertEqual(node.outputs, ())
1679+
1680+
def test_resize_outputs_from_zero(self):
1681+
"""Test that resize_outputs can increase from zero outputs."""
1682+
v0 = _core.Value(name="v0")
1683+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=0)
1684+
1685+
self.assertEqual(len(node.outputs), 0)
1686+
1687+
node.resize_outputs(2)
1688+
1689+
self.assertEqual(len(node.outputs), 2)
1690+
self.assertIsNotNone(node.outputs[0])
1691+
self.assertIsNotNone(node.outputs[1])
1692+
self.assertIs(node.outputs[0].producer(), node)
1693+
self.assertIs(node.outputs[1].producer(), node)
1694+
self.assertEqual(node.outputs[0].index(), 0)
1695+
self.assertEqual(node.outputs[1].index(), 1)
1696+
1697+
def test_resize_outputs_decrease_with_middle_output_having_uses(self):
1698+
"""Test that resize_outputs raises when removing a middle output with uses."""
1699+
v0 = _core.Value(name="v0")
1700+
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=4)
1701+
# Create a consumer for the second output (index 1)
1702+
_consumer = _core.Node("", "Consumer", inputs=(node.outputs[1],), num_outputs=1)
1703+
1704+
# Try to resize to 1 output (remove outputs at indices 1, 2, 3)
1705+
with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"):
1706+
node.resize_outputs(1)
1707+
1708+
# Verify outputs are unchanged
1709+
self.assertEqual(len(node.outputs), 4)
1710+
15101711
# TODO(justinchuby): Test all methods
15111712

15121713

src/onnx_ir/passes/common/unused_removal.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ def is_used_output(i: int) -> bool:
6464
if out not in graph_outputs and (not out.uses()) and optional_info[i] is True:
6565
out.name = ""
6666

67+
# Remove trailing outputs with empty names by counting backwards
68+
new_output_count = len(node.outputs)
69+
for i in reversed(range(len(node.outputs))):
70+
if not node.outputs[i].name:
71+
new_output_count -= 1
72+
else:
73+
break
74+
node.resize_outputs(new_output_count)
75+
76+
77+
def _remove_trailing_empty_inputs(node: ir.Node) -> None:
78+
# Remove trailing None inputs
79+
new_input_count = len(node.inputs)
80+
for i in reversed(range(len(node.inputs))):
81+
if node.inputs[i] is None:
82+
new_input_count -= 1
83+
else:
84+
break
85+
node.resize_inputs(new_input_count)
86+
6787

6888
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
6989
graph_outputs = frozenset(function_or_graph.outputs)
@@ -79,6 +99,7 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph
7999
function_or_graph.remove(node, safe=True)
80100
count += 1
81101
else:
102+
_remove_trailing_empty_inputs(node)
82103
if onnx_opset_version is not None:
83104
_remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
84105
for attr in node.attributes.values():

0 commit comments

Comments
 (0)