Commit a4d411d

fix: tflite conversion and inference

1 parent ec3ccc2

File tree: 20 files changed, +317 −230 lines

.pre-commit-config.yaml

Lines changed: 0 additions & 9 deletions
@@ -9,12 +9,3 @@ repos:
         stages: [pre-commit]
         fail_fast: true
         verbose: true
-      - id: pylint-check
-        name: pylint-check
-        entry: pylint --rcfile=.pylintrc -rn -sn
-        language: system
-        types: [python]
-        stages: [pre-commit]
-        fail_fast: true
-        require_serial: true
-        verbose: true

.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ contextmanager-decorators=contextlib.contextmanager
 # List of members which are set dynamically and missed by pylint inference
 # system, and so shouldn't trigger E1101 when accessed. Python regular
 # expressions are accepted.
-generated-members=tensorflow.python
+generated-members=tensorflow.python,tensorflow.keras
 
 # Tells whether missing members accessed in mixin class should be ignored. A
 # mixin class is detected if its name ends with "mixin" (case insensitive).

README.md

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ See [augmentations](./tensorflow_asr/augmentations/README.md)
 
 ## TFLite Conversion
 
-After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **unicode code points**, then we can convert unicode points to string.
+After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **text and tokens**.
 
 See [tflite_conversion](./docs/tutorials/tflite.md)

docs/tutorials/tflite.md

Lines changed: 58 additions & 4 deletions
@@ -1,12 +1,66 @@
-# TFLite Conversion Tutorial
+- [TFLite Tutorial](#tflite-tutorial)
+  - [Conversion](#conversion)
+  - [Inference](#inference)
+    - [1. Input](#1-input)
+    - [2. Output](#2-output)
+    - [3. Example script](#3-example-script)
 
-## Run
+
+# TFLite Tutorial
+
+## Conversion
 
 ```bash
-python examples/train.py \
+# --bs is the batch size; --beam-width > 0 enables beam search
+python3 examples/tflite.py \
     --config-path=/path/to/config.yml.j2 \
     --h5=/path/to/weight.h5 \
+    --bs=1 \
+    --beam-width=0 \
     --output=/path/to/output.tflite
 ## See other params
 python examples/tflite.py --help
 ```
+
+## Inference
+
+### 1. Input
+
+The inputs of each tflite model depend on the model's parameters and configs.
+
+The `inputs`, `inputs_length` and `previous_tokens` are the same as below for all models.
+
+```python
+schemas.PredictInput(
+    inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32),
+    inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32),
+    previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)),
+    previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)),
+    previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)),
+)
+```
+
+For models that don't have encoder states or decoder states, the defaults are `tf.zeros([], dtype=self.dtype)` tensors for `previous_encoder_states` and `previous_decoder_states`. This is only for tflite conversion, since tflite does not allow `None` values in the `input_signature`. The outputs `next_encoder_states` and `next_decoder_states` are still `None` in that case, so we can simply ignore them.
+
+### 2. Output
+
+```python
+schemas.PredictOutputWithTranscript(
+    transcript=self.tokenizer.detokenize(outputs.tokens),
+    tokens=outputs.tokens,
+    next_tokens=outputs.next_tokens,
+    next_encoder_states=outputs.next_encoder_states,
+    next_decoder_states=outputs.next_decoder_states,
+)
+```
+
+These outputs support streaming inference: each output corresponds to one chunk of the audio signal. We then overwrite `previous_tokens`, `previous_encoder_states` and `previous_decoder_states` with `next_tokens`, `next_encoder_states` and `next_decoder_states` for the next chunk, and continue until the end of the audio signal.
+
+### 3. Example script
+
+See [examples/inferences/tflite.py](../../examples/inferences/tflite.py) for more details.
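
Note: a minimal streaming-style driver built on the schema above might look like the sketch below. This is a sketch under assumptions, not code from this commit: the `stream` helper and the chunking strategy are hypothetical, and the positional input/output ordering (5 inputs, 5 outputs) is assumed to match `examples/inferences/tflite.py`; check `get_input_details()` for a given model.

```python
# Hypothetical streaming loop: feed audio chunk-by-chunk and carry the
# previous_* tensors forward from the next_* outputs of the prior chunk.
import numpy as np
import tensorflow_text as tft
from tensorflow.lite.python import interpreter


def stream(model_path, signal, chunk_size, tokens0, enc0, dec0):
    tflitemodel = interpreter.InterpreterWithCustomOps(
        model_path=model_path,
        custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS,
    )
    input_details = tflitemodel.get_input_details()
    output_details = tflitemodel.get_output_details()
    tokens, enc, dec = tokens0, enc0, dec0  # initial values from the model
    transcripts = []
    for start in range(0, signal.shape[1], chunk_size):
        chunk = signal[:, start : start + chunk_size]
        tflitemodel.resize_tensor_input(input_details[0]["index"], chunk.shape, strict=True)
        tflitemodel.allocate_tensors()
        tflitemodel.set_tensor(input_details[0]["index"], chunk)
        tflitemodel.set_tensor(input_details[1]["index"], np.array([chunk.shape[1]], dtype=np.int32))
        tflitemodel.set_tensor(input_details[2]["index"], tokens)
        tflitemodel.set_tensor(input_details[3]["index"], enc)
        tflitemodel.set_tensor(input_details[4]["index"], dec)
        tflitemodel.invoke()
        transcripts.append(tflitemodel.get_tensor(output_details[0]["index"]))
        # previous_* <- next_* for the next chunk
        tokens = tflitemodel.get_tensor(output_details[2]["index"])
        enc = tflitemodel.get_tensor(output_details[3]["index"])
        dec = tflitemodel.get_tensor(output_details[4]["index"])
    return transcripts
```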

examples/inferences/tflite.py

Lines changed: 32 additions & 16 deletions
@@ -13,41 +13,57 @@
 # limitations under the License.
 
 import tensorflow as tf
+import tensorflow_text as tft
+from tensorflow.lite.python import interpreter
 
 from tensorflow_asr.utils import cli_util, data_util
 
 logger = tf.get_logger()
 
 
 def main(
-    file_path: str,
-    tflite_path: str,
-    previous_encoder_states_shape: list = None,
-    previous_decoder_states_shape: list = None,
-    blank_index: int = 0,
+    audio_file_path: str,
+    tflite: str,
+    sample_rate: int = 16000,
+    blank: int = 0,
 ):
-    tflitemodel = tf.lite.Interpreter(model_path=tflite_path)
-    signal = data_util.read_raw_audio(file_path)
+    wav = data_util.load_and_convert_to_wav(audio_file_path, sample_rate=sample_rate)
+    signal = data_util.read_raw_audio(wav)
     signal = tf.reshape(signal, [1, -1])
     signal_length = tf.reshape(tf.shape(signal)[1], [1])
 
+    tflitemodel = interpreter.InterpreterWithCustomOps(model_path=tflite, custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS)
     input_details = tflitemodel.get_input_details()
     output_details = tflitemodel.get_output_details()
-    tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
+
+    tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape, strict=True)
     tflitemodel.allocate_tensors()
     tflitemodel.set_tensor(input_details[0]["index"], signal)
     tflitemodel.set_tensor(input_details[1]["index"], signal_length)
-    tflitemodel.set_tensor(input_details[2]["index"], tf.constant(blank_index, dtype=tf.int32))
-    if previous_encoder_states_shape:
-        tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(previous_encoder_states_shape, dtype=tf.float32))
-    if previous_decoder_states_shape:
-        tflitemodel.set_tensor(input_details[5]["index"], tf.zeros(previous_decoder_states_shape, dtype=tf.float32))
+    tflitemodel.set_tensor(input_details[2]["index"], tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank)
+    tflitemodel.set_tensor(input_details[3]["index"], tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"]))
+    tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"]))
+
     tflitemodel.invoke()
-    hyp = tflitemodel.get_tensor(output_details[0]["index"])
 
-    transcript = "".join([chr(u) for u in hyp])
+    transcript = tflitemodel.get_tensor(output_details[0]["index"])
+    tokens = tflitemodel.get_tensor(output_details[1]["index"])
+    next_tokens = tflitemodel.get_tensor(output_details[2]["index"])
+    if len(output_details) > 4:
+        next_encoder_states = tflitemodel.get_tensor(output_details[3]["index"])
+        next_decoder_states = tflitemodel.get_tensor(output_details[4]["index"])
+    elif len(output_details) > 3:
+        next_encoder_states = None
+        next_decoder_states = tflitemodel.get_tensor(output_details[3]["index"])
+    else:
+        next_encoder_states = None
+        next_decoder_states = None
+
     logger.info(f"Transcript: {transcript}")
-    return transcript
+    logger.info(f"Tokens: {tokens}")
+    logger.info(f"Next tokens: {next_tokens}")
+    logger.info(f"Next encoder states: {None if next_encoder_states is None else next_encoder_states.shape}")
+    logger.info(f"Next decoder states: {None if next_decoder_states is None else next_decoder_states.shape}")
 
 
 if __name__ == "__main__":
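
Note: the script above wires inputs and outputs by position. If a converted model orders its tensors differently, the ordering can be verified by name before invoking. A minimal sketch using only the standard interpreter introspection calls (the model path is a placeholder):

```python
# Print tensor names/shapes/dtypes so inputs can be wired by position safely.
import tensorflow_text as tft
from tensorflow.lite.python import interpreter

tflitemodel = interpreter.InterpreterWithCustomOps(
    model_path="/path/to/output.tflite",  # placeholder
    custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS,
)
for detail in tflitemodel.get_input_details():
    print("input ", detail["index"], detail["name"], detail["shape"], detail["dtype"])
for detail in tflitemodel.get_output_details():
    print("output", detail["index"], detail["name"], detail["shape"], detail["dtype"])
```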

examples/models/transducer/conformer/inference/run_tflite_model.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

examples/tflite.py

Lines changed: 7 additions & 6 deletions
@@ -23,26 +23,27 @@
 
 def main(
     config_path: str,
-    h5: str,
     output: str,
+    h5: str = None,
     bs: int = 1,
+    beam_width: int = 0,
     repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")),
 ):
-    assert h5 and output
+    assert output
     tf.keras.backend.clear_session()
     env_util.setup_seed()
-    tf.compat.v1.enable_control_flow_v2()
 
     config = Config(config_path, training=False, repodir=repodir)
     tokenizer = tokenizers.get(config)
 
     model: BaseModel = tf.keras.models.model_from_config(config.model_config)
     model.tokenizer = tokenizer
-    model.make()
-    model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
+    model.make(batch_size=bs)
+    if h5 and tf.io.gfile.exists(h5):
+        model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
     model.summary()
 
-    app_util.convert_tflite(model=model, output=output, batch_size=bs)
+    app_util.convert_tflite(model=model, output=output, batch_size=bs, beam_width=beam_width)
 
 
 if __name__ == "__main__":
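
Note: a hypothetical programmatic equivalent of the tutorial's CLI invocation, assuming `examples.tflite` is importable from the repo root; normally the script is run from the command line:

```python
# Hypothetical direct call to the converter's main(); paths are placeholders.
from examples.tflite import main

main(
    config_path="/path/to/config.yml.j2",
    output="/path/to/output.tflite",
    h5="/path/to/weight.h5",  # optional since this commit; skipped if missing
    bs=1,                     # batch size baked into the converted graph
    beam_width=0,             # > 0 enables beam search
)
```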

requirements.txt

Lines changed: 1 addition & 3 deletions
@@ -8,13 +8,11 @@ sounddevice~=0.4.6
 jinja2~=3.1.3
 fire~=0.5.0
 jiwer~=3.0.3
-chardet~=5.1.0
-charset-normalizer~=2.1.1
 
 # extra=dev
 pytest~=7.4.1
 black~=24.3.0
-pylint~=3.1.0
+pylint~=3.2.1
 matplotlib~=3.7.2
 pydot~=1.4.2
 graphviz~=0.20.1

setup.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def parse_requirements(lines: List[str]):
 
 setup(
     name="TensorFlowASR",
-    version="2.0.0",
+    version="2.0.1",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -7,6 +7,7 @@
 os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH", "true")
 
 import tensorflow as tf
+import keras
 from tensorflow.python.util import deprecation  # pylint: disable = no-name-in-module
 
 # might cause performance penalty if ops fallback to cpu, see https://cloud.google.com/tpu/docs/tensorflow-ops
@@ -46,9 +47,9 @@ def match_dtype_and_rank(y_t, y_p, sw):
 
 
 # monkey patch
-tf.keras.layers.Layer.output_shape = output_shape
-tf.keras.layers.Layer.build = build
-tf.keras.layers.Layer.compute_output_shape = compute_output_shape
+keras.layers.Layer.output_shape = output_shape
+keras.layers.Layer.build = build
+keras.layers.Layer.compute_output_shape = compute_output_shape
 compile_utils.match_dtype_and_rank = match_dtype_and_rank
 
 import tensorflow_asr.callbacks
