Skip to content

Commit e18a862

Browse files
committed
Address PR review feedback on rename_values
Signed-off-by: enpasos <matthias.unverzagt@enpasos.com>
1 parent 70c44f1 commit e18a862

File tree

2 files changed

+50
-37
lines changed

2 files changed

+50
-37
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -363,63 +363,61 @@ def replace_all_uses_with(
363363

364364
def rename_values(
365365
values: _core.Value | Sequence[_core.Value],
366-
names: str | None | Sequence[str | None],
366+
names: str | Sequence[str],
367367
) -> None:
368368
"""Rename one or more values.
369369
370-
When multiple initializer-backed values are renamed together, temporary names are
371-
used so swaps and other permutations do not trip the initializer name guard.
370+
Initializer-backed values are removed from their graphs while renaming so swaps
371+
and other permutations do not trip the initializer name guard.
372372
373373
Args:
374374
values: The value or values to rename. Must be `_core.Value` instances or a
375375
sequence of `_core.Value` instances.
376376
names: The target name or names.
377377
378378
Raises:
379+
TypeError: If a value is not a :class:`~onnx_ir.Value` or a name is not a string.
379380
ValueError: If the number of values and names do not match, if one value is
380381
given conflicting target names, or if an initializer target would collide
381382
with an initializer outside the renamed set.
382383
"""
383384
if not isinstance(values, Sequence):
384385
values = (values,)
385-
if isinstance(names, str) or names is None:
386+
if isinstance(names, str) or not isinstance(names, Sequence):
386387
names = (names,)
387388
if len(values) != len(names):
388389
raise ValueError("The number of values and names must match.")
389390

390-
ordered_pairs: list[tuple[_core.Value, str | None]] = []
391-
target_by_value_id: dict[int, str | None] = {}
391+
ordered_pairs: list[tuple[_core.Value, str]] = []
392+
target_by_value: dict[_core.Value, str] = {}
392393
for value, name in zip(values, names):
393394
if not isinstance(value, _core.Value):
394395
raise TypeError(f"value must be a Value object, not {type(value)}")
395-
value_id = id(value)
396-
if value_id in target_by_value_id:
397-
if target_by_value_id[value_id] != name:
396+
if not isinstance(name, str):
397+
raise TypeError(f"name must be a string, not {type(name)}")
398+
if value in target_by_value:
399+
if target_by_value[value] != name:
398400
raise ValueError(
399401
f"Conflicting target names for value {value!r}: "
400-
f"{target_by_value_id[value_id]!r} vs {name!r}."
402+
f"{target_by_value[value]!r} vs {name!r}."
401403
)
402404
continue
403-
target_by_value_id[value_id] = name
405+
target_by_value[value] = name
404406
ordered_pairs.append((value, name))
405407

406-
initializer_pairs_by_graph: dict[_core.Graph, list[tuple[_core.Value, str | None]]] = {}
408+
initializer_pairs_by_graph: dict[_core.Graph, list[tuple[_core.Value, str]]] = {}
407409
for value, name in ordered_pairs:
408410
if not value.is_initializer():
409411
continue
410412
graph = value.graph
411413
assert isinstance(graph, _core.Graph), "Initializer values must belong to a graph"
412414
initializer_pairs_by_graph.setdefault(graph, []).append((value, name))
413415

416+
initializer_values_by_graph: dict[_core.Graph, tuple[_core.Value, ...]] = {}
414417
for graph, initializer_pairs in initializer_pairs_by_graph.items():
415-
renamed_initializer_ids = {id(value) for value, _ in initializer_pairs}
418+
renamed_initializers = {value for value, _ in initializer_pairs}
416419
seen_targets: dict[str, _core.Value] = {}
417420
for value, name in initializer_pairs:
418-
if name is None:
419-
raise ValueError(
420-
"Initializer value cannot have name set to None. "
421-
"Please pop() the value from initializers first to do so."
422-
)
423421
existing = seen_targets.get(name)
424422
if existing is not None and existing is not value:
425423
raise ValueError(
@@ -429,34 +427,26 @@ def rename_values(
429427
seen_targets[name] = value
430428
if name in graph.initializers:
431429
existing_initializer = graph.initializers[name]
432-
if (
433-
existing_initializer is not value
434-
and id(existing_initializer) not in renamed_initializer_ids
435-
):
430+
if existing_initializer is not value and existing_initializer not in renamed_initializers:
436431
raise ValueError(
437432
f"Cannot rename initializer '{value}' to '{name}': "
438433
"an initializer with that name already exists."
439434
)
440435

441-
used_names = set(graph.initializers)
442-
used_names.update(
443-
target_name for _, target_name in initializer_pairs if target_name is not None
444-
)
445-
tmp_index = 0
446-
for value, name in initializer_pairs:
447-
if value.name == name:
448-
continue
449-
tmp_name = f"__onnx_ir_tmp_name_{tmp_index}"
450-
while tmp_name in used_names:
451-
tmp_index += 1
452-
tmp_name = f"__onnx_ir_tmp_name_{tmp_index}"
453-
tmp_index += 1
454-
used_names.add(tmp_name)
455-
value.name = tmp_name
436+
initializer_values_by_graph[graph] = tuple(value for value, _ in initializer_pairs)
437+
438+
for graph, initializer_values in initializer_values_by_graph.items():
439+
for value in initializer_values:
440+
assert value.name is not None, "Initializer values must have names"
441+
graph.initializers.pop(value.name)
456442

457443
for value, name in ordered_pairs:
458444
value.name = name
459445

446+
for graph, initializer_values in initializer_values_by_graph.items():
447+
for value in initializer_values:
448+
graph.initializers.add(value)
449+
460450

461451
def create_value_mapping(
462452
graph: _core.Graph | _core.GraphView | _core.Function,

src/onnx_ir/_convenience/_convenience_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,29 @@ def test_rename_values_supports_initializer_swaps(self):
121121
self.assertIs(graph.initializers["const_1"], first)
122122
self.assertIs(graph.initializers["const_0"], second)
123123

124+
def test_rename_values_rejects_none_names(self):
125+
value = ir.Value(name="value")
126+
127+
with self.assertRaisesRegex(TypeError, "name must be a string"):
128+
_convenience.rename_values(value, None)
129+
130+
def test_rename_values_rejects_initializer_collisions_outside_rename_set(self):
131+
first = ir.Value(name="const_0", const_value=ir.tensor([1], name="const_0"))
132+
second = ir.Value(name="const_1", const_value=ir.tensor([2], name="const_1"))
133+
graph = ir.Graph(
134+
inputs=(),
135+
outputs=[first, second],
136+
nodes=(),
137+
initializers=[first, second],
138+
name="test_graph",
139+
)
140+
141+
with self.assertRaisesRegex(ValueError, "an initializer with that name already exists"):
142+
_convenience.rename_values(first, "const_1")
143+
144+
self.assertIs(graph.initializers["const_0"], first)
145+
self.assertIs(graph.initializers["const_1"], second)
146+
124147

125148
if __name__ == "__main__":
126149
unittest.main()

0 commit comments

Comments
 (0)