|
| 1 | +from typing import Any, Callable, Literal |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import onnx |
| 5 | + |
| 6 | +from rten_convert.errors import ConversionError |
| 7 | +from rten_convert.graph import ConstantNode, Node |
| 8 | +from rten_convert.util import warn_once |
| 9 | + |
| 10 | + |
| 11 | +class AttributeReader: |
| 12 | + """ |
| 13 | + Utility for extracting attribute and input values from an ONNX operator. |
| 14 | +
|
| 15 | + This keeps track of which attributes have been read so that we can warn about |
| 16 | + any unhandled ones. |
| 17 | + """ |
| 18 | + |
| 19 | + onnx_op: onnx.OperatorProto |
| 20 | + |
| 21 | + add_node: Callable[[Node], int] |
| 22 | + """ |
| 23 | + Function that adds a new node to the graph and returns its ID. |
| 24 | +
|
| 25 | + This is used if a new constant node has to be generated to replace an |
| 26 | + operator attribute. |
| 27 | + """ |
| 28 | + |
| 29 | + input_indexes: list[int | None] |
| 30 | + """ |
| 31 | + IDs of the operator's input nodes. |
| 32 | +
|
| 33 | + New inputs may be generated while reading an operator if it has an attribute |
| 34 | + that needs to be converted to a dynamic input. |
| 35 | + """ |
| 36 | + |
| 37 | + _handled_attrs: set[str] |
| 38 | + """Names of attributes that have been handled.""" |
| 39 | + |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + onnx_op: onnx.OperatorProto, |
| 43 | + input_indexes: list[int | None], |
| 44 | + add_node: Callable[[Node], int], |
| 45 | + ): |
| 46 | + self.onnx_op = onnx_op |
| 47 | + |
| 48 | + self.add_node = add_node |
| 49 | + self.input_indexes = input_indexes.copy() |
| 50 | + |
| 51 | + self._handled_attrs = set() |
| 52 | + |
| 53 | + def get_attr(self, name: str, expected_type: str, default): |
| 54 | + """Get the value of an optional operator attribute.""" |
| 55 | + |
| 56 | + self._handled_attrs.add(name) |
| 57 | + |
| 58 | + type_code = getattr(onnx.AttributeProto, expected_type.upper()) |
| 59 | + for attr in self.onnx_op.attribute: |
| 60 | + if attr.name == name: |
| 61 | + if attr.type != type_code: |
| 62 | + raise ConversionError( |
| 63 | + f"Attribute {name} type does not match {expected_type}" |
| 64 | + ) |
| 65 | + val = getattr(attr, _value_fields[type_code]) |
| 66 | + |
| 67 | + # String attribute values are stored as bytes, so we have to decode |
| 68 | + # them. |
| 69 | + if expected_type == "string": |
| 70 | + val = val.decode() |
| 71 | + |
| 72 | + return val |
| 73 | + return default |
| 74 | + |
| 75 | + def get_bool_attr(self, name: str, default: bool) -> bool: |
| 76 | + """ |
| 77 | + Get the value of an optional boolean operator attribute. |
| 78 | +
|
| 79 | + ONNX represents boolean attributes as "int" fields with values 0 or 1 |
| 80 | + rather than a dedicated boolean type. This method converts these |
| 81 | + attributes to Python booleans. |
| 82 | + """ |
| 83 | + return bool(self.get_attr(name, "int", int(default))) |
| 84 | + |
| 85 | + def get_enum_attr(self, name: str, enum: Any, default: str, fallback: Any = None): |
| 86 | + """ |
| 87 | + Get an optional attribute whose value is an enum variant. |
| 88 | +
|
| 89 | + The variant name is Pascal-Cased and looked up on the enum object. |
| 90 | + eg. `round_prefer_floor` => `RoundPreferFloor`. If the Pascal-Cased |
| 91 | + name matches a Python keyword, it is expected to be escaped, eg. |
| 92 | + `none` => `None_`. |
| 93 | +
|
| 94 | + If the attribute value does not match any enum value, this will raise if |
| 95 | + `fallback` is not specified, or emit a warning and use the value |
| 96 | + `fallback` otherwise. Use of `fallback` is appropriate if the |
| 97 | + substitution is unlikely to affect the resulting model's ability to run, |
| 98 | + but might impact accuracy modestly. |
| 99 | + """ |
| 100 | + |
| 101 | + def convert_attr(val: str): |
| 102 | + pascal_case = _snake_case_to_pascal_case(val) |
| 103 | + |
| 104 | + # Enum values that match Python keywords have a trailing underscore appended. |
| 105 | + escaped_pascal_case = pascal_case + "_" |
| 106 | + |
| 107 | + try: |
| 108 | + return getattr(enum, pascal_case) |
| 109 | + except AttributeError: |
| 110 | + return getattr(enum, escaped_pascal_case) |
| 111 | + |
| 112 | + val = self.get_attr(name, "string", default) |
| 113 | + try: |
| 114 | + return convert_attr(val) |
| 115 | + except AttributeError: |
| 116 | + if fallback: |
| 117 | + op = self.onnx_op.op_type |
| 118 | + warn_once( |
| 119 | + f'Replacing unsupported value "{val}" for "{name}" attr in {op} op with "{fallback}"' |
| 120 | + ) |
| 121 | + return convert_attr(fallback) |
| 122 | + raise ConversionError(f'Unsupported value "{val}" for "{name}" attr') |
| 123 | + |
| 124 | + def ignore_attr(self, name: str): |
| 125 | + """ |
| 126 | + Mark an attribute as ignored. |
| 127 | +
|
| 128 | + This is useful in cases where an attribute contains redundant information. |
| 129 | + """ |
| 130 | + self._handled_attrs.add(name) |
| 131 | + |
| 132 | + def require_attr(self, name: str, expected_type: str): |
| 133 | + """Get the value of a required operator attribute.""" |
| 134 | + val = self.get_attr(name, expected_type, default=None) |
| 135 | + if val is None: |
| 136 | + raise ConversionError(f"Missing required attribute {name}") |
| 137 | + return val |
| 138 | + |
| 139 | + def generate_input_from_attr( |
| 140 | + self, input_index: int, attr_name: str, attr_type: str |
| 141 | + ): |
| 142 | + """ |
| 143 | + Generate a constant operator input from an attribute, if it exists. |
| 144 | +
|
| 145 | + Some operator inputs changed from attributes to inputs in different ONNX |
| 146 | + releases. This function checks to see if an operator has an attribute |
| 147 | + and synthesizes a constant input. |
| 148 | +
|
| 149 | + :param input_index: Index of the input that the attribute corresponds to |
| 150 | + :param attr_name: Name of the attribute |
| 151 | + :param attr_type: Expected type of the attribute |
| 152 | + """ |
| 153 | + |
| 154 | + attr_val = self.get_attr(attr_name, attr_type, default=None) |
| 155 | + if attr_val is None: |
| 156 | + return |
| 157 | + |
| 158 | + if input_index < len(self.input_indexes): |
| 159 | + raise ConversionError( |
| 160 | + f'Operator has both an attribute "{attr_name}" and corresponding input at index {input_index}' |
| 161 | + ) |
| 162 | + |
| 163 | + shape: list[int] |
| 164 | + match attr_type: |
| 165 | + case "int": |
| 166 | + shape = [] |
| 167 | + data = np.array(attr_val).astype(np.int32) |
| 168 | + |
| 169 | + case "float": |
| 170 | + shape = [] |
| 171 | + data = np.array(attr_val).astype(np.float32) |
| 172 | + |
| 173 | + case "ints": |
| 174 | + shape = [len(attr_val)] |
| 175 | + data = np.array([attr_val]).astype(np.int32) |
| 176 | + case _: |
| 177 | + raise ConversionError( |
| 178 | + f'Unable to generate input from "{attr_name}" attribute of type "{attr_type}"' |
| 179 | + ) |
| 180 | + |
| 181 | + generated_name = self.onnx_op.name + ":rten-" + attr_name |
| 182 | + const_node = ConstantNode(generated_name, shape, data) |
| 183 | + input_id = self.add_node(const_node) |
| 184 | + |
| 185 | + while len(self.input_indexes) < input_index + 1: |
| 186 | + self.input_indexes.append(None) |
| 187 | + self.input_indexes[input_index] = input_id |
| 188 | + |
| 189 | + def check_attr( |
| 190 | + self, |
| 191 | + name: str, |
| 192 | + expected_type, |
| 193 | + default, |
| 194 | + on_mismatch: Literal["raise", "warn"] = "raise", |
| 195 | + ): |
| 196 | + """ |
| 197 | + Check if an operator has an unsupported non-default value for an attribute. |
| 198 | +
|
| 199 | + If `default` is a tuple, it specifies a set of acceptable defaults. |
| 200 | +
|
| 201 | + :param name: The name of the operator attribute |
| 202 | + :param default: The value which is equivalent to the default behavior |
| 203 | + :param on_mismatch: |
| 204 | + Whether a mismatch should be treated as a fatal error in model |
| 205 | + conversion or merely warn that this might cause a problem. |
| 206 | + """ |
| 207 | + |
| 208 | + val = self.get_attr(name, expected_type, None) |
| 209 | + if val is None: |
| 210 | + return |
| 211 | + |
| 212 | + if not isinstance(default, tuple): |
| 213 | + default = (default,) |
| 214 | + if val not in default: |
| 215 | + msg = f"Unsupported value {val} for attribute {name}. Default is {default}" |
| 216 | + if on_mismatch == "raise": |
| 217 | + raise ConversionError(msg) |
| 218 | + else: |
| 219 | + warn_once(msg) |
| 220 | + |
| 221 | + def unhandled_attrs(self) -> list[onnx.AttributeProto]: |
| 222 | + """Return a list of attributes which have not been read.""" |
| 223 | + return [ |
| 224 | + attr |
| 225 | + for attr in self.onnx_op.attribute |
| 226 | + if attr.name not in self._handled_attrs |
| 227 | + ] |
| 228 | + |
| 229 | + |
| 230 | +def _snake_case_to_pascal_case(s: str) -> str: |
| 231 | + """Transform a snake_case string to PascalCase.""" |
| 232 | + return "".join([word[0].upper() + word[1:] for word in s.split("_")]) |
| 233 | + |
| 234 | + |
| 235 | +# Mapping of ONNX attribute types to the field on an AttributeProto which |
| 236 | +# contains the value. Note that if you try to access the wrong field on an |
| 237 | +# AttributeProto, you get a default value instead of an exception. |
| 238 | +_value_fields = { |
| 239 | + onnx.AttributeProto.FLOAT: "f", |
| 240 | + onnx.AttributeProto.GRAPH: "g", |
| 241 | + onnx.AttributeProto.INT: "i", |
| 242 | + onnx.AttributeProto.INTS: "ints", |
| 243 | + onnx.AttributeProto.STRING: "s", |
| 244 | + onnx.AttributeProto.TENSOR: "t", |
| 245 | +} |
0 commit comments