@@ -155,81 +155,41 @@ def prepare_codegen_rtl_values(self, model):
155155 their key value(s) in the RTL template files"""
156156 code_gen_dict = {}
157157
158- thresholds = model .get_initializer (self .onnx_node .input [1 ])
158+ t_path = self .get_nodeattr ("code_gen_dir_ipgen" )
159+
160+ self .generate_params (model , t_path )
161+
159162 bias = self .get_nodeattr ("ActVal" ) # activation bias value
160163 output_data_type = self .get_nodeattr ("outputDataType" ) # output precision
161164 input_data_type = self .get_nodeattr ("inputDataType" ) # input/threshold precision
162165 o_bitwidth = DataType [output_data_type ].bitwidth ()
163-
164- t_path = self .get_nodeattr ("code_gen_dir_ipgen" )
165- if self .get_nodeattr ("runtime_writeable_weights" ) == 1 :
166- thresh_file_name = f"{ t_path } /memblock.dat"
167- self .make_weight_file (thresholds , "decoupled" , thresh_file_name )
166+ pe = self .get_nodeattr ("PE" )
167+ num_channels = self .get_nodeattr ("NumChannels" ) # number of channels
168168
169169 # The RTL expects 2^N-1 thresholds, but narrow range quantization will result in
170170 # one less threshold, prepending a dummy threshold (minimal possible value determined by
171171 # input data type) and decrease the bias by 1.
172- # Additionally, increase number of threshold steps to reflect new shape
173172 expected_thresholds = 2 ** o_bitwidth - 1
174173 n_thres_steps = self .get_nodeattr ("numSteps" )
175174 wdt = self .get_input_datatype (1 )
176175 if expected_thresholds != n_thres_steps :
177176 if DataType [output_data_type ].signed ():
178- min_val = wdt .min ()
179- thresholds = np .insert (thresholds , 0 , min_val , axis = 1 )
180177 bias = bias - 1
181- # TODO: temporary fix for unsigned narrow quantization
182178 else :
183179 max_val = wdt .max ()
184- if max_val > DataType [input_data_type ].max ():
185- thresholds = np .insert (thresholds , len (thresholds [0 ]), max_val , axis = 1 )
186- else :
180+ if max_val <= DataType [input_data_type ].max ():
187181 max_val = max_val + 1
188182 # increase wdt
189183 if not wdt .signed ():
190184 wdt = DataType .get_smallest_possible (max_val )
191185 else :
192186 wdt = DataType .get_smallest_possible (- max_val - 1 )
193- thresholds = np .insert (thresholds , len (thresholds [0 ]), max_val , axis = 1 )
194- n_thres_steps += 1
195-
196- # add dummy dimension as final dimension (that's what gets packed with next call)
197- t_expand = np .expand_dims (thresholds , axis = - 1 )
198- bw_hexdigit = roundup_to_integer_multiple (wdt .bitwidth (), 4 )
199- t_packed = pack_innermost_dim_as_hex_string (
200- t_expand ,
201- wdt ,
202- bw_hexdigit ,
203- prefix = "" ,
204- )
205-
206- pe = self .get_nodeattr ("PE" )
207- num_channels = self .get_nodeattr ("NumChannels" ) # number of channels
208187
209- # If a single threshold value is found, broadcast the value
210- if t_packed . shape [ 0 ] == 1 :
211- t_packed = np . broadcast_to ( t_packed , ( pe , expected_thresholds ))
188+ # If a single threshold value is found, set num_channels to PE
189+ thresholds = model . get_initializer ( self . onnx_node . input [ 1 ])
190+ if thresholds . shape [ 0 ] == 1 :
212191 num_channels = pe
213192
214- channel_fold = int (num_channels / pe )
215-
216- for stage in range (o_bitwidth ):
217- sn = o_bitwidth - stage - 1
218- for pe_value in range (pe ):
219- thresh_file = t_path + "/%s_threshs_%s_%s.dat" % (
220- self .onnx_node .name ,
221- pe_value ,
222- stage ,
223- )
224- threshs = np .zeros ([channel_fold * (2 ** stage )], dtype = "object" )
225- for ch in range (channel_fold ):
226- for i in range (2 ** stage ):
227- threshs [(ch << stage ) + i ] = t_packed [ch * pe + pe_value ][
228- (i << (o_bitwidth - stage )) + 2 ** sn - 1
229- ]
230- with open (thresh_file , "w" ) as f :
231- for val in threshs :
232- f .write (val + "\n " )
233193 code_gen_dict ["$THRESHOLDS_PATH$" ] = ['"./%s_"' % self .onnx_node .name ]
234194
235195 # Identify the module name
@@ -433,6 +393,14 @@ def get_verilog_top_module_intf_names(self):
433393
434394 return intf_names
435395
396+ def generate_params (self , model , path ):
397+ thresholds = model .get_initializer (self .onnx_node .input [1 ])
398+ rt_weights = self .get_nodeattr ("runtime_writeable_weights" )
399+ file_name = "{}/memblock.dat" .format (path )
400+ if rt_weights :
401+ self .make_weight_file (thresholds , "decoupled_runtime" , file_name )
402+ self .make_weight_file (thresholds , "internal_embedded" , file_name )
403+
436404 def make_weight_file (self , weights , weight_file_mode , weight_file_name ):
437405 """Produce a file containing given weights (thresholds) in appropriate
438406 format for this layer. This file can be used for either synthesis or
@@ -444,14 +412,20 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
444412 * weight_file_name : filename for the weight file to be generated
445413
446414 """
415+ path = os .path .dirname (weight_file_name )
416+ if not path :
417+ path = os .getcwd ()
447418 thresholds = weights
448419 pe = self .get_nodeattr ("PE" )
449- ch = self .get_nodeattr ("NumChannels" )
420+ num_channels = self .get_nodeattr ("NumChannels" ) # number of channels
450421 output_data_type = self .get_nodeattr ("outputDataType" ) # output precision
451422 o_bitwidth = DataType [output_data_type ].bitwidth ()
423+ input_data_type = self .get_nodeattr ("inputDataType" ) # input/threshold precision
424+
452425 # The RTL expects 2^N-1 thresholds, but narrow range quantization will result in
453426 # one less threshold, prepending a dummy threshold (minimal possible value determined by
454- # input data type) and decrease the bias by 1.
427+ # input data type)
428+ # and decrease the bias by 1 (needs to be done in code generation when bias is set).
455429 # Additionally, increase number of threshold steps to reflect new shape
456430 expected_thresholds = 2 ** o_bitwidth - 1
457431 n_thres_steps = self .get_nodeattr ("numSteps" )
@@ -463,7 +437,7 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
463437 # TODO: temporary fix for unsigned narrow quantization
464438 else :
465439 max_val = wdt .max ()
466- if max_val > self . get_input_datatype ( 0 ) .max ():
440+ if max_val > DataType [ input_data_type ] .max ():
467441 thresholds = np .insert (thresholds , len (thresholds [0 ]), max_val , axis = 1 )
468442 else :
469443 max_val = max_val + 1
@@ -475,35 +449,68 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
475449 thresholds = np .insert (thresholds , len (thresholds [0 ]), max_val , axis = 1 )
476450 n_thres_steps += 1
477451
478- # If a single threshold value is found, broadcast the value
479- if thresholds .shape [0 ] == 1 :
480- thresholds = np .broadcast_to (thresholds , (pe , expected_thresholds ))
481- ch = pe
482-
483- width_padded = roundup_to_integer_multiple (thresholds .shape [1 ], 2 ** o_bitwidth )
484- thresh_padded = np .zeros ((thresholds .shape [0 ], width_padded ))
485- thresh_padded [: thresholds .shape [0 ], :n_thres_steps ] = thresholds
486- thresh_stream = []
487- bw_hexdigit = roundup_to_integer_multiple (wdt .bitwidth (), 32 )
488- padding = np .zeros (width_padded , dtype = np .int32 )
489-
490- chan_ind = 0
491- cf = ch // pe
492- for fold in range (cf ):
493- for c in range (2 ** (pe - 1 ).bit_length ()):
494- if (c == 0 or c % pe != 0 ) and c < pe :
495- for t in thresh_padded [chan_ind ]:
496- t_packed = pack_innermost_dim_as_hex_string (
497- [t ], wdt , bw_hexdigit , prefix = ""
498- ).item ()
499- thresh_stream .append (t_packed )
500- chan_ind += 1
501- else :
502- for z in padding :
503- t_packed = pack_innermost_dim_as_hex_string (
504- [z ], wdt , bw_hexdigit , prefix = ""
505- ).item ()
506- thresh_stream .append (t_packed )
507- with open (weight_file_name , "w" ) as f :
508- for val in thresh_stream :
509- f .write (val + "\n " )
452+ if weight_file_mode == "decoupled_runtime" :
453+ # If a single threshold value is found, broadcast the value
454+ if thresholds .shape [0 ] == 1 :
455+ thresholds = np .broadcast_to (thresholds , (pe , expected_thresholds ))
456+ num_channels = pe
457+ width_padded = roundup_to_integer_multiple (thresholds .shape [1 ], 2 ** o_bitwidth )
458+ thresh_padded = np .zeros ((thresholds .shape [0 ], width_padded ))
459+ thresh_padded [: thresholds .shape [0 ], :n_thres_steps ] = thresholds
460+ thresh_stream = []
461+ bw_hexdigit = roundup_to_integer_multiple (wdt .bitwidth (), 32 )
462+ padding = np .zeros (width_padded , dtype = np .int32 )
463+
464+ chan_ind = 0
465+ cf = num_channels // pe
466+ for fold in range (cf ):
467+ for c in range (2 ** (pe - 1 ).bit_length ()):
468+ if (c == 0 or c % pe != 0 ) and c < pe :
469+ for t in thresh_padded [chan_ind ]:
470+ t_packed = pack_innermost_dim_as_hex_string (
471+ [t ], wdt , bw_hexdigit , prefix = ""
472+ ).item ()
473+ thresh_stream .append (t_packed )
474+ chan_ind += 1
475+ else :
476+ for z in padding :
477+ t_packed = pack_innermost_dim_as_hex_string (
478+ [z ], wdt , bw_hexdigit , prefix = ""
479+ ).item ()
480+ thresh_stream .append (t_packed )
481+ with open (weight_file_name , "w" ) as f :
482+ for val in thresh_stream :
483+ f .write (val + "\n " )
484+ elif weight_file_mode == "internal_embedded" :
485+ # add dummy dimension as final dimension (that's what gets packed with next call)
486+ t_expand = np .expand_dims (thresholds , axis = - 1 )
487+ bw_hexdigit = roundup_to_integer_multiple (wdt .bitwidth (), 4 )
488+ t_packed = pack_innermost_dim_as_hex_string (
489+ t_expand ,
490+ wdt ,
491+ bw_hexdigit ,
492+ prefix = "" ,
493+ )
494+ # If a single threshold value is found, broadcast the value
495+ if t_packed .shape [0 ] == 1 :
496+ t_packed = np .broadcast_to (t_packed , (pe , expected_thresholds ))
497+ num_channels = pe
498+ channel_fold = int (num_channels / pe )
499+
500+ for stage in range (o_bitwidth ):
501+ sn = o_bitwidth - stage - 1
502+ for pe_value in range (pe ):
503+ thresh_file = path + "/%s_threshs_%s_%s.dat" % (
504+ self .onnx_node .name ,
505+ pe_value ,
506+ stage ,
507+ )
508+ threshs = np .zeros ([channel_fold * (2 ** stage )], dtype = "object" )
509+ for ch in range (channel_fold ):
510+ for i in range (2 ** stage ):
511+ threshs [(ch << stage ) + i ] = t_packed [ch * pe + pe_value ][
512+ (i << (o_bitwidth - stage )) + 2 ** sn - 1
513+ ]
514+ with open (thresh_file , "w" ) as f :
515+ for val in threshs :
516+ f .write (val + "\n " )
0 commit comments