Skip to content

Commit fbcd896

Browse files
committed
[CustomOp] Update derive TAV method to allow for pre-hook for rtlsim
1 parent 09dbd19 commit fbcd896

5 files changed

Lines changed: 75 additions & 12 deletions

File tree

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ markers =
132132
bnn_kv260: mark tests that execute KV260 BNN tests
133133
bnn_pynq: mark tests that execute Pynq-Z1 BNN tests
134134
bnn_zcu104: mark tests that execute ZCU104 BNN tests
135+
node_tree_modeling: mark tests for analytical FIFO sizing tree models
135136
norecursedirs =
136137
dist
137138
build

src/finn/custom_op/fpgadataflow/hls/thresholding_hls.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,17 @@ def ipgen_extra_directives(self):
769769

770770
return ["config_compile -pipeline_style frp"]
771771

772-
def derive_characteristic_fxns(self, period):
772+
def derive_token_access_vectors(
773+
self,
774+
model,
775+
period,
776+
strategy,
777+
fpga_part,
778+
clk_period,
779+
op_type,
780+
override_dict=None,
781+
pre_hook=None,
782+
):
773783
n_inps = np.prod(self.get_folded_input_shape()[:-1])
774784
io_dict = {
775785
"inputs": {
@@ -782,7 +792,9 @@ def derive_characteristic_fxns(self, period):
782792
n_weight_inps = self.calc_tmem()
783793
num_w_reps = np.prod(self.get_nodeattr("numInputVectors"))
784794
io_dict["inputs"]["in1"] = [0 for i in range(num_w_reps * n_weight_inps)]
785-
super().derive_characteristic_fxns(period, override_rtlsim_dict=io_dict)
795+
super().derive_token_access_vectors(
796+
model, period, strategy, fpga_part, clk_period, op_type, io_dict, pre_hook=pre_hook
797+
)
786798

787799
def minimize_weight_bit_width(self, model):
788800
"""Minimize threshold datatype, with HLS-specific adjustments.

src/finn/custom_op/fpgadataflow/hwcustomop.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,15 @@ def get_tree_model(self):
319319
return None
320320

321321
def derive_token_access_vectors(
322-
self, model, period, strategy, fpga_part, clk_period, op_type, override_dict=None
322+
self,
323+
model,
324+
period,
325+
strategy,
326+
fpga_part,
327+
clk_period,
328+
op_type,
329+
override_dict=None,
330+
pre_hook=None,
323331
):
324332
if override_dict is None:
325333
n_inps = np.prod(self.get_folded_input_shape()[:-1])
@@ -342,7 +350,9 @@ def derive_token_access_vectors(
342350
# there is a 20 clock marging added for when get_exp_cycles()
343351
# is underestimating the real operator runtime.
344352
period = self.get_exp_cycles() + 20
345-
self.derive_token_access_vectors_using_rtlsim(model, period, fpga_part, clk_period, io_dict)
353+
self.derive_token_access_vectors_using_rtlsim(
354+
model, period, fpga_part, clk_period, io_dict, pre_hook=pre_hook
355+
)
346356

347357
def derive_token_access_vectors_using_tree_model(self, period, io_dict):
348358
# Analytical flow
@@ -469,9 +479,9 @@ def apply_micro_buffer_correction(start, txn_in, period):
469479

470480
def generate_hdl_memstream(self, fpgapart, pumped_memory=0):
471481
"""Helper function to generate verilog code for memstream component.
472-
Currently utilized by MVAU, VVAU and HLS Thresholding layer."""
482+
Currently utilized by MVAU, VVAU, HLS Thresholding and Elementwise layers."""
473483
ops = ["MVAU_hls", "MVAU_rtl", "VVAU_hls", "VVAU_rtl", "Thresholding_hls"]
474-
if self.onnx_node.op_type in ops:
484+
if self.onnx_node.op_type in ops or self.onnx_node.op_type.startswith("Elementwise"):
475485
template_path = (
476486
os.environ["FINN_ROOT"] + "/finn-rtllib/memstream/hdl/memstream_wrapper_template.v"
477487
)
@@ -601,10 +611,16 @@ def generate_hdl_dynload(self):
601611
f.write(template_wrapper)
602612

603613
def derive_token_access_vectors_using_rtlsim(
604-
self, model, period, fpga_part, clk_period, override_rtlsim_dict=None
614+
self, model, period, fpga_part, clk_period, override_rtlsim_dict=None, pre_hook=None
605615
):
606616
"""Return the token access vectors for this node using rtlsim.
607-
Used by analytical FIFO sizing approach."""
617+
Used by analytical FIFO sizing approach.
618+
619+
Args:
620+
pre_hook: Optional callable that takes sim as argument, called after
621+
reset_rtlsim but before running the simulation. Used by
622+
FINNLoop to initialize MLO state.
623+
"""
608624
# ensure rtlsim is ready
609625

610626
periods_to_simulate = 5
@@ -650,6 +666,8 @@ def derive_token_access_vectors_using_rtlsim(
650666
# signal name, note no underscore at the end (new finnxsi behavior)
651667
sname = "_V"
652668
self.reset_rtlsim(sim)
669+
if pre_hook is not None:
670+
pre_hook(sim)
653671

654672
# create stream tracers for all input and output streams
655673
for k in txns_in.keys():

src/finn/custom_op/fpgadataflow/rtl/elementwise_binary_rtl.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,17 @@ def instantiate_ip(self, cmd):
380380
"create_bd_cell -type hier -reference %s /%s/%s" % (top_module, node_name, node_name)
381381
)
382382

383-
def derive_characteristic_fxns(self, period, override_rtlsim_dict=None, pre_hook=None):
383+
def derive_token_access_vectors(
384+
self,
385+
model,
386+
period,
387+
strategy,
388+
fpga_part,
389+
clk_period,
390+
op_type,
391+
override_dict=None,
392+
pre_hook=None,
393+
):
384394
n_inps = np.prod(self.get_folded_input_shape(0)[:-1])
385395
io_dict = {
386396
"inputs": {
@@ -389,7 +399,9 @@ def derive_characteristic_fxns(self, period, override_rtlsim_dict=None, pre_hook
389399
},
390400
"outputs": {"out0": []},
391401
}
392-
super().derive_characteristic_fxns(period, override_rtlsim_dict=io_dict, pre_hook=pre_hook)
402+
super().derive_token_access_vectors(
403+
model, period, strategy, fpga_part, clk_period, op_type, io_dict, pre_hook=pre_hook
404+
)
393405

394406
def execute_node(self, context, graph):
395407
mode = self.get_nodeattr("exec_mode")

src/finn/custom_op/fpgadataflow/rtl/finn_loop.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,29 @@ def prepare_rtlsim(self, behav=False):
293293
sim_base, sim_rel = rtlsim_so
294294
self.set_nodeattr("rtlsim_so", sim_base + "/" + sim_rel)
295295

296-
def derive_characteristic_fxns(self, period):
296+
def derive_token_access_vectors(
297+
self,
298+
model,
299+
period,
300+
strategy,
301+
fpga_part,
302+
clk_period,
303+
op_type,
304+
override_dict=None,
305+
pre_hook=None,
306+
):
307+
# FINNLoop always uses rtlsim strategy with MLO prehook
297308
mlo_prehook = mlo_prehook_func_factory(self.onnx_node)
298-
super().derive_characteristic_fxns(period, pre_hook=mlo_prehook)
309+
super().derive_token_access_vectors(
310+
model,
311+
period,
312+
"rtlsim",
313+
fpga_part,
314+
clk_period,
315+
op_type,
316+
override_dict,
317+
pre_hook=mlo_prehook,
318+
)
299319

300320
def execute_node(self, context, graph):
301321
node = self.onnx_node

0 commit comments

Comments
 (0)