Skip to content

Commit aa0fc75

Browse files
authored
Merge pull request #1361 from Xilinx/refactor/thresh_param_gen
Restructure parameter generation for RTL Thresholding
2 parents 9d54173 + 171c7d9 commit aa0fc75

1 file changed

Lines changed: 92 additions & 85 deletions

File tree

src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py

Lines changed: 92 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -155,81 +155,41 @@ def prepare_codegen_rtl_values(self, model):
155155
their key value(s) in the RTL template files"""
156156
code_gen_dict = {}
157157

158-
thresholds = model.get_initializer(self.onnx_node.input[1])
158+
t_path = self.get_nodeattr("code_gen_dir_ipgen")
159+
160+
self.generate_params(model, t_path)
161+
159162
bias = self.get_nodeattr("ActVal") # activation bias value
160163
output_data_type = self.get_nodeattr("outputDataType") # output precision
161164
input_data_type = self.get_nodeattr("inputDataType") # input/threshold precision
162165
o_bitwidth = DataType[output_data_type].bitwidth()
163-
164-
t_path = self.get_nodeattr("code_gen_dir_ipgen")
165-
if self.get_nodeattr("runtime_writeable_weights") == 1:
166-
thresh_file_name = f"{t_path}/memblock.dat"
167-
self.make_weight_file(thresholds, "decoupled", thresh_file_name)
166+
pe = self.get_nodeattr("PE")
167+
num_channels = self.get_nodeattr("NumChannels") # number of channels
168168

169169
# The RTL expects 2^N-1 thresholds, but narrow range quantization will result in
170170
# one less threshold, prepending a dummy threshold (minimal possible value determined by
171171
# input data type) and decrease the bias by 1.
172-
# Additionally, increase number of threshold steps to reflect new shape
173172
expected_thresholds = 2**o_bitwidth - 1
174173
n_thres_steps = self.get_nodeattr("numSteps")
175174
wdt = self.get_input_datatype(1)
176175
if expected_thresholds != n_thres_steps:
177176
if DataType[output_data_type].signed():
178-
min_val = wdt.min()
179-
thresholds = np.insert(thresholds, 0, min_val, axis=1)
180177
bias = bias - 1
181-
# TODO: temporary fix for unsigned narrow quantization
182178
else:
183179
max_val = wdt.max()
184-
if max_val > DataType[input_data_type].max():
185-
thresholds = np.insert(thresholds, len(thresholds[0]), max_val, axis=1)
186-
else:
180+
if max_val <= DataType[input_data_type].max():
187181
max_val = max_val + 1
188182
# increase wdt
189183
if not wdt.signed():
190184
wdt = DataType.get_smallest_possible(max_val)
191185
else:
192186
wdt = DataType.get_smallest_possible(-max_val - 1)
193-
thresholds = np.insert(thresholds, len(thresholds[0]), max_val, axis=1)
194-
n_thres_steps += 1
195-
196-
# add dummy dimension as final dimension (that's what gets packed with next call)
197-
t_expand = np.expand_dims(thresholds, axis=-1)
198-
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 4)
199-
t_packed = pack_innermost_dim_as_hex_string(
200-
t_expand,
201-
wdt,
202-
bw_hexdigit,
203-
prefix="",
204-
)
205-
206-
pe = self.get_nodeattr("PE")
207-
num_channels = self.get_nodeattr("NumChannels") # number of channels
208187

209-
# If a single threshold value is found, broadcast the value
210-
if t_packed.shape[0] == 1:
211-
t_packed = np.broadcast_to(t_packed, (pe, expected_thresholds))
188+
# If a single threshold value is found, set num_channels to PE
189+
thresholds = model.get_initializer(self.onnx_node.input[1])
190+
if thresholds.shape[0] == 1:
212191
num_channels = pe
213192

214-
channel_fold = int(num_channels / pe)
215-
216-
for stage in range(o_bitwidth):
217-
sn = o_bitwidth - stage - 1
218-
for pe_value in range(pe):
219-
thresh_file = t_path + "/%s_threshs_%s_%s.dat" % (
220-
self.onnx_node.name,
221-
pe_value,
222-
stage,
223-
)
224-
threshs = np.zeros([channel_fold * (2**stage)], dtype="object")
225-
for ch in range(channel_fold):
226-
for i in range(2**stage):
227-
threshs[(ch << stage) + i] = t_packed[ch * pe + pe_value][
228-
(i << (o_bitwidth - stage)) + 2**sn - 1
229-
]
230-
with open(thresh_file, "w") as f:
231-
for val in threshs:
232-
f.write(val + "\n")
233193
code_gen_dict["$THRESHOLDS_PATH$"] = ['"./%s_"' % self.onnx_node.name]
234194

235195
# Identify the module name
@@ -433,6 +393,14 @@ def get_verilog_top_module_intf_names(self):
433393

434394
return intf_names
435395

396+
def generate_params(self, model, path):
397+
thresholds = model.get_initializer(self.onnx_node.input[1])
398+
rt_weights = self.get_nodeattr("runtime_writeable_weights")
399+
file_name = "{}/memblock.dat".format(path)
400+
if rt_weights:
401+
self.make_weight_file(thresholds, "decoupled_runtime", file_name)
402+
self.make_weight_file(thresholds, "internal_embedded", file_name)
403+
436404
def make_weight_file(self, weights, weight_file_mode, weight_file_name):
437405
"""Produce a file containing given weights (thresholds) in appropriate
438406
format for this layer. This file can be used for either synthesis or
@@ -444,14 +412,20 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
444412
* weight_file_name : filename for the weight file to be generated
445413
446414
"""
415+
path = os.path.dirname(weight_file_name)
416+
if not path:
417+
path = os.getcwd()
447418
thresholds = weights
448419
pe = self.get_nodeattr("PE")
449-
ch = self.get_nodeattr("NumChannels")
420+
num_channels = self.get_nodeattr("NumChannels") # number of channels
450421
output_data_type = self.get_nodeattr("outputDataType") # output precision
451422
o_bitwidth = DataType[output_data_type].bitwidth()
423+
input_data_type = self.get_nodeattr("inputDataType") # input/threshold precision
424+
452425
# The RTL expects 2^N-1 thresholds, but narrow range quantization will result in
453426
# one less threshold, prepending a dummy threshold (minimal possible value determined by
454-
# input data type) and decrease the bias by 1.
427+
# input data type)
428+
# and decrease the bias by 1 (needs to be done in code generation when bias is set).
455429
# Additionally, increase number of threshold steps to reflect new shape
456430
expected_thresholds = 2**o_bitwidth - 1
457431
n_thres_steps = self.get_nodeattr("numSteps")
@@ -463,7 +437,7 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
463437
# TODO: temporary fix for unsigned narrow quantization
464438
else:
465439
max_val = wdt.max()
466-
if max_val > self.get_input_datatype(0).max():
440+
if max_val > DataType[input_data_type].max():
467441
thresholds = np.insert(thresholds, len(thresholds[0]), max_val, axis=1)
468442
else:
469443
max_val = max_val + 1
@@ -475,35 +449,68 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name):
475449
thresholds = np.insert(thresholds, len(thresholds[0]), max_val, axis=1)
476450
n_thres_steps += 1
477451

478-
# If a single threshold value is found, broadcast the value
479-
if thresholds.shape[0] == 1:
480-
thresholds = np.broadcast_to(thresholds, (pe, expected_thresholds))
481-
ch = pe
482-
483-
width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth)
484-
thresh_padded = np.zeros((thresholds.shape[0], width_padded))
485-
thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds
486-
thresh_stream = []
487-
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32)
488-
padding = np.zeros(width_padded, dtype=np.int32)
489-
490-
chan_ind = 0
491-
cf = ch // pe
492-
for fold in range(cf):
493-
for c in range(2 ** (pe - 1).bit_length()):
494-
if (c == 0 or c % pe != 0) and c < pe:
495-
for t in thresh_padded[chan_ind]:
496-
t_packed = pack_innermost_dim_as_hex_string(
497-
[t], wdt, bw_hexdigit, prefix=""
498-
).item()
499-
thresh_stream.append(t_packed)
500-
chan_ind += 1
501-
else:
502-
for z in padding:
503-
t_packed = pack_innermost_dim_as_hex_string(
504-
[z], wdt, bw_hexdigit, prefix=""
505-
).item()
506-
thresh_stream.append(t_packed)
507-
with open(weight_file_name, "w") as f:
508-
for val in thresh_stream:
509-
f.write(val + "\n")
452+
if weight_file_mode == "decoupled_runtime":
453+
# If a single threshold value is found, broadcast the value
454+
if thresholds.shape[0] == 1:
455+
thresholds = np.broadcast_to(thresholds, (pe, expected_thresholds))
456+
num_channels = pe
457+
width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth)
458+
thresh_padded = np.zeros((thresholds.shape[0], width_padded))
459+
thresh_padded[: thresholds.shape[0], :n_thres_steps] = thresholds
460+
thresh_stream = []
461+
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 32)
462+
padding = np.zeros(width_padded, dtype=np.int32)
463+
464+
chan_ind = 0
465+
cf = num_channels // pe
466+
for fold in range(cf):
467+
for c in range(2 ** (pe - 1).bit_length()):
468+
if (c == 0 or c % pe != 0) and c < pe:
469+
for t in thresh_padded[chan_ind]:
470+
t_packed = pack_innermost_dim_as_hex_string(
471+
[t], wdt, bw_hexdigit, prefix=""
472+
).item()
473+
thresh_stream.append(t_packed)
474+
chan_ind += 1
475+
else:
476+
for z in padding:
477+
t_packed = pack_innermost_dim_as_hex_string(
478+
[z], wdt, bw_hexdigit, prefix=""
479+
).item()
480+
thresh_stream.append(t_packed)
481+
with open(weight_file_name, "w") as f:
482+
for val in thresh_stream:
483+
f.write(val + "\n")
484+
elif weight_file_mode == "internal_embedded":
485+
# add dummy dimension as final dimension (that's what gets packed with next call)
486+
t_expand = np.expand_dims(thresholds, axis=-1)
487+
bw_hexdigit = roundup_to_integer_multiple(wdt.bitwidth(), 4)
488+
t_packed = pack_innermost_dim_as_hex_string(
489+
t_expand,
490+
wdt,
491+
bw_hexdigit,
492+
prefix="",
493+
)
494+
# If a single threshold value is found, broadcast the value
495+
if t_packed.shape[0] == 1:
496+
t_packed = np.broadcast_to(t_packed, (pe, expected_thresholds))
497+
num_channels = pe
498+
channel_fold = int(num_channels / pe)
499+
500+
for stage in range(o_bitwidth):
501+
sn = o_bitwidth - stage - 1
502+
for pe_value in range(pe):
503+
thresh_file = path + "/%s_threshs_%s_%s.dat" % (
504+
self.onnx_node.name,
505+
pe_value,
506+
stage,
507+
)
508+
threshs = np.zeros([channel_fold * (2**stage)], dtype="object")
509+
for ch in range(channel_fold):
510+
for i in range(2**stage):
511+
threshs[(ch << stage) + i] = t_packed[ch * pe + pe_value][
512+
(i << (o_bitwidth - stage)) + 2**sn - 1
513+
]
514+
with open(thresh_file, "w") as f:
515+
for val in threshs:
516+
f.write(val + "\n")

0 commit comments

Comments
 (0)