diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst index 84e9633304..2a6e716031 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst @@ -39,6 +39,14 @@ RTLBackend :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.addclstoken +----------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.addclstoken + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.addstreams ---------------------------------------- @@ -111,6 +119,22 @@ finn.custom\_op.fpgadataflow.labelselect :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken +----------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.selecttoken + :members: + :undoc-members: + :show-inheritance: + +finn.custom\_op.fpgadataflow.where +----------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.where + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.lookup ----------------------------------------------- diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst index 346eddb073..1ad68f9818 100644 --- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst +++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst @@ -5,6 +5,14 @@ Custom Op - fpgadataflow.rtl RTL Custom Op Nodes =================== +finn.custom\_op.fpgadataflow.rtl.addclstoken\_rtl +-------------------------------------------------- + +.. 
automodule:: finn.custom_op.fpgadataflow.rtl.addclstoken_rtl + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.convolutioninputgenerator\_rtl ------------------------------------------------------------ @@ -37,6 +45,22 @@ finn.custom\_op.fpgadataflow.streamingdatawidthconverter\_rtl :undoc-members: :show-inheritance: +finn.custom\_op.fpgadataflow.selecttoken\_rtl +--------------------------------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.rtl.selecttoken_rtl + :members: + :undoc-members: + :show-inheritance: + +finn.custom\_op.fpgadataflow.where\_rtl +--------------------------------------------------------------- + +.. automodule:: finn.custom_op.fpgadataflow.rtl.where_rtl + :members: + :undoc-members: + :show-inheritance: + finn.custom\_op.fpgadataflow.streamingfifo\_rtl ------------------------------------------------- diff --git a/finn-rtllib/addclstoken/hdl/addclstoken.sv b/finn-rtllib/addclstoken/hdl/addclstoken.sv new file mode 100644 index 0000000000..d5bbdc2188 --- /dev/null +++ b/finn-rtllib/addclstoken/hdl/addclstoken.sv @@ -0,0 +1,139 @@ +/**************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @brief Insert a constant class token into a folded token stream. + * @author Oliver Cassidy + * + * @description + * Prepends a learned class token, supplied through cls_data, to each + * input sequence of patch tokens. The class token and patch tokens are + * transferred as SIMD-wide folds of ELEM_WIDTH-bit elements. + * + * Per sequence, the output stream is: + * 1. NUM_CHANNELS/SIMD folds from cls_data + * 2. NUM_TOKENS pass-through input tokens + * 3. 
PAD_TOKENS zero-valued tokens, when padding is enabled + ***************************************************************************/ + +module addclstoken #( + parameter int unsigned NUM_TOKENS = 196, + parameter int unsigned NUM_CHANNELS = 192, + parameter int unsigned SIMD = 1, + parameter int unsigned ELEM_WIDTH = 8, + parameter int unsigned PAD_TOKENS = 0 +)( + input logic clk, + input logic rst, + + output logic irdy, + input logic ivld, + input logic [SIMD*ELEM_WIDTH-1:0] idat, + + input logic ordy, + output logic ovld, + output logic [SIMD*ELEM_WIDTH-1:0] odat, + + input logic [NUM_CHANNELS*ELEM_WIDTH-1:0] cls_data +); + + localparam int unsigned FOLD_WIDTH = SIMD * ELEM_WIDTH; + localparam int unsigned FOLDS_PER_TOKEN = NUM_CHANNELS / SIMD; + localparam int unsigned TOTAL_INPUT_FOLDS = NUM_TOKENS * FOLDS_PER_TOKEN; + localparam int unsigned TOTAL_PAD_FOLDS = PAD_TOKENS * FOLDS_PER_TOKEN; + localparam int unsigned MAX_PHASE_FOLDS = + (TOTAL_INPUT_FOLDS > FOLDS_PER_TOKEN) ? + ((TOTAL_INPUT_FOLDS > TOTAL_PAD_FOLDS) ? + TOTAL_INPUT_FOLDS : TOTAL_PAD_FOLDS) : + ((FOLDS_PER_TOKEN > TOTAL_PAD_FOLDS) ? + FOLDS_PER_TOKEN : TOTAL_PAD_FOLDS); + localparam int unsigned CNT_WIDTH = (MAX_PHASE_FOLDS <= 1) ? 1 : $clog2(MAX_PHASE_FOLDS); + + typedef enum logic [1:0] { + EMIT_CLS, + PASSTHROUGH, + EMIT_PAD + } state_t; + + state_t state; + state_t next_state; + logic [CNT_WIDTH-1:0] fold_cnt; + logic fold_cnt_last; + logic out_transfer; + + logic [CNT_WIDTH-1:0] cls_fold_cnt; + logic [FOLD_WIDTH-1:0] cls_fold; + + assign cls_fold_cnt = (int'(fold_cnt) < FOLDS_PER_TOKEN) ? 
fold_cnt : '0; + assign cls_fold = cls_data[cls_fold_cnt * FOLD_WIDTH +: FOLD_WIDTH]; + assign out_transfer = ovld & ordy; + + always_comb begin + unique case (state) + EMIT_CLS: fold_cnt_last = (int'(fold_cnt) == FOLDS_PER_TOKEN - 1); + PASSTHROUGH: fold_cnt_last = (int'(fold_cnt) == TOTAL_INPUT_FOLDS - 1); + EMIT_PAD: fold_cnt_last = (int'(fold_cnt) == TOTAL_PAD_FOLDS - 1); + default: fold_cnt_last = 1'b1; + endcase + end + + always_comb begin + irdy = 1'b0; + ovld = 1'b0; + odat = '0; + + unique case (state) + EMIT_CLS: begin + ovld = 1'b1; + odat = cls_fold; + end + PASSTHROUGH: begin + irdy = ordy; + ovld = ivld; + odat = idat; + end + EMIT_PAD: begin + ovld = 1'b1; + end + default: begin + end + endcase + end + + always_comb begin + next_state = state; + if (out_transfer && fold_cnt_last) begin + unique case (state) + EMIT_CLS: begin + next_state = PASSTHROUGH; + end + PASSTHROUGH: begin + next_state = (PAD_TOKENS == 0) ? EMIT_CLS : EMIT_PAD; + end + EMIT_PAD: begin + next_state = EMIT_CLS; + end + default: begin + next_state = EMIT_CLS; + end + endcase + end + end + + always_ff @(posedge clk) begin + if (rst) begin + state <= EMIT_CLS; + fold_cnt <= '0; + end else if (out_transfer) begin + if (fold_cnt_last) begin + state <= next_state; + fold_cnt <= '0; + end else begin + fold_cnt <= fold_cnt + 1'b1; + end + end + end + +endmodule diff --git a/finn-rtllib/addclstoken/hdl/addclstoken_template.v b/finn-rtllib/addclstoken/hdl/addclstoken_template.v new file mode 100644 index 0000000000..d38dd72bed --- /dev/null +++ b/finn-rtllib/addclstoken/hdl/addclstoken_template.v @@ -0,0 +1,81 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *****************************************************************************/ + +module $TOP_MODULE_NAME$ #( + parameter FOLD_WIDTH = $FOLD_WIDTH$, + parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8 +)( + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + output in0_V_TREADY, + input in0_V_TVALID, + input [AXI_WIDTH-1:0] in0_V_TDATA, + + input out0_V_TREADY, + output out0_V_TVALID, + output [AXI_WIDTH-1:0] out0_V_TDATA +); + + localparam [$CLS_WIDTH$-1:0] CLS_DATA = $CLS_DATA$; + + wire [FOLD_WIDTH-1:0] core_out; + + assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out; + + generate + if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata + assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}}; + end + endgenerate + + addclstoken #( + .NUM_TOKENS($NUM_TOKENS$), + .NUM_CHANNELS($NUM_CHANNELS$), + .SIMD($SIMD$), + .ELEM_WIDTH($ELEM_WIDTH$), + .PAD_TOKENS($PAD_TOKENS$) + ) impl ( + .clk(ap_clk), + .rst(!ap_rst_n), + .irdy(in0_V_TREADY), + .ivld(in0_V_TVALID), + .idat(in0_V_TDATA[FOLD_WIDTH-1:0]), + .ordy(out0_V_TREADY), + .ovld(out0_V_TVALID), + .odat(core_out), + .cls_data(CLS_DATA) + ); + +endmodule diff --git a/finn-rtllib/selecttoken/hdl/select_token.sv b/finn-rtllib/selecttoken/hdl/select_token.sv new file mode 100644 index 0000000000..fb4c3df800 --- /dev/null +++ b/finn-rtllib/selecttoken/hdl/select_token.sv @@ -0,0 +1,82 @@ +/**************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @brief Select one token from a folded token stream. + * @author Oliver Cassidy + * + * @description + * Consumes NUM_TOKENS token vectors fold-by-fold. 
Folds belonging to + * TOKEN_INDEX are forwarded to the output stream; all other folds are + * consumed and discarded. + ***************************************************************************/ + +module select_token #( + parameter int unsigned NUM_TOKENS = 197, + parameter int unsigned NUM_CHANNELS = 192, + parameter int unsigned SIMD = 1, + parameter int unsigned ELEM_WIDTH = 8, + parameter int unsigned TOKEN_INDEX = 0 +)( + input logic clk, + input logic rst, + + output logic irdy, + input logic ivld, + input logic [SIMD*ELEM_WIDTH-1:0] idat, + + input logic ordy, + output logic ovld, + output logic [SIMD*ELEM_WIDTH-1:0] odat +); + + localparam int unsigned FOLDS_PER_TOKEN = NUM_CHANNELS / SIMD; + localparam int unsigned TOKEN_CNT_WIDTH = (NUM_TOKENS <= 1) ? 1 : $clog2(NUM_TOKENS); + localparam int unsigned FOLD_CNT_WIDTH = + (FOLDS_PER_TOKEN <= 1) ? 1 : $clog2(FOLDS_PER_TOKEN); + + logic [TOKEN_CNT_WIDTH-1:0] token_cnt; + logic [FOLD_CNT_WIDTH-1:0] fold_cnt; + logic is_selected; + logic in_transfer; + logic fold_cnt_last; + logic token_cnt_last; + + assign is_selected = (int'(token_cnt) == TOKEN_INDEX); + assign in_transfer = irdy & ivld; + assign fold_cnt_last = (int'(fold_cnt) == FOLDS_PER_TOKEN - 1); + assign token_cnt_last = (int'(token_cnt) == NUM_TOKENS - 1); + + always_comb begin + irdy = 1'b1; + ovld = 1'b0; + odat = '0; + + if (is_selected) begin + irdy = ordy; + ovld = ivld; + odat = idat; + end + end + + always_ff @(posedge clk) begin + if (rst) begin + token_cnt <= '0; + fold_cnt <= '0; + end else if (in_transfer) begin + if (fold_cnt_last) begin + fold_cnt <= '0; + if (token_cnt_last) begin + token_cnt <= '0; + end else begin + token_cnt <= token_cnt + 1'b1; + end + end else begin + fold_cnt <= fold_cnt + 1'b1; + end + end + end + +endmodule diff --git a/finn-rtllib/selecttoken/hdl/select_token_template.v b/finn-rtllib/selecttoken/hdl/select_token_template.v new file mode 100644 index 0000000000..566fa63ac5 --- /dev/null +++ 
b/finn-rtllib/selecttoken/hdl/select_token_template.v @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *****************************************************************************/ + +module $TOP_MODULE_NAME$ #( + parameter FOLD_WIDTH = $FOLD_WIDTH$, + parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8 +)( + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + output in0_V_TREADY, + input in0_V_TVALID, + input [AXI_WIDTH-1:0] in0_V_TDATA, + + input out0_V_TREADY, + output out0_V_TVALID, + output [AXI_WIDTH-1:0] out0_V_TDATA +); + + wire [FOLD_WIDTH-1:0] core_out; + + assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out; + + generate + if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata + assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}}; + end + endgenerate + + select_token #( + .NUM_TOKENS($NUM_TOKENS$), + .NUM_CHANNELS($NUM_CHANNELS$), + .SIMD($SIMD$), + .ELEM_WIDTH($ELEM_WIDTH$), + .TOKEN_INDEX($TOKEN_INDEX$) + ) impl ( + .clk(ap_clk), + .rst(!ap_rst_n), + .irdy(in0_V_TREADY), + .ivld(in0_V_TVALID), + .idat(in0_V_TDATA[FOLD_WIDTH-1:0]), + .ordy(out0_V_TREADY), + .ovld(out0_V_TVALID), + .odat(core_out) + ); + +endmodule diff --git a/finn-rtllib/where/hdl/where.sv b/finn-rtllib/where/hdl/where.sv new file mode 100644 index 0000000000..61244bbc75 --- /dev/null +++ b/finn-rtllib/where/hdl/where.sv @@ -0,0 +1,466 @@ +/**************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Oliver Cassidy + * + * @brief ONNX Where stream operator with multidirectional broadcasting. + * + * @description + * This module implements the ONNX expression: + * + * OUT = COND ? X : Y + * + * after applying ONNX multidirectional broadcasting across COND, X and Y. 
+ * Each input stream carries one complete tensor frame folded by its own + * innermost dimension. All three frames are first buffered, then output + * words are read in row-major folded order and selected lane by lane. + * + * COND stream ---> C frame buffer ---\ + * X stream ------> X frame buffer ----+--> registered read --> select --> OUT stream + * Y stream ------> Y frame buffer ---/ + * + * The frame-buffered schedule is required for broadcast reuse across + * non-contiguous output positions. The read data and selected output are + * registered so the memory output does not feed the AXI/stream output + * combinatorially. + ***************************************************************************/ + +`default_nettype none + +module where #( + int unsigned DATA_WIDTH = 32, + int unsigned PE = 1, + int unsigned NDIMS = 2, + int unsigned COND_NDIMS = NDIMS, + int unsigned X_NDIMS = NDIMS, + int unsigned Y_NDIMS = NDIMS, + + parameter int unsigned OUT_SHAPE[NDIMS] = '{ default: 1 }, + parameter int unsigned COND_SHAPE[COND_NDIMS] = '{ default: 1 }, + parameter int unsigned X_SHAPE[X_NDIMS] = '{ default: 1 }, + parameter int unsigned Y_SHAPE[Y_NDIMS] = '{ default: 1 }, + parameter RAM_STYLE = "auto", + + localparam int unsigned OUTER_DIMS = (NDIMS > 1)? NDIMS-1 : 1, + localparam int unsigned COND_PE = (COND_SHAPE[COND_NDIMS-1] == 1)? 1 : PE, + localparam int unsigned X_PE = (X_SHAPE[X_NDIMS-1] == 1)? 1 : PE, + localparam int unsigned Y_PE = (Y_SHAPE[Y_NDIMS-1] == 1)? 
1 : PE +)( + // Global Control + input wire logic clk, + input wire logic rst, + + // Condition Stream - folded according to COND_SHAPE + input wire logic [COND_PE-1:0] cdat, + input wire logic cvld, + output logic crdy, + + // X Stream - folded according to X_SHAPE + input wire logic [X_PE-1:0][DATA_WIDTH-1:0] xdat, + input wire logic xvld, + output logic xrdy, + + // Y Stream - folded according to Y_SHAPE + input wire logic [Y_PE-1:0][DATA_WIDTH-1:0] ydat, + input wire logic yvld, + output logic yrdy, + + // Output Stream - folded according to OUT_SHAPE and PE + output logic [PE-1:0][DATA_WIDTH-1:0] odat, + output logic ovld, + input wire logic ordy +); + + typedef logic [31:0] outer_idx_t[OUTER_DIMS]; + typedef logic [COND_PE-1:0] cond_word_t; + typedef logic [X_PE-1:0][DATA_WIDTH-1:0] x_word_t; + typedef logic [Y_PE-1:0][DATA_WIDTH-1:0] y_word_t; + typedef logic [PE-1:0][DATA_WIDTH-1:0] out_word_t; + + function automatic int unsigned out_outer_elems(); + automatic int unsigned r = 1; + for(int unsigned i = 0; i+1 < NDIMS; i++) + r *= OUT_SHAPE[i]; + return r; + endfunction : out_outer_elems + + function automatic int unsigned cond_word_count(); + automatic int unsigned r = 1; + for(int unsigned i = 0; i+1 < COND_NDIMS; i++) + r *= COND_SHAPE[i]; + if(COND_SHAPE[COND_NDIMS-1] != 1) + r *= COND_SHAPE[COND_NDIMS-1] / PE; + return r; + endfunction : cond_word_count + + function automatic int unsigned x_word_count(); + automatic int unsigned r = 1; + for(int unsigned i = 0; i+1 < X_NDIMS; i++) + r *= X_SHAPE[i]; + if(X_SHAPE[X_NDIMS-1] != 1) + r *= X_SHAPE[X_NDIMS-1] / PE; + return r; + endfunction : x_word_count + + function automatic int unsigned y_word_count(); + automatic int unsigned r = 1; + for(int unsigned i = 0; i+1 < Y_NDIMS; i++) + r *= Y_SHAPE[i]; + if(Y_SHAPE[Y_NDIMS-1] != 1) + r *= Y_SHAPE[Y_NDIMS-1] / PE; + return r; + endfunction : y_word_count + + function automatic int unsigned out_word_count(); + return out_outer_elems() * (OUT_SHAPE[NDIMS-1] / 
PE); + endfunction : out_word_count + + function automatic int unsigned cond_dim(input int unsigned axis); + automatic int signed source_axis = int'(axis) + int'(COND_NDIMS) - int'(NDIMS); + if(source_axis < 0) + return 1; + return COND_SHAPE[source_axis]; + endfunction : cond_dim + + function automatic int unsigned x_dim(input int unsigned axis); + automatic int signed source_axis = int'(axis) + int'(X_NDIMS) - int'(NDIMS); + if(source_axis < 0) + return 1; + return X_SHAPE[source_axis]; + endfunction : x_dim + + function automatic int unsigned y_dim(input int unsigned axis); + automatic int signed source_axis = int'(axis) + int'(Y_NDIMS) - int'(NDIMS); + if(source_axis < 0) + return 1; + return Y_SHAPE[source_axis]; + endfunction : y_dim + + function automatic int unsigned cond_word_addr( + input outer_idx_t out_idx, + input int unsigned out_fold + ); + automatic int unsigned r = 0; + for(int unsigned i = 0; i+1 < NDIMS; i++) begin + automatic int signed source_axis = int'(i) + int'(COND_NDIMS) - int'(NDIMS); + if(source_axis >= 0) begin + r *= COND_SHAPE[source_axis]; + if(COND_SHAPE[source_axis] != 1) r += out_idx[i]; + end + end + if(COND_SHAPE[COND_NDIMS-1] != 1) + r = r * (COND_SHAPE[COND_NDIMS-1] / PE) + out_fold; + return r; + endfunction : cond_word_addr + + function automatic int unsigned x_word_addr( + input outer_idx_t out_idx, + input int unsigned out_fold + ); + automatic int unsigned r = 0; + for(int unsigned i = 0; i+1 < NDIMS; i++) begin + automatic int signed source_axis = int'(i) + int'(X_NDIMS) - int'(NDIMS); + if(source_axis >= 0) begin + r *= X_SHAPE[source_axis]; + if(X_SHAPE[source_axis] != 1) r += out_idx[i]; + end + end + if(X_SHAPE[X_NDIMS-1] != 1) + r = r * (X_SHAPE[X_NDIMS-1] / PE) + out_fold; + return r; + endfunction : x_word_addr + + function automatic int unsigned y_word_addr( + input outer_idx_t out_idx, + input int unsigned out_fold + ); + automatic int unsigned r = 0; + for(int unsigned i = 0; i+1 < NDIMS; i++) begin + automatic 
int signed source_axis = int'(i) + int'(Y_NDIMS) - int'(NDIMS); + if(source_axis >= 0) begin + r *= Y_SHAPE[source_axis]; + if(Y_SHAPE[source_axis] != 1) r += out_idx[i]; + end + end + if(Y_SHAPE[Y_NDIMS-1] != 1) + r = r * (Y_SHAPE[Y_NDIMS-1] / PE) + out_fold; + return r; + endfunction : y_word_addr + + localparam int unsigned OUT_FOLDS = OUT_SHAPE[NDIMS-1] / PE; + localparam int unsigned OUT_WORDS = out_word_count(); + localparam int unsigned COND_WORDS = cond_word_count(); + localparam int unsigned X_WORDS = x_word_count(); + localparam int unsigned Y_WORDS = y_word_count(); + localparam int unsigned COND_ADDR_WIDTH = (COND_WORDS > 1)? $clog2(COND_WORDS) : 1; + localparam int unsigned X_ADDR_WIDTH = (X_WORDS > 1)? $clog2(X_WORDS) : 1; + localparam int unsigned Y_ADDR_WIDTH = (Y_WORDS > 1)? $clog2(Y_WORDS) : 1; + localparam int unsigned OUT_FOLD_WIDTH = (OUT_FOLDS > 1)? $clog2(OUT_FOLDS) : 1; + + typedef logic [COND_ADDR_WIDTH-1:0] cond_addr_t; + typedef logic [X_ADDR_WIDTH-1:0] x_addr_t; + typedef logic [Y_ADDR_WIDTH-1:0] y_addr_t; + typedef logic [OUT_FOLD_WIDTH-1:0] out_fold_t; + + initial begin + automatic int unsigned max_dim; + automatic int unsigned cd; + automatic int unsigned xd; + automatic int unsigned yd; + + if(DATA_WIDTH < 1) begin + $error("%m: DATA_WIDTH must be positive."); + $finish; + end + if(PE < 1) begin + $error("%m: PE must be positive."); + $finish; + end + if(NDIMS < 1) begin + $error("%m: NDIMS must be positive."); + $finish; + end + if(COND_NDIMS < 1 || COND_NDIMS > NDIMS) begin + $error("%m: COND_NDIMS must be in the range 1..NDIMS."); + $finish; + end + if(X_NDIMS < 1 || X_NDIMS > NDIMS) begin + $error("%m: X_NDIMS must be in the range 1..NDIMS."); + $finish; + end + if(Y_NDIMS < 1 || Y_NDIMS > NDIMS) begin + $error("%m: Y_NDIMS must be in the range 1..NDIMS."); + $finish; + end + if((OUT_SHAPE[NDIMS-1] % PE) != 0) begin + $error("%m: PE must divide the output innermost dimension."); + $finish; + end + for(int unsigned i = 0; i < 
NDIMS; i++) begin + cd = cond_dim(i); + xd = x_dim(i); + yd = y_dim(i); + max_dim = cd; + + if(cd < 1 || xd < 1 || yd < 1 || OUT_SHAPE[i] < 1) begin + $error("%m: shape dimensions must be positive."); + $finish; + end + if(xd != 1 && max_dim != 1 && xd != max_dim) begin + $error("%m: X_SHAPE is not broadcast-compatible."); + $finish; + end + if(xd != 1) max_dim = xd; + if(yd != 1 && max_dim != 1 && yd != max_dim) begin + $error("%m: Y_SHAPE is not broadcast-compatible."); + $finish; + end + if(yd != 1) max_dim = yd; + if(cd != 1 && cd != max_dim) begin + $error("%m: COND_SHAPE is not broadcast-compatible."); + $finish; + end + if(OUT_SHAPE[i] != max_dim) begin + $error("%m: OUT_SHAPE is not the multidirectional broadcast result."); + $finish; + end + end + if(COND_SHAPE[COND_NDIMS-1] != 1 && (COND_SHAPE[COND_NDIMS-1] % PE) != 0) begin + $error("%m: PE must divide the condition innermost dimension when not broadcast."); + $finish; + end + if(X_SHAPE[X_NDIMS-1] != 1 && (X_SHAPE[X_NDIMS-1] % PE) != 0) begin + $error("%m: PE must divide the X innermost dimension when not broadcast."); + $finish; + end + if(Y_SHAPE[Y_NDIMS-1] != 1 && (Y_SHAPE[Y_NDIMS-1] % PE) != 0) begin + $error("%m: PE must divide the Y innermost dimension when not broadcast."); + $finish; + end + end + + //======================================================================= + // Frame Input Buffers + (* RAM_STYLE = RAM_STYLE *) + cond_word_t Cmem[COND_WORDS]; + (* RAM_STYLE = RAM_STYLE *) + x_word_t Xmem[X_WORDS]; + (* RAM_STYLE = RAM_STYLE *) + y_word_t Ymem[Y_WORDS]; + + cond_addr_t CWr = 0; + x_addr_t XWr = 0; + y_addr_t YWr = 0; + logic CLoaded = 0; + logic XLoaded = 0; + logic YLoaded = 0; + logic Reading = 0; + logic ReadValid = 0; + logic OValid = 0; + + uwire frame_busy = Reading || ReadValid || OValid; + assign crdy = !frame_busy && !CLoaded; + assign xrdy = !frame_busy && !XLoaded; + assign yrdy = !frame_busy && !YLoaded; + + uwire c_fire = cvld && crdy; + uwire x_fire = xvld && xrdy; + 
uwire y_fire = yvld && yrdy; + uwire output_fire = OValid && ordy; + + uwire c_loaded_now = CLoaded || (c_fire && CWr == COND_WORDS-1); + uwire x_loaded_now = XLoaded || (x_fire && XWr == X_WORDS-1); + uwire y_loaded_now = YLoaded || (y_fire && YWr == Y_WORDS-1); + uwire start_reading = !frame_busy && c_loaded_now && x_loaded_now && y_loaded_now; + + uwire frame_done = output_fire && !Reading && !ReadValid; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + CWr <= 0; + XWr <= 0; + YWr <= 0; + CLoaded <= 0; + XLoaded <= 0; + YLoaded <= 0; + end + else begin + if(c_fire) begin + Cmem[CWr] <= cdat; + CLoaded <= (CWr == COND_WORDS-1); + if(CWr != COND_WORDS-1) CWr <= CWr + 1; + end + if(x_fire) begin + Xmem[XWr] <= xdat; + XLoaded <= (XWr == X_WORDS-1); + if(XWr != X_WORDS-1) XWr <= XWr + 1; + end + if(y_fire) begin + Ymem[YWr] <= ydat; + YLoaded <= (YWr == Y_WORDS-1); + if(YWr != Y_WORDS-1) YWr <= YWr + 1; + end + end + end + + //======================================================================= + // Output Indexing + outer_idx_t OutIdx = '{ default: 0 }; + out_fold_t OutFold = 0; + + uwire out_last_fold = (OutFold == OUT_FOLDS-1); + logic out_last_outer; + always_comb begin + out_last_outer = 1; + for(int unsigned i = 0; i+1 < NDIMS; i++) + out_last_outer &= (OutIdx[i] == OUT_SHAPE[i]-1); + end + uwire out_last = out_last_fold && out_last_outer; + + uwire output_ready = !OValid || ordy; + uwire read_ready = !ReadValid || output_ready; + uwire read_issue = Reading && read_ready; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + Reading <= 0; + end + else begin + if(start_reading) + Reading <= 1; + else if(read_issue && out_last) + Reading <= 0; + end + end + + always_ff @(posedge clk) begin + if(rst || frame_done || start_reading) begin + OutIdx <= '{ default: 0 }; + OutFold <= 0; + end + else if(read_issue && !out_last) begin + if(out_last_fold) begin + automatic bit carry = 1; + OutFold <= 0; + for(int i = int'(NDIMS)-2; i >= 0; 
i--) begin + if(carry) begin + if(OutIdx[i] == OUT_SHAPE[i]-1) begin + OutIdx[i] <= 0; + end + else begin + OutIdx[i] <= OutIdx[i] + 1; + carry = 0; + end + end + end + end + else + OutFold <= OutFold + 1; + end + end + + //======================================================================= + // Registered Broadcast Reads + uwire cond_addr_t c_addr = cond_addr_t'(cond_word_addr(OutIdx, OutFold)); + uwire x_addr_t x_addr = x_addr_t'(x_word_addr(OutIdx, OutFold)); + uwire y_addr_t y_addr = y_addr_t'(y_word_addr(OutIdx, OutFold)); + + cond_word_t CWord = 'x; + x_word_t XWord = 'x; + y_word_t YWord = 'x; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + ReadValid <= 0; + CWord <= 'x; + XWord <= 'x; + YWord <= 'x; + end + else if(read_ready) begin + ReadValid <= read_issue; + if(read_issue) begin + CWord <= Cmem[c_addr]; + XWord <= Xmem[x_addr]; + YWord <= Ymem[y_addr]; + end + else begin + CWord <= 'x; + XWord <= 'x; + YWord <= 'x; + end + end + end + + //======================================================================= + // Broadcast Selection + out_word_t selected; + for(genvar lane = 0; lane < PE; lane++) begin : genSelect + uwire c = (COND_SHAPE[COND_NDIMS-1] == 1)? CWord[0] : CWord[lane]; + uwire [DATA_WIDTH-1:0] x = (X_SHAPE[X_NDIMS-1] == 1)? XWord[0] : XWord[lane]; + uwire [DATA_WIDTH-1:0] y = (Y_SHAPE[Y_NDIMS-1] == 1)? YWord[0] : YWord[lane]; + assign selected[lane] = c? 
x : y; + end : genSelect + + out_word_t ODat = 'x; + + always_ff @(posedge clk) begin + if(rst || frame_done) begin + OValid <= 0; + ODat <= 'x; + end + else if(output_ready) begin + OValid <= ReadValid; + if(ReadValid) ODat <= selected; + else ODat <= 'x; + end + end + + assign odat = ODat; + assign ovld = OValid; + +endmodule : where + +`default_nettype wire diff --git a/finn-rtllib/where/hdl/where_core_template.sv b/finn-rtllib/where/hdl/where_core_template.sv new file mode 100644 index 0000000000..cd1f931fec --- /dev/null +++ b/finn-rtllib/where/hdl/where_core_template.sv @@ -0,0 +1,101 @@ +/****************************************************************************** + * Copyright (C) 2026, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
// AXI-Stream adapter around the `where` kernel.
//
// Every $NAME$ token in this file is a placeholder that the code generator
// replaces textually before the file is compiled.
//
// Stream mapping (see the port connections of `impl` below):
//   in0  -> cdat (condition stream)
//   in1  -> xdat (x stream)
//   in2  -> ydat (y stream)
//   out0 <- odat (result stream)
module $TOP_MODULE_NAME$_core #(
    // Payload widths, in bits, of the condition, x, y and output streams.
    parameter COND_WIDTH = $COND_WIDTH$,
    parameter X_WIDTH = $X_WIDTH$,
    parameter Y_WIDTH = $Y_WIDTH$,
    parameter OUT_WIDTH = $OUT_WIDTH$,
    // TDATA bus widths: each payload width rounded up to a multiple of 8 bits.
    parameter COND_AXI_WIDTH = ((COND_WIDTH + 7) / 8) * 8,
    parameter X_AXI_WIDTH = ((X_WIDTH + 7) / 8) * 8,
    parameter Y_AXI_WIDTH = ((Y_WIDTH + 7) / 8) * 8,
    parameter OUT_AXI_WIDTH = ((OUT_WIDTH + 7) / 8) * 8
)(
    input ap_clk,
    input ap_rst_n,  // inverted below before it reaches the kernel's `rst` port

    // Condition stream
    output in0_V_TREADY,
    input in0_V_TVALID,
    input [COND_AXI_WIDTH-1:0] in0_V_TDATA,

    // x stream
    output in1_V_TREADY,
    input in1_V_TVALID,
    input [X_AXI_WIDTH-1:0] in1_V_TDATA,

    // y stream
    output in2_V_TREADY,
    input in2_V_TVALID,
    input [Y_AXI_WIDTH-1:0] in2_V_TDATA,

    // Result stream
    input out0_V_TREADY,
    output out0_V_TVALID,
    output [OUT_AXI_WIDTH-1:0] out0_V_TDATA
);

    // Unpadded kernel output; it occupies the low OUT_WIDTH bits of TDATA.
    wire [OUT_WIDTH-1:0] core_out;

    assign out0_V_TDATA[OUT_WIDTH-1:0] = core_out;

    // Drive the byte-alignment padding bits (if any) with zeros so that the
    // upper part of out0_V_TDATA is never left undriven.
    generate
        if (OUT_AXI_WIDTH > OUT_WIDTH) begin : gen_pad_tdata
            assign out0_V_TDATA[OUT_AXI_WIDTH-1:OUT_WIDTH] = {(OUT_AXI_WIDTH-OUT_WIDTH){1'b0}};
        end
    endgenerate

    // Kernel instance. Shapes, dimension counts, PE and RAM style are
    // substituted by the generator. The input TDATA buses are sliced down to
    // their payload width, which drops the byte-alignment padding bits.
    where #(
        .DATA_WIDTH($DATA_WIDTH$),
        .PE($PE$),
        .NDIMS($NDIMS$),
        .COND_NDIMS($COND_NDIMS$),
        .X_NDIMS($X_NDIMS$),
        .Y_NDIMS($Y_NDIMS$),
        .OUT_SHAPE($OUT_SHAPE$),
        .COND_SHAPE($COND_SHAPE$),
        .X_SHAPE($X_SHAPE$),
        .Y_SHAPE($Y_SHAPE$),
        .RAM_STYLE($RAM_STYLE$)
    ) impl (
        .clk(ap_clk),
        .rst(!ap_rst_n),
        .cdat(in0_V_TDATA[COND_WIDTH-1:0]),
        .cvld(in0_V_TVALID),
        .crdy(in0_V_TREADY),
        .xdat(in1_V_TDATA[X_WIDTH-1:0]),
        .xvld(in1_V_TVALID),
        .xrdy(in1_V_TREADY),
        .ydat(in2_V_TDATA[Y_WIDTH-1:0]),
        .yvld(in2_V_TVALID),
        .yrdy(in2_V_TREADY),
        .odat(core_out),
        .ovld(out0_V_TVALID),
        .ordy(out0_V_TREADY)
    );

endmodule
// Top-level Verilog shell of the generated Where operator.
//
// It carries the clock/reset interface attributes and forwards every port and
// parameter unchanged to the SystemVerilog core ($TOP_MODULE_NAME$_core).
// NOTE(review): presumably a plain-Verilog top is needed because the block
// design references this module by name - confirm against the IPI flow.
//
// Every $NAME$ token is a placeholder replaced textually by the generator.
module $TOP_MODULE_NAME$ #(
    // Payload widths, in bits, of the condition, x, y and output streams.
    parameter COND_WIDTH = $COND_WIDTH$,
    parameter X_WIDTH = $X_WIDTH$,
    parameter Y_WIDTH = $Y_WIDTH$,
    parameter OUT_WIDTH = $OUT_WIDTH$,
    // TDATA bus widths: each payload width rounded up to a multiple of 8 bits.
    parameter COND_AXI_WIDTH = ((COND_WIDTH + 7) / 8) * 8,
    parameter X_AXI_WIDTH = ((X_WIDTH + 7) / 8) * 8,
    parameter Y_AXI_WIDTH = ((Y_WIDTH + 7) / 8) * 8,
    parameter OUT_AXI_WIDTH = ((OUT_WIDTH + 7) / 8) * 8
)(
    // Clock: associated with all four stream interfaces and with ap_rst_n.
    (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *)
    (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:in1_V:in2_V:out0_V, ASSOCIATED_RESET ap_rst_n" *)
    input ap_clk,
    // Reset: declared active-low.
    (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *)
    input ap_rst_n,

    // in0: condition stream
    output in0_V_TREADY,
    input in0_V_TVALID,
    input [COND_AXI_WIDTH-1:0] in0_V_TDATA,

    // in1: x stream
    output in1_V_TREADY,
    input in1_V_TVALID,
    input [X_AXI_WIDTH-1:0] in1_V_TDATA,

    // in2: y stream
    output in2_V_TREADY,
    input in2_V_TVALID,
    input [Y_AXI_WIDTH-1:0] in2_V_TDATA,

    // out0: result stream
    input out0_V_TREADY,
    output out0_V_TVALID,
    output [OUT_AXI_WIDTH-1:0] out0_V_TDATA
);

    // Pure pass-through: all parameters (including the derived AXI widths)
    // and all ports are handed to the core one-to-one.
    $TOP_MODULE_NAME$_core #(
        .COND_WIDTH(COND_WIDTH),
        .X_WIDTH(X_WIDTH),
        .Y_WIDTH(Y_WIDTH),
        .OUT_WIDTH(OUT_WIDTH),
        .COND_AXI_WIDTH(COND_AXI_WIDTH),
        .X_AXI_WIDTH(X_AXI_WIDTH),
        .Y_AXI_WIDTH(Y_AXI_WIDTH),
        .OUT_AXI_WIDTH(OUT_AXI_WIDTH)
    ) impl (
        .ap_clk(ap_clk),
        .ap_rst_n(ap_rst_n),
        .in0_V_TREADY(in0_V_TREADY),
        .in0_V_TVALID(in0_V_TVALID),
        .in0_V_TDATA(in0_V_TDATA),
        .in1_V_TREADY(in1_V_TREADY),
        .in1_V_TVALID(in1_V_TVALID),
        .in1_V_TDATA(in1_V_TDATA),
        .in2_V_TREADY(in2_V_TREADY),
        .in2_V_TVALID(in2_V_TVALID),
        .in2_V_TDATA(in2_V_TDATA),
        .out0_V_TREADY(out0_V_TREADY),
        .out0_V_TVALID(out0_V_TVALID),
        .out0_V_TDATA(out0_V_TDATA)
    );

endmodule
finn.custom_op.fpgadataflow.convolutioninputgenerator import ( ConvolutionInputGenerator, @@ -70,6 +71,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.outer_shuffle import OuterShuffle from finn.custom_op.fpgadataflow.pool import Pool from finn.custom_op.fpgadataflow.requant import Requant +from finn.custom_op.fpgadataflow.selecttoken import SelectToken from finn.custom_op.fpgadataflow.shuffle import Shuffle from finn.custom_op.fpgadataflow.split import StreamingSplit from finn.custom_op.fpgadataflow.streamingdataflowpartition import ( @@ -82,6 +84,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.thresholding import Thresholding from finn.custom_op.fpgadataflow.upsampler import UpsampleNearestNeighbour from finn.custom_op.fpgadataflow.vectorvectoractivation import VVAU +from finn.custom_op.fpgadataflow.where import Where # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure @@ -91,6 +94,7 @@ def register_custom_op(cls): custom_op["VVAU"] = VVAU custom_op["StreamingDataflowPartition"] = StreamingDataflowPartition +custom_op["AddCLSToken"] = AddCLSToken custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator custom_op["Crop"] = Crop custom_op["DuplicateStreams"] = DuplicateStreams @@ -103,10 +107,12 @@ def register_custom_op(cls): custom_op["Lookup"] = Lookup custom_op["OuterShuffle"] = OuterShuffle custom_op["Pool"] = Pool +custom_op["Requant"] = Requant +custom_op["SelectToken"] = SelectToken custom_op["Shuffle"] = Shuffle custom_op["StreamingConcat"] = StreamingConcat custom_op["StreamingSplit"] = StreamingSplit custom_op["StreamingDataWidthConverter"] = StreamingDataWidthConverter custom_op["UpsampleNearestNeighbour"] = UpsampleNearestNeighbour custom_op["HWSoftmax"] = HWSoftmax -custom_op["Requant"] = Requant +custom_op["Where"] = Where diff --git a/src/finn/custom_op/fpgadataflow/addclstoken.py 
b/src/finn/custom_op/fpgadataflow/addclstoken.py new file mode 100644 index 0000000000..35eae4bb29 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/addclstoken.py @@ -0,0 +1,171 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class AddCLSToken(HWCustomOp):
    """Hardware abstraction layer for class-token insertion.

    Input 0 carries the patch tokens with shape (1, NumTokens, NumChannels),
    input 1 carries a single class token with shape (1, 1, NumChannels).
    The output is the class token, followed by the patch tokens, followed by
    PadTokens all-zero tokens.
    """

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the inherited attribute table extended by this op's attributes."""
        attr_types = super().get_nodeattr_types()
        # sequence geometry
        attr_types["NumTokens"] = ("i", True, 0)
        attr_types["NumChannels"] = ("i", True, 0)
        attr_types["PadTokens"] = ("i", False, 0)
        # folding factor along the channel dimension
        attr_types["SIMD"] = ("i", False, 1)
        # FINN datatypes of the patch-token input and of the output
        attr_types["inputDataType"] = ("s", True, "")
        attr_types["outputDataType"] = ("s", False, "")
        return attr_types

    def get_normal_input_shape(self, ind=0):
        """Return the unfolded shape of input `ind` (0: patches, 1: class token)."""
        channels = self.get_nodeattr("NumChannels")
        if ind == 0:
            return (1, self.get_nodeattr("NumTokens"), channels)
        if ind == 1:
            return (1, 1, channels)
        raise Exception("AddCLSToken only has two inputs")

    def get_folded_input_shape(self, ind=0):
        """Return the shape of input `ind` with the channels split into SIMD-wide folds."""
        *leading, channels = self.get_normal_input_shape(ind)
        simd = self.get_nodeattr("SIMD")
        assert channels % simd == 0, "SIMD must divide NumChannels"
        return (*leading, channels // simd, simd)

    def get_normal_output_shape(self, ind=0):
        """Return the unfolded output shape: class token + patches + padding tokens."""
        seq_len = self.get_nodeattr("NumTokens")
        channels = self.get_nodeattr("NumChannels")
        padding = self.get_nodeattr("PadTokens")
        return (1, seq_len + 1 + padding, channels)

    def get_folded_output_shape(self, ind=0):
        """Return the output shape with the channels split into SIMD-wide folds."""
        *leading, channels = self.get_normal_output_shape(ind)
        simd = self.get_nodeattr("SIMD")
        assert channels % simd == 0, "SIMD must divide NumChannels"
        return (*leading, channels // simd, simd)

    def make_shape_compatible_op(self, model):
        """Check both input shapes against the attributes and return a
        constant-shape stand-in node producing the expected output shape."""
        node = self.onnx_node
        expectations = (
            (0, "Unexpected input shape for patch tokens."),
            (1, "Unexpected input shape for CLS token."),
        )
        for idx, message in expectations:
            expected = self.get_normal_input_shape(idx)
            actual = tuple(model.get_tensor_shape(node.input[idx]))
            assert actual == expected, message

        return super().make_const_shape_op(self.get_normal_output_shape())

    def infer_node_datatype(self, model):
        """Synchronise the datatype attributes with the datatypes annotated in
        the model and annotate the class-token input and the output tensor."""
        node = self.onnx_node
        # datatype currently stored on the node, if any
        declared = None
        if self.get_nodeattr("inputDataType") != "":
            declared = self.get_input_datatype()

        # the model annotation wins; the attribute is only a fallback
        idt = model.get_tensor_datatype(node.input[0])
        if idt is None:
            idt = declared
        if idt is None:
            raise Exception("AddCLSToken input datatype is not set")

        if declared is not None and declared != idt:
            warnings.warn(
                "inputDataType changing for %s: %s -> %s" % (node.name, str(declared), str(idt))
            )
        self.set_nodeattr("inputDataType", idt.name)

        # the class token has to share the datatype of the patch tokens
        token_dt = model.get_tensor_datatype(node.input[1])
        if token_dt is None:
            model.set_tensor_datatype(node.input[1], idt)
        else:
            assert token_dt == idt, "CLS token datatype must match input datatype."

        self.set_nodeattr("outputDataType", idt.name)
        model.set_tensor_datatype(node.output[0], idt)

    def verify_node(self):
        pass

    def get_input_datatype(self, ind=0):
        """Return the FINN datatype named by the inputDataType attribute."""
        return DataType[self.get_nodeattr("inputDataType")]

    def get_output_datatype(self, ind=0):
        """Return the output datatype, falling back to the input datatype
        while outputDataType is still unset."""
        name = self.get_nodeattr("outputDataType")
        return self.get_input_datatype(ind) if name == "" else DataType[name]

    def get_instream_width(self, ind=0):
        """Return the input stream width in bits; only input 0 has a stream."""
        if ind == 0:
            return self.get_input_datatype().bitwidth() * self.get_nodeattr("SIMD")
        return 0

    def get_outstream_width(self, ind=0):
        """Return the output stream width in bits."""
        return self.get_output_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_number_output_values(self):
        """Return the number of SIMD-wide output words per sequence."""
        folded = self.get_folded_output_shape()
        return int(np.prod(folded[:-1]))

    def get_exp_cycles(self):
        """Return the expected cycle count: one cycle per output word."""
        folded = self.get_folded_output_shape()
        return int(np.prod(folded[:-1]))

    def execute_node(self, context, graph):
        """Reference execution: concatenate class token, patches and zero padding
        along the token axis and store the float32 result in `context`."""
        node = self.onnx_node
        patch_tokens = context[node.input[0]]
        class_token = context[node.input[1]]

        sequence = np.concatenate([class_token, patch_tokens], axis=1)
        n_pad = self.get_nodeattr("PadTokens")
        if n_pad > 0:
            zeros = np.zeros((1, n_pad, self.get_nodeattr("NumChannels")), dtype=sequence.dtype)
            sequence = np.concatenate([sequence, zeros], axis=1)

        out_shape = self.get_normal_output_shape()
        context[node.output[0]] = np.asarray(sequence, dtype=np.float32).reshape(out_shape)

    def bram_estimation(self):
        return 0

    def lut_estimation(self):
        """Return the LUT estimate: a fixed 128 plus one LUT per channel."""
        return int(128 + self.get_nodeattr("NumChannels"))

    def get_op_and_param_counts(self):
        """Report the class token as NumChannels parameters."""
        return {"param_cls_token": int(self.get_nodeattr("NumChannels"))}
+from finn.custom_op.fpgadataflow.rtl.addclstoken_rtl import AddCLSToken_rtl from finn.custom_op.fpgadataflow.rtl.convolutioninputgenerator_rtl import ( ConvolutionInputGenerator_rtl, ) @@ -40,17 +41,20 @@ from finn.custom_op.fpgadataflow.rtl.layernorm_rtl import LayerNorm_rtl from finn.custom_op.fpgadataflow.rtl.matrixvectoractivation_rtl import MVAU_rtl from finn.custom_op.fpgadataflow.rtl.requant_rtl import Requant_rtl +from finn.custom_op.fpgadataflow.rtl.selecttoken_rtl import SelectToken_rtl from finn.custom_op.fpgadataflow.rtl.streamingdatawidthconverter_rtl import ( StreamingDataWidthConverter_rtl, ) from finn.custom_op.fpgadataflow.rtl.streamingfifo_rtl import StreamingFIFO_rtl from finn.custom_op.fpgadataflow.rtl.thresholding_rtl import Thresholding_rtl from finn.custom_op.fpgadataflow.rtl.vectorvectoractivation_rtl import VVAU_rtl +from finn.custom_op.fpgadataflow.rtl.where_rtl import Where_rtl custom_op = dict() # make sure new HLSCustomOp subclasses are imported here so that they get # registered and plug in correctly into the infrastructure +custom_op["AddCLSToken_rtl"] = AddCLSToken_rtl custom_op["ConvolutionInputGenerator_rtl"] = ConvolutionInputGenerator_rtl custom_op["ElementwiseAdd_rtl"] = ElementwiseAdd_rtl custom_op["ElementwiseSub_rtl"] = ElementwiseSub_rtl @@ -60,9 +64,11 @@ custom_op["StreamingDataWidthConverter_rtl"] = StreamingDataWidthConverter_rtl custom_op["StreamingFIFO_rtl"] = StreamingFIFO_rtl custom_op["MVAU_rtl"] = MVAU_rtl +custom_op["SelectToken_rtl"] = SelectToken_rtl custom_op["VVAU_rtl"] = VVAU_rtl custom_op["Thresholding_rtl"] = Thresholding_rtl custom_op["InnerShuffle_rtl"] = InnerShuffle_rtl custom_op["Requant_rtl"] = Requant_rtl +custom_op["Where_rtl"] = Where_rtl custom_op["FINNLoop"] = FINNLoop diff --git a/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py new file mode 100644 index 0000000000..8ca3daec88 --- /dev/null +++ 
b/src/finn/custom_op/fpgadataflow/rtl/addclstoken_rtl.py @@ -0,0 +1,168 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
def _rtlsrc_dir():
    """Return the directory that holds the AddCLSToken RTL sources.

    Raises:
        KeyError: if the FINN_ROOT environment variable is not set.
    """
    return os.environ["FINN_ROOT"] + "/finn-rtllib/addclstoken/hdl"


class AddCLSToken_rtl(AddCLSToken, RTLBackend):
    """RTL implementation of AddCLSToken.

    The class token is not streamed at run time: generate_hdl() packs the
    constant initializer of input 1 into a Verilog literal that is embedded
    in the generated top-level module.
    """

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the merged attribute table of AddCLSToken and RTLBackend."""
        my_attrs = {}
        my_attrs.update(AddCLSToken.get_nodeattr_types(self))
        my_attrs.update(RTLBackend.get_nodeattr_types(self))
        return my_attrs

    def _pack_value(self, value, dtype):
        """Encode one class-token element as an unsigned `dtype.bitwidth()`-bit integer.

        Args:
            value: the element value as stored in the initializer.
            dtype: the FINN datatype describing the element.

        Returns:
            The raw bit pattern of `value` as a non-negative Python int:
            {0, 1} for BIPOLAR, the IEEE-754 bit pattern for FLOAT32/FLOAT16,
            and the two's-complement integer (after removing the scale
            factor) for integer and fixed-point types.

        Raises:
            Exception: if `dtype` is a non-integer type that has no known
                bit-level encoding here.
        """
        bitwidth = dtype.bitwidth()
        mask = (1 << bitwidth) - 1
        if dtype == DataType["BIPOLAR"]:
            # map {-1, +1} onto {0, 1}
            return int((value + 1) // 2) & mask
        # Floating-point elements must be embedded as their IEEE-754 bit
        # pattern. Converting them with int() would silently drop the
        # fractional part and bake a wrong constant into the RTL.
        if dtype.name == "FLOAT32":
            return int(np.float32(value).view(np.uint32))
        if dtype.name == "FLOAT16":
            return int(np.float16(value).view(np.uint16))
        if dtype.is_fixed_point():
            # recover the underlying integer representation
            value = value / dtype.scale_factor()
        elif not dtype.is_integer():
            raise Exception(
                "AddCLSToken RTL cannot pack CLS token values of datatype %s" % dtype.name
            )
        int_value = int(value)
        if int_value < 0:
            # two's-complement encoding of negative values
            int_value += 1 << bitwidth
        return int_value & mask

    def _pack_cls_token(self, model):
        """Pack the constant class token into a single Verilog hex literal.

        Element 0 of the flattened token occupies the least-significant bits.

        Args:
            model: model wrapper that provides the initializer of input 1.

        Returns:
            A string of the form "<total_bits>'h<hex>".

        Raises:
            Exception: if input 1 has no initializer.
            AssertionError: if the token shape does not match the node
                attributes or a value is not allowed by the input datatype.
        """
        dtype = self.get_input_datatype()
        bitwidth = dtype.bitwidth()
        num_channels = self.get_nodeattr("NumChannels")
        cls_token = model.get_initializer(self.onnx_node.input[1])
        if cls_token is None:
            raise Exception("AddCLSToken RTL generation requires a constant CLS token input.")

        cls_token = np.asarray(cls_token, dtype=np.float32)
        assert cls_token.shape == self.get_normal_input_shape(
            1
        ), "CLS token shape does not match AddCLSToken attributes."
        assert np.vectorize(dtype.allowed)(cls_token).all(), (
            "CLS token values cannot be represented with %s" % dtype.name
        )
        packed = 0
        for i, value in enumerate(cls_token.flatten()):
            packed |= self._pack_value(value, dtype) << (i * bitwidth)
        return "%d'h%x" % (num_channels * bitwidth, packed)

    def generate_hdl(self, model, fpgapart, clk):
        """Instantiate the Verilog template and copy the RTL core next to it.

        Writes "<topname>.v" into code_gen_dir_ipgen, copies addclstoken.sv
        there and records gen_top_module, ipgen_path and ip_path on the node.
        `fpgapart` and `clk` are not used by this implementation.
        """
        simd = self.get_nodeattr("SIMD")
        num_channels = self.get_nodeattr("NumChannels")
        assert num_channels % simd == 0, "SIMD must divide NumChannels"

        rtlsrc = _rtlsrc_dir()
        template_path = rtlsrc + "/addclstoken_template.v"
        with open(template_path, "r") as f:
            template = f.read()

        topname = self.get_verilog_top_module_name()
        self.set_nodeattr("gen_top_module", topname)

        elem_width = self.get_input_datatype().bitwidth()
        fold_width = elem_width * simd
        # placeholder name -> value; each "$KEY$" in the template is replaced
        code_gen_dict = {
            "TOP_MODULE_NAME": topname,
            "NUM_TOKENS": self.get_nodeattr("NumTokens"),
            "NUM_CHANNELS": num_channels,
            "SIMD": simd,
            "ELEM_WIDTH": elem_width,
            "PAD_TOKENS": self.get_nodeattr("PadTokens"),
            "FOLD_WIDTH": fold_width,
            "CLS_WIDTH": num_channels * elem_width,
            "CLS_DATA": self._pack_cls_token(model),
        }

        for key, value in code_gen_dict.items():
            template = template.replace("$%s$" % key, str(value))

        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
        with open(os.path.join(code_gen_dir, topname + ".v"), "w") as f:
            f.write(template)
        shutil.copy(rtlsrc + "/addclstoken.sv", code_gen_dir)

        self.set_nodeattr("ipgen_path", code_gen_dir)
        self.set_nodeattr("ip_path", code_gen_dir)

    def get_rtl_file_list(self, abspath=False):
        """Return the RTL files of this node.

        With `abspath=True` the core is referenced in the RTL library and the
        generated wrapper in code_gen_dir_ipgen; otherwise bare file names
        are returned.
        """
        if abspath:
            code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/"
            rtllib_dir = _rtlsrc_dir() + "/"
        else:
            code_gen_dir = ""
            rtllib_dir = ""

        verilog_files = [
            rtllib_dir + "addclstoken.sv",
            code_gen_dir + self.get_nodeattr("gen_top_module") + ".v",
        ]
        return verilog_files

    def get_rtlsim_input_indices(self):
        """Only patch tokens are streamed; CLS token data is embedded in generated RTL."""
        return [0]

    def code_generation_ipi(self):
        """Return the Tcl commands that add the sources (taken from
        code_gen_dir_ipgen) and instantiate the generated top module."""
        code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
        sourcefiles = self.get_rtl_file_list()
        sourcefiles = [os.path.join(code_gen_dir, f) for f in sourcefiles]

        cmd = []
        for f in sourcefiles:
            cmd += ["add_files -norecurse %s" % f]
        cmd += [
            "create_bd_cell -type module -reference %s %s"
            % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
        ]
        return cmd

    def execute_node(self, context, graph):
        """Dispatch execution according to exec_mode ("cppsim" runs the
        Python reference of AddCLSToken, "rtlsim" runs the RTL simulation)."""
        mode = self.get_nodeattr("exec_mode")
        if mode == "cppsim":
            AddCLSToken.execute_node(self, context, graph)
        elif mode == "rtlsim":
            RTLBackend.execute_node(self, context, graph)
        else:
            raise Exception(
                """Invalid value for attribute exec_mode! Is currently set to: {}
            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
                    mode
                )
            )
def _rtlsrc_dir():
    """Return the location of the SelectToken RTL sources below FINN_ROOT."""
    finn_root = os.environ["FINN_ROOT"]
    return finn_root + "/finn-rtllib/selecttoken/hdl"


class SelectToken_rtl(SelectToken, RTLBackend):
    """RTL implementation of SelectToken."""

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the attribute table of SelectToken overlaid with RTLBackend's."""
        merged = {}
        for parent in (SelectToken, RTLBackend):
            merged.update(parent.get_nodeattr_types(self))
        return merged

    def generate_hdl(self, model, fpgapart, clk):
        """Fill in the Verilog template and place it, together with the RTL
        core, in code_gen_dir_ipgen; record the generated paths on the node."""
        simd = self.get_nodeattr("SIMD")
        num_channels = self.get_nodeattr("NumChannels")
        token_index = self.get_nodeattr("TokenIndex")
        num_tokens = self.get_nodeattr("NumTokens")
        # negative indices count from the end of the sequence
        if token_index < 0:
            token_index = token_index + num_tokens
        assert num_channels % simd == 0, "SIMD must divide NumChannels"
        assert 0 <= token_index < num_tokens, "TokenIndex must select an existing token"

        src_dir = _rtlsrc_dir()
        with open(src_dir + "/select_token_template.v", "r") as template_file:
            hdl = template_file.read()

        top = self.get_verilog_top_module_name()
        self.set_nodeattr("gen_top_module", top)

        elem_bits = self.get_input_datatype().bitwidth()
        substitutions = {
            "TOP_MODULE_NAME": top,
            "NUM_TOKENS": num_tokens,
            "NUM_CHANNELS": num_channels,
            "SIMD": simd,
            "ELEM_WIDTH": elem_bits,
            "TOKEN_INDEX": token_index,
            "FOLD_WIDTH": elem_bits * simd,
        }
        # every "$KEY$" occurrence in the template is replaced by its value
        for placeholder, replacement in substitutions.items():
            hdl = hdl.replace("$" + placeholder + "$", str(replacement))

        out_dir = self.get_nodeattr("code_gen_dir_ipgen")
        with open(os.path.join(out_dir, top + ".v"), "w") as wrapper_file:
            wrapper_file.write(hdl)
        shutil.copy(src_dir + "/select_token.sv", out_dir)

        for attr in ("ipgen_path", "ip_path"):
            self.set_nodeattr(attr, out_dir)

    def get_rtl_file_list(self, abspath=False):
        """Return the core and wrapper file names, prefixed by their
        directories when `abspath` is set."""
        gen_prefix = self.get_nodeattr("code_gen_dir_ipgen") + "/" if abspath else ""
        lib_prefix = _rtlsrc_dir() + "/" if abspath else ""
        return [
            lib_prefix + "select_token.sv",
            gen_prefix + self.get_nodeattr("gen_top_module") + ".v",
        ]

    def code_generation_ipi(self):
        """Return the Tcl commands adding the sources and creating the cell."""
        out_dir = self.get_nodeattr("code_gen_dir_ipgen")
        commands = [
            "add_files -norecurse %s" % os.path.join(out_dir, name)
            for name in self.get_rtl_file_list()
        ]
        commands.append(
            "create_bd_cell -type module -reference %s %s"
            % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
        )
        return commands

    def execute_node(self, context, graph):
        """Run the Python reference ("cppsim") or the RTL simulation ("rtlsim")."""
        mode = self.get_nodeattr("exec_mode")
        if mode == "cppsim":
            SelectToken.execute_node(self, context, graph)
            return
        if mode == "rtlsim":
            RTLBackend.execute_node(self, context, graph)
            return
        raise Exception(
            """Invalid value for attribute exec_mode! Is currently set to: {}
            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
                mode
            )
        )
def _rtlsrc_dir():
    """Return the location of the Where RTL sources below FINN_ROOT."""
    finn_root = os.environ["FINN_ROOT"]
    return finn_root + "/finn-rtllib/where/hdl"


class Where_rtl(Where, RTLBackend):
    """RTL implementation of Where."""

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the attribute table of Where overlaid with RTLBackend's."""
        merged = {}
        for parent in (Where, RTLBackend):
            merged.update(parent.get_nodeattr_types(self))
        return merged

    def _shape_literal(self, shape):
        """Render the RTL form of `shape` as an assignment pattern "'{ a, b, ... }"."""
        dims = [str(int(dim)) for dim in self._rtl_shape(shape)]
        return "'{ " + ", ".join(dims) + " }"

    def generate_hdl(self, model, fpgapart, clk):
        """Fill in the wrapper and core templates, write both to
        code_gen_dir_ipgen together with where.sv and record the paths."""
        pe = self._output_stream_pe()
        out_shape = self.get_normal_output_shape()
        cond_shape, x_shape, y_shape = (self.get_normal_input_shape(i) for i in range(3))
        out_dims = self._rtl_shape(out_shape)
        cond_dims = self._rtl_shape(cond_shape)
        x_dims = self._rtl_shape(x_shape)
        y_dims = self._rtl_shape(y_shape)
        assert out_dims[-1] % pe == 0, "PE must divide the output innermost dimension"

        src_dir = _rtlsrc_dir()
        with open(src_dir + "/where_template.v", "r") as wrapper_src:
            wrapper = wrapper_src.read()
        with open(src_dir + "/where_core_template.sv", "r") as core_src:
            core = core_src.read()

        top = self.get_verilog_top_module_name()
        self.set_nodeattr("gen_top_module", top)

        substitutions = {
            "TOP_MODULE_NAME": top,
            "PE": pe,
            # element width is taken from the datatype of input 1
            "DATA_WIDTH": self.get_input_datatype(1).bitwidth(),
            "NDIMS": len(out_dims),
            "COND_NDIMS": len(cond_dims),
            "X_NDIMS": len(x_dims),
            "Y_NDIMS": len(y_dims),
            "OUT_SHAPE": self._shape_literal(out_shape),
            "COND_SHAPE": self._shape_literal(cond_shape),
            "X_SHAPE": self._shape_literal(x_shape),
            "Y_SHAPE": self._shape_literal(y_shape),
            "COND_WIDTH": self.get_instream_width(0),
            "X_WIDTH": self.get_instream_width(1),
            "Y_WIDTH": self.get_instream_width(2),
            "OUT_WIDTH": self.get_outstream_width(0),
            "RAM_STYLE": '"{}"'.format(self.get_nodeattr("ram_style")),
        }
        # the same placeholder set is applied to both templates
        for placeholder, replacement in substitutions.items():
            token = "$" + placeholder + "$"
            text = str(replacement)
            wrapper = wrapper.replace(token, text)
            core = core.replace(token, text)

        out_dir = self.get_nodeattr("code_gen_dir_ipgen")
        with open(os.path.join(out_dir, top + ".v"), "w") as wrapper_dst:
            wrapper_dst.write(wrapper)
        with open(os.path.join(out_dir, top + "_core.sv"), "w") as core_dst:
            core_dst.write(core)
        shutil.copy(src_dir + "/where.sv", out_dir)

        for attr in ("ipgen_path", "ip_path"):
            self.set_nodeattr(attr, out_dir)

    def get_rtl_file_list(self, abspath=False):
        """Return kernel, generated core and generated wrapper file names,
        prefixed by their directories when `abspath` is set."""
        gen_prefix = self.get_nodeattr("code_gen_dir_ipgen") + "/" if abspath else ""
        lib_prefix = _rtlsrc_dir() + "/" if abspath else ""
        return [
            lib_prefix + "where.sv",
            gen_prefix + self.get_nodeattr("gen_top_module") + "_core.sv",
            gen_prefix + self.get_nodeattr("gen_top_module") + ".v",
        ]

    def code_generation_ipi(self):
        """Return the Tcl commands adding the sources and creating the cell."""
        out_dir = self.get_nodeattr("code_gen_dir_ipgen")
        commands = [
            "add_files -norecurse %s" % os.path.join(out_dir, name)
            for name in self.get_rtl_file_list()
        ]
        commands.append(
            "create_bd_cell -type module -reference %s %s"
            % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
        )
        return commands

    def execute_node(self, context, graph):
        """Run the Python reference ("cppsim") or the RTL simulation ("rtlsim")."""
        mode = self.get_nodeattr("exec_mode")
        if mode == "cppsim":
            Where.execute_node(self, context, graph)
            return
        if mode == "rtlsim":
            RTLBackend.execute_node(self, context, graph)
            return
        raise Exception(
            """Invalid value for attribute exec_mode! Is currently set to: {}
            has to be set to one of the following values ("cppsim", "rtlsim")""".format(
                mode
            )
        )
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class SelectToken(HWCustomOp):
    """Forward one token vector of a (1, NumTokens, NumChannels) token sequence."""

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the HWCustomOp attributes plus the SelectToken-specific ones."""
        own_attrs = {
            "NumTokens": ("i", True, 0),
            "NumChannels": ("i", True, 0),
            "TokenIndex": ("i", True, 0),
            "SIMD": ("i", False, 1),
            "inputDataType": ("s", True, ""),
            "outputDataType": ("s", False, ""),
        }
        attrs = super().get_nodeattr_types()
        attrs.update(own_attrs)
        return attrs

    def get_normal_input_shape(self, ind=0):
        """Return ``(1, NumTokens, NumChannels)``; only input 0 exists."""
        if ind != 0:
            raise Exception("SelectToken only has one input")
        return (1, self.get_nodeattr("NumTokens"), self.get_nodeattr("NumChannels"))

    def get_folded_input_shape(self, ind=0):
        """Split the channel dimension of the input into ``(channels // SIMD, SIMD)``."""
        *outer, channels = self.get_normal_input_shape(ind)
        simd = self.get_nodeattr("SIMD")
        assert channels % simd == 0, "SIMD must divide NumChannels"
        return (*outer, channels // simd, simd)

    def get_normal_output_shape(self, ind=0):
        """Return ``(1, NumChannels)``."""
        return (1, self.get_nodeattr("NumChannels"))

    def get_folded_output_shape(self, ind=0):
        """Split the channel dimension of the output into ``(channels // SIMD, SIMD)``."""
        *outer, channels = self.get_normal_output_shape(ind)
        simd = self.get_nodeattr("SIMD")
        assert channels % simd == 0, "SIMD must divide NumChannels"
        return (*outer, channels // simd, simd)

    def make_shape_compatible_op(self, model):
        """Check the input tensor shape and return a constant-shape stand-in op."""
        expected = self.get_normal_input_shape()
        actual = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
        assert actual == expected, "Unexpected input shape for token sequence."
        return super().make_const_shape_op(self.get_normal_output_shape())

    def infer_node_datatype(self, model):
        """Propagate the input tensor datatype to both datatype attributes and the output.

        The model's datatype for the input tensor takes precedence over the
        ``inputDataType`` attribute; a warning is issued whenever a previously set
        attribute is overwritten with a different datatype.
        """
        node = self.onnx_node
        has_attr_idt = self.get_nodeattr("inputDataType") != ""
        attr_idt = self.get_input_datatype() if has_attr_idt else None

        idt = model.get_tensor_datatype(node.input[0])
        if idt is None:
            idt = attr_idt
        if idt is None:
            raise Exception("SelectToken input datatype is not set")

        if attr_idt is not None and attr_idt != idt:
            warnings.warn(
                "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt))
            )
        self.set_nodeattr("inputDataType", idt.name)

        odt_name = self.get_nodeattr("outputDataType")
        if odt_name != "" and DataType[odt_name] != idt:
            warnings.warn(
                "outputDataType changing for %s: %s -> %s"
                % (node.name, str(DataType[odt_name]), str(idt))
            )
        self.set_nodeattr("outputDataType", idt.name)
        model.set_tensor_datatype(node.output[0], idt)

    def verify_node(self):
        pass

    def get_input_datatype(self, ind=0):
        """Return the DataType named by ``inputDataType``."""
        return DataType[self.get_nodeattr("inputDataType")]

    def get_output_datatype(self, ind=0):
        """Return ``outputDataType``, falling back to the input datatype when unset."""
        odt_name = self.get_nodeattr("outputDataType")
        return DataType[odt_name] if odt_name != "" else self.get_input_datatype(ind)

    def get_instream_width(self, ind=0):
        """Return the input stream width in bits; 0 for any index other than 0."""
        if ind == 0:
            return self.get_input_datatype().bitwidth() * self.get_nodeattr("SIMD")
        return 0

    def get_outstream_width(self, ind=0):
        """Return the output stream width in bits (element width times SIMD)."""
        return self.get_output_datatype().bitwidth() * self.get_nodeattr("SIMD")

    def get_number_output_values(self):
        """Return the number of folded output words."""
        folded = self.get_folded_output_shape()
        return int(np.prod(folded[:-1]))

    def get_exp_cycles(self):
        """Return the number of folded input words as the expected cycle count."""
        folded = self.get_folded_input_shape()
        return int(np.prod(folded[:-1]))

    def execute_node(self, context, graph):
        """Write token ``TokenIndex`` of the input into the output tensor as float32.

        A negative ``TokenIndex`` counts from the end of the sequence.
        """
        node = self.onnx_node
        tokens = context[node.input[0]]
        index = self.get_nodeattr("TokenIndex")
        num_tokens = self.get_nodeattr("NumTokens")
        if index < 0:
            index += num_tokens
        assert 0 <= index < num_tokens, "TokenIndex must select an existing token."

        selected = np.asarray(tokens[:, index, :], dtype=np.float32)
        context[node.output[0]] = selected.reshape(self.get_normal_output_shape())

    def bram_estimation(self):
        return 0

    def lut_estimation(self):
        return 200
class Where(HWCustomOp):
    """Streaming elementwise Where(condition, X, Y) with broadcast input shapes."""

    def __init__(self, onnx_node, **kwargs):
        super().__init__(onnx_node, **kwargs)

    def get_nodeattr_types(self):
        """Return the HWCustomOp attributes plus the Where-specific ones."""
        own_attrs = {
            "Shape": ("ints", True, []),
            "CondShape": ("ints", False, []),
            "XShape": ("ints", False, []),
            "YShape": ("ints", False, []),
            "CondRank": ("i", False, -1),
            "XRank": ("i", False, -1),
            "YRank": ("i", False, -1),
            "PE": ("i", False, 1),
            "conditionDataType": ("s", False, "BINARY"),
            "inputDataType": ("s", True, ""),
            "outputDataType": ("s", False, ""),
            "ram_style": (
                "s",
                False,
                "auto",
                {"auto", "block", "distributed", "ultra"},
            ),
            "inFIFODepths": ("ints", False, [2, 2, 2]),
            "outFIFODepths": ("ints", False, [2]),
        }
        attrs = super().get_nodeattr_types()
        attrs.update(own_attrs)
        return attrs

    def _shape(self):
        """Return the ``Shape`` attribute (the output shape) as a tuple."""
        return tuple(self.get_nodeattr("Shape"))

    def _input_shape(self, ind):
        """Return the shape of input *ind* (0: condition, 1: X, 2: Y).

        When the matching rank attribute is non-negative, the shape attribute is
        returned as is (its length must equal the rank). Otherwise a non-empty
        shape attribute is used, and an empty one falls back to ``Shape``.
        """
        attr_names = {
            0: ("CondShape", "CondRank"),
            1: ("XShape", "XRank"),
            2: ("YShape", "YRank"),
        }
        if ind not in attr_names:
            raise Exception("Where has exactly three inputs")
        attr_name, rank_name = attr_names[ind]

        rank = self.get_nodeattr(rank_name)
        shape = tuple(self.get_nodeattr(attr_name))
        if rank >= 0:
            assert len(shape) == rank, "%s length must match %s" % (
                attr_name,
                rank_name,
            )
            return shape
        return shape if len(shape) != 0 else self._shape()

    def _rtl_shape(self, shape):
        """Return *shape* as a tuple, mapping the empty shape to ``(1,)``."""
        return tuple(shape) if len(shape) != 0 else (1,)

    def _input_stream_pe(self, ind):
        """Return the parallelism of input *ind*: 1 if its innermost dimension is 1."""
        innermost = self._rtl_shape(self.get_normal_input_shape(ind))[-1]
        return 1 if innermost == 1 else self._output_stream_pe()

    def _output_stream_pe(self):
        """Return ``PE``, or 1 if the innermost output dimension is 1."""
        innermost = self._rtl_shape(self.get_normal_output_shape())[-1]
        return 1 if innermost == 1 else self.get_nodeattr("PE")

    def _folded_shape(self, shape, stream_pe):
        """Split the innermost dimension of *shape* into ``(channels // pe, pe)``."""
        *outer, channels = self._rtl_shape(shape)
        assert channels % stream_pe == 0, "Stream PE must divide the innermost dimension"
        return tuple(outer + [channels // stream_pe, stream_pe])

    def get_normal_input_shape(self, ind=0):
        """Return the unfolded shape of input *ind*."""
        if ind not in [0, 1, 2]:
            raise Exception("Where has exactly three inputs")
        return self._input_shape(ind)

    def get_folded_input_shape(self, ind=0):
        """Return the folded shape of input *ind* for its stream parallelism."""
        normal = self.get_normal_input_shape(ind)
        return self._folded_shape(normal, self._input_stream_pe(ind))

    def get_normal_output_shape(self, ind=0):
        """Return the unfolded output shape; only output 0 exists."""
        if ind != 0:
            raise Exception("Where has exactly one output")
        return self._shape()

    def get_folded_output_shape(self, ind=0):
        """Return the folded output shape for the output stream parallelism."""
        normal = self.get_normal_output_shape(ind)
        return self._folded_shape(normal, self._output_stream_pe())

    def make_shape_compatible_op(self, model):
        """Check every input tensor shape and return a constant-shape stand-in op."""
        for i, inp in enumerate(self.onnx_node.input):
            actual = tuple(model.get_tensor_shape(inp))
            assert actual == self.get_normal_input_shape(i), (
                "Unexpected input shape for Where input %d." % i
            )
        return super().make_const_shape_op(self.get_normal_output_shape())

    def infer_node_datatype(self, model):
        """Validate and propagate datatypes.

        The condition must be BINARY. X and Y must share one datatype, which is
        written to ``inputDataType``, ``outputDataType`` and the output tensor;
        tensors without a datatype in the model receive the inferred one.
        """
        node = self.onnx_node
        cond_name, x_name, y_name = node.input[0], node.input[1], node.input[2]

        # condition: fall back to the attribute if the model carries no datatype
        cond_dt = model.get_tensor_datatype(cond_name)
        if cond_dt is None:
            cond_dt = self.get_condition_datatype()
            model.set_tensor_datatype(cond_name, cond_dt)
        if cond_dt != DataType["BINARY"]:
            raise Exception("Where condition datatype must be BINARY")
        self.set_nodeattr("conditionDataType", cond_dt.name)

        has_attr_idt = self.get_nodeattr("inputDataType") != ""
        attr_idt = self.get_input_datatype(1) if has_attr_idt else None

        # data inputs: X's model datatype wins over the attribute
        x_dt = model.get_tensor_datatype(x_name)
        y_dt = model.get_tensor_datatype(y_name)
        idt = attr_idt if x_dt is None else x_dt
        if idt is None:
            raise Exception("Where input datatype is not set")
        if y_dt is None:
            model.set_tensor_datatype(y_name, idt)
        elif y_dt != idt:
            raise Exception("Where X and Y datatypes must match")
        if x_dt is None:
            model.set_tensor_datatype(x_name, idt)

        if attr_idt is not None and attr_idt != idt:
            warnings.warn(
                "inputDataType changing for %s: %s -> %s" % (node.name, str(attr_idt), str(idt))
            )
        self.set_nodeattr("inputDataType", idt.name)

        odt_name = self.get_nodeattr("outputDataType")
        if odt_name != "" and DataType[odt_name] != idt:
            warnings.warn(
                "outputDataType changing for %s: %s -> %s"
                % (node.name, str(DataType[odt_name]), str(idt))
            )
        self.set_nodeattr("outputDataType", idt.name)
        model.set_tensor_datatype(node.output[0], idt)

    def verify_node(self):
        pass

    def get_condition_datatype(self):
        """Return the DataType named by ``conditionDataType``."""
        return DataType[self.get_nodeattr("conditionDataType")]

    def get_input_datatype(self, ind=0):
        """Return the condition datatype for input 0, ``inputDataType`` for 1 and 2."""
        if ind == 0:
            return self.get_condition_datatype()
        if ind in [1, 2]:
            return DataType[self.get_nodeattr("inputDataType")]
        raise Exception("Where has exactly three inputs")

    def get_output_datatype(self, ind=0):
        """Return ``outputDataType``, falling back to the X/Y datatype when unset."""
        odt_name = self.get_nodeattr("outputDataType")
        return DataType[odt_name] if odt_name != "" else self.get_input_datatype(1)

    def get_instream_width(self, ind=0):
        """Return the stream width of input *ind* in bits; 0 for an unknown index.

        The condition stream carries one bit per parallel element.
        """
        if ind == 0:
            return self._input_stream_pe(ind)
        if ind in [1, 2]:
            return self.get_input_datatype(ind).bitwidth() * self._input_stream_pe(ind)
        return 0

    def get_outstream_width(self, ind=0):
        """Return the output stream width in bits."""
        return self.get_output_datatype(ind).bitwidth() * self._output_stream_pe()

    def get_number_output_values(self):
        """Return the number of folded output words."""
        folded = self.get_folded_output_shape()
        return int(np.prod(folded[:-1]))

    def get_exp_cycles(self):
        """Return longest folded input word count + output word count + 4."""
        input_cycles = max(
            int(np.prod(self.get_folded_input_shape(ind)[:-1])) for ind in range(3)
        )
        return input_cycles + self.get_number_output_values() + 4

    def execute_node(self, context, graph):
        """Evaluate ``np.where`` on the three inputs and store the float32 result."""
        node = self.onnx_node
        cond = context[node.input[0]]
        xval = context[node.input[1]]
        yval = context[node.input[2]]

        selected = np.asarray(np.where(cond.astype(bool), xval, yval), dtype=np.float32)
        context[node.output[0]] = selected.reshape(self.get_normal_output_shape())

    def bram_estimation(self):
        return 0

    def lut_estimation(self):
        """Return ``64 + PE * output bitwidth`` as the LUT estimate."""
        return int(64 + self.get_nodeattr("PE") * self.get_output_datatype().bitwidth())

    def get_op_and_param_counts(self):
        """Return one ``op_where`` per output element."""
        return {"op_where": int(np.prod(self.get_normal_output_shape()))}
"""Convert scalar Gather(input, token_index, axis=1) into SelectToken.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for node in graph.node: + node_ind += 1 + if node.op_type != "Gather": + continue + + axis = get_by_name(node.attribute, "axis") + if axis is None or len(node.input) != 2: + continue + + seq_name = node.input[0] + idx_name = node.input[1] + idx_init = model.get_initializer(idx_name) + if idx_init is None or idx_init.size != 1: + continue + if model.get_initializer(seq_name) is not None: + continue + + seq_shape = model.get_tensor_shape(seq_name) + if seq_shape is None or any(x is None for x in seq_shape): + continue + + rank = len(seq_shape) + gather_axis = axis.i if axis.i >= 0 else axis.i + rank + if rank != 3 or gather_axis != 1: + continue + + token_index = int(idx_init.flatten()[0]) + num_tokens = int(seq_shape[1]) + if token_index < 0: + token_index += num_tokens + if token_index < 0 or token_index >= num_tokens: + continue + + out_shape = model.get_tensor_shape(node.output[0]) + exp_oshape = [int(seq_shape[0]), int(seq_shape[2])] + if out_shape is not None and list(out_shape) != exp_oshape: + continue + if seq_shape[0] != 1: + continue + + idt = model.get_tensor_datatype(seq_name) + if idt is None or not idt.is_integer(): + continue + odt = model.get_tensor_datatype(node.output[0]) + if odt is None: + odt = idt + elif odt != idt: + continue + + new_node = helper.make_node( + "SelectToken", + [seq_name], + node.output, + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="SelectToken_" + node.name, + NumTokens=num_tokens, + NumChannels=int(seq_shape[2]), + TokenIndex=token_index, + SIMD=1, + inputDataType=idt.name, + outputDataType=odt.name, + ) + graph.node.insert(node_ind, new_node) + graph.node.remove(node) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) + + 
class InferAddCLSTokenLayer(Transformation):
    """Replace Concat([cls_token, patches], axis=1) by an AddCLSToken node.

    A Concat is converted only if it has exactly two inputs, the first one is an
    initializer of shape (1, 1, C), the second one is a non-initializer tensor
    of shape (1, N, C) with an integer datatype, and the datatypes agree.
    """

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        # node_ind is the insertion position right after the visited node
        for node_ind, node in enumerate(graph.node, start=1):
            if node.op_type != "Concat":
                continue

            axis = get_by_name(node.attribute, "axis")
            if axis is None or len(node.input) != 2:
                continue

            cls_name, patch_name = node.input[0], node.input[1]
            cls_init = model.get_initializer(cls_name)
            if cls_init is None or model.get_initializer(patch_name) is not None:
                continue

            cls_shape = model.get_tensor_shape(cls_name)
            if cls_shape is None:
                cls_shape = list(cls_init.shape)
            patch_shape = model.get_tensor_shape(patch_name)
            if cls_shape is None or patch_shape is None:
                continue
            if any(dim is None for dim in list(cls_shape) + list(patch_shape)):
                continue

            rank = len(patch_shape)
            concat_axis = axis.i + rank if axis.i < 0 else axis.i
            if rank != 3 or concat_axis != 1:
                continue

            if len(cls_shape) != 3 or cls_shape[0] != 1 or cls_shape[1] != 1:
                continue
            if patch_shape[0] != 1 or cls_shape[2] != patch_shape[2]:
                continue

            out_shape = model.get_tensor_shape(node.output[0])
            expected_oshape = [1, patch_shape[1] + 1, patch_shape[2]]
            if out_shape is not None and list(out_shape) != expected_oshape:
                continue

            idt = model.get_tensor_datatype(patch_name)
            if idt is None or not idt.is_integer():
                continue
            cls_dt = model.get_tensor_datatype(cls_name)
            if cls_dt is None:
                model.set_tensor_datatype(cls_name, idt)
            elif cls_dt != idt:
                continue

            # the HW node takes the patches first and the class token second
            addcls_node = helper.make_node(
                "AddCLSToken",
                [patch_name, cls_name],
                node.output,
                domain="finn.custom_op.fpgadataflow",
                backend="fpgadataflow",
                name="AddCLSToken_" + node.name,
                NumTokens=int(patch_shape[1]),
                NumChannels=int(patch_shape[2]),
                PadTokens=0,
                SIMD=1,
                inputDataType=idt.name,
                outputDataType=idt.name,
            )
            graph.node.insert(node_ind, addcls_node)
            graph.node.remove(node)
            graph_modified = True

        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)


class InferWhereLayer(Transformation):
    """Replace ONNX Where(condition, X, Y) by a streaming Where node.

    A node is converted only if none of its inputs is an initializer, all shapes
    are fully known and broadcast to the recorded output shape, X and Y share a
    supported datatype that matches the output, and the condition is BINARY.
    """

    def apply(self, model):
        graph = model.graph
        graph_modified = False
        # node_ind is the insertion position right after the visited node
        for node_ind, node in enumerate(graph.node, start=1):
            if node.op_type != "Where" or node.domain not in ["", "ai.onnx"]:
                continue
            if len(node.input) != 3:
                continue

            cond_name, x_name, y_name = node.input
            if any(model.get_initializer(inp) is not None for inp in node.input):
                continue

            cond_shape = model.get_tensor_shape(cond_name)
            x_shape = model.get_tensor_shape(x_name)
            y_shape = model.get_tensor_shape(y_name)
            out_shape = model.get_tensor_shape(node.output[0])
            if any(shape is None for shape in [cond_shape, x_shape, y_shape, out_shape]):
                continue
            if any(dim is None for dim in list(cond_shape) + list(x_shape) + list(y_shape)):
                continue
            try:
                broadcast_shape = np.broadcast_shapes(
                    tuple(cond_shape), tuple(x_shape), tuple(y_shape)
                )
            except ValueError:
                continue
            broadcast_dims = [int(dim) for dim in broadcast_shape]
            if list(out_shape) != broadcast_dims:
                continue

            x_dt = model.get_tensor_datatype(x_name)
            y_dt = model.get_tensor_datatype(y_name)
            if x_dt is None or y_dt is None or x_dt != y_dt:
                continue
            supported_dt = (
                x_dt.is_integer()
                or x_dt.is_fixed_point()
                or x_dt in [DataType["FLOAT32"], DataType["FLOAT16"]]
            )
            if not supported_dt:
                continue
            out_dt = model.get_tensor_datatype(node.output[0])
            if out_dt is not None and out_dt != x_dt:
                continue

            # a condition without a datatype is annotated as BINARY
            cond_dt = model.get_tensor_datatype(cond_name)
            if cond_dt is None:
                model.set_tensor_datatype(cond_name, DataType["BINARY"])
                cond_dt = DataType["BINARY"]
            if cond_dt != DataType["BINARY"]:
                continue

            where_node = helper.make_node(
                "Where",
                node.input,
                node.output,
                domain="finn.custom_op.fpgadataflow",
                backend="fpgadataflow",
                name="Where_" + node.name,
                CondRank=len(cond_shape),
                XRank=len(x_shape),
                YRank=len(y_shape),
                PE=1,
                conditionDataType=cond_dt.name,
                inputDataType=x_dt.name,
                outputDataType=x_dt.name,
                inFIFODepths=[2, 2, 2],
                outFIFODepths=[2],
            )
            # shape attributes are added with an explicit INTS type
            shape_attrs = [
                ("Shape", broadcast_dims),
                ("CondShape", [int(dim) for dim in cond_shape]),
                ("XShape", [int(dim) for dim in x_shape]),
                ("YShape", [int(dim) for dim in y_shape]),
            ]
            for attr_name, attr_value in shape_attrs:
                where_node.attribute.append(
                    helper.make_attribute(attr_name, attr_value, attr_type=AttributeProto.INTS)
                )
            graph.node.insert(node_ind, where_node)
            graph.node.remove(node)
            graph_modified = True

        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
FPGA_PART = "xc7z020clg400-1"
CLK_NS = 10


def _make_graph(nodes, output_shape, cls_values, finn_dtype=DataType["INT8"]):
    """Wrap *nodes* into a model with a (1, 3, 4) "patches" input and a "cls" initializer."""
    patches_vi = helper.make_tensor_value_info("patches", TensorProto.FLOAT, [1, 3, 4])
    out_vi = helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape)
    cls_init = numpy_helper.from_array(cls_values.astype(np.float32), name="cls")
    graph = helper.make_graph(nodes, "addclstoken_test", [patches_vi], [out_vi], [cls_init])
    onnx_model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
    model = ModelWrapper(onnx_model)
    for tensor_name in ["patches", "cls", "out"]:
        model.set_tensor_datatype(tensor_name, finn_dtype)
    return model


def _make_concat_model():
    """Build a model with a plain ONNX Concat([cls, patches], axis=1)."""
    cls_values = np.asarray([[[1, -2, 3, -4]]], dtype=np.float32)
    concat_node = helper.make_node(
        "Concat",
        ["cls", "patches"],
        ["out"],
        axis=1,
        name="concat_cls",
    )
    return _make_graph([concat_node], [1, 4, 4], cls_values), cls_values


def _make_addclstoken_model(
    pad_tokens=0,
    simd=1,
    finn_dtype=DataType["INT8"],
    cls_values=None,
):
    """Build a model holding a single AddCLSToken node (3 tokens, 4 channels)."""
    if cls_values is None:
        cls_values = np.asarray([[[1, -2, 3, -4]]], dtype=np.float32)
    addcls_node = helper.make_node(
        "AddCLSToken",
        ["patches", "cls"],
        ["out"],
        domain="finn.custom_op.fpgadataflow",
        backend="fpgadataflow",
        name="AddCLSToken_0",
        NumTokens=3,
        NumChannels=4,
        PadTokens=pad_tokens,
        SIMD=simd,
        inputDataType=finn_dtype.name,
        outputDataType=finn_dtype.name,
    )
    out_shape = [1, 4 + pad_tokens, 4]
    return _make_graph([addcls_node], out_shape, cls_values, finn_dtype), cls_values


def _prepare_addclstoken_stitched_ip_model(simd=1, pad_tokens=0):
    """Run the AddCLSToken model through the flow up to a stitched IP."""
    model, cls_values = _make_addclstoken_model(pad_tokens=pad_tokens, simd=simd)
    flow = [
        SpecializeLayers(FPGA_PART),
        InsertFIFO(create_shallow_fifos=True),
        SpecializeLayers(FPGA_PART),
        GiveUniqueNodeNames(),
        PrepareIP(FPGA_PART, CLK_NS),
        HLSSynthIP(),
        CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False),
    ]
    for step in flow:
        model = model.transform(step)
    return model, cls_values


def _make_input_dict(model, patches):
    """Map the first graph input name to *patches*."""
    return {model.graph.input[0].name: patches}


@pytest.mark.fpgadataflow
def test_convert_concat_to_addclstoken():
    """Concat is converted to AddCLSToken, executes identically and specializes to RTL."""
    model, cls_values = _make_concat_model()
    patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
    expected = np.concatenate([cls_values, patches], axis=1)

    reference = execute_onnx(model, _make_input_dict(model, patches))
    assert (reference["out"] == expected).all()

    model = model.transform(InferAddCLSTokenLayer())
    hw_node = model.graph.node[0]
    assert hw_node.op_type == "AddCLSToken"
    assert hw_node.domain == "finn.custom_op.fpgadataflow"
    assert list(hw_node.input) == ["patches", "cls"]

    inst = getCustomOp(hw_node)
    assert inst.get_normal_output_shape() == (1, 4, 4)
    assert inst.get_exp_cycles() == 16

    converted = execute_onnx(model, _make_input_dict(model, patches))
    assert (converted["out"] == expected).all()

    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())
    assert model.graph.node[0].op_type == "AddCLSToken_rtl"
    assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl"


@pytest.mark.fpgadataflow
def test_addclstoken_python_execution_with_padding():
    """Python execution appends PadTokens all-zero tokens after the patches."""
    model, cls_values = _make_addclstoken_model(pad_tokens=2)
    patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
    padding = np.zeros((1, 2, 4), dtype=np.float32)
    expected = np.concatenate(
        [cls_values, patches, padding],
        axis=1,
    )

    result = execute_onnx(model, _make_input_dict(model, patches))
    assert (result["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.parametrize(
    "finn_dtype,cls_values,expected_cls_data",
    [
        (DataType["INT8"], np.asarray([[[1, -2, 3, -4]]], dtype=np.float32), "32'hfc03fe01"),
        (DataType["UINT4"], np.asarray([[[1, 2, 3, 4]]], dtype=np.float32), "16'h4321"),
        (DataType["BIPOLAR"], np.asarray([[[1, -1, 1, -1]]], dtype=np.float32), "4'h5"),
    ],
)
def test_addclstoken_rtl_codegen(tmp_path, finn_dtype, cls_values, expected_cls_data):
    """Generated wrapper carries the expected parameters and packed class-token constant."""
    model, _ = _make_addclstoken_model(
        pad_tokens=1,
        simd=2,
        finn_dtype=finn_dtype,
        cls_values=cls_values,
    )
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())

    node = model.graph.node[0]
    inst = getCustomOp(node)
    inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path))
    inst.code_generation_ipgen(model, FPGA_PART, CLK_NS)

    topname = inst.get_nodeattr("gen_top_module")
    assert topname == node.name
    wrapper_file = tmp_path / (topname + ".v")
    core_file = tmp_path / "addclstoken.sv"
    assert wrapper_file.is_file()
    assert core_file.is_file()
    wrapper_text = wrapper_file.read_text()
    assert "parameter FOLD_WIDTH = %d" % (2 * finn_dtype.bitwidth()) in wrapper_text
    assert ".SIMD(2)" in wrapper_text
    assert ".PAD_TOKENS(1)" in wrapper_text
    assert "CLS_DATA = %s" % expected_cls_data in wrapper_text
    assert "out0_V_TVALID" in wrapper_text
    assert "= '0" not in wrapper_text

    ipi_cmds = inst.code_generation_ipi()
    assert any("addclstoken.sv" in cmd for cmd in ipi_cmds)
    assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds)


@pytest.mark.fpgadataflow
def test_addclstoken_resource_estimation():
    """Both resource-estimation analyses report the expected figures for the single node."""
    model, _ = _make_addclstoken_model(pad_tokens=1, simd=2)
    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())

    expected = {
        "BRAM_18K": 0,
        "BRAM_efficiency": 1,
        "LUT": 132,
        "URAM": 0,
        "URAM_efficiency": 1,
        "DSP": 0,
    }
    estimates = model.analysis(partial(res_estimation, fpgapart=FPGA_PART))
    assert len(estimates) == 1
    assert list(estimates.values())[0] == expected

    complete_estimates = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART))
    assert len(complete_estimates) == 1
    assert list(complete_estimates.values())[0] == [expected]


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,pad_tokens", [(1, 0), (2, 1)])
def test_addclstoken_rtlsim(simd, pad_tokens):
    """Node-level rtlsim matches the reference and the expected cycle count (atol 10)."""
    model, cls_values = _make_addclstoken_model(pad_tokens=pad_tokens, simd=simd)
    patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
    parts = [cls_values, patches]
    if pad_tokens > 0:
        parts.append(np.zeros((1, pad_tokens, 4), dtype=np.float32))
    expected = np.concatenate(parts, axis=1)

    model = model.transform(SpecializeLayers(FPGA_PART))
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(PrepareIP(FPGA_PART, CLK_NS))
    model = model.transform(SetExecMode("rtlsim"))
    model = model.transform(PrepareRTLSim())

    result = execute_onnx(model, _make_input_dict(model, patches))
    assert (result["out"] == expected).all()

    node = model.get_nodes_by_op_type("AddCLSToken_rtl")[0]
    cycles_rtlsim = getCustomOp(node).get_nodeattr("cycles_rtlsim")
    exp_cycles = model.analysis(exp_cycles_per_layer)[node.name]
    assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
    assert exp_cycles != 0


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
@pytest.mark.parametrize("simd,pad_tokens", [(1, 0), (2, 1)])
def test_addclstoken_stitched_ip_rtlsim(simd, pad_tokens):
    """Stitched-IP rtlsim reproduces the reference output."""
    model, cls_values = _prepare_addclstoken_stitched_ip_model(
        simd=simd,
        pad_tokens=pad_tokens,
    )
    patches = np.arange(12, dtype=np.float32).reshape(1, 3, 4)
    parts = [cls_values, patches]
    if pad_tokens > 0:
        parts.append(np.zeros((1, pad_tokens, 4), dtype=np.float32))
    expected = np.concatenate(parts, axis=1)

    model.set_metadata_prop("exec_mode", "rtlsim")

    result = execute_onnx(model, _make_input_dict(model, patches))
    assert (result["out"] == expected).all()


@pytest.mark.fpgadataflow
@pytest.mark.vivado
@pytest.mark.slow
def test_addclstoken_stitched_ip_synth_ooc():
    """Out-of-context synthesis uses LUTs and FFs, no DSP or BRAM, and meets timing."""
    model, _ = _prepare_addclstoken_stitched_ip_model(simd=2, pad_tokens=1)
    model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS))
    synth_report = model.get_metadata_prop("res_total_ooc_synth")
    assert synth_report is not None
    # the metadata property stores the report as a dict literal
    synth_report = eval(synth_report)

    assert synth_report["LUT"] > 0
    assert synth_report["FF"] > 0
    assert synth_report["DSP"] == 0
    assert synth_report["BRAM"] == 0
    assert synth_report["WNS"] >= 0
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +import numpy as np +from functools import partial +from onnx import TensorProto, helper, numpy_helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames + +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.analysis.fpgadataflow.res_estimation import ( + res_estimation, + res_estimation_complete, +) +from finn.core.onnx_exec import execute_onnx +from finn.transformation.fpgadataflow.convert_to_hw_layers import InferSelectTokenLayer +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.transformation.fpgadataflow.synth_ooc import SynthOutOfContext + +FPGA_PART = "xc7z020clg400-1" +CLK_NS = 10 + + +def _make_graph(nodes, output_shape, idx_values=None, finn_dtype=DataType["INT8"]): + tokens_shape = [1, 4, 4] + tokens = helper.make_tensor_value_info("tokens", TensorProto.FLOAT, tokens_shape) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT, output_shape) + initializers = [] + if idx_values is not None: + initializers.append(numpy_helper.from_array(idx_values, name="idx")) + graph = helper.make_graph(nodes, "selecttoken_test", [tokens], [output], initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + model = ModelWrapper(model) + for tensor_name in ["tokens", "out"]: + model.set_tensor_datatype(tensor_name, finn_dtype) + return model + + 
+def _make_gather_model(token_index=0): + idx_values = np.asarray(token_index, dtype=np.int64) + gather = helper.make_node( + "Gather", + ["tokens", "idx"], + ["out"], + axis=1, + name="gather_token", + ) + return _make_graph([gather], [1, 4], idx_values) + + +def _make_selecttoken_model(token_index=0, simd=1, finn_dtype=DataType["INT8"]): + select = helper.make_node( + "SelectToken", + ["tokens"], + ["out"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="SelectToken_0", + NumTokens=4, + NumChannels=4, + TokenIndex=token_index, + SIMD=simd, + inputDataType=finn_dtype.name, + outputDataType=finn_dtype.name, + ) + return _make_graph([select], [1, 4], None, finn_dtype) + + +def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0): + model = _make_selecttoken_model(token_index=token_index, simd=simd) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(InsertFIFO(create_shallow_fifos=True)) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False)) + return model + + +def _make_input_dict(model, tokens): + return {model.graph.input[0].name: tokens} + + +@pytest.mark.fpgadataflow +def test_convert_gather_to_selecttoken(): + model = _make_gather_model(token_index=2) + tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) + expected = tokens[:, 2, :] + + ret = execute_onnx(model, _make_input_dict(model, tokens)) + assert (ret["out"] == expected).all() + + model = model.transform(InferSelectTokenLayer()) + node = model.graph.node[0] + assert node.op_type == "SelectToken" + assert node.domain == "finn.custom_op.fpgadataflow" + assert list(node.input) == ["tokens"] + + inst = getCustomOp(node) + assert inst.get_normal_output_shape() == (1, 4) + assert inst.get_exp_cycles() == 16 + assert 
inst.get_nodeattr("TokenIndex") == 2 + + ret = execute_onnx(model, _make_input_dict(model, tokens)) + assert (ret["out"] == expected).all() + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + assert model.graph.node[0].op_type == "SelectToken_rtl" + assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize("token_index", [0, 1, 3]) +def test_selecttoken_python_execution(token_index): + model = _make_selecttoken_model(token_index=token_index) + tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) + expected = tokens[:, token_index, :] + + ret = execute_onnx(model, _make_input_dict(model, tokens)) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype,fold_width", + [(DataType["INT8"], 16), (DataType["UINT4"], 8), (DataType["BIPOLAR"], 2)], +) +def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width): + model = _make_selecttoken_model(token_index=3, simd=2, finn_dtype=finn_dtype) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + assert topname == node.name + wrapper = tmp_path / (topname + ".v") + core = tmp_path / "select_token.sv" + assert wrapper.is_file() + assert core.is_file() + wrapper_text = wrapper.read_text() + assert "parameter FOLD_WIDTH = %d" % fold_width in wrapper_text + assert ".SIMD(2)" in wrapper_text + assert ".TOKEN_INDEX(3)" in wrapper_text + assert "select_token #(" in wrapper_text + assert "out0_V_TVALID" in wrapper_text + + ipi_cmds = inst.code_generation_ipi() + assert any("select_token.sv" in cmd for cmd in ipi_cmds) + assert any("create_bd_cell" in cmd and topname in cmd for 
cmd in ipi_cmds) + + +@pytest.mark.fpgadataflow +def test_selecttoken_resource_estimation(): + model = _make_selecttoken_model(token_index=1, simd=2) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + expected = { + "BRAM_18K": 0, + "BRAM_efficiency": 1, + "LUT": 200, + "URAM": 0, + "URAM_efficiency": 1, + "DSP": 0, + } + resources = model.analysis(partial(res_estimation, fpgapart=FPGA_PART)) + assert len(resources) == 1 + assert list(resources.values())[0] == expected + + complete_resources = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART)) + assert len(complete_resources) == 1 + assert list(complete_resources.values())[0] == [expected] + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)]) +def test_selecttoken_rtlsim(simd, token_index): + model = _make_selecttoken_model(token_index=token_index, simd=simd) + tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) + expected = tokens[:, token_index, :] + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, _make_input_dict(model, tokens)) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("SelectToken_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("simd,token_index", [(1, 0), (2, 3)]) +def test_selecttoken_stitched_ip_rtlsim(simd, token_index): + model = 
_prepare_selecttoken_stitched_ip_model(simd=simd, token_index=token_index) + tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4) + expected = tokens[:, token_index, :] + + model.set_metadata_prop("exec_mode", "rtlsim") + + ret = execute_onnx(model, _make_input_dict(model, tokens)) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_selecttoken_stitched_ip_synth_ooc(): + model = _prepare_selecttoken_stitched_ip_model(simd=2, token_index=1) + model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS)) + ret = model.get_metadata_prop("res_total_ooc_synth") + assert ret is not None + ret = eval(ret) + + assert ret["LUT"] > 0 + assert ret["FF"] > 0 + assert ret["DSP"] == 0 + assert ret["BRAM"] == 0 + assert ret["WNS"] >= 0 diff --git a/tests/fpgadataflow/test_fpgadataflow_where.py b/tests/fpgadataflow/test_fpgadataflow_where.py new file mode 100644 index 0000000000..3269a04061 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_where.py @@ -0,0 +1,610 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of FINN nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import numpy as np +from functools import partial +from onnx import AttributeProto, TensorProto, helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames + +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.analysis.fpgadataflow.res_estimation import ( + res_estimation, + res_estimation_complete, +) +from finn.core.onnx_exec import execute_onnx +from finn.transformation.fpgadataflow.convert_to_hw_layers import InferWhereLayer +from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers +from finn.transformation.fpgadataflow.synth_ooc 
import SynthOutOfContext + +FPGA_PART = "xc7z020clg400-1" +CLK_NS = 10 + + +def _numel(shape): + return int(np.prod(shape)) if len(shape) > 0 else 1 + + +def _make_graph( + nodes, + shape=None, + finn_dtype=DataType["INT8"], + cond_is_bool=False, + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + cond_shape = shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + out_shape = shape if out_shape is None else out_shape + cond_proto = TensorProto.BOOL if cond_is_bool else TensorProto.FLOAT + cond = helper.make_tensor_value_info("cond", cond_proto, cond_shape) + xval = helper.make_tensor_value_info("xval", TensorProto.FLOAT, x_shape) + yval = helper.make_tensor_value_info("yval", TensorProto.FLOAT, y_shape) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape) + graph = helper.make_graph(nodes, "where_test", [cond, xval, yval], [output]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + model = ModelWrapper(model) + if not cond_is_bool: + model.set_tensor_datatype("cond", DataType["BINARY"]) + for tensor_name in ["xval", "yval", "out"]: + model.set_tensor_datatype(tensor_name, finn_dtype) + return model + + +def _make_onnx_where_model( + shape=None, + finn_dtype=DataType["INT8"], + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + where = helper.make_node("Where", ["cond", "xval", "yval"], ["out"], name="where_select") + return _make_graph( + [where], + shape, + finn_dtype=finn_dtype, + cond_is_bool=True, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + + +def _make_where_model( + shape=None, + pe=1, + finn_dtype=DataType["INT8"], + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + if shape is None: + shape = [1, 2, 4] + cond_shape = 
shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + out_shape = shape if out_shape is None else out_shape + where = helper.make_node( + "Where", + ["cond", "xval", "yval"], + ["out"], + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + name="Where_0", + CondRank=len(cond_shape), + XRank=len(x_shape), + YRank=len(y_shape), + PE=pe, + conditionDataType="BINARY", + inputDataType=finn_dtype.name, + outputDataType=finn_dtype.name, + inFIFODepths=[2, 2, 2], + outFIFODepths=[2], + ) + for attr_name, attr_value in [ + ("Shape", out_shape), + ("CondShape", cond_shape), + ("XShape", x_shape), + ("YShape", y_shape), + ]: + where.attribute.append( + helper.make_attribute(attr_name, attr_value, attr_type=AttributeProto.INTS) + ) + return _make_graph( + [where], + shape, + finn_dtype=finn_dtype, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + + +def _prepare_where_stitched_ip_model( + pe=1, + shape=None, + cond_shape=None, + x_shape=None, + y_shape=None, + out_shape=None, +): + model = _make_where_model( + pe=pe, + shape=shape, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(InsertFIFO(create_shallow_fifos=True)) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(FPGA_PART, CLK_NS, vitis=False)) + return model + + +def _make_inputs(shape=None, cond_shape=None, x_shape=None, y_shape=None): + if shape is None: + shape = [1, 2, 4] + cond_shape = shape if cond_shape is None else cond_shape + x_shape = shape if x_shape is None else x_shape + y_shape = shape if y_shape is None else y_shape + cond = 
(np.arange(_numel(cond_shape), dtype=np.float32) % 2).reshape(cond_shape) + xval = np.arange(_numel(x_shape), dtype=np.float32).reshape(x_shape) + yval = (100 + np.arange(_numel(y_shape), dtype=np.float32)).reshape(y_shape) + return cond, xval, yval + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_to_where(): + model = _make_onnx_where_model() + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + assert node.op_type == "Where" + assert node.domain == "finn.custom_op.fpgadataflow" + assert list(node.input) == ["cond", "xval", "yval"] + + inst = getCustomOp(node) + assert inst.get_normal_output_shape() == (1, 2, 4) + assert inst.get_exp_cycles() == 20 + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + assert model.graph.node[0].op_type == "Where_rtl" + assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl" + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_broadcast_to_where(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_onnx_where_model( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape + ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + 
assert node.op_type == "Where" + assert node.domain == "finn.custom_op.fpgadataflow" + + inst = getCustomOp(node) + assert inst.get_normal_input_shape(0) == tuple(cond_shape) + assert inst.get_normal_input_shape(1) == tuple(x_shape) + assert inst.get_normal_input_shape(2) == tuple(y_shape) + assert inst.get_normal_output_shape() == tuple(out_shape) + assert inst.get_folded_input_shape(0) == (3, 1, 1) + assert inst.get_folded_input_shape(1) == (4, 1) + assert inst.get_folded_input_shape(2) == (2, 1, 1, 1) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_scalar_broadcast_to_where(): + cond_shape = [] + x_shape = [2, 3] + y_shape = [1, 3] + out_shape = [2, 3] + model = _make_onnx_where_model( + cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape, out_shape=out_shape + ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond.astype(bool), "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + inst = getCustomOp(node) + + assert inst.get_normal_input_shape(0) == tuple(cond_shape) + assert inst.get_normal_input_shape(1) == tuple(x_shape) + assert inst.get_normal_input_shape(2) == tuple(y_shape) + assert inst.get_normal_output_shape() == tuple(out_shape) + assert inst.get_folded_input_shape(0) == (1, 1) + assert inst.get_nodeattr("CondRank") == 0 + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_convert_onnx_where_float32_to_where(): + model = _make_onnx_where_model(finn_dtype=DataType["FLOAT32"]) + cond, xval, yval = _make_inputs() + xval = xval + 0.25 + yval = yval 
+ 0.5 + expected = np.where(cond.astype(bool), xval, yval) + + model.set_tensor_datatype("cond", DataType["BINARY"]) + model = model.transform(InferWhereLayer()) + node = model.graph.node[0] + inst = getCustomOp(node) + + assert inst.get_input_datatype(1) == DataType["FLOAT32"] + assert inst.get_output_datatype() == DataType["FLOAT32"] + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype", + [DataType["INT8"], DataType["UINT4"], DataType["BIPOLAR"], DataType["FLOAT32"]], +) +def test_where_python_execution(finn_dtype): + model = _make_where_model(finn_dtype=finn_dtype) + cond, xval, yval = _make_inputs() + if finn_dtype == DataType["BIPOLAR"]: + xval = np.where(xval % 2 == 0, -1, 1).astype(np.float32) + yval = -xval + elif finn_dtype == DataType["UINT4"]: + yval = (15 - xval).astype(np.float32) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +def test_where_python_execution_broadcast(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_where_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) + expected = np.where(cond.astype(bool), xval, yval) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.parametrize( + "finn_dtype,fold_width", + [ + (DataType["INT8"], 16), + (DataType["UINT4"], 8), + (DataType["BIPOLAR"], 2), + (DataType["FLOAT32"], 64), + ], +) +def test_where_rtl_codegen(tmp_path, finn_dtype, fold_width): + model = _make_where_model(pe=2, finn_dtype=finn_dtype) + model = 
model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + assert topname == node.name + wrapper = tmp_path / (topname + ".v") + core_wrapper = tmp_path / (topname + "_core.sv") + core = tmp_path / "where.sv" + assert wrapper.is_file() + assert core_wrapper.is_file() + assert core.is_file() + wrapper_text = wrapper.read_text() + core_wrapper_text = core_wrapper.read_text() + assert "parameter COND_WIDTH = 2" in wrapper_text + assert "parameter X_WIDTH = %d" % fold_width in wrapper_text + assert "parameter Y_WIDTH = %d" % fold_width in wrapper_text + assert "parameter OUT_WIDTH = %d" % fold_width in wrapper_text + assert ".DATA_WIDTH(%d)" % finn_dtype.bitwidth() in core_wrapper_text + assert ".PE(2)" in core_wrapper_text + assert ".NDIMS(3)" in core_wrapper_text + assert '.RAM_STYLE("auto")' in core_wrapper_text + assert "in2_V_TDATA" in wrapper_text + assert "out0_V_TVALID" in wrapper_text + + ipi_cmds = inst.code_generation_ipi() + assert any("where.sv" in cmd for cmd in ipi_cmds) + assert any(topname + "_core.sv" in cmd for cmd in ipi_cmds) + assert any(topname + ".v" in cmd for cmd in ipi_cmds) + assert any("create_bd_cell" in cmd and topname in cmd for cmd in ipi_cmds) + + +@pytest.mark.fpgadataflow +def test_where_rtl_codegen_broadcast(tmp_path): + model = _make_where_model( + pe=2, + cond_shape=[3, 1], + x_shape=[4], + y_shape=[2, 1, 1], + out_shape=[2, 3, 4], + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + wrapper_text = 
(tmp_path / (topname + ".v")).read_text() + core_wrapper_text = (tmp_path / (topname + "_core.sv")).read_text() + assert "parameter COND_WIDTH = 1" in wrapper_text + assert "parameter X_WIDTH = 16" in wrapper_text + assert "parameter Y_WIDTH = 8" in wrapper_text + assert "parameter OUT_WIDTH = 16" in wrapper_text + assert ".NDIMS(3)" in core_wrapper_text + assert ".COND_NDIMS(2)" in core_wrapper_text + assert ".X_NDIMS(1)" in core_wrapper_text + assert ".Y_NDIMS(3)" in core_wrapper_text + assert ".OUT_SHAPE('{ 2, 3, 4 })" in core_wrapper_text + assert ".COND_SHAPE('{ 3, 1 })" in core_wrapper_text + assert ".X_SHAPE('{ 4 })" in core_wrapper_text + assert ".Y_SHAPE('{ 2, 1, 1 })" in core_wrapper_text + + +@pytest.mark.fpgadataflow +def test_where_rtl_codegen_scalar_broadcast(tmp_path): + model = _make_where_model( + pe=3, + cond_shape=[], + x_shape=[2, 3], + y_shape=[1, 3], + out_shape=[2, 3], + ) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + node = model.graph.node[0] + inst = getCustomOp(node) + inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path)) + inst.code_generation_ipgen(model, FPGA_PART, CLK_NS) + + topname = inst.get_nodeattr("gen_top_module") + wrapper_text = (tmp_path / (topname + ".v")).read_text() + core_wrapper_text = (tmp_path / (topname + "_core.sv")).read_text() + assert "parameter COND_WIDTH = 1" in wrapper_text + assert "parameter X_WIDTH = 24" in wrapper_text + assert "parameter Y_WIDTH = 24" in wrapper_text + assert "parameter OUT_WIDTH = 24" in wrapper_text + assert ".NDIMS(2)" in core_wrapper_text + assert ".COND_NDIMS(1)" in core_wrapper_text + assert ".COND_SHAPE('{ 1 })" in core_wrapper_text + + +@pytest.mark.fpgadataflow +def test_where_resource_estimation(): + model = _make_where_model(pe=2) + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + + expected = { + "BRAM_18K": 0, + "BRAM_efficiency": 1, + "LUT": 80, + "URAM": 0, 
+ "URAM_efficiency": 1, + "DSP": 0, + } + resources = model.analysis(partial(res_estimation, fpgapart=FPGA_PART)) + assert len(resources) == 1 + assert list(resources.values())[0] == expected + + complete_resources = model.analysis(partial(res_estimation_complete, fpgapart=FPGA_PART)) + assert len(complete_resources) == 1 + assert list(complete_resources.values())[0] == [expected] + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("pe", [1, 2]) +def test_where_rtlsim(pe): + model = _make_where_model(pe=pe) + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("Where_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_rtlsim_broadcast(): + cond_shape = [3, 1] + x_shape = [4] + y_shape = [2, 1, 1] + out_shape = [2, 3, 4] + model = _make_where_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) + expected = np.where(cond.astype(bool), xval, yval) + + model = model.transform(SpecializeLayers(FPGA_PART)) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP(FPGA_PART, CLK_NS)) + model = 
model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareRTLSim()) + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + node = model.get_nodes_by_op_type("Where_rtl")[0] + inst = getCustomOp(node) + cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[node.name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=15) + assert exp_cycles != 0 + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +@pytest.mark.parametrize("pe", [1, 2]) +def test_where_stitched_ip_rtlsim(pe): + model = _prepare_where_stitched_ip_model(pe=pe) + cond, xval, yval = _make_inputs() + expected = np.where(cond.astype(bool), xval, yval) + + model.set_metadata_prop("exec_mode", "rtlsim") + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_stitched_ip_rtlsim_broadcast(): + cond_shape = [1, 3, 1] + x_shape = [1, 1, 4] + y_shape = [1, 2, 1, 1] + out_shape = [1, 2, 3, 4] + model = _prepare_where_stitched_ip_model( + pe=2, + cond_shape=cond_shape, + x_shape=x_shape, + y_shape=y_shape, + out_shape=out_shape, + ) + cond, xval, yval = _make_inputs(cond_shape=cond_shape, x_shape=x_shape, y_shape=y_shape) + expected = np.where(cond.astype(bool), xval, yval) + + model.set_metadata_prop("exec_mode", "rtlsim") + + ret = execute_onnx(model, {"cond": cond, "xval": xval, "yval": yval}) + assert (ret["out"] == expected).all() + + +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_where_stitched_ip_synth_ooc(): + model = _prepare_where_stitched_ip_model(pe=2) + model = model.transform(SynthOutOfContext(FPGA_PART, CLK_NS)) + ret = model.get_metadata_prop("res_total_ooc_synth") + assert ret is not None + ret = eval(ret) + + assert ret["LUT"] > 0 + assert 
ret["FF"] > 0 + assert ret["DSP"] == 0 + assert ret["BRAM"] == 0 + assert ret["WNS"] >= 0