Skip to content

Commit ae8f3e2

Browse files
committed
Address SelectToken follow-ups after AddCLSToken merge
1 parent 09c3a5a commit ae8f3e2

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

tests/fpgadataflow/test_fpgadataflow_selecttoken.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,17 @@ def _prepare_selecttoken_stitched_ip_model(simd=1, token_index=0):
113113
return model
114114

115115

116+
def _make_input_dict(model, tokens):
117+
return {model.graph.input[0].name: tokens}
118+
119+
116120
@pytest.mark.fpgadataflow
117121
def test_convert_gather_to_selecttoken():
118122
model = _make_gather_model(token_index=2)
119123
tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
120124
expected = tokens[:, 2, :]
121125

122-
ret = execute_onnx(model, {"tokens": tokens})
126+
ret = execute_onnx(model, _make_input_dict(model, tokens))
123127
assert (ret["out"] == expected).all()
124128

125129
model = model.transform(InferSelectTokenLayer())
@@ -133,13 +137,13 @@ def test_convert_gather_to_selecttoken():
133137
assert inst.get_exp_cycles() == 16
134138
assert inst.get_nodeattr("TokenIndex") == 2
135139

136-
ret = execute_onnx(model, {"tokens": tokens})
140+
ret = execute_onnx(model, _make_input_dict(model, tokens))
137141
assert (ret["out"] == expected).all()
138142

139143
model = model.transform(SpecializeLayers(FPGA_PART))
144+
model = model.transform(GiveUniqueNodeNames())
140145
assert model.graph.node[0].op_type == "SelectToken_rtl"
141146
assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow.rtl"
142-
assert model.graph.node[0].name == "SelectToken_gather_token"
143147

144148

145149
@pytest.mark.fpgadataflow
@@ -149,7 +153,7 @@ def test_selecttoken_python_execution(token_index):
149153
tokens = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
150154
expected = tokens[:, token_index, :]
151155

152-
ret = execute_onnx(model, {"tokens": tokens})
156+
ret = execute_onnx(model, _make_input_dict(model, tokens))
153157
assert (ret["out"] == expected).all()
154158

155159

@@ -161,14 +165,15 @@ def test_selecttoken_python_execution(token_index):
161165
def test_selecttoken_rtl_codegen(tmp_path, finn_dtype, fold_width):
162166
model = _make_selecttoken_model(token_index=3, simd=2, finn_dtype=finn_dtype)
163167
model = model.transform(SpecializeLayers(FPGA_PART))
168+
model = model.transform(GiveUniqueNodeNames())
164169

165170
node = model.graph.node[0]
166171
inst = getCustomOp(node)
167172
inst.set_nodeattr("code_gen_dir_ipgen", str(tmp_path))
168173
inst.code_generation_ipgen(model, FPGA_PART, CLK_NS)
169174

170175
topname = inst.get_nodeattr("gen_top_module")
171-
assert topname == "SelectToken_0"
176+
assert topname == node.name
172177
wrapper = tmp_path / (topname + ".v")
173178
core = tmp_path / "select_token.sv"
174179
assert wrapper.is_file()
@@ -223,7 +228,7 @@ def test_selecttoken_rtlsim(simd, token_index):
223228
model = model.transform(SetExecMode("rtlsim"))
224229
model = model.transform(PrepareRTLSim())
225230

226-
ret = execute_onnx(model, {"tokens": tokens})
231+
ret = execute_onnx(model, _make_input_dict(model, tokens))
227232
assert (ret["out"] == expected).all()
228233

229234
node = model.get_nodes_by_op_type("SelectToken_rtl")[0]
@@ -246,7 +251,7 @@ def test_selecttoken_stitched_ip_rtlsim(simd, token_index):
246251

247252
model.set_metadata_prop("exec_mode", "rtlsim")
248253

249-
ret = execute_onnx(model, {"tokens": tokens})
254+
ret = execute_onnx(model, _make_input_dict(model, tokens))
250255
assert (ret["out"] == expected).all()
251256

252257

0 commit comments

Comments
 (0)