@@ -113,13 +113,17 @@ def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0):
113113 return model
114114
115115
116+ def _make_input_dict (model , tokens ):
117+ return {model .graph .input [0 ].name : tokens }
118+
119+
116120@pytest .mark .fpgadataflow
117121def test_convert_gather_to_selecttoken ():
118122 model = _make_gather_model (token_index = 2 )
119123 tokens = np .arange (16 , dtype = np .float32 ).reshape (1 , 4 , 4 )
120124 expected = tokens [:, 2 , :]
121125
122- ret = execute_onnx (model , { "tokens" : tokens } )
126+ ret = execute_onnx (model , _make_input_dict ( model , tokens ) )
123127 assert (ret ["out" ] == expected ).all ()
124128
125129 model = model .transform (InferSelectTokenLayer ())
@@ -133,13 +137,13 @@ def test_convert_gather_to_selecttoken():
133137 assert inst .get_exp_cycles () == 16
134138 assert inst .get_nodeattr ("TokenIndex" ) == 2
135139
136- ret = execute_onnx (model , { "tokens" : tokens } )
140+ ret = execute_onnx (model , _make_input_dict ( model , tokens ) )
137141 assert (ret ["out" ] == expected ).all ()
138142
139143 model = model .transform (SpecializeLayers (FPGA_PART ))
144+ model = model .transform (GiveUniqueNodeNames ())
140145 assert model .graph .node [0 ].op_type == "SelectToken_rtl"
141146 assert model .graph .node [0 ].domain == "finn.custom_op.fpgadataflow.rtl"
142- assert model .graph .node [0 ].name == "SelectToken_gather_token"
143147
144148
145149@pytest .mark .fpgadataflow
@@ -149,7 +153,7 @@ def test_selecttoken_python_execution(token_index):
149153 tokens = np .arange (16 , dtype = np .float32 ).reshape (1 , 4 , 4 )
150154 expected = tokens [:, token_index , :]
151155
152- ret = execute_onnx (model , { "tokens" : tokens } )
156+ ret = execute_onnx (model , _make_input_dict ( model , tokens ) )
153157 assert (ret ["out" ] == expected ).all ()
154158
155159
@@ -161,14 +165,15 @@ def test_selecttoken_python_execution(token_index):
161165def test_selecttoken_rtl_codegen (tmp_path , finn_dtype , fold_width ):
162166 model = _make_selecttoken_model (token_index = 3 , simd = 2 , finn_dtype = finn_dtype )
163167 model = model .transform (SpecializeLayers (FPGA_PART ))
168+ model = model .transform (GiveUniqueNodeNames ())
164169
165170 node = model .graph .node [0 ]
166171 inst = getCustomOp (node )
167172 inst .set_nodeattr ("code_gen_dir_ipgen" , str (tmp_path ))
168173 inst .code_generation_ipgen (model , FPGA_PART , CLK_NS )
169174
170175 topname = inst .get_nodeattr ("gen_top_module" )
171- assert topname == "SelectToken_0"
176+ assert topname == node . name
172177 wrapper = tmp_path / (topname + ".v" )
173178 core = tmp_path / "select_token.sv"
174179 assert wrapper .is_file ()
@@ -223,7 +228,7 @@ def test_selecttoken_rtlsim(simd, token_index):
223228 model = model .transform (SetExecMode ("rtlsim" ))
224229 model = model .transform (PrepareRTLSim ())
225230
226- ret = execute_onnx (model , { "tokens" : tokens } )
231+ ret = execute_onnx (model , _make_input_dict ( model , tokens ) )
227232 assert (ret ["out" ] == expected ).all ()
228233
229234 node = model .get_nodes_by_op_type ("SelectToken_rtl" )[0 ]
@@ -246,7 +251,7 @@ def test_selecttoken_stitched_ip_rtlsim(simd, token_index):
246251
247252 model .set_metadata_prop ("exec_mode" , "rtlsim" )
248253
249- ret = execute_onnx (model , { "tokens" : tokens } )
254+ ret = execute_onnx (model , _make_input_dict ( model , tokens ) )
250255 assert (ret ["out" ] == expected ).all ()
251256
252257
0 commit comments