Skip to content

Commit b61314d

Browse files
authored
Merge pull request #118 from iksnagreb/fix/gen_finn_dt_tensor
Make gen_finn_dt_tensor consider the numpy type for INT and FIXED types
2 parents 124fb35 + 99dd4dc commit b61314d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/qonnx/util/basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,12 @@ def gen_finn_dt_tensor(finn_dt, tensor_shape):
251251
elif finn_dt == DataType["BINARY"]:
252252
tensor_values = np.random.randint(2, size=tensor_shape)
253253
elif "INT" in finn_dt.name or finn_dt == DataType["TERNARY"]:
254-
tensor_values = np.random.randint(finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape)
254+
tensor_values = np.random.randint(
255+
finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape, dtype=finn_dt.to_numpy_dt()
256+
)
255257
elif "FIXED" in finn_dt.name:
256258
int_dt = DataType["INT" + str(finn_dt.bitwidth())]
257-
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape)
259+
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape, dtype=int_dt.to_numpy_dt())
258260
tensor_values = tensor_values * finn_dt.scale_factor()
259261
elif finn_dt in [DataType["FLOAT32"], DataType["FLOAT16"]]:
260262
tensor_values = np.random.randn(*tensor_shape)

0 commit comments

Comments
 (0)