Skip to content
90 changes: 90 additions & 0 deletions src/onnx_ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
__all__ = [
"convert_attribute",
"convert_attributes",
"rename_values",
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
Expand Down Expand Up @@ -360,6 +361,95 @@ def replace_all_uses_with(
value.replace_all_uses_with(replacement, replace_graph_outputs=replace_graph_outputs)


def rename_values(
values: _core.Value | Sequence[_core.Value],
names: str | Sequence[str],
) -> None:
"""Rename one or more values.

Initializer-backed values are removed from their graphs while renaming so swaps
and other permutations do not trip the initializer name guard.

Args:
values: The value or values to rename. Must be `_core.Value` instances or a
sequence of `_core.Value` instances.
names: The target name or names.

Raises:
TypeError: If a value is not a :class:`~onnx_ir.Value` or a name is not a string.
ValueError: If the number of values and names do not match, if one value is
given conflicting target names, or if an initializer target would collide
with an initializer outside the renamed set.
"""
if not isinstance(values, Sequence):
values = (values,)
if isinstance(names, str) or not isinstance(names, Sequence):
names = (names,)
if len(values) != len(names):
raise ValueError("The number of values and names must match.")

ordered_pairs: list[tuple[_core.Value, str]] = []
target_by_value: dict[_core.Value, str] = {}
for value, name in zip(values, names):
if not isinstance(value, _core.Value):
raise TypeError(f"value must be a Value object, not {type(value)}")
Comment thread
enpasos marked this conversation as resolved.
if not isinstance(name, str):
raise TypeError(f"name must be a string, not {type(name)}")
if value in target_by_value:
if target_by_value[value] != name:
raise ValueError(
f"Conflicting target names for value {value!r}: "
f"{target_by_value[value]!r} vs {name!r}."
)
continue
target_by_value[value] = name
ordered_pairs.append((value, name))

initializer_pairs_by_graph: dict[_core.Graph, list[tuple[_core.Value, str]]] = {}
for value, name in ordered_pairs:
if not value.is_initializer():
continue
graph = value.graph
assert isinstance(graph, _core.Graph), "Initializer values must belong to a graph"
initializer_pairs_by_graph.setdefault(graph, []).append((value, name))

initializer_values_by_graph: dict[_core.Graph, tuple[_core.Value, ...]] = {}
for graph, initializer_pairs in initializer_pairs_by_graph.items():
renamed_initializers = {value for value, _ in initializer_pairs}
seen_targets: dict[str, _core.Value] = {}
for value, name in initializer_pairs:
if name == "":
raise ValueError("Initializer value name cannot be an empty string.")
existing = seen_targets.get(name)
if existing is not None and existing is not value:
raise ValueError(
f"Cannot rename initializer '{value}' to '{name}': "
Comment thread
enpasos marked this conversation as resolved.
"another initializer in the rename set already targets that name."
)
seen_targets[name] = value
if name in graph.initializers:
existing_initializer = graph.initializers[name]
if existing_initializer is not value and existing_initializer not in renamed_initializers:
raise ValueError(
f"Cannot rename initializer '{value}' to '{name}': "
"an initializer with that name already exists."
)

initializer_values_by_graph[graph] = tuple(value for value, _ in initializer_pairs)

for graph, initializer_values in initializer_values_by_graph.items():
for value in initializer_values:
assert value.name is not None, "Initializer values must have names"
graph.initializers.pop(value.name)

for value, name in ordered_pairs:
value.name = name

for graph, initializer_values in initializer_values_by_graph.items():
for value in initializer_values:
graph.initializers.add(value)


def create_value_mapping(
graph: _core.Graph | _core.GraphView | _core.Function,
*,
Expand Down
64 changes: 64 additions & 0 deletions src/onnx_ir/_convenience/_convenience_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,69 @@ def test_constant_value(self, _: str, attr: ir.Attr, expected: np.ndarray):
self.assertEqual(node.outputs[0].type, ir.TensorType(result_2.dtype))


class RenameValuesTest(unittest.TestCase):
def test_rename_values_supports_initializer_swaps(self):
first = ir.Value(name="const_0", const_value=ir.tensor([1], name="const_0"))
second = ir.Value(name="const_1", const_value=ir.tensor([2], name="const_1"))
graph = ir.Graph(
inputs=(),
outputs=[first, second],
nodes=(),
initializers=[first, second],
name="test_graph",
)

_convenience.rename_values((first, second), ("const_1", "const_0"))

self.assertEqual(first.name, "const_1")
self.assertEqual(second.name, "const_0")
self.assertEqual(first.const_value.name, "const_1")
self.assertEqual(second.const_value.name, "const_0")
self.assertEqual(set(graph.initializers), {"const_0", "const_1"})
self.assertIs(graph.initializers["const_1"], first)
self.assertIs(graph.initializers["const_0"], second)

def test_rename_values_rejects_none_names(self):
value = ir.Value(name="value")

with self.assertRaisesRegex(TypeError, "name must be a string"):
_convenience.rename_values(value, None)

def test_rename_values_rejects_initializer_collisions_outside_rename_set(self):
first = ir.Value(name="const_0", const_value=ir.tensor([1], name="const_0"))
second = ir.Value(name="const_1", const_value=ir.tensor([2], name="const_1"))
graph = ir.Graph(
inputs=(),
outputs=[first, second],
nodes=(),
initializers=[first, second],
name="test_graph",
)

with self.assertRaisesRegex(ValueError, "an initializer with that name already exists"):
_convenience.rename_values(first, "const_1")

self.assertIs(graph.initializers["const_0"], first)
self.assertIs(graph.initializers["const_1"], second)

def test_rename_values_rejects_empty_initializer_name_without_mutating_graph(self):
value = ir.Value(name="const_0", const_value=ir.tensor([1], name="const_0"))
graph = ir.Graph(
inputs=(),
outputs=[value],
nodes=(),
initializers=[value],
name="test_graph",
)

with self.assertRaisesRegex(ValueError, "empty string"):
_convenience.rename_values(value, "")

self.assertEqual(value.name, "const_0")
self.assertEqual(value.const_value.name, "const_0")
self.assertEqual(list(graph.initializers), ["const_0"])
self.assertIs(graph.initializers["const_0"], value)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions src/onnx_ir/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"create_value_mapping",
"extract",
"get_const_tensor",
"rename_values",
"replace_all_uses_with",
"replace_nodes_and_values",
]
Expand All @@ -19,6 +20,7 @@
convert_attributes,
create_value_mapping,
get_const_tensor,
rename_values,
replace_all_uses_with,
replace_nodes_and_values,
)
Expand Down
2 changes: 1 addition & 1 deletion src/onnx_ir/passes/common/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def enter_graph(graph_like) -> None:

if isinstance(graph_like, ir.Graph):
# For graphs, also fix initializers
for initializer in graph_like.initializers.values():
for initializer in tuple(graph_like.initializers.values()):
if self._process_value(
initializer, scoped_used_value_names[-1], seen_values, value_counter
):
Expand Down
26 changes: 26 additions & 0 deletions src/onnx_ir/passes/common/naming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,32 @@ def test_graph_inputs_outputs_have_precedence(self):
self.assertNotEqual(add_node.outputs[0].name, "important_input")
self.assertTrue(add_node.outputs[0].name.startswith("important_input_"))

def test_initializer_collision_does_not_mutate_dict_during_iteration(self):
"""Test NameFixPass handles collisions with initializer names safely."""
input_value = ir.val(
"input", shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.FLOAT)
)
initializer = ir.Value(name="weights", const_value=ir.tensor([1.0], name="weights"))
graph = ir.Graph(
inputs=[input_value],
outputs=[input_value],
nodes=(),
initializers=[initializer],
name="test_graph",
)
model = ir.Model(graph, ir_version=10)

input_value.name = "weights"

result = naming.NameFixPass()(model)

self.assertTrue(result.modified)
self.assertEqual(input_value.name, "weights")
self.assertEqual(initializer.name, "weights_1")
self.assertEqual(initializer.const_value.name, "weights_1")
self.assertEqual(list(graph.initializers), ["weights_1"])
self.assertIs(graph.initializers["weights_1"], initializer)


if __name__ == "__main__":
unittest.main()
Loading