Skip to content

Commit cb502d1

Browse files
authored
Merge pull request #1593 from ollycassidy13/feature/applyconfig-error-reporting
Report Invalid Custom Op Configs
2 parents 9fe9922 + 11f7e8c commit cb502d1

2 files changed

Lines changed: 60 additions & 9 deletions

File tree

src/finn/transformation/general.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
# Protobuf onnx graph node type
1919
from onnx import AttributeProto, NodeProto, mapping # noqa
20-
from qonnx.custom_op.registry import getCustomOp
20+
from qonnx.custom_op.registry import getCustomOp, is_custom_op
2121
from qonnx.transformation.base import Transformation
2222

2323

@@ -42,6 +42,7 @@ def __init__(self, config, node_filter=lambda x: True):
4242
self.node_filter = node_filter
4343
self.used_configurations = ["Defaults"]
4444
self.missing_configurations = []
45+
self.ignored_non_custom_configurations = []
4546

4647
def configure_network(self, graph_proto, model_config, subgraph_hier):
4748
# Configure network - graph_proto can be a GraphProto or ModelWrapper
@@ -66,10 +67,7 @@ def configure_network(self, graph_proto, model_config, subgraph_hier):
6667
self.missing_configurations += [node.name]
6768
node_config = {}
6869

69-
if node_config:
70-
self.used_configurations += [config_key]
71-
72-
try:
70+
if is_custom_op(node.domain, node.op_type):
7371
inst = getCustomOp(node)
7472

7573
if "Defaults" in model_config.keys():
@@ -92,9 +90,11 @@ def configure_network(self, graph_proto, model_config, subgraph_hier):
9290
# set node attributes from specified configuration
9391
for attr_name, value in node_config.items():
9492
inst.set_nodeattr(attr_name, value)
95-
except Exception:
96-
# Node is not a custom op, but it might have subgraphs
97-
pass
93+
94+
if node_config:
95+
self.used_configurations += [config_key]
96+
elif node_config:
97+
self.ignored_non_custom_configurations += [(config_key, node.op_type)]
9898

9999
# Recursively handle nested subgraphs
100100
for attr in node.attribute:
@@ -125,9 +125,29 @@ def apply(self, model):
125125
if len(unique_missing) > 0:
126126
warnings.warn("\nNo HW configuration for nodes: " + ", ".join(unique_missing))
127127

128+
# Check for matched configs that couldn't be applied because they were
129+
# specified for standard ONNX nodes instead of custom ops.
130+
unique_non_custom = list(dict.fromkeys(self.ignored_non_custom_configurations))
131+
if len(unique_non_custom) > 0:
132+
formatted_non_custom = [
133+
"{} ({})".format(config_key, op_type) for config_key, op_type in unique_non_custom
134+
]
135+
warnings.warn(
136+
"\nHW configurations for non-custom nodes were ignored: "
137+
+ ", ".join(formatted_non_custom)
138+
+ ". Configs can only be applied to custom ops."
139+
)
140+
128141
# Check for unused configs (top-level configs that weren't applied)
142+
ignored_configurations = [
143+
config_key for config_key, _ in self.ignored_non_custom_configurations
144+
]
129145
unused_configs = [
130-
x for x in model_config if x not in self.used_configurations and x != "Defaults"
146+
x
147+
for x in model_config
148+
if x not in self.used_configurations
149+
and x not in ignored_configurations
150+
and x != "Defaults"
131151
]
132152
if len(unused_configs) > 0:
133153
warnings.warn("\nUnused HW configurations: " + ", ".join(unused_configs))

tests/util/test_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,34 @@ def modify_all_im2col_nodes(graph_proto):
193193

194194
if os.path.exists(config_json_file):
195195
os.remove(config_json_file)
196+
197+
198+
@pytest.mark.util
199+
def test_apply_config_reports_invalid_custom_op_attr():
200+
"""Test invalid custom-op configs fail instead of being silently ignored."""
201+
202+
model, _ = make_im2col_test_model()
203+
im2col_node = model.graph.node[0]
204+
205+
with pytest.raises(Exception):
206+
model.transform(ApplyConfig({im2col_node.name: {"not_an_im2col_attr": 1}}))
207+
208+
209+
@pytest.mark.util
210+
def test_apply_config_warns_for_non_custom_op_config():
211+
"""Test explicit configs for standard ONNX nodes are reported."""
212+
213+
model, _ = make_im2col_test_model()
214+
if_node = model.graph.node[1]
215+
config = {if_node.name: {"kernel_size": [1, 1]}}
216+
217+
def node_filter(node):
218+
return node.name == if_node.name
219+
220+
with pytest.warns(UserWarning) as warn_records:
221+
model.transform(ApplyConfig(config, node_filter=node_filter))
222+
223+
warning_messages = [str(record.message) for record in warn_records]
224+
assert any("Configs can only be applied to custom ops" in msg for msg in warning_messages)
225+
assert any("{} ({})".format(if_node.name, if_node.op_type) in msg for msg in warning_messages)
226+
assert not any("Unused HW configurations" in msg for msg in warning_messages)

0 commit comments

Comments
 (0)