2727# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2828# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
30+ import math
3031import numpy as np
32+ import warnings
3133from qonnx .core .datatype import DataType
3234from qonnx .util .basic import roundup_to_integer_multiple
3335
3638
3739class StreamingConcat (HWCustomOp ):
3840 """Abstraction layer for HW implementation of Concat.
39- Only supports concatenating along the last axis."""
41+ Only supports concatenating along the last (channel) axis."""
4042
def __init__(self, onnx_node, **kwargs):
    """Initialize the op by delegating to the parent class initializer.

    `onnx_node` and any extra keyword arguments are forwarded unchanged;
    no additional state is set up here.
    """
    super().__init__(onnx_node, **kwargs)
4345
4446 def get_nodeattr_types (self ):
4547 my_attrs = {
48+ "SIMD" : ("i" , True , 0 ),
4649 # number of elements from each stream to concat
47- "ElemsPerStream " : ("ints" , True , []),
48- # FINN DataTypes for inputs; output datatype inferred from input
49- "inputDataType " : ("s " , True , "" ),
50+ "ChannelsPerStream " : ("ints" , True , []),
51+ # FINN DataTypes for inputs; output datatype inferred from inputs
52+ "inputDataTypes " : ("strings " , True , [ "" ] ),
5053 # number of input vectors for non-concat axes, examples:
5154 # [1] is a single vector (like a FC layer with batch=1)
5255 # [4] is four vectors (like a FC layer with batch=4)
@@ -57,29 +60,36 @@ def get_nodeattr_types(self):
5760 return my_attrs
5861
def get_n_inputs(self):
    """Return the number of input streams that are concatenated.

    This is the number of entries in the "ChannelsPerStream" node
    attribute, which lists one element count per input stream.
    """
    return len(self.get_nodeattr("ChannelsPerStream"))
6164
def get_total_elems(self):
    """Return the summed element count over all input streams.

    Adds up every entry of the "ChannelsPerStream" node attribute and
    returns the result as a plain Python int.
    """
    channels_per_stream = self.get_nodeattr("ChannelsPerStream")
    return int(np.sum(channels_per_stream))
6568
def get_normal_input_shape(self, ind=0):
    """Return the unfolded shape of input stream `ind` as a tuple.

    The shape is the "numInputVectors" dimensions followed by the element
    count of the selected stream taken from "ChannelsPerStream".
    """
    channels = self.get_nodeattr("ChannelsPerStream")[ind]
    outer_dims = list(self.get_nodeattr("numInputVectors"))
    return tuple(outer_dims + [channels])
7275
def get_folded_input_shape(self, ind=0):
    """Return the SIMD-folded shape of input stream `ind` as a tuple.

    The element count of the selected stream is split into
    (folds, SIMD), appended to the "numInputVectors" dimensions.
    The fold count uses floor division, so a stream size that is not a
    multiple of SIMD loses its remainder, and SIMD == 0 raises
    ZeroDivisionError.
    """
    simd = self.get_nodeattr("SIMD")
    folds = self.get_nodeattr("ChannelsPerStream")[ind] // simd
    outer_dims = list(self.get_nodeattr("numInputVectors"))
    return tuple(outer_dims + [folds, simd])
7581
def get_normal_output_shape(self, ind=0):
    """Return the unfolded output shape as a tuple.

    The result is the "numInputVectors" dimensions followed by the total
    element count over all input streams. `ind` is accepted but not used.
    """
    concat_elems = self.get_total_elems()
    return (*self.get_nodeattr("numInputVectors"), concat_elems)
8086
def get_folded_output_shape(self, ind=0):
    """Return the SIMD-folded output shape as a tuple.

    The total element count over all input streams is split into
    (folds, SIMD), appended to the "numInputVectors" dimensions.
    The fold count uses floor division, so a total that is not a multiple
    of SIMD loses its remainder, and SIMD == 0 raises ZeroDivisionError.
    `ind` is accepted but not used.
    """
    total_elems = self.get_total_elems()
    simd = self.get_nodeattr("SIMD")
    folds = total_elems // simd
    outer_dims = list(self.get_nodeattr("numInputVectors"))
    return tuple(outer_dims + [folds, simd])
8393
8494 def make_shape_compatible_op (self , model ):
8595 # check all input shapes
def infer_node_datatype(self, model):
    """Reconcile declared input datatypes with the model, then set the output.

    For each input tensor, the datatype annotated in `model` is compared
    with the corresponding entry of the "inputDataTypes" attribute. On a
    mismatch a warning is emitted and that entry is overwritten with the
    model's datatype name. Finally the first output tensor is annotated
    with the datatype returned by get_output_datatype().
    """
    # check all input datatypes
    for i, inp in enumerate(self.onnx_node.input):
        idt = model.get_tensor_datatype(inp)
        declared = self.get_input_datatype(i)
        if idt != declared:
            warn_str = "inputDataType changing for %s: %s -> %s " % (
                self.onnx_node.name,
                str(declared),
                str(idt),
            )
            warnings.warn(warn_str)
            # write the attribute back as a whole list; only entry i changes
            datatypes_attr = self.get_nodeattr("inputDataTypes")
            datatypes_attr[i] = idt.name
            self.set_nodeattr("inputDataTypes", datatypes_attr)
    # computed after the loop so it reflects any updated input datatypes
    odt = self.get_output_datatype()
    model.set_tensor_datatype(self.onnx_node.output[0], odt)
100119
@@ -103,21 +122,37 @@ def verify_node(self):
103122
def get_input_datatype(self, ind=0):
    """Return the DataType declared for input stream `ind`.

    Each input has its own entry in the "inputDataTypes" attribute; the
    entry's name is resolved through a DataType lookup.
    """
    return DataType[self.get_nodeattr("inputDataTypes")[ind]]
107126
def get_output_datatype(self, ind=0):
    """Infer the output DataType from the declared "inputDataTypes".

    The output type is the narrowest UINT<P> (if no input can be negative)
    or INT<P> (otherwise) whose range covers the minimum and maximum over
    all input datatypes. `ind` is accepted but not used.
    """
    n_inputs = len(self.get_nodeattr("inputDataTypes"))
    input_dts = [self.get_input_datatype(i) for i in range(n_inputs)]
    # seed both bounds with 0 so the covered range always includes zero
    min_input = min([0] + [idt.min() for idt in input_dts])
    max_input = max([0] + [idt.max() for idt in input_dts])
    if min_input >= 0:
        # unsigned case: need max_input <= 2^P - 1
        out_bit_width = math.ceil(np.log2(max_input + 1))
        return DataType[f"UINT{out_bit_width}"]
    # signed case: need min_input >= -2^(P-1) and max_input <= 2^(P-1) - 1,
    # i.e. 2^(P-1) >= max(-min_input, 1 + max_input)
    max_abs_input = max(-min_input, 1 + max_input)
    out_bit_width = math.ceil(np.log2(max_abs_input) + 1)
    return DataType[f"INT{out_bit_width}"]
110148
def get_instream_width(self, ind=0):
    """Return the width in bits of input stream `ind`.

    The width is the bitwidth of that input's own datatype multiplied by
    the "SIMD" attribute.
    """
    ibits = self.get_input_datatype(ind).bitwidth()
    return ibits * self.get_nodeattr("SIMD")
116152
def get_outstream_width(self, ind=0):
    """Return the width in bits of the output stream.

    The width is the bitwidth of the output datatype multiplied by the
    "SIMD" attribute. `ind` is accepted but not used.
    """
    obits = self.get_output_datatype().bitwidth()
    return obits * self.get_nodeattr("SIMD")
122157
123158 def get_number_output_values (self ):
0 commit comments