From 44e66e5a1fea66010c05371fbe15379c99f7c0cb Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Tue, 28 Apr 2026 17:31:31 +0100 Subject: [PATCH 1/7] AddCLSToken initial commit --- .../finn.custom_op.fpgadataflow.rst | 8 + .../finn.custom_op.fpgadataflow.rtl.rst | 8 + finn-rtllib/addclstoken/hdl/addclstoken.sv | 150 +++++++++ .../addclstoken/hdl/addclstoken_template.v | 81 +++++ src/finn/builder/build_dataflow_steps.py | 2 + src/finn/custom_op/fpgadataflow/__init__.py | 2 + .../custom_op/fpgadataflow/addclstoken.py | 171 ++++++++++ .../custom_op/fpgadataflow/rtl/__init__.py | 2 + .../fpgadataflow/rtl/addclstoken_rtl.py | 211 ++++++++++++ .../fpgadataflow/convert_to_hw_layers.py | 79 +++++ .../fpgadataflow/specialize_layers.py | 1 + src/finn/util/vivado.py | 44 ++- .../test_fpgadataflow_addclstoken.py | 299 ++++++++++++++++++ 13 files changed, 1050 insertions(+), 8 deletions(-) create mode 100644 finn-rtllib/addclstoken/hdl/addclstoken.sv create mode 100644 finn-rtllib/addclstoken/hdl/addclstoken_template.v create mode 100644 src/finn/custom_op/fpgadataflow/addclstoken.py create mode 100644 src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py create mode 100644 tests/fpgadataflow/test_fpgadataflow_addclstoken.py diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst index 25aafc324e..0688664bfe 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst @@ -39,6 +39,14 @@ RTLBackend :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.addclstoken +----------------------------------------- + +.. 
automodule:: finn.custom_op.fpgadataflow.addclstoken + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.addstreams ---------------------------------------- diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst index 346eddb073..859a789f2f 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst @@ -5,6 +5,14 @@ Custom Op - fpgadataflow.rtl RTL Custom Op Nodes =================== +finn.custom\_op.fpgadataflow.rtl.addclstoken\_rtl +-------------------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.rtl.addclstoken_rtl + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.convolutioninputgenerator\_rtl ------------------------------------------------------------ diff --git a/finn-rtllib/addclstoken/hdl/addclstoken.sv b/finn-rtllib/addclstoken/hdl/addclstoken.sv new file mode 100644 index 0000000000..768b2a9a06 --- /dev/null +++ b/finn-rtllib/addclstoken/hdl/addclstoken.sv @@ -0,0 +1,150 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ + +module addclstoken #( + parameter int unsigned NUM_TOKENS = 196, + parameter int unsigned NUM_CHANNELS = 192, + parameter int unsigned SIMD = 1, + parameter int unsigned ELEM_WIDTH = 8, + parameter int unsigned PAD_TOKENS = 0 +)( + input logic clk, + input logic rst, + + output logic irdy, + input logic ivld, + input logic [SIMD*ELEM_WIDTH-1:0] idat, + + input logic ordy, + output logic ovld, + output logic [SIMD*ELEM_WIDTH-1:0] odat, + + input logic [NUM_CHANNELS*ELEM_WIDTH-1:0] cls_data +); + + localparam int unsigned FOLD_WIDTH = SIMD * ELEM_WIDTH; + localparam int unsigned FOLDS_PER_TOKEN = NUM_CHANNELS / SIMD; + localparam int unsigned TOTAL_INPUT_FOLDS = NUM_TOKENS * FOLDS_PER_TOKEN; + localparam int unsigned TOTAL_PAD_FOLDS = PAD_TOKENS * FOLDS_PER_TOKEN; + localparam int unsigned MAX_PHASE_FOLDS = + (TOTAL_INPUT_FOLDS > FOLDS_PER_TOKEN) ? + ((TOTAL_INPUT_FOLDS > TOTAL_PAD_FOLDS) ? 
+ TOTAL_INPUT_FOLDS : TOTAL_PAD_FOLDS) : + ((FOLDS_PER_TOKEN > TOTAL_PAD_FOLDS) ? + FOLDS_PER_TOKEN : TOTAL_PAD_FOLDS); + localparam int unsigned CNT_WIDTH = (MAX_PHASE_FOLDS <= 1) ? 1 : $clog2(MAX_PHASE_FOLDS); + + typedef enum logic [1:0] { + EMIT_CLS, + PASSTHROUGH, + EMIT_PAD + } state_t; + + state_t state; + state_t next_state; + logic [CNT_WIDTH-1:0] fold_cnt; + logic fold_cnt_last; + logic out_transfer; + + logic [CNT_WIDTH-1:0] cls_fold_cnt; + logic [FOLD_WIDTH-1:0] cls_fold; + + assign cls_fold_cnt = (int'(fold_cnt) < FOLDS_PER_TOKEN) ? fold_cnt : '0; + assign cls_fold = cls_data[cls_fold_cnt * FOLD_WIDTH +: FOLD_WIDTH]; + assign out_transfer = ovld & ordy; + + always_comb begin + unique case (state) + EMIT_CLS: fold_cnt_last = (int'(fold_cnt) == FOLDS_PER_TOKEN - 1); + PASSTHROUGH: fold_cnt_last = (int'(fold_cnt) == TOTAL_INPUT_FOLDS - 1); + EMIT_PAD: fold_cnt_last = (int'(fold_cnt) == TOTAL_PAD_FOLDS - 1); + default: fold_cnt_last = 1'b1; + endcase + end + + always_comb begin + irdy = 1'b0; + ovld = 1'b0; + odat = '0; + + unique case (state) + EMIT_CLS: begin + ovld = 1'b1; + odat = cls_fold; + end + PASSTHROUGH: begin + irdy = ordy; + ovld = ivld; + odat = idat; + end + EMIT_PAD: begin + ovld = 1'b1; + end + default: begin + end + endcase + end + + always_comb begin + next_state = state; + if (out_transfer && fold_cnt_last) begin + unique case (state) + EMIT_CLS: begin + next_state = PASSTHROUGH; + end + PASSTHROUGH: begin + next_state = (PAD_TOKENS == 0) ? 
EMIT_CLS : EMIT_PAD; + end + EMIT_PAD: begin + next_state = EMIT_CLS; + end + default: begin + next_state = EMIT_CLS; + end + endcase + end + end + + always_ff @(posedge clk) begin + if (rst) begin + state <= EMIT_CLS; + fold_cnt <= '0; + end else if (out_transfer) begin + if (fold_cnt_last) begin + state <= next_state; + fold_cnt <= '0; + end else begin + fold_cnt <= fold_cnt + 1'b1; + end + end + end + +endmodule diff --git a/finn-rtllib/addclstoken/hdl/addclstoken_template.v b/finn-rtllib/addclstoken/hdl/addclstoken_template.v new file mode 100644 index 0000000000..57bba51c8d --- /dev/null +++ b/finn-rtllib/addclstoken/hdl/addclstoken_template.v @@ -0,0 +1,81 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ + +module $TOP_MODULE_NAME$ #( + parameter FOLD_WIDTH = $FOLD_WIDTH$, + parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8 +)( + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out_V, ASSOCIATED_RESET ap_rst_n" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + output in0_V_TREADY, + input in0_V_TVALID, + input [AXI_WIDTH-1:0] in0_V_TDATA, + + input out_V_TREADY, + output out_V_TVALID, + output [AXI_WIDTH-1:0] out_V_TDATA +); + + localparam [$CLS_WIDTH$-1:0] CLS_DATA = $CLS_DATA$; + + wire [FOLD_WIDTH-1:0] core_out; + + assign out_V_TDATA[FOLD_WIDTH-1:0] = core_out; + + generate + if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata + assign out_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}}; + end + endgenerate + + addclstoken #( + .NUM_TOKENS($NUM_TOKENS$), + .NUM_CHANNELS($NUM_CHANNELS$), + .SIMD($SIMD$), + .ELEM_WIDTH($ELEM_WIDTH$), + .PAD_TOKENS($PAD_TOKENS$) + ) impl ( + .clk(ap_clk), + .rst(!ap_rst_n), + .irdy(in0_V_TREADY), + .ivld(in0_V_TVALID), + .idat(in0_V_TDATA[FOLD_WIDTH-1:0]), + .ordy(out_V_TREADY), + .ovld(out_V_TVALID), + .odat(core_out), + .cls_data(CLS_DATA) + ); + +endmodule diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index ecc1d28c53..8c2f79c1d6 100644 
--- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -348,6 +348,8 @@ def step_convert_to_hw(model: ModelWrapper, cfg: DataflowBuildConfig): model = model.transform(to_hw.InferQuantizedMatrixVectorActivation()) # TopK to LabelSelect model = model.transform(to_hw.InferLabelSelectLayer()) + # sequence CLS token insertion + model = model.transform(to_hw.InferAddCLSTokenLayer()) # input quantization (if any) as standalone threshold model = model.transform(to_hw.InferThresholdingLayer()) # needed for convolutions -- TODO always exec? diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index aed2ab7fe1..c6e8dd1dcc 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -27,6 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from finn.custom_op.fpgadataflow.addclstoken import AddCLSToken from finn.custom_op.fpgadataflow.addstreams import AddStreams from finn.custom_op.fpgadataflow.channelwise_op import ChannelwiseOp from finn.custom_op.fpgadataflow.concat import StreamingConcat @@ -66,6 +67,7 @@ custom_op["StreamingDataflowPartition"] = StreamingDataflowPartition custom_op["AddStreams"] = AddStreams +custom_op["AddCLSToken"] = AddCLSToken custom_op["ChannelwiseOp"] = ChannelwiseOp custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator custom_op["DownSampler"] = DownSampler diff --git a/src/finn/custom_op/fpgadataflow/addclstoken.py b/src/finn/custom_op/fpgadataflow/addclstoken.py new file mode 100644 index 0000000000..35eae4bb29 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/addclstoken.py @@ -0,0 +1,171 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import numpy as np +import warnings +from qonnx.core.datatype import DataType + +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + + +class AddCLSToken(HWCustomOp): + """Prepend a learned class token to a sequence of patch tokens.""" + + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + my_attrs = super().get_nodeattr_types() + my_attrs.update( + { + "NumTokens": ("i", True, 0), + "NumChannels": ("i", True, 0), + "PadTokens": ("i", False, 0), + "SIMD": ("i", False, 1), + "inputDataType": ("s", True, ""), + "outputDataType": ("s", False, ""), + } + ) + return my_attrs + + def get_normal_input_shape(self, ind=0): + num_channels = self.get_nodeattr("NumChannels") + if ind == 0: + return (1, self.get_nodeattr("NumTokens"), num_channels) + elif ind == 1: + return (1, 1, num_channels) + else: + raise Exception("AddCLSToken only has two inputs") + + def get_folded_input_shape(self, ind=0): + normal_shape = self.get_normal_input_shape(ind) + simd = self.get_nodeattr("SIMD") + num_channels = normal_shape[-1] + assert num_channels % simd == 0, "SIMD must divide NumChannels" + return normal_shape[:-1] + (num_channels // simd, simd) + + def get_normal_output_shape(self, ind=0): + num_tokens = self.get_nodeattr("NumTokens") + num_channels = self.get_nodeattr("NumChannels") + pad_tokens = self.get_nodeattr("PadTokens") + return (1, num_tokens + 1 + pad_tokens, num_channels) + + def get_folded_output_shape(self, ind=0): + normal_shape = self.get_normal_output_shape(ind) + simd = self.get_nodeattr("SIMD") + num_channels = normal_shape[-1] + assert num_channels % simd == 0, "SIMD must divide NumChannels" + return normal_shape[:-1] + (num_channels // simd, simd) + + def make_shape_compatible_op(self, model): + exp_ishape = self.get_normal_input_shape(0) + ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0])) + assert ishape == exp_ishape, "Unexpected input shape for patch tokens." 
+ + exp_wshape = self.get_normal_input_shape(1) + wshape = tuple(model.get_tensor_shape(self.onnx_node.input[1])) + assert wshape == exp_wshape, "Unexpected input shape for CLS token." + + return super().make_const_shape_op(self.get_normal_output_shape()) + + def infer_node_datatype(self, model): + node = self.onnx_node + attr_idt = None + if self.get_nodeattr("inputDataType") != "": + attr_idt = self.get_input_datatype() + + idt = model.get_tensor_datatype(node.input[0]) + if idt is None: + idt = attr_idt + if idt is None: + raise Exception("AddCLSToken input datatype is not set") + + if attr_idt is not None and attr_idt != idt: + warnings.warn( + "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt)) + ) + self.set_nodeattr("inputDataType", idt.name) + + cls_dt = model.get_tensor_datatype(node.input[1]) + if cls_dt is None: + model.set_tensor_datatype(node.input[1], idt) + else: + assert cls_dt == idt, "CLS token datatype must match input datatype." + + self.set_nodeattr("outputDataType", idt.name) + model.set_tensor_datatype(node.output[0], idt) + + def verify_node(self): + pass + + def get_input_datatype(self, ind=0): + return DataType[self.get_nodeattr("inputDataType")] + + def get_output_datatype(self, ind=0): + odt = self.get_nodeattr("outputDataType") + if odt == "": + return self.get_input_datatype(ind) + return DataType[odt] + + def get_instream_width(self, ind=0): + if ind != 0: + return 0 + return self.get_input_datatype().bitwidth() * self.get_nodeattr("SIMD") + + def get_outstream_width(self, ind=0): + return self.get_output_datatype().bitwidth() * self.get_nodeattr("SIMD") + + def get_number_output_values(self): + return int(np.prod(self.get_folded_output_shape()[:-1])) + + def get_exp_cycles(self): + return int(np.prod(self.get_folded_output_shape()[:-1])) + + def execute_node(self, context, graph): + node = self.onnx_node + patches = context[node.input[0]] + cls_token = context[node.input[1]] + + result = 
np.concatenate([cls_token, patches], axis=1) + pad_tokens = self.get_nodeattr("PadTokens") + if pad_tokens > 0: + pad_shape = (1, pad_tokens, self.get_nodeattr("NumChannels")) + padding = np.zeros(pad_shape, dtype=result.dtype) + result = np.concatenate([result, padding], axis=1) + + oshape = self.get_normal_output_shape() + context[node.output[0]] = np.asarray(result, dtype=np.float32).reshape(oshape) + + def bram_estimation(self): + return 0 + + def lut_estimation(self): + return int(128 + self.get_nodeattr("NumChannels")) + + def get_op_and_param_counts(self): + return {"param_cls_token": int(self.get_nodeattr("NumChannels"))} diff --git a/src/finn/custom_op/fpgadataflow/rtl/__init__.py b/src/finn/custom_op/fpgadataflow/rtl/__init__.py index 06067a4fca..26ed73e382 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/__init__.py +++ b/src/finn/custom_op/fpgadataflow/rtl/__init__.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from finn.custom_op.fpgadataflow.rtl.addclstoken_rtl import AddCLSToken_rtl from finn.custom_op.fpgadataflow.rtl.convolutioninputgenerator_rtl import ( ConvolutionInputGenerator_rtl, ) @@ -42,6 +43,7 @@ # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure +custom_op["AddCLSToken_rtl"] = AddCLSToken_rtl custom_op["ConvolutionInputGenerator_rtl"] = ConvolutionInputGenerator_rtl custom_op["FMPadding_rtl"] = FMPadding_rtl custom_op["StreamingDataWidthConverter_rtl"] = StreamingDataWidthConverter_rtl diff --git a/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py new file mode 100644 index 0000000000..53e6318f49 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py @@ -0,0 +1,211 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import numpy as np +import os +import shutil +from qonnx.core.datatype import DataType + +from finn.custom_op.fpgadataflow.addclstoken import AddCLSToken +from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend +from finn.util.basic import get_rtlsim_trace_depth, make_build_dir +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + +try: + from pyverilator import PyVerilator +except ModuleNotFoundError: + PyVerilator = None + + +def _rtlsrc_dir(): + return os.environ["FINN_ROOT"] + "/finn-rtllib/addclstoken/hdl" + + +class AddCLSToken_rtl(AddCLSToken, RTLBackend): + """RTL implementation of AddCLSToken.""" + + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + my_attrs = {} + my_attrs.update(AddCLSToken.get_nodeattr_types(self)) + my_attrs.update(RTLBackend.get_nodeattr_types(self)) + return my_attrs + + def _pack_value(self, value, dtype): + bitwidth = dtype.bitwidth() + if dtype == DataType["BIPOLAR"]: + int_value = int((value + 1) // 2) + else: + if dtype.is_fixed_point(): + value = value / dtype.scale_factor() + int_value = int(value) + if int_value < 0: + int_value += 1 << bitwidth + return int_value & ((1 << bitwidth) - 1) + + def _pack_cls_token(self, model): + dtype = self.get_input_datatype() + bitwidth = dtype.bitwidth() + num_channels = self.get_nodeattr("NumChannels") + cls_token = model.get_initializer(self.onnx_node.input[1]) + if cls_token is None: + raise Exception("AddCLSToken RTL generation requires a constant CLS token input.") + + cls_token = np.asarray(cls_token, dtype=np.float32) + assert cls_token.shape == self.get_normal_input_shape( + 1 + ), "CLS token shape does not match AddCLSToken attributes." 
+ assert np.vectorize(dtype.allowed)(cls_token).all(), ( + "CLS token values cannot be represented with %s" % dtype.name + ) + packed = 0 + for i, value in enumerate(cls_token.flatten()): + packed |= self._pack_value(value, dtype) << (i * bitwidth) + return "%d'h%x" % (num_channels * bitwidth, packed) + + def generate_hdl(self, model, fpgapart, clk): + simd = self.get_nodeattr("SIMD") + num_channels = self.get_nodeattr("NumChannels") + assert num_channels % simd == 0, "SIMD must divide NumChannels" + + rtlsrc = _rtlsrc_dir() + template_path = rtlsrc + "/addclstoken_template.v" + with open(template_path, "r") as f: + template = f.read() + + topname = self.get_verilog_top_module_name() + self.set_nodeattr("gen_top_module", topname) + + elem_width = self.get_input_datatype().bitwidth() + fold_width = elem_width * simd + code_gen_dict = { + "TOP_MODULE_NAME": topname, + "NUM_TOKENS": self.get_nodeattr("NumTokens"), + "NUM_CHANNELS": num_channels, + "SIMD": simd, + "ELEM_WIDTH": elem_width, + "PAD_TOKENS": self.get_nodeattr("PadTokens"), + "FOLD_WIDTH": fold_width, + "CLS_WIDTH": num_channels * elem_width, + "CLS_DATA": self._pack_cls_token(model), + } + + for key, value in code_gen_dict.items(): + template = template.replace("$%s$" % key, str(value)) + + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + with open(os.path.join(code_gen_dir, topname + ".v"), "w") as f: + f.write(template) + shutil.copy(rtlsrc + "/addclstoken.sv", code_gen_dir) + + self.set_nodeattr("ipgen_path", code_gen_dir) + self.set_nodeattr("ip_path", code_gen_dir) + + def prepare_rtlsim(self): + if PyVerilator is None: + raise ImportError("Installation of PyVerilator is required.") + + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + verilog_files = [ + "addclstoken.sv", + self.get_nodeattr("gen_top_module") + ".v", + ] + sim = PyVerilator.build( + verilog_files, + build_dir=make_build_dir("pyverilator_" + self.onnx_node.name + "_"), + verilog_path=[code_gen_dir], + 
trace_depth=get_rtlsim_trace_depth(), + top_module_name=self.get_nodeattr("gen_top_module"), + ) + self.set_nodeattr("rtlsim_so", sim.lib._name) + return sim + + def code_generation_ipi(self): + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + sourcefiles = [ + "addclstoken.sv", + self.get_nodeattr("gen_top_module") + ".v", + ] + sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles] + + cmd = [] + for f in sourcefiles: + cmd += ["add_files -norecurse %s" % f] + cmd += [ + "create_bd_cell -type module -reference %s %s" + % (self.get_nodeattr("gen_top_module"), self.onnx_node.name) + ] + return cmd + + def execute_node(self, context, graph): + mode = self.get_nodeattr("exec_mode") + if mode == "cppsim": + AddCLSToken.execute_node(self, context, graph) + elif mode == "rtlsim": + node = self.onnx_node + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + exp_ishape = self.get_normal_input_shape(0) + exp_oshape = self.get_normal_output_shape() + + inp = context[node.input[0]] + assert str(inp.dtype) == "float32", "Input datatype is not float32" + assert inp.shape == exp_ishape, "Input shape does not match expected shape." + + folded_ishape = self.get_folded_input_shape(0) + np.save(os.path.join(code_gen_dir, "input_0.npy"), inp.reshape(folded_ishape).copy()) + + sim = self.get_rtlsim() + export_idt = self.get_input_datatype() + rtlsim_inp = npy_to_rtlsim_input( + os.path.join(code_gen_dir, "input_0.npy"), + export_idt, + self.get_instream_width(), + ) + self.reset_rtlsim(sim) + self.toggle_clk(sim) + rtlsim_output = self.rtlsim(sim, rtlsim_inp) + + odt = self.get_output_datatype() + out_npy = rtlsim_output_to_npy( + rtlsim_output, + os.path.join(code_gen_dir, "output.npy"), + odt, + self.get_folded_output_shape(), + self.get_outstream_width(), + odt.bitwidth(), + ) + context[node.output[0]] = np.asarray(out_npy, dtype=np.float32).reshape(exp_oshape) + else: + raise Exception( + """Invalid value for attribute exec_mode! 
Is currently set to: {} + has to be set to one of the following values ("cppsim", "rtlsim")""".format( + mode + ) + ) diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index e14181b140..e486b19ce4 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -1196,6 +1196,85 @@ def apply(self, model): return (model, graph_modified) +class InferAddCLSTokenLayer(Transformation): + """Convert Concat([cls_token, patches], axis=1) into AddCLSToken.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for node in graph.node: + node_ind += 1 + if node.op_type != "Concat": + continue + + axis = get_by_name(node.attribute, "axis") + if axis is None or len(node.input) != 2: + continue + + cls_name = node.input[0] + patch_name = node.input[1] + cls_init = model.get_initializer(cls_name) + if cls_init is None or model.get_initializer(patch_name) is not None: + continue + + cls_shape = model.get_tensor_shape(cls_name) + if cls_shape is None: + cls_shape = list(cls_init.shape) + patch_shape = model.get_tensor_shape(patch_name) + if cls_shape is None or patch_shape is None: + continue + if any(x is None for x in list(cls_shape) + list(patch_shape)): + continue + + rank = len(patch_shape) + concat_axis = axis.i if axis.i >= 0 else axis.i + rank + if rank != 3 or concat_axis != 1: + continue + + if len(cls_shape) != 3 or cls_shape[0] != 1 or cls_shape[1] != 1: + continue + if patch_shape[0] != 1 or cls_shape[2] != patch_shape[2]: + continue + + out_shape = model.get_tensor_shape(node.output[0]) + exp_oshape = [1, patch_shape[1] + 1, patch_shape[2]] + if out_shape is not None and list(out_shape) != exp_oshape: + continue + + idt = model.get_tensor_datatype(patch_name) + if idt is None or not idt.is_integer(): + continue + cls_dt = model.get_tensor_datatype(cls_name) + if 
cls_dt is None: + model.set_tensor_datatype(cls_name, idt) + elif cls_dt != idt: + continue + + new_node = helper.make_node( + "AddCLSToken", + [patch_name, cls_name], + node.output, + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="AddCLSToken_" + node.name, + NumTokens=int(patch_shape[1]), + NumChannels=int(patch_shape[2]), + PadTokens=0, + SIMD=1, + inputDataType=idt.name, + outputDataType=idt.name, + ) + graph.node.insert(node_ind, new_node) + graph.node.remove(node) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + class InferStreamingEltwise(Transformation): """Convert eltwise Sub or Sub -> Abs to StreamingEltwise layer with SubEltwise or AbsDiffEltwise op.""" diff --git a/src/finn/transformation/fpgadataflow/specialize_layers.py b/src/finn/transformation/fpgadataflow/specialize_layers.py index dbcadd1df5..ac26028106 100644 --- a/src/finn/transformation/fpgadataflow/specialize_layers.py +++ b/src/finn/transformation/fpgadataflow/specialize_layers.py @@ -311,6 +311,7 @@ def apply(self, model): node.input, node.output, domain="finn.custom_op.fpgadataflow." + impl_style, + name=node.name, ) # add all attributes for attribute in node.attribute: diff --git a/src/finn/util/vivado.py b/src/finn/util/vivado.py index bc8ca40d88..14cdba54df 100644 --- a/src/finn/util/vivado.py +++ b/src/finn/util/vivado.py @@ -27,10 +27,27 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import os +import re from finn.util.basic import launch_process_helper, which +def _extract_util_from_report(vivado_proj_folder, row_name): + """Extract the Used column for a row in Vivado's utilization report.""" + + log_path = os.path.join(vivado_proj_folder, "vivado.log") + if not os.path.isfile(log_path): + return None + + row_pattern = re.compile(r"^\|\s*%s\s*\|\s*([0-9.]+)\s*\|" % re.escape(row_name)) + with open(log_path, "r") as f: + for line in f: + match = row_pattern.match(line) + if match is not None: + return float(match.group(1)) + return None + + def out_of_context_synth( verilog_dir, top_name, @@ -48,16 +65,17 @@ def out_of_context_synth( raise Exception("vivado is not in PATH, ensure settings64.sh is sourced.") omx_path = os.environ["OHMYXILINX"] script = "vivadocompile.sh" - # vivadocompile.sh - call_omx = "zsh %s/%s %s %s %s %f" % ( - omx_path, - script, + # vivadocompile.sh + # + call_omx = [ + "zsh", + os.path.join(omx_path, script), top_name, + "", clk_name, fpga_part, - float(clk_period_ns), - ) - call_omx = call_omx.split() + "%f" % float(clk_period_ns), + ] launch_process_helper(call_omx, proc_env=os.environ.copy(), cwd=verilog_dir) vivado_proj_folder = "%s/results_%s" % (verilog_dir, top_name) @@ -67,13 +85,23 @@ def out_of_context_synth( res_data = myfile.read().split("\n") ret = {} ret["vivado_proj_folder"] = vivado_proj_folder + util_report_rows = { + "DSP": "DSPs", + } for res_line in res_data: res_fields = res_line.split("=") print(res_fields) try: ret[res_fields[0]] = float(res_fields[1]) except ValueError: - ret[res_fields[0]] = 0 + util_value = None + if res_fields[0] in util_report_rows: + util_value = _extract_util_from_report( + vivado_proj_folder, util_report_rows[res_fields[0]] + ) + if util_value is None: + raise + ret[res_fields[0]] = util_value except IndexError: ret[res_fields[0]] = 0 if ret["WNS"] == 0: diff --git a/tests/fpgadataflow/test_fpgadataflow_addclstoken.py b/tests/fpgadataflow/test_fpgadataflow_addclstoken.py 
new file mode 100644 index 0000000000..07caadc99f --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_addclstoken.py @@ -0,0 +1,299 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +import numpy as np +import os +from functools import partial +from onnx import TensorProto, helper, numpy_helper +from pathlib import Path +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames + +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.analysis.fpgadataflow.res_estimation import ( + res_estimation, + res_estimation_complete, +) +from finn.core.onnx_exec import execute_onnx +from finn.transformation.fpgadataflow.convert_to_hw_layers import InferAddCLSTokenLayer +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext + +FPGA_PART = "xc7z020clg400-1" +CLK_NS = 10 + + +def _make_graph(nodes, output_shape, cls_values, finn_dtype=DataType["INT8"]): + patch_shape = [1, 3, 4] + patches = helper.make_tensor_value_info("patches", TensorProto.FLOAT, patch_shape) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape) + cls_init = numpy_helper.from_array(cls_values.astype(np.float32), name="cls") + graph = helper.make_graph(nodes, "addclstoken_test", [patches], [output], [cls_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + model = ModelWrapper(model) + for tensor_name in ["patches", "cls", "out"]: + model.set_tensor_datatype(tensor_name, finn_dtype) + return model + + 
+def _make_concat_model(): + cls_values = np.asarray([[[1, -2, 3, -4]]], dtype=np.float32) + concat = helper.make_node( + "Concat", + ["cls", "patches"], + ["out"], + axis=1, + name="concat_cls", + ) + model = _make_graph([concat], [1, 4, 4], cls_values) + return model, cls_values + + +def _make_addclstoken_model( + pad_tokens=0, + simd=1, + finn_dtype=DataType["INT8"], + cls_values=None, +): + if cls_values is None: + cls_values = np.asarray([[[1, -2, 3, -4]]], dtype=np.float32) + addcls = helper.make_node( + "AddCLSToken", + ["patches", "cls"], + ["out"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="AddCLSToken_0", + NumTokens=3, + NumChannels=4, + PadTokens=pad_tokens, + SIMD=simd, + inputDataType=finn_dtype.name, + outputDataType=finn_dtype.name, + ) + model = _make_graph([addcls], [1, 4 + pad_tokens, 4], cls_values, finn_dtype) + return model, cls_values + + +def _prepare_addclstoken_stitched_ip_model(simd=1, pad_tokens=0): + model, cls_values = _make_addclstoken_model(pad_tokens=pad_tokens, simd=simd) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(InsertFIFO(create_shallow_fifos=True)) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False)) + return model, cls_values + + +@pytest.mark.fpgadataflow +def test_convert_concat_to_addclstoken(): + model, cls_values = _make_concat_model() + patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4) + expected = np.concatenate([cls_values, patches], axis=1) + + ret = execute_onnx(model, {"patches": patches}) + assert (ret["out"] == expected).all() + + model = model.transform(InferAddCLSTokenLayer()) + node = model.graph.node[0] + assert node.op_type == "AddCLSToken" + assert node.domain == "finn.custom_op.fpgadataflow" + assert 
list(node.input) == ["patches", "cls"] + + inst = getCustomOp(node) + assert inst.get_normal_output_shape() == (1, 4, 4) + assert inst.get_exp_cycles() == 16 + + ret = execute_onnx(model, {"patches": patches}) + assert (ret["out"] == expected).all() + + model = model.transform(SpecializeLayers("xc7z020clg400-1")) + assert model.graph.node[0].op_type == "AddCLSToken_rtl" + assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" + assert model.graph.node[0].name == "AddCLSToken_concat_cls" + + +@pytest.mark.fpgadataflow +def test_addclstoken_python_execution_with_padding(): + model, cls_values = _make_addclstoken_model(pad_tokens=2) + patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4) + expected = np.concatenate( + [cls_values, patches, np.zeros((1, 2, 4), dtype=np.float32)], + axis=1, + ) + + ret = execute_onnx(model, {"patches": patches}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype,cls_values,expected_cls_data", + [ + (DataType["INT8"], np.asarray([[[1, -2, 3, -4]]], dtype=np.float32), "32'hfc03fe01"), + (DataType["UINT4"], np.asarray([[[1, 2, 3, 4]]], dtype=np.float32), "16'h4321"), + (DataType["BIPOLAR"], np.asarray([[[1, -1, 1, -1]]], dtype=np.float32), "4'h5"), + ], +) +def test_addclstoken_rtl_codegen(tmp_path, monkeypatch, finn_dtype, cls_values, expected_cls_data): + if "FINN_ROOT" not in os.environ: + monkeypatch.setenv("FINN_ROOT", str(Path(__file__).resolve().parents[2])) + + model, _ = _make_addclstoken_model( + pad_tokens=1, + simd=2, + finn_dtype=finn_dtype, + cls_values=cls_values, + ) + model = model.transform(SpecializeLayers("xc7z020clg400-1")) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, "xc7z020clg400-1", 10) + + topname = inst.get_nodeattr("gen_top_module") + assert topname == "AddCLSToken_0" + wrapper = tmp_path / (topname + ".v") + core = tmp_path / 
"addclstoken.sv" + assert wrapper.is_file() + assert core.is_file() + wrapper_text = wrapper.read_text() + assert "parameter FOLD_WIDTH = %d" % (2 * finn_dtype.bitwidth()) in wrapper_text + assert ".SIMD(2)" in wrapper_text + assert ".PAD_TOKENS(1)" in wrapper_text + assert "CLS_DATA = %s" % expected_cls_data in wrapper_text + assert "= '0" not in wrapper_text + + ipi_cmds = inst.code_generation_ipi() + assert any("addclstoken.sv" in cmd for cmd in ipi_cmds) + assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds) + + +@pytest.mark.fpgadataflow +def test_addclstoken_resource_estimation(): + model, _ = _make_addclstoken_model(pad_tokens=1, simd=2) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + expected = { + "BRAM_18K": 0, + "BRAM_efficiency": 1, + "LUT": 132, + "URAM": 0, + "URAM_efficiency": 1, + "DSP": 0, + } + resources = model.analysis(partial(res_estimation, fpgapart=FPGA_PART)) + assert len(resources) == 1 + assert list(resources.values())[0] == expected + + complete_resources = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART)) + assert len(complete_resources) == 1 + assert list(complete_resources.values())[0] == [expected] + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("simd,pad_tokens", [(1, 0), (2, 1)]) +def test_addclstoken_rtlsim(simd, pad_tokens): + model, cls_values = _make_addclstoken_model(pad_tokens=pad_tokens, simd=simd) + patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4) + expected_values = [cls_values, patches] + if pad_tokens > 0: + expected_values.append(np.zeros((1, pad_tokens, 4), dtype=np.float32)) + expected = np.concatenate(expected_values, axis=1) + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(SetExecMode("rtlsim")) + model = 
model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, {"patches": patches}) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("AddCLSToken_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("simd,pad_tokens", [(1, 0), (2, 1)]) +def test_addclstoken_stitched_ip_rtlsim(simd, pad_tokens): + model, cls_values = _prepare_addclstoken_stitched_ip_model( + simd=simd, + pad_tokens=pad_tokens, + ) + patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4) + expected_values = [cls_values, patches] + if pad_tokens > 0: + expected_values.append(np.zeros((1, pad_tokens, 4), dtype=np.float32)) + expected = np.concatenate(expected_values, axis=1) + + model.set_metadata_prop("exec_mode", "rtlsim") + model.set_metadata_prop("extra_verilator_args", str(["-Wno-TIMESCALEMOD"])) + + ret = execute_onnx(model, {"patches": patches}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_addclstoken_stitched_ip_synth_ooc(): + model, _ = _prepare_addclstoken_stitched_ip_model(simd=2, pad_tokens=1) + model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS)) + ret = model.get_metadata_prop("res_total_ooc_synth") + assert ret is not None + ret = eval(ret) + + assert ret["LUT"] > 0 + assert ret["FF"] > 0 + assert ret["DSP"] == 0 + assert ret["BRAM"] == 0 + assert ret["WNS"] >= 0 From a3eac3899098d61f2a6ec08019f468ed7c2d8a7a Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Wed, 29 Apr 2026 09:57:21 +0100 Subject: [PATCH 2/7] header --- finn-rtllib/addclstoken/hdl/addclstoken.sv | 37 ++++++++-------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git 
a/finn-rtllib/addclstoken/hdl/addclstoken.sv b/finn-rtllib/addclstoken/hdl/addclstoken.sv index 768b2a9a06..d5bbdc2188 100644 --- a/finn-rtllib/addclstoken/hdl/addclstoken.sv +++ b/finn-rtllib/addclstoken/hdl/addclstoken.sv @@ -1,33 +1,22 @@ -/****************************************************************************** +/**************************************************************************** * Copyright (C) 2026, Advanced Micro Devices, Inc. * All rights reserved. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: + * SPDX-License-Identifier: BSD-3-Clause * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. + * @brief Insert a constant class token into a folded token stream. + * @author Oliver Cassidy * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. + * @description + * Prepends a learned class token, supplied through cls_data, to each + * input sequence of patch tokens. The class token and patch tokens are + * transferred as SIMD-wide folds of ELEM_WIDTH-bit elements. * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - *****************************************************************************/ + * Per sequence, the output stream is: + * 1. NUM_CHANNELS/SIMD folds from cls_data + * 2. NUM_TOKENS pass-through input tokens + * 3. PAD_TOKENS zero-valued tokens, when padding is enabled + ***************************************************************************/ module addclstoken #( parameter int unsigned NUM_TOKENS = 196, From 598b572c9e5f6593a2e4e9d7629f02c9a24abe83 Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Wed, 29 Apr 2026 20:52:23 +0100 Subject: [PATCH 3/7] select token initial commit --- .../finn.custom_op.fpgadataflow.rst | 8 + .../finn.custom_op.fpgadataflow.rtl.rst | 8 + finn-rtllib/selecttoken/hdl/select_token.sv | 82 ++++++ .../selecttoken/hdl/select_token_template.v | 78 +++++ src/finn/builder/build_dataflow_steps.py | 1 + src/finn/custom_op/fpgadataflow/__init__.py | 4 +- .../custom_op/fpgadataflow/rtl/__init__.py | 2 + .../fpgadataflow/rtl/selecttoken_rtl.py | 133 +++++++++ .../custom_op/fpgadataflow/selecttoken.py | 155 ++++++++++ .../fpgadataflow/convert_to_hw_layers.py | 82 +++++- .../test_fpgadataflow_selecttoken.py | 267 ++++++++++++++++++ 11 files changed, 818 insertions(+), 2 deletions(-) create mode 100644 finn-rtllib/selecttoken/hdl/select_token.sv create mode 100644 finn-rtllib/selecttoken/hdl/select_token_template.v create mode 100644 src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py create mode 100644 
src/finn/custom_op/fpgadataflow/selecttoken.py create mode 100644 tests/fpgadataflow/test_fpgadataflow_selecttoken.py diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst index 26a2073e4a..6cefa2f15d 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst @@ -119,6 +119,14 @@ finn.custom\_op.fpgadataflow.labelselect :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken +----------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.selecttoken + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.lookup ----------------------------------------------- diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst index 859a789f2f..26834ec610 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst @@ -45,6 +45,14 @@ finn.custom\_op.fpgadataflow.streamingdatawidthconverter\_rtl :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken\_rtl +--------------------------------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.rtl.selecttoken_rtl + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.streamingfifo\_rtl ------------------------------------------------- diff --git a/finn-rtllib/selecttoken/hdl/select_token.sv b/finn-rtllib/selecttoken/hdl/select_token.sv new file mode 100644 index 0000000000..fb4c3df800 --- /dev/null +++ b/finn-rtllib/selecttoken/hdl/select_token.sv @@ -0,0 +1,82 @@ +/**************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. 
+ * + * SPDX-License-Identifier: BSD-3-Clause + * + * @brief Select one token from a folded token stream. + * @author Oliver Cassidy + * + * @description + * Consumes NUM_TOKENS token vectors fold-by-fold. Folds belonging to + * TOKEN_INDEX are forwarded to the output stream; all other folds are + * consumed and discarded. + ***************************************************************************/ + +module select_token #( + parameter int unsigned NUM_TOKENS = 197, + parameter int unsigned NUM_CHANNELS = 192, + parameter int unsigned SIMD = 1, + parameter int unsigned ELEM_WIDTH = 8, + parameter int unsigned TOKEN_INDEX = 0 +)( + input logic clk, + input logic rst, + + output logic irdy, + input logic ivld, + input logic [SIMD*ELEM_WIDTH-1:0] idat, + + input logic ordy, + output logic ovld, + output logic [SIMD*ELEM_WIDTH-1:0] odat +); + + localparam int unsigned FOLDS_PER_TOKEN = NUM_CHANNELS / SIMD; + localparam int unsigned TOKEN_CNT_WIDTH = (NUM_TOKENS <= 1) ? 1 : $clog2(NUM_TOKENS); + localparam int unsigned FOLD_CNT_WIDTH = + (FOLDS_PER_TOKEN <= 1) ? 
1 : $clog2(FOLDS_PER_TOKEN); + + logic [TOKEN_CNT_WIDTH-1:0] token_cnt; + logic [FOLD_CNT_WIDTH-1:0] fold_cnt; + logic is_selected; + logic in_transfer; + logic fold_cnt_last; + logic token_cnt_last; + + assign is_selected = (int'(token_cnt) == TOKEN_INDEX); + assign in_transfer = irdy & ivld; + assign fold_cnt_last = (int'(fold_cnt) == FOLDS_PER_TOKEN - 1); + assign token_cnt_last = (int'(token_cnt) == NUM_TOKENS - 1); + + always_comb begin + irdy = 1'b1; + ovld = 1'b0; + odat = '0; + + if (is_selected) begin + irdy = ordy; + ovld = ivld; + odat = idat; + end + end + + always_ff @(posedge clk) begin + if (rst) begin + token_cnt <= '0; + fold_cnt <= '0; + end else if (in_transfer) begin + if (fold_cnt_last) begin + fold_cnt <= '0; + if (token_cnt_last) begin + token_cnt <= '0; + end else begin + token_cnt <= token_cnt + 1'b1; + end + end else begin + fold_cnt <= fold_cnt + 1'b1; + end + end + end + +endmodule diff --git a/finn-rtllib/selecttoken/hdl/select_token_template.v b/finn-rtllib/selecttoken/hdl/select_token_template.v new file mode 100644 index 0000000000..566fa63ac5 --- /dev/null +++ b/finn-rtllib/selecttoken/hdl/select_token_template.v @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *****************************************************************************/ + +module $TOP_MODULE_NAME$ #( + parameter FOLD_WIDTH = $FOLD_WIDTH$, + parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8 +)( + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + output in0_V_TREADY, + input in0_V_TVALID, + input [AXI_WIDTH-1:0] in0_V_TDATA, + + input out0_V_TREADY, + output out0_V_TVALID, + output [AXI_WIDTH-1:0] out0_V_TDATA +); + + wire [FOLD_WIDTH-1:0] core_out; + + assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out; + + generate + if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata + assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}}; + end + endgenerate + + select_token #( + .NUM_TOKENS($NUM_TOKENS$), + .NUM_CHANNELS($NUM_CHANNELS$), + .SIMD($SIMD$), + 
.ELEM_WIDTH($ELEM_WIDTH$), + .TOKEN_INDEX($TOKEN_INDEX$) + ) impl ( + .clk(ap_clk), + .rst(!ap_rst_n), + .irdy(in0_V_TREADY), + .ivld(in0_V_TVALID), + .idat(in0_V_TDATA[FOLD_WIDTH-1:0]), + .ordy(out0_V_TREADY), + .ovld(out0_V_TVALID), + .odat(core_out) + ); + +endmodule diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py index ca15a01c07..e84b997b2e 100644 --- a/src/finn/builder/build_dataflow_steps.py +++ b/src/finn/builder/build_dataflow_steps.py @@ -539,6 +539,7 @@ def apply_if_relevant(model, op_types, transform, desc=""): ) # Lookup layers + model = apply_if_relevant(model, ["Gather"], to_hw.InferSelectTokenLayer(), "token selection") model = apply_if_relevant(model, ["Gather"], to_hw.InferLookupLayer(), "lookup layers") # Activation functions diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index c00a1d5054..4dc93e7dd6 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -71,6 +71,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.outer_shuffle import OuterShuffle from finn.custom_op.fpgadataflow.pool import Pool from finn.custom_op.fpgadataflow.requant import Requant +from finn.custom_op.fpgadataflow.selecttoken import SelectToken from finn.custom_op.fpgadataflow.shuffle import Shuffle from finn.custom_op.fpgadataflow.split import StreamingSplit from finn.custom_op.fpgadataflow.streamingdataflowpartition import ( @@ -105,10 +106,11 @@ def register_custom_op(cls): custom_op["Lookup"] = Lookup custom_op["OuterShuffle"] = OuterShuffle custom_op["Pool"] = Pool +custom_op["Requant"] = Requant +custom_op["SelectToken"] = SelectToken custom_op["Shuffle"] = Shuffle custom_op["StreamingConcat"] = StreamingConcat custom_op["StreamingSplit"] = StreamingSplit custom_op["StreamingDataWidthConverter"] = StreamingDataWidthConverter custom_op["UpsampleNearestNeighbour"] = UpsampleNearestNeighbour 
custom_op["HWSoftmax"] = HWSoftmax -custom_op["Requant"] = Requant diff --git a/src/finn/custom_op/fpgadataflow/rtl/__init__.py b/src/finn/custom_op/fpgadataflow/rtl/__init__.py index fd3df3fbb7..10deceb9c3 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/__init__.py +++ b/src/finn/custom_op/fpgadataflow/rtl/__init__.py @@ -41,6 +41,7 @@ from finn.custom_op.fpgadataflow.rtl.layernorm_rtl import LayerNorm_rtl from finn.custom_op.fpgadataflow.rtl.matrixvectoractivation_rtl import MVAU_rtl from finn.custom_op.fpgadataflow.rtl.requant_rtl import Requant_rtl +from finn.custom_op.fpgadataflow.rtl.selecttoken_rtl import SelectToken_rtl from finn.custom_op.fpgadataflow.rtl.streamingdatawidthconverter_rtl import ( StreamingDataWidthConverter_rtl, ) @@ -62,6 +63,7 @@ custom_op["StreamingDataWidthConverter_rtl"] = StreamingDataWidthConverter_rtl custom_op["StreamingFIFO_rtl"] = StreamingFIFO_rtl custom_op["MVAU_rtl"] = MVAU_rtl +custom_op["SelectToken_rtl"] = SelectToken_rtl custom_op["VVAU_rtl"] = VVAU_rtl custom_op["Thresholding_rtl"] = Thresholding_rtl custom_op["InnerShuffle_rtl"] = InnerShuffle_rtl diff --git a/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py new file mode 100644 index 0000000000..c429f6e4ed --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/rtl/selecttoken_rtl.py @@ -0,0 +1,133 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import shutil + +from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend +from finn.custom_op.fpgadataflow.selecttoken import SelectToken + + +def _rtlsrc_dir(): + return os.environ["FINN_ROOT"] + "/finn-rtllib/selecttoken/hdl" + + +class SelectToken_rtl(SelectToken, RTLBackend): + """RTL implementation of SelectToken.""" + + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + my_attrs = {} + my_attrs.update(SelectToken.get_nodeattr_types(self)) + my_attrs.update(RTLBackend.get_nodeattr_types(self)) + return my_attrs + + def generate_hdl(self, model, fpgapart, clk): + simd = self.get_nodeattr("SIMD") + num_channels = self.get_nodeattr("NumChannels") + token_index = self.get_nodeattr("TokenIndex") + num_tokens = self.get_nodeattr("NumTokens") + if token_index < 0: + token_index += num_tokens + assert num_channels % simd == 0, "SIMD must divide NumChannels" + assert 0 <= token_index < num_tokens, "TokenIndex must select an 
existing token" + + rtlsrc = _rtlsrc_dir() + template_path = rtlsrc + "/select_token_template.v" + with open(template_path, "r") as f: + template = f.read() + + topname = self.get_verilog_top_module_name() + self.set_nodeattr("gen_top_module", topname) + + elem_width = self.get_input_datatype().bitwidth() + fold_width = elem_width * simd + code_gen_dict = { + "TOP_MODULE_NAME": topname, + "NUM_TOKENS": num_tokens, + "NUM_CHANNELS": num_channels, + "SIMD": simd, + "ELEM_WIDTH": elem_width, + "TOKEN_INDEX": token_index, + "FOLD_WIDTH": fold_width, + } + + for key, value in code_gen_dict.items(): + template = template.replace("$%s$" % key, str(value)) + + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + with open(os.path.join(code_gen_dir, topname + ".v"), "w") as f: + f.write(template) + shutil.copy(rtlsrc + "/select_token.sv", code_gen_dir) + + self.set_nodeattr("ipgen_path", code_gen_dir) + self.set_nodeattr("ip_path", code_gen_dir) + + def get_rtl_file_list(self, abspath=False): + if abspath: + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" + rtllib_dir = _rtlsrc_dir() + "/" + else: + code_gen_dir = "" + rtllib_dir = "" + + verilog_files = [ + rtllib_dir + "select_token.sv", + code_gen_dir + self.get_nodeattr("gen_top_module") + ".v", + ] + return verilog_files + + def code_generation_ipi(self): + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + sourcefiles = self.get_rtl_file_list() + sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles] + + cmd = [] + for f in sourcefiles: + cmd += ["add_files -norecurse %s" % f] + cmd += [ + "create_bd_cell -type module -reference %s %s" + % (self.get_nodeattr("gen_top_module"), self.onnx_node.name) + ] + return cmd + + def execute_node(self, context, graph): + mode = self.get_nodeattr("exec_mode") + if mode == "cppsim": + SelectToken.execute_node(self, context, graph) + elif mode == "rtlsim": + RTLBackend.execute_node(self, context, graph) + else: + raise Exception( + """Invalid value 
for attribute exec_mode! Is currently set to: {} + has to be set to one of the following values ("cppsim", "rtlsim")""".format( + mode + ) + ) diff --git a/src/finn/custom_op/fpgadataflow/selecttoken.py b/src/finn/custom_op/fpgadataflow/selecttoken.py new file mode 100644 index 0000000000..8139fbfbc8 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/selecttoken.py @@ -0,0 +1,155 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import numpy as np
import warnings
from qonnx.core.datatype import DataType

from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp


class SelectToken(HWCustomOp):
    """Select one token vector from a sequence of token vectors.

    The node consumes a (1, NumTokens, NumChannels) tensor and produces the
    (1, NumChannels) vector stored at position TokenIndex.
    """

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the base attribute table extended with the SelectToken attributes."""
        my_attrs = super().get_nodeattr_types()
        my_attrs.update(
            {
                "NumTokens": ("i", True, 0),
                "NumChannels": ("i", True, 0),
                "TokenIndex": ("i", True, 0),
                "SIMD": ("i", False, 1),
                "inputDataType": ("s", True, ""),
                "outputDataType": ("s", False, ""),
            }
        )
        return my_attrs

    def _fold_last_dim(self, normal_shape):
        """Split the last (channel) dimension of normal_shape into (channels // SIMD, SIMD)."""
        simd = self.get_nodeattr("SIMD")
        num_channels = normal_shape[-1]
        assert num_channels % simd == 0, "SIMD must divide NumChannels"
        return normal_shape[:-1] + (num_channels // simd, simd)

    def get_normal_input_shape(self, ind=0):
        """Return (1, NumTokens, NumChannels); only input index 0 exists."""
        if ind != 0:
            raise Exception("SelectToken only has one input")
        return (1, self.get_nodeattr("NumTokens"), self.get_nodeattr("NumChannels"))

    def get_folded_input_shape(self, ind=0):
        """Return the input shape with the channel dimension folded by SIMD."""
        return self._fold_last_dim(self.get_normal_input_shape(ind))

    def get_normal_output_shape(self, ind=0):
        """Return (1, NumChannels), the shape of the selected token."""
        return (1, self.get_nodeattr("NumChannels"))

    def get_folded_output_shape(self, ind=0):
        """Return the output shape with the channel dimension folded by SIMD."""
        return self._fold_last_dim(self.get_normal_output_shape(ind))

    def make_shape_compatible_op(self, model):
        """Check the model's input shape and return a constant op of the output shape."""
        exp_ishape = self.get_normal_input_shape()
        ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
        assert ishape == exp_ishape, "Unexpected input shape for token sequence."
        return super().make_const_shape_op(self.get_normal_output_shape())

    def infer_node_datatype(self, model):
        """Propagate the input tensor datatype to the attributes and the output tensor.

        The datatype annotated on the input tensor takes precedence over the
        inputDataType attribute; a warning is issued whenever an attribute
        that was already set gets overwritten with a different datatype.
        """
        node = self.onnx_node
        attr_idt = None
        if self.get_nodeattr("inputDataType") != "":
            attr_idt = self.get_input_datatype()

        idt = model.get_tensor_datatype(node.input[0])
        if idt is None:
            # fall back to the attribute when the tensor carries no annotation
            idt = attr_idt
        if idt is None:
            raise Exception("SelectToken input datatype is not set")

        if attr_idt is not None and attr_idt != idt:
            warnings.warn(
                "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt))
            )
        self.set_nodeattr("inputDataType", idt.name)

        attr_odt = self.get_nodeattr("outputDataType")
        if attr_odt != "" and DataType[attr_odt] != idt:
            warnings.warn(
                "outputDataType changing for %s: %s -> %s"
                % (node.name, str(DataType[attr_odt]), str(idt))
            )
        # selecting a token does not alter values, so output type == input type
        self.set_nodeattr("outputDataType", idt.name)
        model.set_tensor_datatype(node.output[0], idt)

    def verify_node(self):
        pass

    def get_input_datatype(self, ind=0):
        """Return the DataType named by the inputDataType attribute."""
        return DataType[self.get_nodeattr("inputDataType")]

    def get_output_datatype(self, ind=0):
        """Return the DataType named by outputDataType, or the input datatype if unset."""
        odt = self.get_nodeattr("outputDataType")
        if odt == "":
            return self.get_input_datatype(ind)
        return DataType[odt]

    def get_instream_width(self, ind=0):
        """Return the input stream width in bits (0 for any index other than 0)."""
        if ind != 0:
            return 0
        return self.get_input_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_outstream_width(self, ind=0):
        """Return the output stream width in bits."""
        return self.get_output_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_number_output_values(self):
        """Return the number of folded output words (product of all but the SIMD dim)."""
        return int(np.prod(self.get_folded_output_shape()[:-1]))

    def get_exp_cycles(self):
        """Return the number of folded input words, used as the expected cycle count."""
        return int(np.prod(self.get_folded_input_shape()[:-1]))

    def execute_node(self, context, graph):
        """Copy token TokenIndex of the input tensor into the output tensor."""
        node = self.onnx_node
        inp = context[node.input[0]]
        token_index = self.get_nodeattr("TokenIndex")
        num_tokens = self.get_nodeattr("NumTokens")
        if token_index < 0:
            # negative indices count from the end of the sequence
            token_index += num_tokens
        assert 0 <= token_index < num_tokens, "TokenIndex must select an existing token."

        result = inp[:, token_index, :]
        context[node.output[0]] = np.asarray(result, dtype=np.float32).reshape(
            self.get_normal_output_shape()
        )

    def bram_estimation(self):
        return 0

    def lut_estimation(self):
        return 200


# NOTE: InferSelectTokenLayer belongs to
# src/finn/transformation/fpgadataflow/convert_to_hw_layers.py and relies on the
# names that module already has in scope (Transformation, get_by_name, helper,
# InferShapes, InferDataTypes).
class InferSelectTokenLayer(Transformation):
    """Convert scalar Gather(input, token_index, axis=1) into SelectToken.

    A Gather node is converted only if all of the following hold:

    * it has an explicit ``axis`` attribute and exactly two inputs,
    * the index is a rank-0 (scalar) initializer and the data input is not an
      initializer,
    * the data input has a fully known rank-3 shape with leading dimension 1
      and the gather axis resolves to 1,
    * the (possibly negative) index addresses an existing token,
    * the annotated output shape, if any, equals [1, NumChannels],
    * the input datatype is an integer type and matches the output datatype.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for node in graph.node:
            node_ind += 1
            if node.op_type != "Gather":
                continue

            axis = get_by_name(node.attribute, "axis")
            if axis is None or len(node.input) != 2:
                continue

            seq_name = node.input[0]
            idx_name = node.input[1]
            idx_init = model.get_initializer(idx_name)
            # Only a rank-0 index removes the gathered axis. A shape-[1] index
            # keeps it (output [N, 1, C]) and cannot be expressed by the
            # rank-2 SelectToken output, so it must not be converted even when
            # the output shape annotation is missing.
            if idx_init is None or idx_init.ndim != 0:
                continue
            if model.get_initializer(seq_name) is not None:
                continue

            seq_shape = model.get_tensor_shape(seq_name)
            if seq_shape is None or any(x is None for x in seq_shape):
                continue

            rank = len(seq_shape)
            gather_axis = axis.i if axis.i >= 0 else axis.i + rank
            if rank != 3 or gather_axis != 1:
                continue

            token_index = int(idx_init.flatten()[0])
            num_tokens = int(seq_shape[1])
            if token_index < 0:
                # normalise negative indices so the HW op gets a plain offset
                token_index += num_tokens
            if token_index < 0 or token_index >= num_tokens:
                continue

            out_shape = model.get_tensor_shape(node.output[0])
            exp_oshape = [int(seq_shape[0]), int(seq_shape[2])]
            if out_shape is not None and list(out_shape) != exp_oshape:
                continue
            if seq_shape[0] != 1:
                continue

            idt = model.get_tensor_datatype(seq_name)
            if idt is None or not idt.is_integer():
                continue
            odt = model.get_tensor_datatype(node.output[0])
            if odt is None:
                odt = idt
            elif odt != idt:
                continue

            new_node = helper.make_node(
                "SelectToken",
                [seq_name],
                node.output,
                domain="finn.custom_op.fpgadataflow",
                backend="fpgadataflow",
                name="SelectToken_" + node.name,
                NumTokens=num_tokens,
                NumChannels=int(seq_shape[2]),
                TokenIndex=token_index,
                SIMD=1,
                inputDataType=idt.name,
                outputDataType=odt.name,
            )
            graph.node.insert(node_ind, new_node)
            graph.node.remove(node)
            graph_modified = True

        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

import ast
import numpy as np
from functools import partial
from onnx import TensorProto, helper, numpy_helper
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.general import GiveUniqueNodeNames

from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer
from finn.analysis.fpgadataflow.res_estimation import (
    res_estimation,
    res_estimation_complete,
)
from finn.core.onnx_exec import execute_onnx
from finn.transformation.fpgadataflow.convert_to_hw_layers import InferSelectTokenLayer
from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers
from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext

FPGA_PART = "xc7z020clg400-1"
CLK_NS = 10


def _make_graph(nodes, output_shape, idx_values=None, finn_dtype=DataType["INT8"]):
    """Wrap nodes in a model with a [1, 4, 4] input "tokens" and an output "out".

    If idx_values is given it is added as the initializer "idx". Both graph
    tensors are annotated with finn_dtype.
    """
    tokens_shape = [1, 4, 4]
    tokens = helper.make_tensor_value_info("tokens", TensorProto.FLOAT, tokens_shape)
    output = helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)
    initializers = []
    if idx_values is not None:
        initializers.append(numpy_helper.from_array(idx_values, name="idx"))
    graph = helper.make_graph(nodes, "selecttoken_test", [tokens], [output], initializers)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
    model = ModelWrapper(model)
    for tensor_name in ["tokens", "out"]:
        model.set_tensor_datatype(tensor_name, finn_dtype)
    return model


def _make_gather_model(token_index=0):
    """Build a model holding a single Gather(tokens, idx, axis=1) with a scalar index."""
    # np.asarray on a Python int yields a rank-0 array, i.e. a scalar index
    idx_values = np.asarray(token_index, dtype=np.int64)
    gather = helper.make_node(
        "Gather",
        ["tokens", "idx"],
        ["out"],
        axis=1,
        name="gather_token",
    )
    return _make_graph([gather], [1, 4], idx_values)


def _make_selecttoken_model(token_index=0, simd=1, finn_dtype=DataType["INT8"]):
    """Build a model holding a single SelectToken node over 4 tokens of 4 channels."""
    select = helper.make_node(
        "SelectToken",
        ["tokens"],
        ["out"],
        domain="finn.custom_op.fpgadataflow",
        backend="fpgadataflow",
        name="SelectToken_0",
        NumTokens=4,
        NumChannels=4,
        TokenIndex=token_index,
        SIMD=simd,
        inputDataType=finn_dtype.name,
        outputDataType=finn_dtype.name,
    )
    return _make_graph([select], [1, 4], None, finn_dtype)


def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0):
    """Take a SelectToken model through specialization, FIFO insertion and IP stitching."""
    model = _make_selecttoken_model(token_index=token_index, simd=simd)
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(InsertFIFO(create_shallow_fifos=True))
    # second pass specializes the FIFO nodes that InsertFIFO just added
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(PrepareIP(FPGA_PART, CLK_NS))
    model = model.transform(HLSSynthIP())
    model = model.transform(CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False))
    return model


@pytest.mark.fpgadataflow
def test_convert_gather_to_selecttoken():
    """Gather is converted to SelectToken, executes identically and specializes to RTL."""
    model = _make_gather_model(token_index=2)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, 2, :]

    ret = execute_onnx(model, {"tokens": tokens})
    assert (ret["out"] == expected).all()

    model = model.transform(InferSelectTokenLayer())
    node = model.graph.node[0]
    assert node.op_type == "SelectToken"
    assert node.domain == "finn.custom_op.fpgadataflow"
    assert list(node.input) == ["tokens"]

    inst = getCustomOp(node)
    assert inst.get_normal_output_shape() == (1, 4)
    assert inst.get_exp_cycles() == 16
    assert inst.get_nodeattr("TokenIndex") == 2

    ret = execute_onnx(model, {"tokens": tokens})
    assert (ret["out"] == expected).all()

    model = model.transform(SpecializeLayers(FPGA_PART))
    assert model.graph.node[0].op_type == "SelectToken_rtl"
    assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl"
    assert model.graph.node[0].name == "SelectToken_gather_token"


@pytest.mark.fpgadataflow
@pytest.mark.parametrize("token_index", [0, 1, 3])
def test_selecttoken_python_execution(token_index):
    """Python execution of SelectToken returns the addressed token."""
    model = _make_selecttoken_model(token_index=token_index)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]

    ret = execute_onnx(model, {"tokens": tokens})
    assert (ret["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.parametrize(
    "finn_dtype,fold_width",
    [(DataType["INT8"], 16), (DataType["UINT4"], 8), (DataType["BIPOLAR"], 2)],
)
def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width):
    """RTL code generation emits the wrapper and core with the expected parameters."""
    model = _make_selecttoken_model(token_index=3, simd=2, finn_dtype=finn_dtype)
    model = model.transform(SpecializeLayers(FPGA_PART))

    node = model.graph.node[0]
    inst = getCustomOp(node)
    inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path))
    inst.code_generation_ipgen(model, FPGA_PART, CLK_NS)

    topname = inst.get_nodeattr("gen_top_module")
    assert topname == "SelectToken_0"
    wrapper = tmp_path / (topname + ".v")
    core = tmp_path / "select_token.sv"
    assert wrapper.is_file()
    assert core.is_file()
    wrapper_text = wrapper.read_text()
    assert "parameter FOLD_WIDTH = %d" % fold_width in wrapper_text
    assert ".SIMD(2)" in wrapper_text
    assert ".TOKEN_INDEX(3)" in wrapper_text
    assert "select_token #(" in wrapper_text
    assert "out0_V_TVALID" in wrapper_text

    ipi_cmds = inst.code_generation_ipi()
    assert any("select_token.sv" in cmd for cmd in ipi_cmds)
    assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds)


@pytest.mark.fpgadataflow
def test_selecttoken_resource_estimation():
    """Resource estimation reports the fixed SelectToken estimates."""
    model = _make_selecttoken_model(token_index=1, simd=2)
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())

    expected = {
        "BRAM_18K": 0,
        "BRAM_efficiency": 1,
        "LUT": 200,
        "URAM": 0,
        "URAM_efficiency": 1,
        "DSP": 0,
    }
    resources = model.analysis(partial(res_estimation, fpgapart=FPGA_PART))
    assert len(resources) == 1
    assert list(resources.values())[0] == expected

    complete_resources = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART))
    assert len(complete_resources) == 1
    assert list(complete_resources.values())[0] == [expected]


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)])
def test_selecttoken_rtlsim(simd, token_index):
    """Node-level RTL simulation matches the reference and the expected cycle count."""
    model = _make_selecttoken_model(token_index=token_index, simd=simd)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]

    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(PrepareIP(FPGA_PART, CLK_NS))
    model = model.transform(SetExecMode("rtlsim"))
    model = model.transform(PrepareRTLSim())

    ret = execute_onnx(model, {"tokens": tokens})
    assert (ret["out"] == expected).all()

    node = model.get_nodes_by_op_type("SelectToken_rtl")[0]
    inst = getCustomOp(node)
    cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
    exp_cycles_dict = model.analysis(exp_cycles_per_layer)
    exp_cycles = exp_cycles_dict[node.name]
    assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
    assert exp_cycles != 0


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)])
def test_selecttoken_stitched_ip_rtlsim(simd, token_index):
    """RTL simulation of the stitched IP returns the addressed token."""
    model = _prepare_selecttoken_stitched_ip_model(simd=simd, token_index=token_index)
    tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
    expected = tokens[:, token_index, :]

    model.set_metadata_prop("exec_mode", "rtlsim")

    ret = execute_onnx(model, {"tokens": tokens})
    assert (ret["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
def test_selecttoken_stitched_ip_synth_ooc():
    """Out-of-context synthesis uses logic but no DSP/BRAM and meets timing."""
    model = _prepare_selecttoken_stitched_ip_model(simd=2, token_index=1)
    model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS))
    ret = model.get_metadata_prop("res_total_ooc_synth")
    assert ret is not None
    # NOTE(review): presumably the property holds the repr of a plain dict of
    # literals; literal_eval parses that without executing arbitrary code.
    ret = ast.literal_eval(ret)

    assert ret["LUT"] > 0
    assert ret["FF"] > 0
    assert ret["DSP"] == 0
    assert ret["BRAM"] == 0
    assert ret["WNS"] >= 0
.../test_fpgadataflow_addclstoken.py | 19 ++++++++++++------- 9 files changed, 28 insertions(+), 16 deletions(-) diff --git a/docs/finn/components/rtl-swg.rst b/docs/finn/components/rtl-swg.rst index e8db1d2fa7..8d48dc9d5a 100644 --- a/docs/finn/components/rtl-swg.rst +++ b/docs/finn/components/rtl-swg.rst @@ -96,7 +96,7 @@ Dynamic Mode The "default" style also supports a dynamic mode, which provides an interface to change feature map dimensions, stride, or dilation at run-time. See `this pull request `_ for more information. Folding -------- +======= The RTL SWG is supported by the basic automatic folding algorithm in FINN (:py:mod:`finn.transformation.fpgadataflow.set_folding.SetFolding`). Consider the following implications: diff --git a/docs/finn/developers.rst b/docs/finn/developers.rst index 985b86b279..a265c699c9 100644 --- a/docs/finn/developers.rst +++ b/docs/finn/developers.rst @@ -99,7 +99,7 @@ computer, and you should be able to launch the various .tcl scripts or .xpr proj Docker container as well. Linting -------- +======= We use a pre-commit hook to auto-format Python code and check for issues. See https://pre-commit.com/ for installation. 
Once you have pre-commit, you can install diff --git a/docs/finn/source_code/finn.builder.rst b/docs/finn/source_code/finn.builder.rst index e4dc810e81..caadf3f91f 100644 --- a/docs/finn/source_code/finn.builder.rst +++ b/docs/finn/source_code/finn.builder.rst @@ -3,7 +3,7 @@ Builder ******* Modules -~~~~~~~ +======= finn.builder.build\_dataflow ---------------------------- diff --git a/docs/finn/source_code/finn.core.rst b/docs/finn/source_code/finn.core.rst index 4f16b3ac74..28cb47eaf7 100644 --- a/docs/finn/source_code/finn.core.rst +++ b/docs/finn/source_code/finn.core.rst @@ -3,7 +3,7 @@ Core **** Modules -~~~~~~~ +======= qonnx.core.data\_layout ------------------------- diff --git a/docs/finn/source_code/finn.rst b/docs/finn/source_code/finn.rst index f67dd0fe9c..5547a46623 100644 --- a/docs/finn/source_code/finn.rst +++ b/docs/finn/source_code/finn.rst @@ -6,7 +6,7 @@ The FINN sources are divided into different modules. They are listed below. .. note:: **Some of these functions and modules are located in the `qonnx` repository.** Modules -~~~~~~~ +======= .. 
toctree:: :maxdepth: 1 diff --git a/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py index 7b3f810cad..8ca3daec88 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py @@ -135,6 +135,10 @@ def get_rtl_file_list(self, abspath=False): ] return verilog_files + def get_rtlsim_input_indices(self): + """Only patch tokens are streamed; CLS token data is embedded in generated RTL.""" + return [0] + def code_generation_ipi(self): code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") sourcefiles = self.get_rtl_file_list() diff --git a/src/finn/custom_op/fpgadataflow/rtlbackend.py b/src/finn/custom_op/fpgadataflow/rtlbackend.py index 642523f2db..2b8db0310e 100644 --- a/src/finn/custom_op/fpgadataflow/rtlbackend.py +++ b/src/finn/custom_op/fpgadataflow/rtlbackend.py @@ -85,6 +85,10 @@ def code_generation_ipi(self): def code_generation_ipgen(self, model, fpgapart, clk): self.generate_hdl(model, fpgapart, clk) + def get_rtlsim_input_indices(self): + """Return ONNX input indices that are driven as RTLSim input streams.""" + return range(len(self.onnx_node.input)) + def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") @@ -92,10 +96,10 @@ def execute_node(self, context, graph): if mode == "rtlsim": node = self.onnx_node inputs = {} - for i, inp in enumerate(node.input): + for i in self.get_rtlsim_input_indices(): + inp = node.input[i] nbits = self.get_instream_width(i) - if nbits == 0: - continue + assert nbits > 0, "RTLSim input stream %d has zero width." 
% i exp_ishape = tuple(self.get_normal_input_shape(i)) folded_ishape = self.get_folded_input_shape(i) inp_val = context[inp] diff --git a/src/finn/transformation/fpgadataflow/specialize_layers.py b/src/finn/transformation/fpgadataflow/specialize_layers.py index b2a8629789..dcd2472e0a 100644 --- a/src/finn/transformation/fpgadataflow/specialize_layers.py +++ b/src/finn/transformation/fpgadataflow/specialize_layers.py @@ -389,7 +389,6 @@ def apply(self, model): node.input, node.output, domain="finn.custom_op.fpgadataflow." + impl_style, - name=node.name, ) # add all attributes for attribute in node.attribute: diff --git a/tests/fpgadataflow/test_fpgadataflow_addclstoken.py b/tests/fpgadataflow/test_fpgadataflow_addclstoken.py index 766d783271..7e57c3ef0e 100644 --- a/tests/fpgadataflow/test_fpgadataflow_addclstoken.py +++ b/tests/fpgadataflow/test_fpgadataflow_addclstoken.py @@ -120,13 +120,17 @@ def _prepare_addclstoken_stitched_ip_model(simd=1, pad_tokens=0): return model, cls_values +def _make_input_dict(model, patches): + return {model.graph.input[0].name: patches} + + @pytest.mark.fpgadataflow def test_convert_concat_to_addclstoken(): model, cls_values = _make_concat_model() patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4) expected = np.concatenate([cls_values, patches], axis=1) - ret = execute_onnx(model, {"patches": patches}) + ret = execute_onnx(model, _make_input_dict(model, patches)) assert (ret["out"] == expected).all() model = model.transform(InferAddCLSTokenLayer()) @@ -139,13 +143,13 @@ def test_convert_concat_to_addclstoken(): assert inst.get_normal_output_shape() == (1, 4, 4) assert inst.get_exp_cycles() == 16 - ret = execute_onnx(model, {"patches": patches}) + ret = execute_onnx(model, _make_input_dict(model, patches)) assert (ret["out"] == expected).all() model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) assert model.graph.node[0].op_type == "AddCLSToken_rtl" assert 
model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" - assert model.graph.node[0].name == "AddCLSToken_concat_cls" @pytest.mark.fpgadataflow @@ -157,7 +161,7 @@ def test_addclstoken_python_execution_with_padding(): axis=1, ) - ret = execute_onnx(model, {"patches": patches}) + ret = execute_onnx(model, _make_input_dict(model, patches)) assert (ret["out"] == expected).all() @@ -178,6 +182,7 @@ def test_addclstoken_rtl_codegen(tmp_path, finn_dtype, cls_values, expected_cls_ cls_values=cls_values, ) model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) node = model.graph.node[0] inst = getCustomOp(node) @@ -185,7 +190,7 @@ def test_addclstoken_rtl_codegen(tmp_path, finn_dtype, cls_values, expected_cls_ inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) topname = inst.get_nodeattr("gen_top_module") - assert topname == "AddCLSToken_0" + assert topname == node.name wrapper = tmp_path / (topname + ".v") core = tmp_path / "addclstoken.sv" assert wrapper.is_file() @@ -244,7 +249,7 @@ def test_addclstoken_rtlsim(simd, pad_tokens): model = model.transform(SetExecMode("rtlsim")) model = model.transform(PrepareRTLSim()) - ret = execute_onnx(model, {"patches": patches}) + ret = execute_onnx(model, _make_input_dict(model, patches)) assert (ret["out"] == expected).all() node = model.get_nodes_by_op_type("AddCLSToken_rtl")[0] @@ -273,7 +278,7 @@ def test_addclstoken_stitched_ip_rtlsim(simd, pad_tokens): model.set_metadata_prop("exec_mode", "rtlsim") - ret = execute_onnx(model, {"patches": patches}) + ret = execute_onnx(model, _make_input_dict(model, patches)) assert (ret["out"] == expected).all() From ae8f3e2b072e45a9e1002d6d2f9ded792c6c23d7 Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Fri, 1 May 2026 08:44:53 +0100 Subject: [PATCH 5/7] Address SelectToken follow-ups after AddCLSToken merge --- .../test_fpgadataflow_selecttoken.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git 
a/tests/fpgadataflow/test_fpgadataflow_selecttoken.py b/tests/fpgadataflow/test_fpgadataflow_selecttoken.py index 47709a7520..29c6323ac8 100644 --- a/tests/fpgadataflow/test_fpgadataflow_selecttoken.py +++ b/tests/fpgadataflow/test_fpgadataflow_selecttoken.py @@ -113,13 +113,17 @@ def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0): return model +def _make_input_dict(model, tokens): + return {model.graph.input[0].name: tokens} + + @pytest.mark.fpgadataflow def test_convert_gather_to_selecttoken(): model = _make_gather_model(token_index=2) tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) expected = tokens[:, 2, :] - ret = execute_onnx(model, {"tokens": tokens}) + ret = execute_onnx(model, _make_input_dict(model, tokens)) assert (ret["out"] == expected).all() model = model.transform(InferSelectTokenLayer()) @@ -133,13 +137,13 @@ def test_convert_gather_to_selecttoken(): assert inst.get_exp_cycles() == 16 assert inst.get_nodeattr("TokenIndex") == 2 - ret = execute_onnx(model, {"tokens": tokens}) + ret = execute_onnx(model, _make_input_dict(model, tokens)) assert (ret["out"] == expected).all() model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) assert model.graph.node[0].op_type == "SelectToken_rtl" assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" - assert model.graph.node[0].name == "SelectToken_gather_token" @pytest.mark.fpgadataflow @@ -149,7 +153,7 @@ def test_selecttoken_python_execution(token_index): tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) expected = tokens[:, token_index, :] - ret = execute_onnx(model, {"tokens": tokens}) + ret = execute_onnx(model, _make_input_dict(model, tokens)) assert (ret["out"] == expected).all() @@ -161,6 +165,7 @@ def test_selecttoken_python_execution(token_index): def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width): model = _make_selecttoken_model(token_index=3, simd=2, finn_dtype=finn_dtype) model = 
model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) node = model.graph.node[0] inst = getCustomOp(node) @@ -168,7 +173,7 @@ def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width): inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) topname = inst.get_nodeattr("gen_top_module") - assert topname == "SelectToken_0" + assert topname == node.name wrapper = tmp_path / (topname + ".v") core = tmp_path / "select_token.sv" assert wrapper.is_file() @@ -223,7 +228,7 @@ def test_selecttoken_rtlsim(simd, token_index): model = model.transform(SetExecMode("rtlsim")) model = model.transform(PrepareRTLSim()) - ret = execute_onnx(model, {"tokens": tokens}) + ret = execute_onnx(model, _make_input_dict(model, tokens)) assert (ret["out"] == expected).all() node = model.get_nodes_by_op_type("SelectToken_rtl")[0] @@ -246,7 +251,7 @@ def test_selecttoken_stitched_ip_rtlsim(simd, token_index): model.set_metadata_prop("exec_mode", "rtlsim") - ret = execute_onnx(model, {"tokens": tokens}) + ret = execute_onnx(model, _make_input_dict(model, tokens)) assert (ret["out"] == expected).all() From 050b3c38b37bd438e81e2d1cf27a649354ebe11c Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Tue, 5 May 2026 18:07:38 +0100 Subject: [PATCH 6/7] Add fpgadataflow Where RTL op --- .../finn.custom_op.fpgadataflow.rst | 8 + .../finn.custom_op.fpgadataflow.rtl.rst | 8 + finn-rtllib/where/hdl/where.sv | 389 +++++++++++ finn-rtllib/where/hdl/where_core_template.sv | 100 +++ finn-rtllib/where/hdl/where_template.v | 91 +++ src/finn/builder/build_dataflow_steps.py | 1 + src/finn/custom_op/fpgadataflow/__init__.py | 2 + .../custom_op/fpgadataflow/rtl/__init__.py | 2 + .../custom_op/fpgadataflow/rtl/where_rtl.py | 156 +++++ src/finn/custom_op/fpgadataflow/where.py | 227 +++++++ .../fpgadataflow/convert_to_hw_layers.py | 94 ++- tests/fpgadataflow/test_fpgadataflow_where.py | 619 ++++++++++++++++++ 12 files changed, 1696 insertions(+), 1 deletion(-) create mode 
100644 finn-rtllib/where/hdl/where.sv create mode 100644 finn-rtllib/where/hdl/where_core_template.sv create mode 100644 finn-rtllib/where/hdl/where_template.v create mode 100644 src/finn/custom_op/fpgadataflow/rtl/where_rtl.py create mode 100644 src/finn/custom_op/fpgadataflow/where.py create mode 100644 tests/fpgadataflow/test_fpgadataflow_where.py diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst index 6cefa2f15d..2a6e716031 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst @@ -127,6 +127,14 @@ finn.custom\_op.fpgadataflow.selecttoken :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.where +----------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.where + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.lookup ----------------------------------------------- diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst index 26834ec610..1ad68f9818 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst @@ -53,6 +53,14 @@ finn.custom\_op.fpgadataflow.selecttoken\_rtl :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.where\_rtl +--------------------------------------------------------------- + +.. 
automodule:: finn.custom_op.fpgadataflow.rtl.where_rtl
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 finn.custom\_op.fpgadataflow.streamingfifo\_rtl
 -------------------------------------------------
 
diff --git a/finn-rtllib/where/hdl/where.sv b/finn-rtllib/where/hdl/where.sv
new file mode 100644
index 0000000000..f50a92af9e
--- /dev/null
+++ b/finn-rtllib/where/hdl/where.sv
@@ -0,0 +1,389 @@
+/****************************************************************************
+ * Copyright (C) 2026, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * @brief ONNX Where stream operator with multidirectional broadcasting.
+ *
+ * @description
+ *  The three input tensors are consumed once per frame into local word
+ *  memories. The output tensor is then emitted in row-major folded order.
+ *  This frame-buffered schedule supports full ONNX multidirectional
+ *  broadcasting, including reuse across non-contiguous output positions.
+ ***************************************************************************/
+
+`default_nettype none
+
+module where_broadcast #(
+    int unsigned DATA_WIDTH = 32,
+    int unsigned PE = 1,
+    int unsigned NDIMS = 2,
+    int unsigned COND_NDIMS = NDIMS,
+    int unsigned X_NDIMS = NDIMS,
+    int unsigned Y_NDIMS = NDIMS,
+
+    parameter int unsigned OUT_SHAPE[NDIMS] = '{ default: 1 },
+    parameter int unsigned COND_SHAPE[COND_NDIMS] = '{ default: 1 },
+    parameter int unsigned X_SHAPE[X_NDIMS] = '{ default: 1 },
+    parameter int unsigned Y_SHAPE[Y_NDIMS] = '{ default: 1 },
+
+    // Derived: an input whose innermost dimension is 1 streams one element
+    // per word; every other input streams PE elements per word.
+    localparam int unsigned OUTER_DIMS = (NDIMS > 1)? NDIMS-1 : 1,
+    localparam int unsigned COND_PE = (COND_SHAPE[COND_NDIMS-1] == 1)? 1 : PE,
+    localparam int unsigned X_PE = (X_SHAPE[X_NDIMS-1] == 1)? 1 : PE,
+    localparam int unsigned Y_PE = (Y_SHAPE[Y_NDIMS-1] == 1)? 1 : PE
+)(
+    // Global Control
+    input wire logic clk,
+    input wire logic rst,
+
+    // Condition Stream - folded according to COND_SHAPE
+    input wire logic [COND_PE-1:0] cdat,
+    input wire logic cvld,
+    output logic crdy,
+
+    // X Stream - folded according to X_SHAPE
+    input wire logic [X_PE-1:0][DATA_WIDTH-1:0] xdat,
+    input wire logic xvld,
+    output logic xrdy,
+
+    // Y Stream - folded according to Y_SHAPE
+    input wire logic [Y_PE-1:0][DATA_WIDTH-1:0] ydat,
+    input wire logic yvld,
+    output logic yrdy,
+
+    // Output Stream - folded according to OUT_SHAPE and PE
+    output logic [PE-1:0][DATA_WIDTH-1:0] odat,
+    output logic ovld,
+    input wire logic ordy
+);
+
+    typedef int unsigned outer_idx_t[OUTER_DIMS];
+    typedef logic [COND_PE-1:0] cond_word_t;
+    typedef logic [X_PE-1:0][DATA_WIDTH-1:0] x_word_t;
+    typedef logic [Y_PE-1:0][DATA_WIDTH-1:0] y_word_t;
+    typedef logic [PE-1:0][DATA_WIDTH-1:0] out_word_t;
+
+    // Product of all output dimensions except the innermost one.
+    function automatic int unsigned out_outer_elems();
+        automatic int unsigned r = 1;
+        for(int unsigned i = 0; i+1 < NDIMS; i++)
+            r *= OUT_SHAPE[i];
+        return r;
+    endfunction : out_outer_elems
+
+    // Stream words per frame for the condition input: product of the outer
+    // dimensions, times the innermost fold count unless that dimension is 1.
+    function automatic int unsigned cond_word_count();
+        automatic int unsigned r = 1;
+        for(int unsigned i = 0; i+1 < COND_NDIMS; i++)
+            r *= COND_SHAPE[i];
+        if(COND_SHAPE[COND_NDIMS-1] != 1)
+            r *= COND_SHAPE[COND_NDIMS-1] / PE;
+        return r;
+    endfunction : cond_word_count
+
+    // Stream words per frame for the X input (same rule as cond_word_count).
+    function automatic int unsigned x_word_count();
+        automatic int unsigned r = 1;
+        for(int unsigned i = 0; i+1 < X_NDIMS; i++)
+            r *= X_SHAPE[i];
+        if(X_SHAPE[X_NDIMS-1] != 1)
+            r *= X_SHAPE[X_NDIMS-1] / PE;
+        return r;
+    endfunction : x_word_count
+
+    // Stream words per frame for the Y input (same rule as cond_word_count).
+    function automatic int unsigned y_word_count();
+        automatic int unsigned r = 1;
+        for(int unsigned i = 0; i+1 < Y_NDIMS; i++)
+            r *= Y_SHAPE[i];
+        if(Y_SHAPE[Y_NDIMS-1] != 1)
+            r *= Y_SHAPE[Y_NDIMS-1] / PE;
+        return r;
+    endfunction : y_word_count
+
+    // Stream words per output frame.
+    function automatic int unsigned out_word_count();
+        return out_outer_elems() * (OUT_SHAPE[NDIMS-1] / PE);
+    endfunction : out_word_count
+
+    // Size of the condition input along output axis `axis`. Shapes are
+    // right-aligned, so leading axes the input does not have count as 1.
+    function automatic int unsigned cond_dim(input int unsigned axis);
+        automatic int signed source_axis = int'(axis) + int'(COND_NDIMS) - int'(NDIMS);
+        if(source_axis < 0)
+            return 1;
+        return COND_SHAPE[source_axis];
+    endfunction : cond_dim
+
+    // Size of the X input along output axis `axis` (see cond_dim).
+    function automatic int unsigned x_dim(input int unsigned axis);
+        automatic int signed source_axis = int'(axis) + int'(X_NDIMS) - int'(NDIMS);
+        if(source_axis < 0)
+            return 1;
+        return X_SHAPE[source_axis];
+    endfunction : x_dim
+
+    // Size of the Y input along output axis `axis` (see cond_dim).
+    function automatic int unsigned y_dim(input int unsigned axis);
+        automatic int signed source_axis = int'(axis) + int'(Y_NDIMS) - int'(NDIMS);
+        if(source_axis < 0)
+            return 1;
+        return Y_SHAPE[source_axis];
+    endfunction : y_dim
+
+    // Address in Cmem of the word that feeds the output word at
+    // (out_idx, out_fold). Axes of size 1 contribute no offset, which is
+    // what implements the broadcast.
+    function automatic int unsigned cond_word_addr(
+        input outer_idx_t out_idx,
+        input int unsigned out_fold
+    );
+        automatic int unsigned r = 0;
+        for(int unsigned i = 0; i+1 < NDIMS; i++) begin
+            automatic int signed source_axis = int'(i) + int'(COND_NDIMS) - int'(NDIMS);
+            if(source_axis >= 0) begin
+                r *= COND_SHAPE[source_axis];
+                if(COND_SHAPE[source_axis] != 1) r += out_idx[i];
+            end
+        end
+        if(COND_SHAPE[COND_NDIMS-1] != 1)
+            r = r * (COND_SHAPE[COND_NDIMS-1] / PE) + out_fold;
+        return r;
+    endfunction : cond_word_addr
+
+    // Address in Xmem of the word that feeds the output word at
+    // (out_idx, out_fold); same scheme as cond_word_addr.
+    function automatic int unsigned x_word_addr(
+        input outer_idx_t out_idx,
+        input int unsigned out_fold
+    );
+        automatic int unsigned r = 0;
+        for(int unsigned i = 0; i+1 < NDIMS; i++) begin
+            automatic int signed source_axis = int'(i) + int'(X_NDIMS) - int'(NDIMS);
+            if(source_axis >= 0) begin
+                r *= X_SHAPE[source_axis];
+                if(X_SHAPE[source_axis] != 1) r += out_idx[i];
+            end
+        end
+        if(X_SHAPE[X_NDIMS-1] != 1)
+            r = r * (X_SHAPE[X_NDIMS-1] / PE) + out_fold;
+        return r;
+    endfunction : x_word_addr
+
+    // Address in Ymem of the word that feeds the output word at
+    // (out_idx, out_fold); same scheme as cond_word_addr.
+    function automatic int unsigned y_word_addr(
+        input outer_idx_t out_idx,
+        input int unsigned out_fold
+    );
+        automatic int unsigned r = 0;
+        for(int unsigned i = 0; i+1 < NDIMS; i++) begin
+            automatic int signed source_axis = int'(i) + int'(Y_NDIMS) - int'(NDIMS);
+            if(source_axis >= 0) begin
+                r *= Y_SHAPE[source_axis];
+                if(Y_SHAPE[source_axis] != 1) r += out_idx[i];
+            end
+        end
+        if(Y_SHAPE[Y_NDIMS-1] != 1)
+            r = r * (Y_SHAPE[Y_NDIMS-1] / PE) + out_fold;
+        return r;
+    endfunction : y_word_addr
+
+    localparam int unsigned OUT_FOLDS = OUT_SHAPE[NDIMS-1] / PE;
+    localparam int unsigned OUT_WORDS = out_word_count();
+    localparam int unsigned COND_WORDS = cond_word_count();
+    localparam int unsigned X_WORDS = x_word_count();
+    localparam int unsigned Y_WORDS = y_word_count();
+
+    // Parameter validation: any violated constraint reports an error and
+    // stops the simulation.
+    initial begin
+        automatic int unsigned max_dim;
+        automatic int unsigned cd;
+        automatic int unsigned xd;
+        automatic int unsigned yd;
+
+        if(DATA_WIDTH < 1) begin
+            $error("%m: DATA_WIDTH must be positive.");
+            $finish;
+        end
+        if(PE < 1) begin
+            $error("%m: PE must be positive.");
+            $finish;
+        end
+        if(NDIMS < 1) begin
+            $error("%m: NDIMS must be positive.");
+            $finish;
+        end
+        if(COND_NDIMS < 1 || COND_NDIMS > NDIMS) begin
+            $error("%m: COND_NDIMS must be in the range 1..NDIMS.");
+            $finish;
+        end
+        if(X_NDIMS < 1 || X_NDIMS > NDIMS) begin
+            $error("%m: X_NDIMS must be in the range 1..NDIMS.");
+            $finish;
+        end
+        if(Y_NDIMS < 1 || Y_NDIMS > NDIMS) begin
+            $error("%m: Y_NDIMS must be in the range 1..NDIMS.");
+            $finish;
+        end
+        if((OUT_SHAPE[NDIMS-1] % PE) != 0) begin
+            $error("%m: PE must divide the output innermost dimension.");
+            $finish;
+        end
+        // Per output axis: the three aligned input sizes must agree wherever
+        // they are not 1, and OUT_SHAPE must equal the resulting size.
+        for(int unsigned i = 0; i < NDIMS; i++) begin
+            cd = cond_dim(i);
+            xd = x_dim(i);
+            yd = y_dim(i);
+            max_dim = cd;
+
+            if(cd < 1 || xd < 1 || yd < 1 || OUT_SHAPE[i] < 1) begin
+                $error("%m: shape dimensions must be positive.");
+                $finish;
+            end
+            if(xd != 1 && max_dim != 1 && xd != max_dim) begin
+                $error("%m: X_SHAPE is not broadcast-compatible.");
+                $finish;
+            end
+            if(xd != 1) max_dim = xd;
+            if(yd != 1 && max_dim != 1 && yd != max_dim) begin
+                $error("%m: Y_SHAPE is not broadcast-compatible.");
+                $finish;
+            end
+            if(yd != 1) max_dim = yd;
+            if(cd != 1 && cd != max_dim) begin
+                $error("%m: COND_SHAPE is not broadcast-compatible.");
+                $finish;
+            end
+            if(OUT_SHAPE[i] != max_dim) begin
+                $error("%m: OUT_SHAPE is not the multidirectional broadcast result.");
+                $finish;
+            end
+        end
+        if(COND_SHAPE[COND_NDIMS-1] != 1 && (COND_SHAPE[COND_NDIMS-1] % PE) != 0) begin
+            $error("%m: PE must divide the condition innermost dimension when not broadcast.");
+            $finish;
+        end
+        if(X_SHAPE[X_NDIMS-1] != 1 && (X_SHAPE[X_NDIMS-1] % PE) != 0) begin
+            $error("%m: PE must divide the X innermost dimension when not broadcast.");
+            $finish;
+        end
+        if(Y_SHAPE[Y_NDIMS-1] != 1 && (Y_SHAPE[Y_NDIMS-1] % PE) != 0) begin
+            $error("%m: PE must divide the Y innermost dimension when not broadcast.");
+            $finish;
+        end
+    end
+
+    //------------------------------------------------------------------------
+    // Frame Input Buffers
+    cond_word_t Cmem[COND_WORDS];
+    x_word_t Xmem[X_WORDS];
+    y_word_t Ymem[Y_WORDS];
+
+    int unsigned CWr = 0;
+    int unsigned XWr = 0;
+    int unsigned YWr = 0;
+    logic CLoaded = 0;
+    logic XLoaded = 0;
+    logic YLoaded = 0;
+    logic Emit = 0;
+
+    // An input is accepted only while its buffer is not yet full and the
+    // module is not emitting; the three inputs load independently.
+    assign crdy = !Emit && !CLoaded;
+    assign xrdy = !Emit && !XLoaded;
+    assign yrdy = !Emit && !YLoaded;
+
+    uwire c_fire = cvld && crdy;
+    uwire x_fire = xvld && xrdy;
+    uwire y_fire = yvld && yrdy;
+    uwire emit_fire = Emit && ordy;
+
+    // "Loaded" including a final word that is being written this cycle.
+    uwire c_loaded_now = CLoaded || (c_fire && CWr == COND_WORDS-1);
+    uwire x_loaded_now = XLoaded || (x_fire && XWr == X_WORDS-1);
+    uwire y_loaded_now = YLoaded || (y_fire && YWr == Y_WORDS-1);
+
+    //------------------------------------------------------------------------
+    // Output Indexing
+    outer_idx_t OutIdx = '{ default: 0 };
+    int unsigned OutFold = 0;
+
+    uwire out_last_fold = (OutFold == OUT_FOLDS-1);
+    logic out_last_outer;
+    always_comb begin
+        out_last_outer = 1;
+        for(int unsigned i = 0; i+1 < NDIMS; i++)
+            out_last_outer &= (OutIdx[i] == OUT_SHAPE[i]-1);
+    end
+    uwire out_last = out_last_fold && out_last_outer;
+    uwire frame_done = emit_fire && out_last;
+
+    // Load/emit sequencing: fill the three buffers, switch to Emit once all
+    // are full, step the output index per accepted output word, and return
+    // to the load state when the last output word is accepted.
+    always_ff @(posedge clk) begin
+        if(rst) begin
+            CWr <= 0;
+            XWr <= 0;
+            YWr <= 0;
+            CLoaded <= 0;
+            XLoaded <= 0;
+            YLoaded <= 0;
+            Emit <= 0;
+            OutIdx <= '{ default: 0 };
+            OutFold <= 0;
+        end
+        else begin
+            if(frame_done) begin
+                CWr <= 0;
+                XWr <= 0;
+                YWr <= 0;
+                CLoaded <= 0;
+                XLoaded <= 0;
+                YLoaded <= 0;
+                Emit <= 0;
+                OutIdx <= '{ default: 0 };
+                OutFold <= 0;
+            end
+            else begin
+                if(c_fire) begin
+                    Cmem[CWr] <= cdat;
+                    CLoaded <= (CWr == COND_WORDS-1);
+                    if(CWr != COND_WORDS-1) CWr <= CWr + 1;
+                end
+                if(x_fire) begin
+                    Xmem[XWr] <= xdat;
+                    XLoaded <= (XWr == X_WORDS-1);
+                    if(XWr != X_WORDS-1) XWr <= XWr + 1;
+                end
+                if(y_fire) begin
+                    Ymem[YWr] <= ydat;
+                    YLoaded <= (YWr == Y_WORDS-1);
+                    if(YWr != Y_WORDS-1) YWr <= YWr + 1;
+                end
+                if(!Emit && c_loaded_now && x_loaded_now && y_loaded_now)
+                    Emit <= 1;
+                else if(emit_fire) begin
+                    if(out_last_fold) begin
+                        // Increment the outer index with carry, innermost
+                        // outer axis first.
+                        automatic bit carry = 1;
+                        OutFold <= 0;
+                        for(int i = int'(NDIMS)-2; i >= 0; i--) begin
+                            if(carry) begin
+                                if(OutIdx[i] == OUT_SHAPE[i]-1) begin
+                                    OutIdx[i] <= 0;
+                                end
+                                else begin
+                                    OutIdx[i] <= OutIdx[i] + 1;
+                                    carry = 0;
+                                end
+                            end
+                        end
+                    end
+                    else
+                        OutFold <= OutFold + 1;
+                end
+            end
+        end
+    end
+
+    //------------------------------------------------------------------------
+    // Broadcast Selection
+    uwire logic [31:0] c_addr = cond_word_addr(OutIdx, OutFold);
+    uwire logic [31:0] x_addr = x_word_addr(OutIdx, OutFold);
+    uwire logic [31:0] y_addr = y_word_addr(OutIdx, OutFold);
+    uwire cond_word_t c_word = Cmem[c_addr];
+    uwire x_word_t x_word = Xmem[x_addr];
+    uwire y_word_t y_word = Ymem[y_addr];
+
+    // Per-lane select; an input whose innermost dimension is 1 supplies its
+    // single element to every lane.
+    out_word_t selected;
+    for(genvar lane = 0; lane < PE; lane++) begin : genSelect
+        uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? c_word[0] : c_word[lane];
+        uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? x_word[0] : x_word[lane];
+        uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? y_word[0] : y_word[lane];
+        assign selected[lane] = c? x : y;
+    end : genSelect
+
+    assign odat = selected;
+    assign ovld = Emit;
+
+endmodule : where_broadcast
+
+`default_nettype wire
diff --git a/finn-rtllib/where/hdl/where_core_template.sv b/finn-rtllib/where/hdl/where_core_template.sv
new file mode 100644
index 0000000000..381b1db9a5
--- /dev/null
+++ b/finn-rtllib/where/hdl/where_core_template.sv
@@ -0,0 +1,100 @@
+/******************************************************************************
+ * Copyright (C) 2026, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *    this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION).
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *****************************************************************************/
+
+// SystemVerilog core wrapper: exposes byte-padded AXI-Stream data ports and
+// instantiates where_broadcast with the template-substituted shape
+// parameters. Only the low *_WIDTH bits of each input TDATA are used.
+module $TOP_MODULE_NAME$_core #(
+    parameter COND_WIDTH = $COND_WIDTH$,
+    parameter X_WIDTH = $X_WIDTH$,
+    parameter Y_WIDTH = $Y_WIDTH$,
+    parameter OUT_WIDTH = $OUT_WIDTH$,
+    // Stream widths rounded up to the next multiple of 8 bits.
+    parameter COND_AXI_WIDTH = ((COND_WIDTH + 7) / 8) * 8,
+    parameter X_AXI_WIDTH = ((X_WIDTH + 7) / 8) * 8,
+    parameter Y_AXI_WIDTH = ((Y_WIDTH + 7) / 8) * 8,
+    parameter OUT_AXI_WIDTH = ((OUT_WIDTH + 7) / 8) * 8
+)(
+    input ap_clk,
+    input ap_rst_n,
+
+    output in0_V_TREADY,
+    input in0_V_TVALID,
+    input [COND_AXI_WIDTH-1:0] in0_V_TDATA,
+
+    output in1_V_TREADY,
+    input in1_V_TVALID,
+    input [X_AXI_WIDTH-1:0] in1_V_TDATA,
+
+    output in2_V_TREADY,
+    input in2_V_TVALID,
+    input [Y_AXI_WIDTH-1:0] in2_V_TDATA,
+
+    input out0_V_TREADY,
+    output out0_V_TVALID,
+    output [OUT_AXI_WIDTH-1:0] out0_V_TDATA
+);
+
+    wire [OUT_WIDTH-1:0] core_out;
+
+    assign out0_V_TDATA[OUT_WIDTH-1:0] = core_out;
+
+    // Drive the padding bits of the output TDATA to zero.
+    generate
+        if (OUT_AXI_WIDTH > OUT_WIDTH) begin : gen_pad_tdata
+            assign out0_V_TDATA[OUT_AXI_WIDTH-1:OUT_WIDTH] = {(OUT_AXI_WIDTH-OUT_WIDTH){1'b0}};
+        end
+    endgenerate
+
+    where_broadcast #(
+        .DATA_WIDTH($DATA_WIDTH$),
+        .PE($PE$),
+        .NDIMS($NDIMS$),
+        .COND_NDIMS($COND_NDIMS$),
+        .X_NDIMS($X_NDIMS$),
+        .Y_NDIMS($Y_NDIMS$),
+        .OUT_SHAPE($OUT_SHAPE$),
+        .COND_SHAPE($COND_SHAPE$),
+        .X_SHAPE($X_SHAPE$),
+        .Y_SHAPE($Y_SHAPE$)
+    ) impl (
+        .clk(ap_clk),
+        .rst(!ap_rst_n),  // the core uses an active-high reset
+        .cdat(in0_V_TDATA[COND_WIDTH-1:0]),
+        .cvld(in0_V_TVALID),
+        .crdy(in0_V_TREADY),
+        .xdat(in1_V_TDATA[X_WIDTH-1:0]),
+        .xvld(in1_V_TVALID),
+        .xrdy(in1_V_TREADY),
+        .ydat(in2_V_TDATA[Y_WIDTH-1:0]),
+        .yvld(in2_V_TVALID),
+        .yrdy(in2_V_TREADY),
+        .odat(core_out),
+        .ovld(out0_V_TVALID),
+        .ordy(out0_V_TREADY)
+    );
+
+endmodule
diff --git a/finn-rtllib/where/hdl/where_template.v b/finn-rtllib/where/hdl/where_template.v
new file mode 100644
index 0000000000..3479b88bcf
--- /dev/null
+++ b/finn-rtllib/where/hdl/where_template.v
@@ -0,0 +1,91 @@
+/******************************************************************************
+ * Copyright (C) 2026, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *    this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *****************************************************************************/
+
+// Plain-Verilog top-level wrapper: carries the X_INTERFACE_* attributes on
+// the clock and reset ports and forwards every port unchanged to the
+// SystemVerilog core wrapper $TOP_MODULE_NAME$_core.
+module $TOP_MODULE_NAME$ #(
+    parameter COND_WIDTH = $COND_WIDTH$,
+    parameter X_WIDTH = $X_WIDTH$,
+    parameter Y_WIDTH = $Y_WIDTH$,
+    parameter OUT_WIDTH = $OUT_WIDTH$,
+    // Stream widths rounded up to the next multiple of 8 bits.
+    parameter COND_AXI_WIDTH = ((COND_WIDTH + 7) / 8) * 8,
+    parameter X_AXI_WIDTH = ((X_WIDTH + 7) / 8) * 8,
+    parameter Y_AXI_WIDTH = ((Y_WIDTH + 7) / 8) * 8,
+    parameter OUT_AXI_WIDTH = ((OUT_WIDTH + 7) / 8) * 8
+)(
+    (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *)
+    (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:in1_V:in2_V:out0_V, ASSOCIATED_RESET ap_rst_n" *)
+    input ap_clk,
+    (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *)
+    input ap_rst_n,
+
+    output in0_V_TREADY,
+    input in0_V_TVALID,
+    input [COND_AXI_WIDTH-1:0] in0_V_TDATA,
+
+    output in1_V_TREADY,
+    input in1_V_TVALID,
+    input [X_AXI_WIDTH-1:0] in1_V_TDATA,
+
+    output in2_V_TREADY,
+    input in2_V_TVALID,
+    input [Y_AXI_WIDTH-1:0] in2_V_TDATA,
+
+    input out0_V_TREADY,
+    output out0_V_TVALID,
+    output [OUT_AXI_WIDTH-1:0] out0_V_TDATA
+);
+
+    $TOP_MODULE_NAME$_core #(
+        .COND_WIDTH(COND_WIDTH),
+        .X_WIDTH(X_WIDTH),
+        .Y_WIDTH(Y_WIDTH),
+        .OUT_WIDTH(OUT_WIDTH),
+        .COND_AXI_WIDTH(COND_AXI_WIDTH),
+        .X_AXI_WIDTH(X_AXI_WIDTH),
+        .Y_AXI_WIDTH(Y_AXI_WIDTH),
+        .OUT_AXI_WIDTH(OUT_AXI_WIDTH)
+    ) impl (
+        .ap_clk(ap_clk),
+        .ap_rst_n(ap_rst_n),
+        .in0_V_TREADY(in0_V_TREADY),
+        .in0_V_TVALID(in0_V_TVALID),
+        .in0_V_TDATA(in0_V_TDATA),
+        .in1_V_TREADY(in1_V_TREADY),
+        .in1_V_TVALID(in1_V_TVALID),
+        .in1_V_TDATA(in1_V_TDATA),
+        .in2_V_TREADY(in2_V_TREADY),
+        .in2_V_TVALID(in2_V_TVALID),
+        .in2_V_TDATA(in2_V_TDATA),
+        .out0_V_TREADY(out0_V_TREADY),
+        .out0_V_TVALID(out0_V_TVALID),
+        .out0_V_TDATA(out0_V_TDATA)
+    );
+
+endmodule
diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index e84b997b2e..9c1b235f4a 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++
b/src/finn/builder/build_dataflow_steps.py @@ -526,6 +526,7 @@ def apply_if_relevant(model, op_types, transform, desc=""): to_hw.InferElementwiseBinaryOperation(), "elementwise binary operations", ) + model = apply_if_relevant(model, ["Where"], to_hw.InferWhereLayer(), "where selection") model = apply_if_relevant( model, ["Relu"], to_hw.InferReLUAsElementwiseMax(), "ReLU as elementwise max" ) diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 4dc93e7dd6..4a04853d39 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -84,6 +84,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.thresholding import Thresholding from finn.custom_op.fpgadataflow.upsampler import UpsampleNearestNeighbour from finn.custom_op.fpgadataflow.vectorvectoractivation import VVAU +from finn.custom_op.fpgadataflow.where import Where # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure @@ -114,3 +115,4 @@ def register_custom_op(cls): custom_op["StreamingDataWidthConverter"] = StreamingDataWidthConverter custom_op["UpsampleNearestNeighbour"] = UpsampleNearestNeighbour custom_op["HWSoftmax"] = HWSoftmax +custom_op["Where"] = Where diff --git a/src/finn/custom_op/fpgadataflow/rtl/__init__.py b/src/finn/custom_op/fpgadataflow/rtl/__init__.py index 10deceb9c3..fb8b76d6d3 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/__init__.py +++ b/src/finn/custom_op/fpgadataflow/rtl/__init__.py @@ -48,6 +48,7 @@ from finn.custom_op.fpgadataflow.rtl.streamingfifo_rtl import StreamingFIFO_rtl from finn.custom_op.fpgadataflow.rtl.thresholding_rtl import Thresholding_rtl from finn.custom_op.fpgadataflow.rtl.vectorvectoractivation_rtl import VVAU_rtl +from finn.custom_op.fpgadataflow.rtl.where_rtl import Where_rtl custom_op = dict() @@ -68,5 +69,6 @@ custom_op["Thresholding_rtl"] = Thresholding_rtl 
custom_op["InnerShuffle_rtl"] = InnerShuffle_rtl custom_op["Requant_rtl"] = Requant_rtl +custom_op["Where_rtl"] = Where_rtl custom_op["FINNLoop"] = FINNLoop diff --git a/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py new file mode 100644 index 0000000000..7497b44399 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py @@ -0,0 +1,156 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import os
+import shutil
+
+from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend
+from finn.custom_op.fpgadataflow.where import Where
+
+
+def _rtlsrc_dir():
+    """Return the Where HDL source directory below $FINN_ROOT.
+
+    Raises KeyError if the FINN_ROOT environment variable is not set.
+    """
+    return os.environ["FINN_ROOT"] + "/finn-rtllib/where/hdl"
+
+
+class Where_rtl(Where, RTLBackend):
+    """RTL implementation of Where."""
+
+    def __init__(self, onnx_node, **kwargs):
+        super().__init__(onnx_node, **kwargs)
+
+    def get_nodeattr_types(self):
+        """Return the merged node attributes of Where and RTLBackend."""
+        my_attrs = {}
+        my_attrs.update(Where.get_nodeattr_types(self))
+        my_attrs.update(RTLBackend.get_nodeattr_types(self))
+        return my_attrs
+
+    def _shape_literal(self, shape):
+        """Format a shape as a SystemVerilog array literal, e.g. '{ 2, 8 }.
+
+        An empty shape is emitted as '{ 1 } because _rtl_shape maps it to (1,).
+        """
+        rtl_shape = self._rtl_shape(shape)
+        return "'{ " + ", ".join(str(int(x)) for x in rtl_shape) + " }"
+
+    def generate_hdl(self, model, fpgapart, clk):
+        """Instantiate the Where HDL templates into code_gen_dir_ipgen.
+
+        Substitutes the $KEY$ placeholders of where_template.v and
+        where_core_template.sv, writes them as <top>.v and <top>_core.sv,
+        copies where.sv next to them and records the directory in the
+        ipgen_path and ip_path attributes. The model, fpgapart and clk
+        arguments are not used by this implementation.
+        """
+        pe = self._output_stream_pe()
+        out_shape = self.get_normal_output_shape()
+        cond_shape = self.get_normal_input_shape(0)
+        x_shape = self.get_normal_input_shape(1)
+        y_shape = self.get_normal_input_shape(2)
+        out_rtl_shape = self._rtl_shape(out_shape)
+        cond_rtl_shape = self._rtl_shape(cond_shape)
+        x_rtl_shape = self._rtl_shape(x_shape)
+        y_rtl_shape = self._rtl_shape(y_shape)
+        assert out_rtl_shape[-1] % pe == 0, "PE must divide the output innermost dimension"
+
+        rtlsrc = _rtlsrc_dir()
+        template_path = rtlsrc + "/where_template.v"
+        with open(template_path, "r") as f:
+            template = f.read()
+        core_template_path = rtlsrc + "/where_core_template.sv"
+        with open(core_template_path, "r") as f:
+            core_template = f.read()
+
+        topname = self.get_verilog_top_module_name()
+        self.set_nodeattr("gen_top_module", topname)
+
+        # DATA_WIDTH is the element width of the X input; Y and the output
+        # are given the same datatype by Where.infer_node_datatype.
+        elem_width = self.get_input_datatype(1).bitwidth()
+        cond_width = self.get_instream_width(0)
+        x_width = self.get_instream_width(1)
+        y_width = self.get_instream_width(2)
+        out_width = self.get_outstream_width(0)
+        code_gen_dict = {
+            "TOP_MODULE_NAME": topname,
+            "PE": pe,
+            "DATA_WIDTH": elem_width,
+            "NDIMS": len(out_rtl_shape),
+            "COND_NDIMS": len(cond_rtl_shape),
+            "X_NDIMS": len(x_rtl_shape),
+            "Y_NDIMS": len(y_rtl_shape),
+            "OUT_SHAPE": self._shape_literal(out_shape),
+            "COND_SHAPE": self._shape_literal(cond_shape),
+            "X_SHAPE": self._shape_literal(x_shape),
+            "Y_SHAPE": self._shape_literal(y_shape),
+            "COND_WIDTH": cond_width,
+            "X_WIDTH": x_width,
+            "Y_WIDTH": y_width,
+            "OUT_WIDTH": out_width,
+        }
+
+        # Both templates share one placeholder dictionary.
+        for key, value in code_gen_dict.items():
+            template = template.replace("$%s$" % key, str(value))
+            core_template = core_template.replace("$%s$" % key, str(value))
+
+        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        with open(os.path.join(code_gen_dir, topname + ".v"), "w") as f:
+            f.write(template)
+        with open(os.path.join(code_gen_dir, topname + "_core.sv"), "w") as f:
+            f.write(core_template)
+        shutil.copy(rtlsrc + "/where.sv", code_gen_dir)
+
+        self.set_nodeattr("ipgen_path", code_gen_dir)
+        self.set_nodeattr("ip_path", code_gen_dir)
+
+    def get_rtl_file_list(self, abspath=False):
+        """Return the three HDL files of this node.
+
+        With abspath=True, where.sv is prefixed with the rtllib source
+        directory and the generated wrappers with code_gen_dir_ipgen;
+        otherwise bare file names are returned.
+        """
+        if abspath:
+            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/"
+            rtllib_dir = _rtlsrc_dir() + "/"
+        else:
+            code_gen_dir = ""
+            rtllib_dir = ""
+
+        return [
+            rtllib_dir + "where.sv",
+            code_gen_dir + self.get_nodeattr("gen_top_module") + "_core.sv",
+            code_gen_dir + self.get_nodeattr("gen_top_module") + ".v",
+        ]
+
+    def code_generation_ipi(self):
+        """Return the Tcl commands that add the HDL files and create the BD cell."""
+        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+        # Bare names are joined with code_gen_dir; generate_hdl copied
+        # where.sv there as well.
+        sourcefiles = self.get_rtl_file_list()
+        sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles]
+
+        cmd = []
+        for f in sourcefiles:
+            cmd += ["add_files -norecurse %s" % f]
+        cmd += [
+            "create_bd_cell -type module -reference %s %s"
+            % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
+        ]
+        return cmd
+
+    def execute_node(self, context, graph):
+        """Dispatch on exec_mode: "cppsim" runs Where.execute_node, "rtlsim"
+        runs RTLBackend.execute_node; any other value raises Exception."""
+        mode = self.get_nodeattr("exec_mode")
+        if mode == "cppsim":
+            Where.execute_node(self, context, graph)
+        elif mode == "rtlsim":
+            RTLBackend.execute_node(self, context, graph)
+        else:
+            raise Exception(
+                """Invalid value for attribute exec_mode! Is currently set to: {}
+            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
+                    mode
+                )
+            )
diff --git a/src/finn/custom_op/fpgadataflow/where.py b/src/finn/custom_op/fpgadataflow/where.py
new file mode 100644
index 0000000000..b83dabbd3d
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/where.py
@@ -0,0 +1,227 @@
+# Copyright (C) 2026, Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import warnings
+from qonnx.core.datatype import DataType
+
+from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp
+
+
+class Where(HWCustomOp):
+    """Elementwise ONNX Where with multidirectional broadcasting.
+
+    Input 0 is the BINARY condition, inputs 1 and 2 are the X and Y
+    operands; they share one datatype, which is also the output datatype.
+    """
+
+    def __init__(self, onnx_node, **kwargs):
+        super().__init__(onnx_node, **kwargs)
+
+    def get_nodeattr_types(self):
+        """Return the node attributes of Where on top of the HWCustomOp ones."""
+        my_attrs = super().get_nodeattr_types()
+        my_attrs.update(
+            {
+                "Shape": ("ints", True, []),
+                "CondShape": ("ints", False, []),
+                "XShape": ("ints", False, []),
+                "YShape": ("ints", False, []),
+                "CondRank": ("i", False, -1),
+                "XRank": ("i", False, -1),
+                "YRank": ("i", False, -1),
+                "PE": ("i", False, 1),
+                "conditionDataType": ("s", False, "BINARY"),
+                "inputDataType": ("s", True, ""),
+                "outputDataType": ("s", False, ""),
+                "inFIFODepths": ("ints", False, [2, 2, 2]),
+                "outFIFODepths": ("ints", False, [2]),
+            }
+        )
+        return my_attrs
+
+    def _shape(self):
+        """Return the output shape attribute as a tuple."""
+        return tuple(self.get_nodeattr("Shape"))
+
+    def _input_shape(self, ind):
+        """Return the unfolded shape of input ``ind`` (0: cond, 1: X, 2: Y).
+
+        With a non-negative rank attribute the shape attribute is returned
+        as is (after checking its length), so an empty shape with rank 0
+        stays empty. Without a rank, an empty shape attribute falls back to
+        the output shape.
+        """
+        if ind == 0:
+            attr_name, rank_name = "CondShape", "CondRank"
+        elif ind == 1:
+            attr_name, rank_name = "XShape", "XRank"
+        elif ind == 2:
+            attr_name, rank_name = "YShape", "YRank"
+        else:
+            raise Exception("Where has exactly three inputs")
+
+        rank = self.get_nodeattr(rank_name)
+        shape = tuple(self.get_nodeattr(attr_name))
+        if rank >= 0:
+            assert len(shape) == rank, "%s length must match %s" % (attr_name, rank_name)
+            return shape
+        if len(shape) != 0:
+            return shape
+        return self._shape()
+
+    def _rtl_shape(self, shape):
+        """Return ``shape`` as a tuple, mapping the empty shape to ``(1,)``."""
+        if len(shape) == 0:
+            return (1,)
+        return tuple(shape)
+
+    def _input_stream_pe(self, ind):
+        """Elements per stream word of input ``ind``: 1 if its innermost
+        dimension is 1, otherwise the output stream PE."""
+        shape = self._rtl_shape(self.get_normal_input_shape(ind))
+        if shape[-1] == 1:
+            return 1
+        return self._output_stream_pe()
+
+    def _output_stream_pe(self):
+        """Elements per output stream word: 1 if the innermost output
+        dimension is 1, otherwise the PE attribute."""
+        shape = self._rtl_shape(self.get_normal_output_shape())
+        if shape[-1] == 1:
+            return 1
+        return self.get_nodeattr("PE")
+
+    def _folded_shape(self, shape, stream_pe):
+        """Split the innermost dimension into ``(channels // pe, pe)``."""
+        rtl_shape = self._rtl_shape(shape)
+        *outer, channels = rtl_shape
+        assert channels % stream_pe == 0, "Stream PE must divide the innermost dimension"
+        return tuple(outer + [channels // stream_pe, stream_pe])
+
+    def get_normal_input_shape(self, ind=0):
+        if ind not in [0, 1, 2]:
+            raise Exception("Where has exactly three inputs")
+        return self._input_shape(ind)
+
+    def get_folded_input_shape(self, ind=0):
+        return self._folded_shape(self.get_normal_input_shape(ind), self._input_stream_pe(ind))
+
+    def get_normal_output_shape(self, ind=0):
+        if ind != 0:
+            raise Exception("Where has exactly one output")
+        return self._shape()
+
+    def get_folded_output_shape(self, ind=0):
+        return self._folded_shape(self.get_normal_output_shape(ind), self._output_stream_pe())
+
+    def make_shape_compatible_op(self, model):
+        """Check the graph input shapes against the node attributes and
+        return a constant-shape op producing the output shape."""
+        for i, inp in enumerate(self.onnx_node.input):
+            ishape = tuple(model.get_tensor_shape(inp))
+            assert ishape == self.get_normal_input_shape(i), (
+                "Unexpected input shape for Where input %d." % i
+            )
+        return super().make_const_shape_op(self.get_normal_output_shape())
+
+    def infer_node_datatype(self, model):
+        """Propagate datatypes: the condition must be BINARY, X and Y must
+        agree, and the output takes the X/Y datatype.
+
+        Raises Exception on a non-BINARY condition, mismatching X/Y
+        datatypes, or when no input datatype can be determined.
+        """
+        node = self.onnx_node
+
+        cond_dt = model.get_tensor_datatype(node.input[0])
+        if cond_dt is None:
+            cond_dt = self.get_condition_datatype()
+            model.set_tensor_datatype(node.input[0], cond_dt)
+        if cond_dt != DataType["BINARY"]:
+            raise Exception("Where condition datatype must be BINARY")
+        self.set_nodeattr("conditionDataType", cond_dt.name)
+
+        attr_idt = None
+        if self.get_nodeattr("inputDataType") != "":
+            attr_idt = self.get_input_datatype(1)
+
+        # The tensor annotation of X wins over the node attribute.
+        x_dt = model.get_tensor_datatype(node.input[1])
+        y_dt = model.get_tensor_datatype(node.input[2])
+        idt = x_dt if x_dt is not None else attr_idt
+        if idt is None:
+            raise Exception("Where input datatype is not set")
+        if y_dt is None:
+            model.set_tensor_datatype(node.input[2], idt)
+        elif y_dt != idt:
+            raise Exception("Where X and Y datatypes must match")
+        if x_dt is None:
+            model.set_tensor_datatype(node.input[1], idt)
+
+        if attr_idt is not None and attr_idt != idt:
+            warnings.warn(
+                "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt))
+            )
+        self.set_nodeattr("inputDataType", idt.name)
+
+        attr_odt = self.get_nodeattr("outputDataType")
+        if attr_odt != "" and DataType[attr_odt] != idt:
+            warnings.warn(
+                "outputDataType changing for %s: %s -> %s"
+                % (node.name, str(DataType[attr_odt]), str(idt))
+            )
+        self.set_nodeattr("outputDataType", idt.name)
+        model.set_tensor_datatype(node.output[0], idt)
+
+    def verify_node(self):
+        pass
+
+    def get_condition_datatype(self):
+        return DataType[self.get_nodeattr("conditionDataType")]
+
+    def get_input_datatype(self, ind=0):
+        if ind == 0:
+            return self.get_condition_datatype()
+        if ind in [1, 2]:
+            return DataType[self.get_nodeattr("inputDataType")]
+        raise Exception("Where has exactly three inputs")
+
+    def get_output_datatype(self, ind=0):
+        odt = self.get_nodeattr("outputDataType")
+        if odt == "":
+            return self.get_input_datatype(1)
+        return DataType[odt]
+
+    def get_instream_width(self, ind=0):
+        """Stream width in bits of input ``ind``; the condition uses one bit
+        per element. Returns 0 for an index other than 0, 1 or 2."""
+        if ind == 0:
+            return self._input_stream_pe(ind)
+        if ind in [1, 2]:
+            return self.get_input_datatype(ind).bitwidth() * self._input_stream_pe(ind)
+        return 0
+
+    def get_outstream_width(self, ind=0):
+        return self.get_output_datatype(ind).bitwidth() * self._output_stream_pe()
+
+    def get_number_output_values(self):
+        """Number of output stream words per frame."""
+        return int(np.prod(self.get_folded_output_shape()[:-1]))
+
+    def _number_input_words(self, ind):
+        """Number of stream words input ``ind`` delivers per frame."""
+        return int(np.prod(self.get_folded_input_shape(ind)[:-1]))
+
+    def get_exp_cycles(self):
+        """Expected cycles per frame: load phase plus emit phase.
+
+        The RTL (finn-rtllib/where/hdl/where.sv) first buffers all three
+        inputs, which load in parallel at one word per cycle, and only then
+        emits the output at one word per cycle; it accepts no input while
+        emitting. A frame therefore costs the longest input plus the output
+        word count, not the output word count alone.
+        """
+        load_cycles = max(self._number_input_words(i) for i in range(3))
+        return load_cycles + self.get_number_output_values()
+
+    def execute_node(self, context, graph):
+        """Reference execution via numpy.where; the result is stored as
+        float32 in the output shape."""
+        node = self.onnx_node
+        cond = context[node.input[0]]
+        xval = context[node.input[1]]
+        yval = context[node.input[2]]
+
+        result = np.where(cond.astype(bool), xval, yval)
+        context[node.output[0]] = np.asarray(result, dtype=np.float32).reshape(
+            self.get_normal_output_shape()
+        )
+
+    def bram_estimation(self):
+        # NOTE(review): the RTL buffers a full frame of each input in local
+        # memories; whether these map to BRAM depends on synthesis - confirm.
+        return 0
+
+    def lut_estimation(self):
+        # Use the effective lane count: the generated RTL is instantiated
+        # with _output_stream_pe(), which is 1 when the innermost output
+        # dimension is 1, regardless of the PE attribute.
+        return int(64 + self._output_stream_pe() * self.get_output_datatype().bitwidth())
+
+    def get_op_and_param_counts(self):
+        return {"op_where": int(np.prod(self.get_normal_output_shape()))}
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py
index 2c73c88702..277d3362ea 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py
@@ -30,7 +30,7 @@
 import numpy as np
 import qonnx.core.data_layout as DataLayout
 import warnings
-from onnx import NodeProto, TensorProto, helper
+from onnx import AttributeProto, NodeProto, TensorProto, helper
 from qonnx.core.datatype import DataType
 
 # QONNX wrapper to ONNX model graphs
@@ -1493,6 +1493,98 @@ def apply(self, model):
         return (model, graph_modified)
 
 
+class InferWhereLayer(Transformation):
+    """Convert ONNX Where(condition, X, Y) into a streaming Where layer."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for node in graph.node:
+            node_ind += 1
+            if node.op_type != "Where" or node.domain not in ["", "ai.onnx"]:
+                continue
+            if len(node.input) != 3:
+                continue
+
+            cond_name, x_name, y_name = node.input
+            if any(model.get_initializer(inp) is not None for inp in node.input):
+                continue
+
+            cond_shape = model.get_tensor_shape(cond_name)
+            x_shape = model.get_tensor_shape(x_name)
+            y_shape = model.get_tensor_shape(y_name)
+            out_shape = model.get_tensor_shape(node.output[0])
+            if any(s is None for s in [cond_shape, x_shape, y_shape, out_shape]):
+                continue
+            if any(x is None for x in list(cond_shape) + list(x_shape) + list(y_shape)):
+                continue
+            try:
+                broadcast_shape = np.broadcast_shapes(
+                    tuple(cond_shape), tuple(x_shape), tuple(y_shape)
+                )
+            except ValueError:
+                continue
+            if list(out_shape) != [int(x) for x in broadcast_shape]:
+                continue
+            x_dt = model.get_tensor_datatype(x_name)
+            y_dt = model.get_tensor_datatype(y_name)
+            if x_dt is None or y_dt is None
or x_dt != y_dt: + continue + supported_dt = ( + x_dt.is_integer() + or x_dt.is_fixed_point() + or x_dt in [DataType["FLOAT32"], DataType["FLOAT16"]] + ) + if not supported_dt: + continue + out_dt = model.get_tensor_datatype(node.output[0]) + if out_dt is not None and out_dt != x_dt: + continue + + cond_dt = model.get_tensor_datatype(cond_name) + if cond_dt is None: + model.set_tensor_datatype(cond_name, DataType["BINARY"]) + cond_dt = DataType["BINARY"] + if cond_dt != DataType["BINARY"]: + continue + + new_node = helper.make_node( + "Where", + node.input, + node.output, + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="Where_" + node.name, + CondRank=len(cond_shape), + XRank=len(x_shape), + YRank=len(y_shape), + PE=1, + conditionDataType=cond_dt.name, + inputDataType=x_dt.name, + outputDataType=x_dt.name, + inFIFODepths=[2, 2, 2], + outFIFODepths=[2], + ) + for attr_name, attr_value in [ + ("Shape", [int(x) for x in broadcast_shape]), + ("CondShape", [int(x) for x in cond_shape]), + ("XShape", [int(x) for x in x_shape]), + ("YShape", [int(x) for x in y_shape]), + ]: + new_node.attribute.append( + helper.make_attribute(attr_name, attr_value, attr_type=AttributeProto.INTS) + ) + graph.node.insert(node_ind, new_node) + graph.node.remove(node) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + class InferStreamingEltwise(Transformation): """ DEPRECATED: This transformation is deprecated and now redirects to diff --git a/tests/fpgadataflow/test_fpgadataflow_where.py b/tests/fpgadataflow/test_fpgadataflow_where.py new file mode 100644 index 0000000000..4aa9b53bf7 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_where.py @@ -0,0 +1,619 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +import numpy as np +from functools import partial +from onnx import AttributeProto, TensorProto, helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames + +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.analysis.fpgadataflow.res_estimation import ( + res_estimation, + res_estimation_complete, +) +from finn.core.onnx_exec import execute_onnx +from finn.transformation.fpgadataflow.convert_to_hw_layers import InferWhereLayer +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext + +FPGA_PART = "xc7z020clg400-1" +CLK_NS = 10 + + +def _numel(shape): + return int(np.prod(shape)) if len(shape) > 0 else 1 + + +def _make_graph( + nodes, + shape=None, + finn_dtype=DataType["INT8"], + cond_is_bool=False, + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + cond_shape = shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + out_shape = shape if out_shape is None else out_shape + cond_proto = TensorProto.BOOL if cond_is_bool else TensorProto.FLOAT + cond = helper.make_tensor_value_info("cond", cond_proto, cond_shape) + xval = helper.make_tensor_value_info("xval", TensorProto.FLOAT, 
x_shape) + yval = helper.make_tensor_value_info("yval", TensorProto.FLOAT, y_shape) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape) + graph = helper.make_graph(nodes, "where_test", [cond, xval, yval], [output]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + model = ModelWrapper(model) + if not cond_is_bool: + model.set_tensor_datatype("cond", DataType["BINARY"]) + for tensor_name in ["xval", "yval", "out"]: + model.set_tensor_datatype(tensor_name, finn_dtype) + return model + + +def _make_onnx_where_model( + shape=None, + finn_dtype=DataType["INT8"], + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + where = helper.make_node("Where", ["cond", "xval", "yval"], ["out"], name="where_select") + return _make_graph( + [where], + shape, + finn_dtype=finn_dtype, + cond_is_bool=True, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + + +def _make_where_model( + shape=None, + pe=1, + finn_dtype=DataType["INT8"], + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + cond_shape = shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + out_shape = shape if out_shape is None else out_shape + where = helper.make_node( + "Where", + ["cond", "xval", "yval"], + ["out"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="Where_0", + CondRank=len(cond_shape), + XRank=len(x_shape), + YRank=len(y_shape), + PE=pe, + conditionDataType="BINARY", + inputDataType=finn_dtype.name, + outputDataType=finn_dtype.name, + inFIFODepths=[2, 2, 2], + outFIFODepths=[2], + ) + for attr_name, attr_value in [ + ("Shape", out_shape), + ("CondShape", cond_shape), + ("XShape", x_shape), + ("YShape", y_shape), + ]: + where.attribute.append( + 
helper.make_attribute(attr_name, attr_value, attr_type=AttributeProto.INTS) + ) + return _make_graph( + [where], + shape, + finn_dtype=finn_dtype, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + + +def _prepare_where_stitched_ip_model( + pe=1, + shape=None, + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + model = _make_where_model( + pe=pe, + shape=shape, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(InsertFIFO(create_shallow_fifos=True)) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False)) + return model + + +def _make_inputs(shape=None, cond_shape=None, x_shape=None, y_shape=None): + if shape is None: + shape = [1, 2, 4] + cond_shape = shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + cond = (np.arange(_numel(cond_shape), dtype=np.float32) % 2).reshape(cond_shape) + xval = np.arange(_numel(x_shape), dtype=np.float32).reshape(x_shape) + yval = (100 + np.arange(_numel(y_shape), dtype=np.float32)).reshape(y_shape) + return cond, xval, yval + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_to_where(): + model = _make_onnx_where_model() + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + assert node.op_type == "Where" + assert node.domain == 
"finn.custom_op.fpgadataflow" + assert list(node.input) == ["cond", "xval", "yval"] + + inst = getCustomOp(node) + assert inst.get_normal_output_shape() == (1, 2, 4) + assert inst.get_exp_cycles() == 8 + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + assert model.graph.node[0].op_type == "Where_rtl" + assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_broadcast_to_where(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_onnx_where_model( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape + ) + cond, xval, yval = _make_inputs( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape + ) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + assert node.op_type == "Where" + assert node.domain == "finn.custom_op.fpgadataflow" + + inst = getCustomOp(node) + assert inst.get_normal_input_shape(0) == tuple(cond_shape) + assert inst.get_normal_input_shape(1) == tuple(x_shape) + assert inst.get_normal_input_shape(2) == tuple(y_shape) + assert inst.get_normal_output_shape() == tuple(out_shape) + assert inst.get_folded_input_shape(0) == (3, 1, 1) + assert inst.get_folded_input_shape(1) == (4, 1) + assert inst.get_folded_input_shape(2) == (2, 1, 1, 1) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_scalar_broadcast_to_where(): + cond_shape = [] + x_shape = [2, 3] + 
y_shape = [1, 3] + out_shape = [2, 3] + model = _make_onnx_where_model( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape + ) + cond, xval, yval = _make_inputs( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape + ) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + inst = getCustomOp(node) + + assert inst.get_normal_input_shape(0) == tuple(cond_shape) + assert inst.get_normal_input_shape(1) == tuple(x_shape) + assert inst.get_normal_input_shape(2) == tuple(y_shape) + assert inst.get_normal_output_shape() == tuple(out_shape) + assert inst.get_folded_input_shape(0) == (1, 1) + assert inst.get_nodeattr("CondRank") == 0 + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_float32_to_where(): + model = _make_onnx_where_model(finn_dtype=DataType["FLOAT32"]) + cond, xval, yval = _make_inputs() + xval = xval + 0.25 + yval = yval + 0.5 + expected = np.where(cond.astype(bool), xval, yval) + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + inst = getCustomOp(node) + + assert inst.get_input_datatype(1) == DataType["FLOAT32"] + assert inst.get_output_datatype() == DataType["FLOAT32"] + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype", + [DataType["INT8"], DataType["UINT4"], DataType["BIPOLAR"], DataType["FLOAT32"]], +) +def test_where_python_execution(finn_dtype): + model = _make_where_model(finn_dtype=finn_dtype) + cond, xval, yval = _make_inputs() 
+ if finn_dtype == DataType["BIPOLAR"]: + xval = np.where(xval % 2 == 0, -1, 1).astype(np.float32) + yval = -xval + elif finn_dtype == DataType["UINT4"]: + yval = (15 - xval).astype(np.float32) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_where_python_execution_broadcast(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_where_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape + ) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype,fold_width", + [ + (DataType["INT8"], 16), + (DataType["UINT4"], 8), + (DataType["BIPOLAR"], 2), + (DataType["FLOAT32"], 64), + ], +) +def test_where_rtl_codegen(tmp_path, finn_dtype, fold_width): + model = _make_where_model(pe=2, finn_dtype=finn_dtype) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + assert topname == node.name + wrapper = tmp_path / (topname + ".v") + core_wrapper = tmp_path / (topname + "_core.sv") + core = tmp_path / "where.sv" + assert wrapper.is_file() + assert core_wrapper.is_file() + assert core.is_file() + wrapper_text = wrapper.read_text() + core_wrapper_text = core_wrapper.read_text() + assert "parameter COND_WIDTH = 2" in wrapper_text + assert "parameter X_WIDTH = %d" % fold_width in wrapper_text + 
assert "parameter Y_WIDTH = %d" % fold_width in wrapper_text + assert "parameter OUT_WIDTH = %d" % fold_width in wrapper_text + assert ".DATA_WIDTH(%d)" % finn_dtype.bitwidth() in core_wrapper_text + assert ".PE(2)" in core_wrapper_text + assert ".NDIMS(3)" in core_wrapper_text + assert "in2_V_TDATA" in wrapper_text + assert "out0_V_TVALID" in wrapper_text + + ipi_cmds = inst.code_generation_ipi() + assert any("where.sv" in cmd for cmd in ipi_cmds) + assert any(topname + "_core.sv" in cmd for cmd in ipi_cmds) + assert any(topname + ".v" in cmd for cmd in ipi_cmds) + assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds) + + +@pytest.mark.fpgadataflow +def test_where_rtl_codegen_broadcast(tmp_path): + model = _make_where_model( + pe=2, + cond_shape=[3, 1], + x_shape=[4], + y_shape=[2, 1, 1], + out_shape=[2, 3, 4], + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + wrapper_text = (tmp_path / (topname + ".v")).read_text() + core_wrapper_text = (tmp_path / (topname + "_core.sv")).read_text() + assert "parameter COND_WIDTH = 1" in wrapper_text + assert "parameter X_WIDTH = 16" in wrapper_text + assert "parameter Y_WIDTH = 8" in wrapper_text + assert "parameter OUT_WIDTH = 16" in wrapper_text + assert ".NDIMS(3)" in core_wrapper_text + assert ".COND_NDIMS(2)" in core_wrapper_text + assert ".X_NDIMS(1)" in core_wrapper_text + assert ".Y_NDIMS(3)" in core_wrapper_text + assert ".OUT_SHAPE('{ 2, 3, 4 })" in core_wrapper_text + assert ".COND_SHAPE('{ 3, 1 })" in core_wrapper_text + assert ".X_SHAPE('{ 4 })" in core_wrapper_text + assert ".Y_SHAPE('{ 2, 1, 1 })" in core_wrapper_text + + +@pytest.mark.fpgadataflow +def test_where_rtl_codegen_scalar_broadcast(tmp_path): + model = 
_make_where_model( + pe=3, + cond_shape=[], + x_shape=[2, 3], + y_shape=[1, 3], + out_shape=[2, 3], + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + wrapper_text = (tmp_path / (topname + ".v")).read_text() + core_wrapper_text = (tmp_path / (topname + "_core.sv")).read_text() + assert "parameter COND_WIDTH = 1" in wrapper_text + assert "parameter X_WIDTH = 24" in wrapper_text + assert "parameter Y_WIDTH = 24" in wrapper_text + assert "parameter OUT_WIDTH = 24" in wrapper_text + assert ".NDIMS(2)" in core_wrapper_text + assert ".COND_NDIMS(1)" in core_wrapper_text + assert ".COND_SHAPE('{ 1 })" in core_wrapper_text + + +@pytest.mark.fpgadataflow +def test_where_resource_estimation(): + model = _make_where_model(pe=2) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + expected = { + "BRAM_18K": 0, + "BRAM_efficiency": 1, + "LUT": 80, + "URAM": 0, + "URAM_efficiency": 1, + "DSP": 0, + } + resources = model.analysis(partial(res_estimation, fpgapart=FPGA_PART)) + assert len(resources) == 1 + assert list(resources.values())[0] == expected + + complete_resources = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART)) + assert len(complete_resources) == 1 + assert list(complete_resources.values())[0] == [expected] + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("pe", [1, 2]) +def test_where_rtlsim(pe): + model = _make_where_model(pe=pe) + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, 
CLK_NS)) + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("Where_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_rtlsim_broadcast(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_where_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape + ) + expected = np.where(cond.astype(bool), xval, yval) + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("Where_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=15) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("pe", [1, 2]) +def test_where_stitched_ip_rtlsim(pe): + model = _prepare_where_stitched_ip_model(pe=pe) + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + 
model.set_metadata_prop("exec_mode", "rtlsim") + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_stitched_ip_rtlsim_broadcast(): + cond_shape = [1, 3, 1] + x_shape = [1, 1, 4] + y_shape = [1, 2, 1, 1] + out_shape = [1, 2, 3, 4] + model = _prepare_where_stitched_ip_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape + ) + expected = np.where(cond.astype(bool), xval, yval) + + model.set_metadata_prop("exec_mode", "rtlsim") + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_stitched_ip_synth_ooc(): + model = _prepare_where_stitched_ip_model(pe=2) + model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS)) + ret = model.get_metadata_prop("res_total_ooc_synth") + assert ret is not None + ret = eval(ret) + + assert ret["LUT"] > 0 + assert ret["FF"] > 0 + assert ret["DSP"] == 0 + assert ret["BRAM"] == 0 + assert ret["WNS"] >= 0 From 2b2800745fb530325f7d497c465ad84741454bd9 Mon Sep 17 00:00:00 2001 From: ollycassidy13 Date: Fri, 8 May 2026 11:59:14 +0100 Subject: [PATCH 7/7] rtl and precommit update --- finn-rtllib/where/hdl/where.sv | 247 ++++++++++++------ finn-rtllib/where/hdl/where_core_template.sv | 5 +- .../custom_op/fpgadataflow/rtl/where_rtl.py | 1 + src/finn/custom_op/fpgadataflow/where.py | 15 +- tests/fpgadataflow/test_fpgadataflow_where.py | 23 +- 5 files changed, 186 insertions(+), 105 deletions(-) diff --git a/finn-rtllib/where/hdl/where.sv b/finn-rtllib/where/hdl/where.sv index f50a92af9e..61244bbc75 100644 --- a/finn-rtllib/where/hdl/where.sv +++ b/finn-rtllib/where/hdl/where.sv @@ -4,18 +4,33 @@ * * 
SPDX-License-Identifier: BSD-3-Clause * + * @author Oliver Cassidy + * * @brief ONNX Where stream operator with multidirectional broadcasting. * * @description - * The three input tensors are consumed once per frame into local word - * memories. The output tensor is then emitted in row-major folded order. - * This frame-buffered schedule supports full ONNX multidirectional - * broadcasting, including reuse across non-contiguous output positions. + * This module implements the ONNX expression: + * + * OUT = COND ? X : Y + * + * after applying ONNX multidirectional broadcasting across COND, X and Y. + * Each input stream carries one complete tensor frame folded by its own + * innermost dimension. All three frames are first buffered, then output + * words are read in row-major folded order and selected lane by lane. + * + * COND stream ---> C frame buffer ---\ + * X stream ------> X frame buffer ----+--> registered read --> select --> OUT stream + * Y stream ------> Y frame buffer ---/ + * + * The frame-buffered schedule is required for broadcast reuse across + * non-contiguous output positions. The read data and selected output are + * registered so the memory output does not feed the AXI/stream output + * combinatorially. ***************************************************************************/ `default_nettype none -module where_broadcast #( +module where #( int unsigned DATA_WIDTH = 32, int unsigned PE = 1, int unsigned NDIMS = 2, @@ -27,6 +42,7 @@ module where_broadcast #( parameter int unsigned COND_SHAPE[COND_NDIMS] = '{ default: 1 }, parameter int unsigned X_SHAPE[X_NDIMS] = '{ default: 1 }, parameter int unsigned Y_SHAPE[Y_NDIMS] = '{ default: 1 }, + parameter RAM_STYLE = "auto", localparam int unsigned OUTER_DIMS = (NDIMS > 1)? NDIMS-1 : 1, localparam int unsigned COND_PE = (COND_SHAPE[COND_NDIMS-1] == 1)? 
1 : PE, @@ -58,7 +74,7 @@ module where_broadcast #( input wire logic ordy ); - typedef int unsigned outer_idx_t[OUTER_DIMS]; + typedef logic [31:0] outer_idx_t[OUTER_DIMS]; typedef logic [COND_PE-1:0] cond_word_t; typedef logic [X_PE-1:0][DATA_WIDTH-1:0] x_word_t; typedef logic [Y_PE-1:0][DATA_WIDTH-1:0] y_word_t; @@ -179,6 +195,15 @@ module where_broadcast #( localparam int unsigned COND_WORDS = cond_word_count(); localparam int unsigned X_WORDS = x_word_count(); localparam int unsigned Y_WORDS = y_word_count(); + localparam int unsigned COND_ADDR_WIDTH = (COND_WORDS > 1)? $clog2(COND_WORDS) : 1; + localparam int unsigned X_ADDR_WIDTH = (X_WORDS > 1)? $clog2(X_WORDS) : 1; + localparam int unsigned Y_ADDR_WIDTH = (Y_WORDS > 1)? $clog2(Y_WORDS) : 1; + localparam int unsigned OUT_FOLD_WIDTH = (OUT_FOLDS > 1)? $clog2(OUT_FOLDS) : 1; + + typedef logic [COND_ADDR_WIDTH-1:0] cond_addr_t; + typedef logic [X_ADDR_WIDTH-1:0] x_addr_t; + typedef logic [Y_ADDR_WIDTH-1:0] y_addr_t; + typedef logic [OUT_FOLD_WIDTH-1:0] out_fold_t; initial begin automatic int unsigned max_dim; @@ -257,37 +282,74 @@ module where_broadcast #( end end - //------------------------------------------------------------------------ + //======================================================================= // Frame Input Buffers + (* RAM_STYLE = RAM_STYLE *) cond_word_t Cmem[COND_WORDS]; + (* RAM_STYLE = RAM_STYLE *) x_word_t Xmem[X_WORDS]; + (* RAM_STYLE = RAM_STYLE *) y_word_t Ymem[Y_WORDS]; - int unsigned CWr = 0; - int unsigned XWr = 0; - int unsigned YWr = 0; + cond_addr_t CWr = 0; + x_addr_t XWr = 0; + y_addr_t YWr = 0; logic CLoaded = 0; logic XLoaded = 0; logic YLoaded = 0; - logic Emit = 0; + logic Reading = 0; + logic ReadValid = 0; + logic OValid = 0; - assign crdy = !Emit && !CLoaded; - assign xrdy = !Emit && !XLoaded; - assign yrdy = !Emit && !YLoaded; + uwire frame_busy = Reading || ReadValid || OValid; + assign crdy = !frame_busy && !CLoaded; + assign xrdy = !frame_busy && !XLoaded; + 
assign yrdy = !frame_busy && !YLoaded; uwire c_fire = cvld && crdy; uwire x_fire = xvld && xrdy; uwire y_fire = yvld && yrdy; - uwire emit_fire = Emit && ordy; + uwire output_fire = OValid && ordy; uwire c_loaded_now = CLoaded || (c_fire && CWr == COND_WORDS-1); uwire x_loaded_now = XLoaded || (x_fire && XWr == X_WORDS-1); uwire y_loaded_now = YLoaded || (y_fire && YWr == Y_WORDS-1); + uwire start_reading = !frame_busy && c_loaded_now && x_loaded_now && y_loaded_now; + + uwire frame_done = output_fire && !Reading && !ReadValid; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + CWr <= 0; + XWr <= 0; + YWr <= 0; + CLoaded <= 0; + XLoaded <= 0; + YLoaded <= 0; + end + else begin + if(c_fire) begin + Cmem[CWr] <= cdat; + CLoaded <= (CWr == COND_WORDS-1); + if(CWr != COND_WORDS-1) CWr <= CWr + 1; + end + if(x_fire) begin + Xmem[XWr] <= xdat; + XLoaded <= (XWr == X_WORDS-1); + if(XWr != X_WORDS-1) XWr <= XWr + 1; + end + if(y_fire) begin + Ymem[YWr] <= ydat; + YLoaded <= (YWr == Y_WORDS-1); + if(YWr != Y_WORDS-1) YWr <= YWr + 1; + end + end + end - //------------------------------------------------------------------------ + //======================================================================= // Output Indexing outer_idx_t OutIdx = '{ default: 0 }; - int unsigned OutFold = 0; + out_fold_t OutFold = 0; uwire out_last_fold = (OutFold == OUT_FOLDS-1); logic out_last_outer; @@ -297,93 +359,108 @@ module where_broadcast #( out_last_outer &= (OutIdx[i] == OUT_SHAPE[i]-1); end uwire out_last = out_last_fold && out_last_outer; - uwire frame_done = emit_fire && out_last; + + uwire output_ready = !OValid || ordy; + uwire read_ready = !ReadValid || output_ready; + uwire read_issue = Reading && read_ready; always_ff @(posedge clk) begin - if(rst) begin - CWr <= 0; - XWr <= 0; - YWr <= 0; - CLoaded <= 0; - XLoaded <= 0; - YLoaded <= 0; - Emit <= 0; + if(rst || frame_done) begin + Reading <= 0; + end + else begin + if(start_reading) + Reading <= 1; + else 
if(read_issue && out_last) + Reading <= 0; + end + end + + always_ff @(posedge clk) begin + if(rst || frame_done || start_reading) begin OutIdx <= '{ default: 0 }; OutFold <= 0; end - else begin - if(frame_done) begin - CWr <= 0; - XWr <= 0; - YWr <= 0; - CLoaded <= 0; - XLoaded <= 0; - YLoaded <= 0; - Emit <= 0; - OutIdx <= '{ default: 0 }; + else if(read_issue && !out_last) begin + if(out_last_fold) begin + automatic bit carry = 1; OutFold <= 0; - end - else begin - if(c_fire) begin - Cmem[CWr] <= cdat; - CLoaded <= (CWr == COND_WORDS-1); - if(CWr != COND_WORDS-1) CWr <= CWr + 1; - end - if(x_fire) begin - Xmem[XWr] <= xdat; - XLoaded <= (XWr == X_WORDS-1); - if(XWr != X_WORDS-1) XWr <= XWr + 1; - end - if(y_fire) begin - Ymem[YWr] <= ydat; - YLoaded <= (YWr == Y_WORDS-1); - if(YWr != Y_WORDS-1) YWr <= YWr + 1; - end - if(!Emit && c_loaded_now && x_loaded_now && y_loaded_now) - Emit <= 1; - else if(emit_fire) begin - if(out_last_fold) begin - automatic bit carry = 1; - OutFold <= 0; - for(int i = int'(NDIMS)-2; i >= 0; i--) begin - if(carry) begin - if(OutIdx[i] == OUT_SHAPE[i]-1) begin - OutIdx[i] <= 0; - end - else begin - OutIdx[i] <= OutIdx[i] + 1; - carry = 0; - end - end + for(int i = int'(NDIMS)-2; i >= 0; i--) begin + if(carry) begin + if(OutIdx[i] == OUT_SHAPE[i]-1) begin + OutIdx[i] <= 0; + end + else begin + OutIdx[i] <= OutIdx[i] + 1; + carry = 0; end end - else - OutFold <= OutFold + 1; end end + else + OutFold <= OutFold + 1; end end - //------------------------------------------------------------------------ - // Broadcast Selection - uwire logic [31:0] c_addr = cond_word_addr(OutIdx, OutFold); - uwire logic [31:0] x_addr = x_word_addr(OutIdx, OutFold); - uwire logic [31:0] y_addr = y_word_addr(OutIdx, OutFold); - uwire cond_word_t c_word = Cmem[c_addr]; - uwire x_word_t x_word = Xmem[x_addr]; - uwire y_word_t y_word = Ymem[y_addr]; + //======================================================================= + // Registered Broadcast Reads + uwire 
cond_addr_t c_addr = cond_addr_t'(cond_word_addr(OutIdx, OutFold)); + uwire x_addr_t x_addr = x_addr_t'(x_word_addr(OutIdx, OutFold)); + uwire y_addr_t y_addr = y_addr_t'(y_word_addr(OutIdx, OutFold)); + + cond_word_t CWord = 'x; + x_word_t XWord = 'x; + y_word_t YWord = 'x; + always_ff @(posedge clk) begin + if(rst || frame_done) begin + ReadValid <= 0; + CWord <= 'x; + XWord <= 'x; + YWord <= 'x; + end + else if(read_ready) begin + ReadValid <= read_issue; + if(read_issue) begin + CWord <= Cmem[c_addr]; + XWord <= Xmem[x_addr]; + YWord <= Ymem[y_addr]; + end + else begin + CWord <= 'x; + XWord <= 'x; + YWord <= 'x; + end + end + end + + //======================================================================= + // Broadcast Selection out_word_t selected; for(genvar lane = 0; lane < PE; lane++) begin : genSelect - uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? c_word[0] : c_word[lane]; - uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? x_word[0] : x_word[lane]; - uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? y_word[0] : y_word[lane]; + uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? CWord[0] : CWord[lane]; + uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? XWord[0] : XWord[lane]; + uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? YWord[0] : YWord[lane]; assign selected[lane] = c? 
x : y; end : genSelect - assign odat = selected; - assign ovld = Emit; + out_word_t ODat = 'x; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + OValid <= 0; + ODat <= 'x; + end + else if(output_ready) begin + OValid <= ReadValid; + if(ReadValid) ODat <= selected; + else ODat <= 'x; + end + end + + assign odat = ODat; + assign ovld = OValid; -endmodule : where_broadcast +endmodule : where `default_nettype wire diff --git a/finn-rtllib/where/hdl/where_core_template.sv b/finn-rtllib/where/hdl/where_core_template.sv index 381b1db9a5..cd1f931fec 100644 --- a/finn-rtllib/where/hdl/where_core_template.sv +++ b/finn-rtllib/where/hdl/where_core_template.sv @@ -69,7 +69,7 @@ module $TOP_MODULE_NAME$_core #( end endgenerate - where_broadcast #( + where #( .DATA_WIDTH($DATA_WIDTH$), .PE($PE$), .NDIMS($NDIMS$), @@ -79,7 +79,8 @@ module $TOP_MODULE_NAME$_core #( .OUT_SHAPE($OUT_SHAPE$), .COND_SHAPE($COND_SHAPE$), .X_SHAPE($X_SHAPE$), - .Y_SHAPE($Y_SHAPE$) + .Y_SHAPE($Y_SHAPE$), + .RAM_STYLE($RAM_STYLE$) ) impl ( .clk(ap_clk), .rst(!ap_rst_n), diff --git a/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py index 7497b44399..0a7c30050e 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/where_rtl.py @@ -97,6 +97,7 @@ def generate_hdl(self, model, fpgapart, clk): "X_WIDTH": x_width, "Y_WIDTH": y_width, "OUT_WIDTH": out_width, + "RAM_STYLE": '"{}"'.format(self.get_nodeattr("ram_style")), } for key, value in code_gen_dict.items(): diff --git a/src/finn/custom_op/fpgadataflow/where.py b/src/finn/custom_op/fpgadataflow/where.py index b83dabbd3d..ea8c982e6c 100644 --- a/src/finn/custom_op/fpgadataflow/where.py +++ b/src/finn/custom_op/fpgadataflow/where.py @@ -54,6 +54,12 @@ def get_nodeattr_types(self): "conditionDataType": ("s", False, "BINARY"), "inputDataType": ("s", True, ""), "outputDataType": ("s", False, ""), + "ram_style": ( + "s", + False, + "auto", + {"auto", 
"block", "distributed", "ultra"}, + ), "inFIFODepths": ("ints", False, [2, 2, 2]), "outFIFODepths": ("ints", False, [2]), } @@ -76,7 +82,10 @@ def _input_shape(self, ind): rank = self.get_nodeattr(rank_name) shape = tuple(self.get_nodeattr(attr_name)) if rank >= 0: - assert len(shape) == rank, "%s length must match %s" % (attr_name, rank_name) + assert len(shape) == rank, "%s length must match %s" % ( + attr_name, + rank_name, + ) return shape if len(shape) != 0: return shape @@ -204,7 +213,9 @@ def get_number_output_values(self): return int(np.prod(self.get_folded_output_shape()[:-1])) def get_exp_cycles(self): - return self.get_number_output_values() + input_cycles = max(int(np.prod(self.get_folded_input_shape(ind)[:-1])) for ind in range(3)) + output_cycles = self.get_number_output_values() + return input_cycles + output_cycles + 4 def execute_node(self, context, graph): node = self.onnx_node diff --git a/tests/fpgadataflow/test_fpgadataflow_where.py b/tests/fpgadataflow/test_fpgadataflow_where.py index 4aa9b53bf7..3269a04061 100644 --- a/tests/fpgadataflow/test_fpgadataflow_where.py +++ b/tests/fpgadataflow/test_fpgadataflow_where.py @@ -222,7 +222,7 @@ def test_convert_onnx_where_to_where(): inst = getCustomOp(node) assert inst.get_normal_output_shape() == (1, 2, 4) - assert inst.get_exp_cycles() == 8 + assert inst.get_exp_cycles() == 20 ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) assert (ret["out"] == expected).all() @@ -242,9 +242,7 @@ def test_convert_onnx_where_broadcast_to_where(): model = _make_onnx_where_model( cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape ) - cond, xval, yval = _make_inputs( - cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape - ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) expected = np.where(cond.astype(bool), xval, yval) ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) @@ -278,9 +276,7 @@ def 
test_convert_onnx_where_scalar_broadcast_to_where(): model = _make_onnx_where_model( cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape ) - cond, xval, yval = _make_inputs( - cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape - ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) expected = np.where(cond.astype(bool), xval, yval) ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) @@ -354,9 +350,7 @@ def test_where_python_execution_broadcast(): y_shape=y_shape, out_shape=out_shape, ) - cond, xval, yval = _make_inputs( - cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape - ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) expected = np.where(cond.astype(bool), xval, yval) ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) @@ -400,6 +394,7 @@ def test_where_rtl_codegen(tmp_path, finn_dtype, fold_width): assert ".DATA_WIDTH(%d)" % finn_dtype.bitwidth() in core_wrapper_text assert ".PE(2)" in core_wrapper_text assert ".NDIMS(3)" in core_wrapper_text + assert '.RAM_STYLE("auto")' in core_wrapper_text assert "in2_V_TDATA" in wrapper_text assert "out0_V_TVALID" in wrapper_text @@ -538,9 +533,7 @@ def test_where_rtlsim_broadcast(): y_shape=y_shape, out_shape=out_shape, ) - cond, xval, yval = _make_inputs( - cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape - ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) expected = np.where(cond.astype(bool), xval, yval) model = model.transform(SpecializeLayers(FPGA_PART)) @@ -591,9 +584,7 @@ def test_where_stitched_ip_rtlsim_broadcast(): y_shape=y_shape, out_shape=out_shape, ) - cond, xval, yval = _make_inputs( - cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape - ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) expected = np.where(cond.astype(bool), xval, yval) 
model.set_metadata_prop("exec_mode", "rtlsim")