Skip to content

Commit dfa7544

Browse files
Support metadata_prop merge and version 25 in version converter (microsoft#2782)
Fix pytorch/pytorch#172784 --- This pull request adds support for ONNX opset version 25 in the version converter and introduces a new mechanism to copy node metadata during version conversions. It also includes comprehensive tests to ensure that metadata is properly transferred to new or replacement nodes created by adapters during the conversion process. **Version converter improvements:** * Increased the maximum supported ONNX opset version from 23 to 25 in `SUPPORTED_MAX_ONNX_OPSET` within `onnxscript/version_converter/_version_converter.py`. * Integrated a new `metadata_merger` utility and implemented a default metadata merger to ensure node metadata is copied during version conversion. Metadata is now merged from original nodes to all replacement nodes in the conversion process. [[1]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faR239-R244) [[2]](diffhunk://#diff-b6c70f90bafaee79b30e43c90bc0fd5192fb3de7ccc4cf9d48a209798dd775faR303) **Testing and validation:** * Added a new `VersionConverterMetadataMergeTest` class in `onnxscript/version_converter/_version_converter_test.py` to verify that metadata is copied correctly to replacement nodes and to all nodes created by adapters during conversion. * Updated the test suite to reflect the new maximum supported opset version and to ensure that conversion beyond version 25 is marked as expected to fail for future-proofing. --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 0ed5f23 commit dfa7544

File tree

2 files changed

+92
-5
lines changed

2 files changed

+92
-5
lines changed

onnxscript/version_converter/_version_converter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
import onnx_ir.convenience as ir_convenience
1313

1414
import onnxscript.ir._tape as _tape
15+
import onnxscript.utils.metadata_merger as metadata_merger
1516
from onnxscript import ir
1617

1718
logger = logging.getLogger(__name__)
1819

1920

20-
SUPPORTED_MAX_ONNX_OPSET = 23
21+
SUPPORTED_MAX_ONNX_OPSET = 25
2122
SUPPORTED_MIN_ONNX_OPSET = 18
2223

2324

@@ -238,6 +239,12 @@ def groupnormalization_20_21(node: ir.Node, op):
238239
class _VersionConverter:
239240
def __init__(self, target_version: int):
240241
self._target_version = target_version
242+
# Default metadata merger: no merging should be needed; keep the first value.
243+
self._default_metadata_merger: metadata_merger.MetadataMerger = (
244+
metadata_merger.MetadataMerger(
245+
dict(),
246+
)
247+
)
241248

242249
def process_node(
243250
self, node: ir.Node, from_version: int, up_conversion: bool = True
@@ -293,6 +300,7 @@ def visit_node(
293300
for new_node in replacement.new_nodes:
294301
# TODO: control-flow
295302
new_node.version = to_version
303+
self._default_metadata_merger.copy_merged_metadata([node], replacement.new_nodes)
296304
self.replace_node(node, replacement, root)
297305

298306
def visit_graph(self, graph: ir.Graph) -> None:

onnxscript/version_converter/_version_converter_test.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,91 @@ def test_version_groupnorm_no_bias(self):
296296
self.assertEqual(model.graph.node(0).version, 20)
297297

298298

299-
class VersionConverter23to24Test(unittest.TestCase):
300-
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 23 not yet supported.")
299+
class VersionConverterMetadataMergeTest(unittest.TestCase):
300+
def test_metadata_is_copied_on_version_conversion(self):
301+
"""Test that metadata is copied from original node to replacement nodes during version conversion."""
302+
model = ir.from_onnx_text(
303+
"""
304+
<ir_version: 7, opset_import: [ "" : 18]>
305+
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
306+
{
307+
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
308+
reshape_x = Reshape (input_x, shape_a)
309+
dft = DFT <axis = 2, onesided = 1> (reshape_x)
310+
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
311+
output = Reshape (dft, shape_c)
312+
}
313+
"""
314+
)
315+
# Find the DFT node and add metadata to it
316+
dft_node = model.graph.node(2)
317+
self.assertEqual(dft_node.op_type, "DFT")
318+
dft_node.metadata_props["test_key"] = "test_value"
319+
dft_node.metadata_props["another_key"] = "another_value"
320+
321+
target_version = 25
322+
version_converter.convert_version(model, target_version=target_version)
323+
self.assertEqual(model.opset_imports[""], target_version)
324+
325+
# After conversion, DFT adapter adds a Constant node for axis and the DFT node is replaced
326+
# The replacement DFT node should have the metadata copied
327+
new_dft_node = model.graph.node(3)
328+
self.assertEqual(new_dft_node.op_type, "DFT")
329+
self.assertEqual(new_dft_node.version, 25)
330+
331+
# Verify metadata was copied to the new DFT node
332+
self.assertEqual(new_dft_node.metadata_props.get("test_key"), "test_value")
333+
self.assertEqual(new_dft_node.metadata_props.get("another_key"), "another_value")
334+
335+
def test_metadata_is_copied_to_multiple_replacement_nodes(self):
336+
"""Test that metadata is copied to all replacement nodes when an adapter creates multiple nodes."""
337+
model = ir.from_onnx_text(
338+
"""
339+
<ir_version: 7, opset_import: [ "" : 18]>
340+
agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output)
341+
{
342+
groupnorm = GroupNormalization <num_groups = 2> (input_x, scale, bias)
343+
shape_c = Constant<value: tensor = int64[4] {4, 512, 512}>()
344+
output = Reshape (groupnorm, shape_c)
345+
}
346+
"""
347+
)
348+
# Find the GroupNormalization node and add metadata to it
349+
groupnorm_node = model.graph.node(0)
350+
self.assertEqual(groupnorm_node.op_type, "GroupNormalization")
351+
groupnorm_node.metadata_props["source"] = "original_groupnorm"
352+
353+
target_version = 21
354+
version_converter.convert_version(model, target_version=target_version)
355+
self.assertEqual(model.opset_imports[""], target_version)
356+
357+
# GroupNormalization adapter creates multiple nodes (Reshape, Expand, etc.)
358+
# Verify that metadata was copied to the new nodes created by the adapter
359+
new_groupnorm_node = model.graph.node(9)
360+
self.assertEqual(new_groupnorm_node.op_type, "GroupNormalization")
361+
self.assertEqual(new_groupnorm_node.version, 21)
362+
363+
# Verify metadata was copied to the new GroupNormalization node
364+
self.assertEqual(new_groupnorm_node.metadata_props.get("source"), "original_groupnorm")
365+
366+
# Also check that intermediate nodes created by the adapter received the metadata
367+
# The adapter creates Reshape, Expand, Reshape nodes for scale and bias
368+
for i in range(9):
369+
node = model.graph.node(i)
370+
if node.version == 21 and node.op_type in ("Reshape", "Expand", "Constant"):
371+
self.assertEqual(
372+
node.metadata_props.get("source"),
373+
"original_groupnorm",
374+
f"Node {i} ({node.op_type}) should have metadata copied",
375+
)
376+
377+
378+
class VersionConverter25to26Test(unittest.TestCase):
379+
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 25 not yet supported.")
301380
def test_version_convert_compatible(self):
302381
model = ir.from_onnx_text(
303382
"""
304-
<ir_version: 7, opset_import: [ "" : 23]>
383+
<ir_version: 7, opset_import: [ "" : 25]>
305384
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
306385
{
307386
shape_a = Constant<value: tensor = int64[3] {4, 512, 512}>()
@@ -314,7 +393,7 @@ def test_version_convert_compatible(self):
314393
}
315394
"""
316395
)
317-
target_version = 24
396+
target_version = 26
318397
version_converter.convert_version(model, target_version=target_version)
319398

320399

0 commit comments

Comments
 (0)