Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/finn/source_code/finn.custom_op.fpgadataflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ RTLBackend
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.addclstoken
-----------------------------------------

.. automodule:: finn.custom_op.fpgadataflow.addclstoken
:members:
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.addstreams
----------------------------------------

Expand Down
8 changes: 8 additions & 0 deletions docs/finn/source_code/finn.custom_op.fpgadataflow.rtl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ Custom Op - fpgadataflow.rtl
RTL Custom Op Nodes
===================

finn.custom\_op.fpgadataflow.rtl.addclstoken\_rtl
--------------------------------------------------

.. automodule:: finn.custom_op.fpgadataflow.rtl.addclstoken_rtl
:members:
:undoc-members:
:show-inheritance:

finn.custom\_op.fpgadataflow.convolutioninputgenerator\_rtl
------------------------------------------------------------

Expand Down
139 changes: 139 additions & 0 deletions finn-rtllib/addclstoken/hdl/addclstoken.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/****************************************************************************
* Copyright (C) 2026, Advanced Micro Devices, Inc.
* All rights reserved.
*
* SPDX-License-Identifier: BSD-3-Clause
*
* @brief Insert a constant class token into a folded token stream.
* @author Oliver Cassidy <oliver.cassidy@amd.com>
*
* @description
* Prepends a learned class token, supplied through cls_data, to each
* input sequence of patch tokens. The class token and patch tokens are
* transferred as SIMD-wide folds of ELEM_WIDTH-bit elements.
*
* Per sequence, the output stream is:
* 1. NUM_CHANNELS/SIMD folds from cls_data
* 2. NUM_TOKENS pass-through input tokens
* 3. PAD_TOKENS zero-valued tokens, when padding is enabled
***************************************************************************/

module addclstoken #(
parameter int unsigned NUM_TOKENS = 196,
parameter int unsigned NUM_CHANNELS = 192,
parameter int unsigned SIMD = 1,
parameter int unsigned ELEM_WIDTH = 8,
parameter int unsigned PAD_TOKENS = 0
)(
input logic clk,
input logic rst,

output logic irdy,
input logic ivld,
input logic [SIMD*ELEM_WIDTH-1:0] idat,

input logic ordy,
output logic ovld,
output logic [SIMD*ELEM_WIDTH-1:0] odat,

input logic [NUM_CHANNELS*ELEM_WIDTH-1:0] cls_data
);

localparam int unsigned FOLD_WIDTH = SIMD * ELEM_WIDTH;
localparam int unsigned FOLDS_PER_TOKEN = NUM_CHANNELS / SIMD;
localparam int unsigned TOTAL_INPUT_FOLDS = NUM_TOKENS * FOLDS_PER_TOKEN;
localparam int unsigned TOTAL_PAD_FOLDS = PAD_TOKENS * FOLDS_PER_TOKEN;
localparam int unsigned MAX_PHASE_FOLDS =
(TOTAL_INPUT_FOLDS > FOLDS_PER_TOKEN) ?
((TOTAL_INPUT_FOLDS > TOTAL_PAD_FOLDS) ?
TOTAL_INPUT_FOLDS : TOTAL_PAD_FOLDS) :
((FOLDS_PER_TOKEN > TOTAL_PAD_FOLDS) ?
FOLDS_PER_TOKEN : TOTAL_PAD_FOLDS);
localparam int unsigned CNT_WIDTH = (MAX_PHASE_FOLDS <= 1) ? 1 : $clog2(MAX_PHASE_FOLDS);

typedef enum logic [1:0] {
EMIT_CLS,
PASSTHROUGH,
EMIT_PAD
} state_t;

state_t state;
state_t next_state;
logic [CNT_WIDTH-1:0] fold_cnt;
logic fold_cnt_last;
logic out_transfer;

logic [CNT_WIDTH-1:0] cls_fold_cnt;
logic [FOLD_WIDTH-1:0] cls_fold;

assign cls_fold_cnt = (int'(fold_cnt) < FOLDS_PER_TOKEN) ? fold_cnt : '0;
assign cls_fold = cls_data[cls_fold_cnt * FOLD_WIDTH +: FOLD_WIDTH];
assign out_transfer = ovld & ordy;

always_comb begin
unique case (state)
EMIT_CLS: fold_cnt_last = (int'(fold_cnt) == FOLDS_PER_TOKEN - 1);
PASSTHROUGH: fold_cnt_last = (int'(fold_cnt) == TOTAL_INPUT_FOLDS - 1);
EMIT_PAD: fold_cnt_last = (int'(fold_cnt) == TOTAL_PAD_FOLDS - 1);
default: fold_cnt_last = 1'b1;
endcase
end

always_comb begin
irdy = 1'b0;
ovld = 1'b0;
odat = '0;

unique case (state)
EMIT_CLS: begin
ovld = 1'b1;
odat = cls_fold;
end
PASSTHROUGH: begin
irdy = ordy;
ovld = ivld;
odat = idat;
end
EMIT_PAD: begin
ovld = 1'b1;
end
default: begin
end
endcase
end

always_comb begin
next_state = state;
if (out_transfer && fold_cnt_last) begin
unique case (state)
EMIT_CLS: begin
next_state = PASSTHROUGH;
end
PASSTHROUGH: begin
next_state = (PAD_TOKENS == 0) ? EMIT_CLS : EMIT_PAD;
end
EMIT_PAD: begin
next_state = EMIT_CLS;
end
default: begin
next_state = EMIT_CLS;
end
endcase
end
end

always_ff @(posedge clk) begin
if (rst) begin
state <= EMIT_CLS;
fold_cnt <= '0;
end else if (out_transfer) begin
if (fold_cnt_last) begin
state <= next_state;
fold_cnt <= '0;
end else begin
fold_cnt <= fold_cnt + 1'b1;
end
end
end

endmodule
81 changes: 81 additions & 0 deletions finn-rtllib/addclstoken/hdl/addclstoken_template.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/******************************************************************************
* Copyright (C) 2026, Advanced Micro Devices, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
* THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
* OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

module $TOP_MODULE_NAME$ #(
parameter FOLD_WIDTH = $FOLD_WIDTH$,
parameter AXI_WIDTH = ((FOLD_WIDTH + 7) / 8) * 8
)(
(* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *)
(* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *)
input ap_clk,
(* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *)
input ap_rst_n,

output in0_V_TREADY,
input in0_V_TVALID,
input [AXI_WIDTH-1:0] in0_V_TDATA,

input out0_V_TREADY,
output out0_V_TVALID,
output [AXI_WIDTH-1:0] out0_V_TDATA
);

localparam [$CLS_WIDTH$-1:0] CLS_DATA = $CLS_DATA$;

wire [FOLD_WIDTH-1:0] core_out;

assign out0_V_TDATA[FOLD_WIDTH-1:0] = core_out;

generate
if (AXI_WIDTH > FOLD_WIDTH) begin : gen_pad_tdata
assign out0_V_TDATA[AXI_WIDTH-1:FOLD_WIDTH] = {(AXI_WIDTH-FOLD_WIDTH){1'b0}};
end
endgenerate

addclstoken #(
.NUM_TOKENS($NUM_TOKENS$),
.NUM_CHANNELS($NUM_CHANNELS$),
.SIMD($SIMD$),
.ELEM_WIDTH($ELEM_WIDTH$),
.PAD_TOKENS($PAD_TOKENS$)
) impl (
.clk(ap_clk),
.rst(!ap_rst_n),
.irdy(in0_V_TREADY),
.ivld(in0_V_TVALID),
.idat(in0_V_TDATA[FOLD_WIDTH-1:0]),
.ordy(out0_V_TREADY),
.ovld(out0_V_TVALID),
.odat(core_out),
.cls_data(CLS_DATA)
);

endmodule
3 changes: 3 additions & 0 deletions src/finn/builder/build_dataflow_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ def apply_if_relevant(model, op_types, transform, desc=""):
)

# Streaming operations
model = apply_if_relevant(
model, ["Concat"], to_hw.InferAddCLSTokenLayer(), "CLS token insertion"
)
model = apply_if_relevant(model, ["Concat"], to_hw.InferConcatLayer(), "concat layers")
model = apply_if_relevant(model, ["Split"], to_hw.InferSplitLayer(), "split layers")

Expand Down
2 changes: 2 additions & 0 deletions src/finn/custom_op/fpgadataflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def register_custom_op(cls):
# Import the submodule containing specializations of ElementwiseBinaryOperation
# Note: This will automatically register all decorated classes into this domain
import finn.custom_op.fpgadataflow.elementwise_binary
from finn.custom_op.fpgadataflow.addclstoken import AddCLSToken
from finn.custom_op.fpgadataflow.concat import StreamingConcat
from finn.custom_op.fpgadataflow.convolutioninputgenerator import (
ConvolutionInputGenerator,
Expand Down Expand Up @@ -91,6 +92,7 @@ def register_custom_op(cls):
custom_op["VVAU"] = VVAU
custom_op["StreamingDataflowPartition"] = StreamingDataflowPartition

custom_op["AddCLSToken"] = AddCLSToken
custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator
custom_op["Crop"] = Crop
custom_op["DuplicateStreams"] = DuplicateStreams
Expand Down
Loading
Loading