Skip to content

Commit bf7e352

Browse files
committed
Update loop rolling to custom op registration changes
1 parent 4988deb commit bf7e352

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/finn/transformation/fpgadataflow/loop_rolling.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from qonnx.core.modelwrapper import ModelWrapper
99
from qonnx.transformation.base import Transformation
1010
from qonnx.transformation.fold_constants import FoldConstants
11+
from qonnx.custom_op.registry import is_custom_op
1112
from typing import List, Tuple
1213

1314
from finn.util import onnxscript_helpers as osh
1415

15-
1616
def get_constant_from_value(value):
1717
"""
1818
Get the constant value of a tensor.
@@ -567,10 +567,12 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]:
567567
from finn.util.basic import getHWCustomOp
568568

569569
for loop_node in model_wrapper.get_nodes_by_op_type("FINNLoop"):
570-
loop_body_graph = get_by_name(loop_node.attribute, "body").g
571-
for node in loop_body_graph.node:
570+
loop_body = getHWCustomOp(loop_node).get_nodeattr("body")
571+
for node in loop_body.graph.node:
572+
if not is_custom_op(node.domain):
573+
continue
572574
try:
573-
inst = getHWCustomOp(node, model_wrapper)
575+
inst = getHWCustomOp(node)
574576
inst.adapt_for_loop_body(LoopBody.signature)
575577
except (KeyError, AttributeError):
576578
# Operator doesn't need adaptation or doesn't support it

src/finn/util/onnxscript_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RewriterContext,
1818
pattern_builder,
1919
)
20-
from qonnx.util.basic import is_finn_op
20+
from qonnx.custom_op.registry import is_custom_op
2121
from typing import List, Optional
2222

2323

@@ -323,10 +323,10 @@ def is_fpgadataflow_onnxir_node(node):
323323
"""Returns True if given node is fpgadataflow node. Otherwise False."""
324324
is_node = False
325325
if node is not None:
326-
if is_finn_op(node.domain):
326+
if is_custom_op(node.domain):
327327
if "backend" in node.attributes:
328328
backend_value = node.attributes["backend"].as_string()
329-
if backend_value == "fpgadataflow":
329+
if backend_value in ["fpgadataflow", "hls", "rtl"]:
330330
is_node = True
331331

332332
return is_node

0 commit comments

Comments
 (0)