@@ -319,7 +319,15 @@ def get_tree_model(self):
319319 return None
320320
321321 def derive_token_access_vectors (
322- self , model , period , strategy , fpga_part , clk_period , op_type , override_dict = None
322+ self ,
323+ model ,
324+ period ,
325+ strategy ,
326+ fpga_part ,
327+ clk_period ,
328+ op_type ,
329+ override_dict = None ,
330+ pre_hook = None ,
323331 ):
324332 if override_dict is None :
325333 n_inps = np .prod (self .get_folded_input_shape ()[:- 1 ])
@@ -342,7 +350,9 @@ def derive_token_access_vectors(
342350 # there is a 20 clock marging added for when get_exp_cycles()
343351 # is underestimating the real operator runtime.
344352 period = self .get_exp_cycles () + 20
345- self .derive_token_access_vectors_using_rtlsim (model , period , fpga_part , clk_period , io_dict )
353+ self .derive_token_access_vectors_using_rtlsim (
354+ model , period , fpga_part , clk_period , io_dict , pre_hook = pre_hook
355+ )
346356
347357 def derive_token_access_vectors_using_tree_model (self , period , io_dict ):
348358 # Analytical flow
@@ -469,9 +479,9 @@ def apply_micro_buffer_correction(start, txn_in, period):
469479
470480 def generate_hdl_memstream (self , fpgapart , pumped_memory = 0 ):
471481 """Helper function to generate verilog code for memstream component.
472- Currently utilized by MVAU, VVAU and HLS Thresholding layer ."""
482+ Currently utilized by MVAU, VVAU, HLS Thresholding and Elementwise layers ."""
473483 ops = ["MVAU_hls" , "MVAU_rtl" , "VVAU_hls" , "VVAU_rtl" , "Thresholding_hls" ]
474- if self .onnx_node .op_type in ops :
484+ if self .onnx_node .op_type in ops or self . onnx_node . op_type . startswith ( "Elementwise" ) :
475485 template_path = (
476486 os .environ ["FINN_ROOT" ] + "/finn-rtllib/memstream/hdl/memstream_wrapper_template.v"
477487 )
@@ -601,10 +611,16 @@ def generate_hdl_dynload(self):
601611 f .write (template_wrapper )
602612
603613 def derive_token_access_vectors_using_rtlsim (
604- self , model , period , fpga_part , clk_period , override_rtlsim_dict = None
614+ self , model , period , fpga_part , clk_period , override_rtlsim_dict = None , pre_hook = None
605615 ):
606616 """Return the token access vectors for this node using rtlsim.
607- Used by analytical FIFO sizing approach."""
617+ Used by analytical FIFO sizing approach.
618+
619+ Args:
620+ pre_hook: Optional callable that takes sim as argument, called after
621+ reset_rtlsim but before running the simulation. Used by
622+ FINNLoop to initialize MLO state.
623+ """
608624 # ensure rtlsim is ready
609625
610626 periods_to_simulate = 5
@@ -650,6 +666,8 @@ def derive_token_access_vectors_using_rtlsim(
650666 # signal name, note no underscore at the end (new finnxsi behavior)
651667 sname = "_V"
652668 self .reset_rtlsim (sim )
669+ if pre_hook is not None :
670+ pre_hook (sim )
653671
654672 # create stream tracers for all input and output streams
655673 for k in txns_in .keys ():
0 commit comments