Skip to content

Commit 2c868b8

Browse files
authored
Merge pull request #649 from robertknight/constant-op-value-fields
Support `value_{int, ints, float, floats}` attributes in `Constant` op
2 parents d5d6f04 + d42a6a7 commit 2c868b8

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

Diff for: rten-convert/rten_convert/converter.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def snake_case_to_pascal_case(s: str) -> str:
210210

211211
class ONNXOperatorReader:
212212
"""
213-
Utiliy for extracting attribute and input values from an ONNX operator.
213+
Utility for extracting attribute and input values from an ONNX operator.
214214
215215
This keeps track of which attributes have been read so that we can warn about
216216
any unhandled ones.
@@ -501,10 +501,29 @@ def noop_add_node(node: Node) -> int:
501501

502502
output_name = onnx_op.output[0]
503503

504-
tensor = ONNXOperatorReader(
505-
onnx_op, input_indexes=[], add_node=noop_add_node
506-
).require_attr("value", "tensor")
507-
const_node = constant_node_from_onnx_initializer(tensor, output_name)
504+
attrs = ONNXOperatorReader(onnx_op, input_indexes=[], add_node=noop_add_node)
505+
if (tensor := attrs.get_attr("value", "tensor", None)) is not None:
506+
const_node = constant_node_from_onnx_initializer(tensor, output_name)
507+
else:
508+
if (int_ := attrs.get_attr("value_int", "int", None)) is not None:
509+
shape = []
510+
data = np.array(int_).astype(np.int32)
511+
elif (ints := attrs.get_attr("value_ints", "ints", None)) is not None:
512+
shape = [len(ints)]
513+
data = np.array(ints).astype(np.int32)
514+
elif (float_ := attrs.get_attr("value_float", "float", None)) is not None:
515+
shape = []
516+
data = np.array(float_).astype(np.float32)
517+
elif (floats := attrs.get_attr("value_floats", "floats", None)) is not None:
518+
shape = [len(floats)]
519+
data = np.array(floats).astype(np.float32)
520+
else:
521+
# Unsupported attributes: value_string, value_strings
522+
raise Exception(
523+
f'Unable to get value from "Constant" operator "{onnx_op.name}"'
524+
)
525+
const_node = ConstantNode("dummy_name", shape, data)
526+
508527
const_node.name = output_name
509528

510529
return const_node

0 commit comments

Comments
 (0)