
Commit cd041b5

simonmaurer and lgeiger authored

QuantConv2D binarized activations with tf.int32 bitpacked output (#611)

* Added function strip_lcedequantize_ops:
  - Strips the output LceDequantize operators of a model so that the output is a bitpacked tf.int32 tensor.
  - Usually the LCE converter dequantizes the bitpacked output back to tf.float32, resulting in an identity tensor.
  - Use case: larq.layers.QuantConv2D followed by a sign operation (i.e. larq.math.sign or larq.quantizers.SteSign()).
  - Import using `from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops`.
* Reformatted using black code style.
* Added a pytest module for verifying lce_dequantize_ops.
* Fixed larq import errors and renamed the unit test function.
* Fixed a PyFlakes error caused by a typo when defining toy_model.
* Used the Interpreter from larq_compute_engine.tflite.python.interpreter instead of tf.lite.
* Reformatted strip_lcedequantize_test.py using black code style.
* Removed the dependency on the compute engine interpreter.
* Added a bazel target for the dequantize test.
* Fixed the test_strip_lcedequantize_ops test, as only models with tf.float32 output will result in tf.int32 tensor outputs when used with strip_lcedequantize_ops.
* Refactored an if-else statement in strip_lcedequantize_test.py.
* Deactivated setting default int8 ranges for tf.float32 models, as the strip_lcedequantize_ops function will not remove LceDequantize ops there.
* Fixed accidentally added merge indicators.
* Testing strip_lcedequantize_ops for tf.float32 output: fixed double allocation of the Interpreter by using tf.lite.Interpreter instead, and fixed a typo when converting the model to a TFLite model.
* Removed the import of the Larq interpreter because lint tests were failing.
* Adapted the unit test for output type checking: only validate the output after LceDequantize ops have been stripped; input type tests are already covered by end2end_test.py.
* Set inference_input_type statically to tf.float32, since only the output is validated.
* Set tf.float32 as the parametrized input type.
* Updated strip_lcedequantize_ops() to support more models:
  - Updated signature defs for TF 2.5 compatibility.
  - Support int8-quantized models when stripping the LceDequantize op for int8 output.
  - Support int8-quantized models when using dequantized tf.float32 output: strip the Dequantize operator first, then LceDequantize.
* Added unit tests for tf.int8 input/output models.
* Corrected toy_model_int8_sign: fake-quantize before QuantConv2D.
* Extended the test_strip_lcedequantize_ops() unit tests to parametrize experimental_enable_bitpacked_activations.
* Cleaned up using black code style.

Co-authored-by: Lukas Geiger <[email protected]>
1 parent 5b8284c commit cd041b5
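As a quick orientation, a minimal usage sketch of the new API, assembled from the toy model in the new test file below (variable names are illustrative):

import larq as lq
import tensorflow as tf

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops

# Binary conv followed by a sign op -- the use case named in the commit message.
img = tf.keras.layers.Input(shape=(224, 224, 3))
x = lq.layers.QuantConv2D(
    256,
    kernel_size=3,
    padding="same",
    pad_values=1,
    input_quantizer="ste_sign",
    kernel_quantizer="ste_sign",
    kernel_constraint="weight_clip",
)(img)
x = lq.quantizers.SteSign()(x)
model = tf.keras.Model(inputs=img, outputs=x)

# Convert, then strip the trailing LceDequantize op so the TFLite model
# emits the bitpacked tf.int32 tensor directly instead of tf.float32.
flatbuffer = convert_keras_model(model)
flatbuffer = strip_lcedequantize_ops(flatbuffer)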

4 files changed: +226 -0 lines changed

.github/workflows/unittests.yml (+2 lines)

@@ -90,6 +90,8 @@ jobs:
         run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all
       - name: Run End2End tests
         run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all
+      - name: Run Strip dequantize op tests
+        run: bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all
 
   ConverterPython:
     runs-on: ubuntu-latest

larq_compute_engine/mlir/python/util.py (+143 lines)

@@ -225,3 +225,146 @@ def modify_integer_quantized_model_io_type(
 
     # Convert the model to a bytearray
     return _convert_model_from_object_to_bytearray(model)
+
+
+def strip_lcedequantize_ops(model):
+    """Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors."""
+    # Convert the model to an object
+    model = _convert_model_from_bytearray_to_object(model)
+
+    if len(model.subgraphs) > 1:
+        raise ValueError(
+            "Model must only have one subgraph. Instead, it has "
+            "{} subgraphs.".format(len(model.subgraphs))
+        )
+
+    # Ensure model has at least one LceDequantize and/or Dequantize operator
+    lce_dequant_opcode_idx, dequant_opcode_idx = None, None
+    for idx, opcode in enumerate(model.operatorCodes):
+        if opcode.customCode == b"LceDequantize":
+            lce_dequant_opcode_idx = idx
+        elif opcode.builtinCode == tflite_schema.BuiltinOperator.DEQUANTIZE:
+            dequant_opcode_idx = idx
+        if lce_dequant_opcode_idx is not None and dequant_opcode_idx is not None:
+            break
+    if lce_dequant_opcode_idx is None and dequant_opcode_idx is None:
+        raise ValueError(
+            "Model does not contain any LceDequantize or Dequantize operators."
+        )
+
+    # Ensure model outputs are dequantized and remove Dequantize ops first if any
+    if dequant_opcode_idx is not None:
+        subgraph = model.subgraphs[0]
+        tensors = subgraph.tensors
+        operators = subgraph.operators
+        remove_tensors_idxs = set()
+
+        output_dequant_ops = []
+        for op in operators:
+            # Find output Dequantize operator
+            if (
+                op.opcodeIndex == dequant_opcode_idx
+                and op.outputs[0] in subgraph.outputs
+            ):
+                pos, float_tensor, int_tensor = (
+                    "output",
+                    tensors[op.outputs[0]],
+                    tensors[op.inputs[0]],
+                )
+                output_dequant_ops.append(op)
+            # Otherwise, ignore
+            else:
+                continue
+            # If found, validate the input/output tensor type
+            if float_tensor.type != tflite_schema.TensorType.FLOAT32:
+                raise ValueError(
+                    "Model {} type must be tf.float32. Expected type for tensor with "
+                    "name '{}' is tf.float32, instead type is tf.{}".format(
+                        pos,
+                        float_tensor.name,
+                        _convert_tflite_enum_type_to_tf_type(float_tensor.type).name,
+                    )
+                )
+            if int_tensor.type != tflite_schema.TensorType.INT8:
+                raise ValueError(
+                    "Model is not integer quantized. Expected type for tensor with "
+                    "name '{}' is tf.int8, instead type is tf.{}".format(
+                        int_tensor.name,
+                        _convert_tflite_enum_type_to_tf_type(int_tensor.type).name,
+                    )
+                )
+
+        # Remove the Dequantize operators
+        for op in output_dequant_ops:
+            subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
+            if model.signatureDefs:
+                signature_def = model.signatureDefs[0]
+                for i in range(len(signature_def.outputs)):
+                    if signature_def.outputs[i].tensorIndex == op.outputs[0]:
+                        signature_def.outputs[i].tensorIndex = op.inputs[0]
+            remove_tensors_idxs.add(op.outputs[0])
+            operators.remove(op)
+
+        # Remove tensors marked for deletion.
+        _remove_tensors_from_model(model, remove_tensors_idxs)
+
+    subgraph = model.subgraphs[0]
+    tensors = subgraph.tensors
+    operators = subgraph.operators
+    remove_tensors_idxs = set()
+
+    # Ensure model outputs are Lce dequantized and remove LceDequantize ops
+    lce_output_dequant_ops = []
+    for op in operators:
+        # Find output LceDequantize operator
+        if (
+            op.opcodeIndex == lce_dequant_opcode_idx
+            and op.outputs[0] in subgraph.outputs
+        ):
+            pos, output_tensor, input_tensor = (
+                "output",
+                tensors[op.outputs[0]],
+                tensors[op.inputs[0]],
+            )
+            lce_output_dequant_ops.append(op)
+        # Otherwise, ignore
+        else:
+            continue
+        # If found, validate the input/output tensor type
+        if (
+            output_tensor.type != tflite_schema.TensorType.FLOAT32
+            and output_tensor.type != tflite_schema.TensorType.INT8
+        ):
+            raise ValueError(
+                "Model {} type must be tf.float32/tf.int8. Expected type for tensor with "
+                "name '{}' is tf.float32/tf.int8, instead type is tf.{}".format(
+                    pos,
+                    output_tensor.name,
+                    _convert_tflite_enum_type_to_tf_type(output_tensor.type).name,
+                )
+            )
+        if input_tensor.type != tflite_schema.TensorType.INT32:
+            raise ValueError(
+                "Expected type for tensor with "
+                "name '{}' is tf.int32, instead type is tf.{}".format(
+                    input_tensor.name,
+                    _convert_tflite_enum_type_to_tf_type(input_tensor.type).name,
+                )
+            )
+
+    # Remove the LceDequantize operators
+    for op in lce_output_dequant_ops:
+        subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
+        if model.signatureDefs:
+            signature_def = model.signatureDefs[0]
+            for i in range(len(signature_def.outputs)):
+                if signature_def.outputs[i].tensorIndex == op.outputs[0]:
+                    signature_def.outputs[i].tensorIndex = op.inputs[0]
+        remove_tensors_idxs.add(op.outputs[0])
+        operators.remove(op)
+
+    # Remove tensors marked for deletion.
+    _remove_tensors_from_model(model, remove_tensors_idxs)
+
+    # Convert the model to a bytearray
+    return _convert_model_from_object_to_bytearray(model)
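A note on the removal step above: subgraph.outputs in the flatbuffer object API behaves as a NumPy array, so the boolean-mask assignment subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] rewires every output slot that referenced the dequantized tensor to the op's bitpacked input in one step. A standalone sketch of that idiom (toy indices, not real tensor ids):

import numpy as np

# Toy stand-ins: the subgraph outputs reference tensor 7, the Dequantize
# op's output tensor; tensor 5 is that op's bitpacked input tensor.
subgraph_outputs = np.array([3, 7])
dequant_output, dequant_input = 7, 5

# Boolean-mask assignment replaces every matching output index in place.
subgraph_outputs[subgraph_outputs == dequant_output] = dequant_input
print(subgraph_outputs)  # [3 5]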

larq_compute_engine/tests/BUILD (+8 lines)

@@ -20,6 +20,14 @@ py_test(
     ],
 )
 
+py_test(
+    name = "strip_lcedequantize_test",
+    srcs = ["strip_lcedequantize_test.py"],
+    deps = [
+        "//larq_compute_engine/mlir:converter",
+    ],
+)
+
 py_test(
     name = "convert_model",
     srcs = ["convert_model.py"],
larq_compute_engine/tests/strip_lcedequantize_test.py (new file, +73 lines)

@@ -0,0 +1,73 @@
+import sys
+
+import larq as lq
+import pytest
+import tensorflow as tf
+
+from larq_compute_engine.mlir.python.converter import convert_keras_model
+from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops
+
+
+def toy_model_sign(**kwargs):
+    img = tf.keras.layers.Input(shape=(224, 224, 3))
+    x = lq.layers.QuantConv2D(
+        256,
+        kernel_size=3,
+        strides=1,
+        padding="same",
+        pad_values=1,
+        input_quantizer="ste_sign",
+        kernel_quantizer="ste_sign",
+        kernel_constraint="weight_clip",
+    )(img)
+    x = lq.quantizers.SteSign()(x)
+    return tf.keras.Model(inputs=img, outputs=x)
+
+
+def quant(x):
+    return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)
+
+
+def toy_model_int8_sign(**kwargs):
+    img = tf.keras.layers.Input(shape=(224, 224, 3))
+    x = quant(img)
+    x = lq.layers.QuantConv2D(
+        256,
+        kernel_size=3,
+        strides=1,
+        padding="same",
+        pad_values=1,
+        input_quantizer="ste_sign",
+        kernel_quantizer="ste_sign",
+        kernel_constraint="weight_clip",
+    )(x)
+    x = lq.quantizers.SteSign()(x)
+    x = quant(x)
+    return tf.keras.Model(inputs=img, outputs=x)
+
+
+@pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
+@pytest.mark.parametrize("inference_input_type", [tf.float32, tf.int8])
+@pytest.mark.parametrize("inference_output_type", [tf.float32, tf.int8])
+@pytest.mark.parametrize("experimental_enable_bitpacked_activations", [True, False])
+def test_strip_lcedequantize_ops(
+    model_cls,
+    inference_input_type,
+    inference_output_type,
+    experimental_enable_bitpacked_activations,
+):
+    model_lce = convert_keras_model(
+        model_cls(),
+        inference_input_type=inference_input_type,
+        inference_output_type=inference_output_type,
+        experimental_enable_bitpacked_activations=experimental_enable_bitpacked_activations,
+    )
+    model_lce = strip_lcedequantize_ops(model_lce)
+    interpreter = tf.lite.Interpreter(model_content=model_lce)
+    output_details = interpreter.get_output_details()
+    assert len(output_details) == 1
+    assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__, "-s"]))
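For completeness, a hedged sketch of what a consumer of the stripped flatbuffer observes (inspect_stripped_output is a hypothetical helper, not part of this commit; the shape comment is an assumption based on 32 binary channels being packed per int32 word):

import tensorflow as tf


def inspect_stripped_output(model_lce: bytes) -> None:
    """Print the output signature of a model returned by strip_lcedequantize_ops."""
    interpreter = tf.lite.Interpreter(model_content=model_lce)
    output_details = interpreter.get_output_details()
    # After stripping, the output dtype is the bitpacked tf.int32 rather than
    # tf.float32 or tf.int8. With 256 binary channels, one would expect
    # 256 / 32 = 8 int32 words per spatial position (assumption, not asserted).
    print(output_details[0]["dtype"], output_details[0]["shape"])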
