diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 236a42acdb..b74bbf538d 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -58,6 +58,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.convolutioninputgenerator import ( ConvolutionInputGenerator, ) +from finn.custom_op.fpgadataflow.crop import Crop from finn.custom_op.fpgadataflow.duplicatestreams import DuplicateStreams from finn.custom_op.fpgadataflow.fmpadding import FMPadding from finn.custom_op.fpgadataflow.fmpadding_pixel import FMPadding_Pixel @@ -95,6 +96,7 @@ def register_custom_op(cls): custom_op["AddStreams"] = AddStreams custom_op["ChannelwiseOp"] = ChannelwiseOp custom_op["ConvolutionInputGenerator"] = ConvolutionInputGenerator +custom_op["Crop"] = Crop custom_op["DuplicateStreams"] = DuplicateStreams custom_op["FMPadding"] = FMPadding custom_op["FMPadding_Pixel"] = FMPadding_Pixel diff --git a/src/finn/custom_op/fpgadataflow/crop.py b/src/finn/custom_op/fpgadataflow/crop.py new file mode 100644 index 0000000000..13effc61fe --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/crop.py @@ -0,0 +1,141 @@ +################################################################################### +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# Copyright for portions of this file is held by AMD and Microsoft under +# MIT license as part of project Brainsmith. +# All other copyright is held by AMD and is provided under BSD-3-Clause license. +# +################################################################################### + +import numpy as np +import warnings +from qonnx.core.datatype import DataType + +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + + +class Crop(HWCustomOp): + """Abstraction layer for Crop layers.""" + + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + my_attrs = { + "DataType": ("s", True, ""), + "ImgDim": ("ints", True, []), # [h, w] + "NumChannels": ("i", True, 0), + "CropNorth": ("i", True, []), + "CropSouth": ("i", True, []), + "CropWest": ("i", True, []), + "CropEast": ("i", True, []), + "SIMD": ("i", False, 1), + "numInputVectors": ("ints", False, []), + } + my_attrs.update(super().get_nodeattr_types()) + return my_attrs + + def get_normal_input_shape(self, ind=0): + num_vec = self.get_nodeattr("numInputVectors") + h, w = self.get_nodeattr("ImgDim") + if h == 0: + img_dim = [w] + else: + img_dim = [h, w] + ch = self.get_nodeattr("NumChannels") + return num_vec + img_dim + [ch] if num_vec != [0] else img_dim + [ch] + + def get_normal_output_shape(self, ind=0): + num_vec = self.get_nodeattr("numInputVectors") + height, width = self.get_nodeattr("ImgDim") + ch = self.get_nodeattr("NumChannels") + crop_north = self.get_nodeattr("CropNorth") + crop_east = self.get_nodeattr("CropEast") + crop_west = self.get_nodeattr("CropWest") + crop_south = self.get_nodeattr("CropSouth") + owidth = width - (crop_west + crop_east) + oheight = height - (crop_north + crop_south) + if oheight == 0: + o_img_dim = [owidth] + else: + o_img_dim = [oheight, owidth] + return num_vec + o_img_dim + [ch] if num_vec != [0] else o_img_dim + [ch] + + def execute_node(self, context, graph): + node = self.onnx_node + h, w = self.get_nodeattr("ImgDim") + crop_north = self.get_nodeattr("CropNorth") + crop_east = self.get_nodeattr("CropEast") + crop_west = self.get_nodeattr("CropWest") + crop_south = self.get_nodeattr("CropSouth") + inp = context[node.input[0]] + if len(inp.shape) == 3: + cropped_slice = inp[crop_north : h - crop_south, crop_west : w - crop_east, :] + elif len(inp.shape) == 2: + cropped_slice = inp[crop_west : w - crop_east, :] + elif len(inp.shape) == 4: + cropped_slice = inp[:, crop_north : h - crop_south, crop_west : w - crop_east, :] + else: + raise Exception("Crop execute node currently only supports 2D - 4D input tensors.") + assert cropped_slice.shape == tuple(self.get_normal_output_shape()) + context[node.output[0]] = cropped_slice + + def get_input_datatype(self, ind=0): + return DataType[self.get_nodeattr("DataType")] + + def infer_node_datatype(self, model): + node = self.onnx_node + dt = model.get_tensor_datatype(node.input[0]) + if dt != self.get_input_datatype(): + warn_str = ( + f"data_type changing for {node.name}: {str(self.get_input_datatype())} -> {str(dt)}" + ) + warnings.warn(warn_str) + self.set_nodeattr("DataType", dt.name) + + def get_instream_width(self, ind=0): + ibits = self.get_input_datatype().bitwidth() + simd = self.get_nodeattr("SIMD") + return ibits * simd + + def get_outstream_width(self, ind=0): + obits = self.get_output_datatype().bitwidth() + simd = self.get_nodeattr("SIMD") + return obits * simd + + def get_output_datatype(self, ind=0): + return DataType[self.get_nodeattr("DataType")] + + def get_folded_output_shape(self, ind=0): + normal_oshape = list(self.get_normal_output_shape()) + simd = self.get_nodeattr("SIMD") + assert normal_oshape[-1] % simd == 0, "Innermost dimension must be divisible by SIMD" + fold = int(normal_oshape[-1] / simd) + folded_oshape = normal_oshape[:-1] + [fold, simd] + return tuple(folded_oshape) + + def get_folded_input_shape(self, ind=0): + normal_ishape = list(self.get_normal_input_shape()) + simd = self.get_nodeattr("SIMD") + assert normal_ishape[-1] % simd == 0, "Innermost dimension must be divisible by SIMD" + fold = int(normal_ishape[-1] / simd) + folded_ishape = normal_ishape[:-1] + [fold, simd] + return tuple(folded_ishape) + + def get_exp_cycles(self): + simd = self.get_nodeattr("SIMD") + num_vec = self.get_nodeattr("numInputVectors") + height, width = self.get_nodeattr("ImgDim") + ch = self.get_nodeattr("NumChannels") + if height == 0: + # pretend that height is 1 for code generation + height = 1 + + return ( + np.prod(num_vec) * height * width * (ch // simd) + if num_vec != [0] + else height * width * (ch // simd) + ) diff --git a/src/finn/custom_op/fpgadataflow/hls/__init__.py b/src/finn/custom_op/fpgadataflow/hls/__init__.py index ff97ebe136..f63e082f04 100644 --- a/src/finn/custom_op/fpgadataflow/hls/__init__.py +++ b/src/finn/custom_op/fpgadataflow/hls/__init__.py @@ -57,6 +57,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.hls.channelwise_op_hls import ChannelwiseOp_hls from finn.custom_op.fpgadataflow.hls.checksum_hls import CheckSum_hls from finn.custom_op.fpgadataflow.hls.concat_hls import StreamingConcat_hls +from finn.custom_op.fpgadataflow.hls.crop_hls import Crop_hls from finn.custom_op.fpgadataflow.hls.duplicatestreams_hls import DuplicateStreams_hls from finn.custom_op.fpgadataflow.hls.fmpadding_pixel_hls import FMPadding_Pixel_hls from finn.custom_op.fpgadataflow.hls.globalaccpool_hls import GlobalAccPool_hls @@ -82,6 +83,7 @@ def register_custom_op(cls): custom_op["AddStreams_hls"] = AddStreams_hls custom_op["ChannelwiseOp_hls"] = ChannelwiseOp_hls custom_op["CheckSum_hls"] = CheckSum_hls +custom_op["Crop_hls"] = Crop_hls custom_op["DuplicateStreams_hls"] = DuplicateStreams_hls custom_op["FMPadding_Pixel_hls"] = FMPadding_Pixel_hls custom_op["GlobalAccPool_hls"] = GlobalAccPool_hls diff --git a/src/finn/custom_op/fpgadataflow/hls/crop_hls.py b/src/finn/custom_op/fpgadataflow/hls/crop_hls.py new file mode 100644 index 0000000000..14a60de42f --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/hls/crop_hls.py @@ -0,0 +1,89 @@ +################################################################################### +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# Copyright for portions of this file is held by AMD and Microsoft under +# MIT license as part of project Brainsmith. +# All other copyright is held by AMD and is provided under BSD-3-Clause license. +# +################################################################################### + +from finn.custom_op.fpgadataflow.crop import Crop +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend + + +class Crop_hls(Crop, HLSBackend): + def __init__(self, onnx_node, **kwargs): + super().__init__(onnx_node, **kwargs) + + def get_nodeattr_types(self): + return Crop.get_nodeattr_types(self) | HLSBackend.get_nodeattr_types(self) + + def global_includes(self): + self.code_gen_dict["$GLOBALS$"] = [ + '#include "crop.hpp"', + ] + + def defines(self, var): + simd = self.get_nodeattr("SIMD") + dtype = self.get_input_datatype() + height, width = self.get_nodeattr("ImgDim") + if height == 0: + # pretend that height is 1 for code generation + height = 1 + ch = self.get_nodeattr("NumChannels") + self.code_gen_dict["$DEFINES$"] = [ + f""" + constexpr unsigned SIMD = {simd}; + constexpr unsigned H = {height}; + constexpr unsigned W = {width}; + constexpr unsigned CF = {ch // simd}; + constexpr unsigned CROP_N = {self.get_nodeattr("CropNorth")}; + constexpr unsigned CROP_E = {self.get_nodeattr("CropEast")}; + constexpr unsigned CROP_S = {self.get_nodeattr("CropSouth")}; + constexpr unsigned CROP_W = {self.get_nodeattr("CropWest")}; + using TV = hls::vector<{dtype.get_hls_datatype_str()}, SIMD>; + """ + ] + + def docompute(self): + self.code_gen_dict["$DOCOMPUTE$"] = [ + """ + hls::stream src0; + hls::stream dst0; + #pragma HLS stream variable=src0 depth=2 + #pragma HLS stream variable=dst0 depth=2 + + move(in0_V, src0); + crop< H, W, CF, CROP_N, CROP_E, CROP_S, CROP_W, TV>(src0, dst0); + move(dst0, out0_V); + """ + ] + + def blackboxfunction(self): + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + f""" + void {self.onnx_node.name} ( + hls::stream &in0_V, + hls::stream &out0_V + ) + """ + ] + + def pragmas(self): + self.code_gen_dict["$PRAGMAS$"] = [ + """ + #pragma HLS interface AXIS port=in0_V + #pragma HLS interface AXIS port=out0_V + #pragma HLS aggregate variable=in0_V compact=bit + #pragma HLS aggregate variable=out0_V compact=bit + + #pragma HLS interface ap_ctrl_none port=return + #pragma HLS dataflow disable_start_propagation + """ + ] + + def execute_node(self, context, graph): + HLSBackend.execute_node(self, context, graph) diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index 3668f491a8..fac85a5bfe 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -2097,3 +2097,134 @@ def apply(self, model): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) return (model, graph_modified) + + +def elements_are_consecutive(indices): + if indices.size == 1: + return True + else: + indices.sort() + return np.all(np.diff(indices) == 1) + + +class InferCrop(Transformation): + """ + Find gather layers that can be converted into a Crop layer + and replace them with a Crop layer + """ + + def __init__(self): + super().__init__() + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Gather": + # ensure that the indices input is an initializer + if model.get_initializer(n.input[1]) is None: + continue + + # ensure that the axis is among the two innermost dimensions + input_shape = model.get_tensor_shape(n.input[0]) + assert ( + len(input_shape) > 1 + ), "Input shape needs to be at least 2D to be converted to Crop." + + max_index = len(input_shape) - 1 + axis = get_by_name(n.attribute, "axis").i + if len(input_shape) >= 3: + assert axis in [ + max_index - 1, + max_index - 2, + ], "Crop Operates on height and width of the input, assuming (N)HWC layout." + else: + assert ( + axis == max_index - 1 + ), "Crop Operates on width of the input, for 2D input assuming WC layout." + is_vertical = axis == max_index # otherwise horizontal + assert is_vertical is False, "This operator does not current support vertical crops" + + # assume that the indices input is an int64 scalar or array + indices = model.get_initializer(n.input[1]) + assert indices.dtype == np.int64, "Indices must be int64" + # Handle both scalar (0-d) and array cases + if indices.ndim == 0: + # Single scalar index - always consecutive + indices_to_check = np.array([indices.item()]) + else: + indices_to_check = indices + assert elements_are_consecutive(indices_to_check), "Indices must be consecutive" + + idt0 = model.get_tensor_datatype(n.input[0]) + + crop_north = 0 + crop_east = 0 + crop_west = 0 + crop_south = 0 + num_inp_vec = [0] + + if len(input_shape) >= 3: + height_ind = len(input_shape) - 3 + width_ind = len(input_shape) - 2 + channels_ind = len(input_shape) - 1 + + height = input_shape[height_ind] + width = input_shape[width_ind] + channels = input_shape[channels_ind] + # save other dimensions in numInpVectors + if len(input_shape) > 3: + num_inp_vec = list(input_shape[:height_ind]) + + crop_min = int(np.min(indices_to_check)) + crop_max = input_shape[axis] - int(np.max(indices_to_check)) - 1 + + if axis == height_ind: + crop_north = crop_min + crop_south = crop_max + elif axis == width_ind: + crop_west = crop_min + crop_east = crop_max + + elif len(input_shape) == 2: + # if there are only two dimensions, assume + height = 0 + width_ind = len(input_shape) - 2 + channels_ind = len(input_shape) - 1 + width = input_shape[width_ind] + channels = input_shape[channels_ind] + + # axis is on width dimension + crop_west = int(np.min(indices_to_check)) + crop_east = input_shape[axis] - int(np.max(indices_to_check)) - 1 + + # create and insert new node + new_node = helper.make_node( + "Crop", + [n.input[0]], # input tensor(s) + [n.output[0]], # output tensor(s) + domain="finn.custom_op.fpgadataflow", + backend="fpgadataflow", + DataType=idt0.name, + name="Crop" + n.name, + SIMD=1, + ImgDim=[height, width], + NumChannels=channels, + CropNorth=crop_north, + CropEast=crop_east, + CropWest=crop_west, + CropSouth=crop_south, + numInputVectors=num_inp_vec, + cpp_interface="hls_vector", + hls_style="freerunning", + ) + graph.node.insert(node_ind, new_node) + graph.node.remove(n) + graph_modified = True + + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + return (model, graph_modified) diff --git a/tests/fpgadataflow/test_fpgadataflow_crop.py b/tests/fpgadataflow/test_fpgadataflow_crop.py new file mode 100644 index 0000000000..32dcd30474 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_crop.py @@ -0,0 +1,132 @@ +############################################################################ +# Copyright (C) 2025, Advanced Micro Devices, Inc. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# Copyright for portions of this file is held by AMD and Microsoft under +# MIT license as part of project Brainsmith. +# All other copyright is held by AMD and is provided under BSD-3-Clause license. +# +# Note: This test was originally written by Josh Monson and was adjusted. +# +############################################################################ + +import pytest + +import numpy as np +from onnx import TensorProto, helper +from qonnx.core.datatype import DataType +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.transformation.general import GiveUniqueNodeNames +from qonnx.transformation.infer_shapes import InferShapes +from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model + +import finn.core.onnx_exec as oxe +import finn.transformation.fpgadataflow.convert_to_hw_layers as to_hw +from finn.analysis.fpgadataflow.exp_cycles_per_layer import exp_cycles_per_layer +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers + +test_fpga_part: str = "xczu7ev-ffvc1156-2-e" +target_clk_ns = 5 + + +def make_gather_model(indices, ishape, axis): + # Define the input tensor + data = helper.make_tensor_value_info("data", TensorProto.FLOAT, ishape) + + # Define the output tensor and leave shape undefined to be inferred later + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, None) + + indices = helper.make_tensor("indices", TensorProto.INT64, [len(indices)], indices) + + gather_node = helper.make_node( + "Gather", inputs=["data", "indices"], outputs=["output"], axis=axis + ) + + # Create the graph + graph = helper.make_graph( + nodes=[gather_node], + name="GatherGraph", + inputs=[data], + outputs=[output], + initializer=[ + indices, + ], + ) + + # Create the QONNX model + model = qonnx_make_model(graph, producer_name="gather-model") + model = ModelWrapper(model, fix_missing_initializer_valueinfo=True) + model = model.transform(InferShapes()) + + return model + + +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +@pytest.mark.parametrize( + "ishape_axis_indices", + [ + ([1, 16, 48], 1, [0]), + ([1, 16, 48], 1, [4, 5, 6]), + ([32, 48], 0, [15]), + ([1, 16, 48, 16], 2, [1]), + ], +) +@pytest.mark.parametrize("simd", [1, 8, 16]) +@pytest.mark.parametrize("idt", [DataType["INT8"], DataType["FLOAT32"]]) +@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) +def test_fpgadataflow_gather_crop(ishape_axis_indices, simd, idt, exec_mode): + ishape, axis, indices = ishape_axis_indices + indices = np.array(indices) + model = make_gather_model(indices, ishape, axis=axis) + model.set_tensor_datatype(model.graph.input[0].name, idt) + + # reference calculation + input = gen_finn_dt_tensor(idt, ishape) + input_t = {model.graph.input[0].name: input} + + y_ref = oxe.execute_onnx(model, input_t)[model.graph.output[0].name] + + model = model.transform(to_hw.InferCrop()) + + input_t = {model.graph.input[0].name: input} + y_hw = oxe.execute_onnx(model, input_t)[model.graph.output[0].name] + + assert (y_ref == y_hw).all() + + model = model.transform(SpecializeLayers(test_fpga_part)) + assert model.graph.node[0].op_type == "Crop_hls" + getCustomOp(model.graph.node[0]).set_nodeattr("SIMD", simd) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(SetExecMode(exec_mode)) + + if exec_mode == "cppsim": + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + elif exec_mode == "rtlsim": + model = model.transform(PrepareIP(test_fpga_part, target_clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + input_t = {model.graph.input[0].name: input} + + y_sim = oxe.execute_onnx(model, input_t)[model.graph.output[0].name] + + assert (y_ref == y_sim).all() + + if exec_mode == "rtlsim": + cycles_rtlsim = getCustomOp(model.graph.node[0]).get_nodeattr("cycles_rtlsim") + exp_cycles_dict = model.analysis(exp_cycles_per_layer) + exp_cycles = exp_cycles_dict[model.graph.node[0].name] + assert np.isclose(exp_cycles, cycles_rtlsim, atol=10) + assert exp_cycles != 0