Skip to content

Commit 64282e5

Browse files
authored
Merge pull request #31 from iksnagreb/fix/lookup
[Lookup] Relax input datatype constraints
2 parents bfc66a0 + 3de81d0 commit 64282e5

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/finn/custom_op/fpgadataflow/hls/lookup_hls.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import numpy as np
3030
import os
31+
import warnings
3132
from math import ceil, log2
3233
from qonnx.core.datatype import DataType
3334

@@ -273,7 +274,18 @@ def execute_node(self, context, graph):
273274
)
274275

275276
inp = context[node.input[0]]
276-
assert inp.dtype == np.int64, "Inputs must be contained in int64 ndarray"
277+
278+
# Make sure the input has the right container datatype
279+
if inp.dtype is not np.float32:
280+
# Issue a warning to make the user aware of this type-cast
281+
warnings.warn(
282+
f"{node.name}: Changing input container datatype from "
283+
f"{inp.dtype} to {np.float32}"
284+
)
285+
# Convert the input to floating point representation as the
286+
# container datatype
287+
inp = inp.astype(np.float32)
288+
277289
assert inp.shape == exp_ishape, """Input shape doesn't match expected shape."""
278290
export_idt = self.get_input_datatype()
279291
odt = self.get_output_datatype()

src/finn/custom_op/fpgadataflow/rtl/streamingfifo_rtl.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,18 @@ def execute_node(self, context, graph):
133133
elif mode == "rtlsim":
134134
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
135135
# create a npy file for the input of the node
136-
assert (
137-
str(inp.dtype) == "float32"
138-
), """Input datatype is
139-
not float32 as expected."""
136+
137+
# Make sure the input has the right container datatype
138+
if inp.dtype is not np.float32:
139+
# Issue a warning to make the user aware of this type-cast
140+
warnings.warn(
141+
f"{node.name}: Changing input container datatype from "
142+
f"{inp.dtype} to {np.float32}"
143+
)
144+
# Convert the input to floating point representation as the
145+
# container datatype
146+
inp = inp.astype(np.float32)
147+
140148
expected_inp_shape = self.get_folded_input_shape()
141149
reshaped_input = inp.reshape(expected_inp_shape)
142150
if DataType[self.get_nodeattr("dataType")] == DataType["BIPOLAR"]:

0 commit comments

Comments
 (0)