Skip to content

Commit 0286ca5

Browse files
authored
Merge pull request #653 from robertknight/rten-convert-refactor
Split some functionality out of the main `rten_convert.converter` module
2 parents 160eb2c + 1a44b4c commit 0286ca5

File tree

5 files changed

+426
-401
lines changed

5 files changed

+426
-401
lines changed

Diff for: rten-convert/rten_convert/attr_reader.py

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
from typing import Any, Callable, Literal
2+
3+
import numpy as np
4+
import onnx
5+
6+
from rten_convert.errors import ConversionError
7+
from rten_convert.graph import ConstantNode, Node
8+
from rten_convert.util import warn_once
9+
10+
11+
class AttributeReader:
12+
"""
13+
Utility for extracting attribute and input values from an ONNX operator.
14+
15+
This keeps track of which attributes have been read so that we can warn about
16+
any unhandled ones.
17+
"""
18+
19+
onnx_op: onnx.OperatorProto
20+
21+
add_node: Callable[[Node], int]
22+
"""
23+
Function that adds a new node to the graph and returns its ID.
24+
25+
This is used if a new constant node has to be generated to replace an
26+
operator attribute.
27+
"""
28+
29+
input_indexes: list[int | None]
30+
"""
31+
IDs of the operator's input nodes.
32+
33+
New inputs may be generated while reading an operator if it has an attribute
34+
that needs to be converted to a dynamic input.
35+
"""
36+
37+
_handled_attrs: set[str]
38+
"""Names of attributes that have been handled."""
39+
40+
def __init__(
41+
self,
42+
onnx_op: onnx.OperatorProto,
43+
input_indexes: list[int | None],
44+
add_node: Callable[[Node], int],
45+
):
46+
self.onnx_op = onnx_op
47+
48+
self.add_node = add_node
49+
self.input_indexes = input_indexes.copy()
50+
51+
self._handled_attrs = set()
52+
53+
def get_attr(self, name: str, expected_type: str, default):
54+
"""Get the value of an optional operator attribute."""
55+
56+
self._handled_attrs.add(name)
57+
58+
type_code = getattr(onnx.AttributeProto, expected_type.upper())
59+
for attr in self.onnx_op.attribute:
60+
if attr.name == name:
61+
if attr.type != type_code:
62+
raise ConversionError(
63+
f"Attribute {name} type does not match {expected_type}"
64+
)
65+
val = getattr(attr, _value_fields[type_code])
66+
67+
# String attribute values are stored as bytes, so we have to decode
68+
# them.
69+
if expected_type == "string":
70+
val = val.decode()
71+
72+
return val
73+
return default
74+
75+
def get_bool_attr(self, name: str, default: bool) -> bool:
76+
"""
77+
Get the value of an optional boolean operator attribute.
78+
79+
ONNX represents boolean attributes as "int" fields with values 0 or 1
80+
rather than a dedicated boolean type. This method converts these
81+
attributes to Python booleans.
82+
"""
83+
return bool(self.get_attr(name, "int", int(default)))
84+
85+
def get_enum_attr(self, name: str, enum: Any, default: str, fallback: Any = None):
86+
"""
87+
Get an optional attribute whose value is an enum variant.
88+
89+
The variant name is Pascal-Cased and looked up on the enum object.
90+
eg. `round_prefer_floor` => `RoundPreferFloor`. If the Pascal-Cased
91+
name matches a Python keyword, it is expected to be escaped, eg.
92+
`none` => `None_`.
93+
94+
If the attribute value does not match any enum value, this will raise if
95+
`fallback` is not specified, or emit a warning and use the value
96+
`fallback` otherwise. Use of `fallback` is appropriate if the
97+
substitution is unlikely to affect the resulting model's ability to run,
98+
but might impact accuracy modestly.
99+
"""
100+
101+
def convert_attr(val: str):
102+
pascal_case = _snake_case_to_pascal_case(val)
103+
104+
# Enum values that match Python keywords have a trailing underscore appended.
105+
escaped_pascal_case = pascal_case + "_"
106+
107+
try:
108+
return getattr(enum, pascal_case)
109+
except AttributeError:
110+
return getattr(enum, escaped_pascal_case)
111+
112+
val = self.get_attr(name, "string", default)
113+
try:
114+
return convert_attr(val)
115+
except AttributeError:
116+
if fallback:
117+
op = self.onnx_op.op_type
118+
warn_once(
119+
f'Replacing unsupported value "{val}" for "{name}" attr in {op} op with "{fallback}"'
120+
)
121+
return convert_attr(fallback)
122+
raise ConversionError(f'Unsupported value "{val}" for "{name}" attr')
123+
124+
def ignore_attr(self, name: str):
125+
"""
126+
Mark an attribute as ignored.
127+
128+
This is useful in cases where an attribute contains redundant information.
129+
"""
130+
self._handled_attrs.add(name)
131+
132+
def require_attr(self, name: str, expected_type: str):
133+
"""Get the value of a required operator attribute."""
134+
val = self.get_attr(name, expected_type, default=None)
135+
if val is None:
136+
raise ConversionError(f"Missing required attribute {name}")
137+
return val
138+
139+
def generate_input_from_attr(
140+
self, input_index: int, attr_name: str, attr_type: str
141+
):
142+
"""
143+
Generate a constant operator input from an attribute, if it exists.
144+
145+
Some operator inputs changed from attributes to inputs in different ONNX
146+
releases. This function checks to see if an operator has an attribute
147+
and synthesizes a constant input.
148+
149+
:param input_index: Index of the input that the attribute corresponds to
150+
:param attr_name: Name of the attribute
151+
:param attr_type: Expected type of the attribute
152+
"""
153+
154+
attr_val = self.get_attr(attr_name, attr_type, default=None)
155+
if attr_val is None:
156+
return
157+
158+
if input_index < len(self.input_indexes):
159+
raise ConversionError(
160+
f'Operator has both an attribute "{attr_name}" and corresponding input at index {input_index}'
161+
)
162+
163+
shape: list[int]
164+
match attr_type:
165+
case "int":
166+
shape = []
167+
data = np.array(attr_val).astype(np.int32)
168+
169+
case "float":
170+
shape = []
171+
data = np.array(attr_val).astype(np.float32)
172+
173+
case "ints":
174+
shape = [len(attr_val)]
175+
data = np.array([attr_val]).astype(np.int32)
176+
case _:
177+
raise ConversionError(
178+
f'Unable to generate input from "{attr_name}" attribute of type "{attr_type}"'
179+
)
180+
181+
generated_name = self.onnx_op.name + ":rten-" + attr_name
182+
const_node = ConstantNode(generated_name, shape, data)
183+
input_id = self.add_node(const_node)
184+
185+
while len(self.input_indexes) < input_index + 1:
186+
self.input_indexes.append(None)
187+
self.input_indexes[input_index] = input_id
188+
189+
def check_attr(
190+
self,
191+
name: str,
192+
expected_type,
193+
default,
194+
on_mismatch: Literal["raise", "warn"] = "raise",
195+
):
196+
"""
197+
Check if an operator has an unsupported non-default value for an attribute.
198+
199+
If `default` is a tuple, it specifies a set of acceptable defaults.
200+
201+
:param name: The name of the operator attribute
202+
:param default: The value which is equivalent to the default behavior
203+
:param on_mismatch:
204+
Whether a mismatch should be treated as a fatal error in model
205+
conversion or merely warn that this might cause a problem.
206+
"""
207+
208+
val = self.get_attr(name, expected_type, None)
209+
if val is None:
210+
return
211+
212+
if not isinstance(default, tuple):
213+
default = (default,)
214+
if val not in default:
215+
msg = f"Unsupported value {val} for attribute {name}. Default is {default}"
216+
if on_mismatch == "raise":
217+
raise ConversionError(msg)
218+
else:
219+
warn_once(msg)
220+
221+
def unhandled_attrs(self) -> list[onnx.AttributeProto]:
222+
"""Return a list of attributes which have not been read."""
223+
return [
224+
attr
225+
for attr in self.onnx_op.attribute
226+
if attr.name not in self._handled_attrs
227+
]
228+
229+
230+
def _snake_case_to_pascal_case(s: str) -> str:
231+
"""Transform a snake_case string to PascalCase."""
232+
return "".join([word[0].upper() + word[1:] for word in s.split("_")])
233+
234+
235+
# Mapping of ONNX attribute types to the field on an AttributeProto which
236+
# contains the value. Note that if you try to access the wrong field on an
237+
# AttributeProto, you get a default value instead of an exception.
238+
_value_fields = {
239+
onnx.AttributeProto.FLOAT: "f",
240+
onnx.AttributeProto.GRAPH: "g",
241+
onnx.AttributeProto.INT: "i",
242+
onnx.AttributeProto.INTS: "ints",
243+
onnx.AttributeProto.STRING: "s",
244+
onnx.AttributeProto.TENSOR: "t",
245+
}

0 commit comments

Comments
 (0)