Skip to content

Commit 4d5893f

Browse files
authored
Migrated ctc_asr from keras2 to keras3 (#2262)
* Migrated ctc_asr from keras2 to keras3 * Modified py and ipynb * Modified py and ipynb * Addressed all the suggestions as per PR comments and migrated all the supported APIs * Addressed all the suggestions as per PR comments and migrated all the supported APIs * Addressed Comments in PR * Generated ipynb and md files
1 parent de04af1 commit 4d5893f

File tree

4 files changed

+410
-203
lines changed

4 files changed

+410
-203
lines changed

examples/audio/ctc_asr.py

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
Title: Automatic Speech Recognition using CTC
33
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
44
Date created: 2021/09/26
5-
Last modified: 2021/09/26
5+
Last modified: 2026/01/27
66
Description: Training a CTC-based model for automatic speech recognition.
77
Accelerator: GPU
8+
Converted to Keras 3 by: [Harshith K](https://github.com/kharshith-k/)
89
"""
910

1011
"""
@@ -48,7 +49,6 @@
4849
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
4950
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
5051
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
51-
5252
"""
5353

5454
"""
@@ -58,8 +58,9 @@
5858
import pandas as pd
5959
import numpy as np
6060
import tensorflow as tf
61-
from tensorflow import keras
62-
from tensorflow.keras import layers
61+
import keras
62+
from keras import layers
63+
from keras import ops
6364
import matplotlib.pyplot as plt
6465
from IPython import display
6566
from jiwer import wer
@@ -84,8 +85,8 @@
8485

8586
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
8687
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
87-
wavs_path = data_path + "/wavs/"
88-
metadata_path = data_path + "/metadata.csv"
88+
wavs_path = data_path + "/LJSpeech-1.1/wavs/"
89+
metadata_path = data_path + "/LJSpeech-1.1" + "/metadata.csv"
8990

9091

9192
# Read metadata file and parse it
@@ -95,7 +96,6 @@
9596
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
9697
metadata_df.head(3)
9798

98-
9999
"""
100100
We now split the data into training and validation set.
101101
"""
@@ -107,7 +107,6 @@
107107
print(f"Size of the training set: {len(df_train)}")
108108
print(f"Size of the training set: {len(df_val)}")
109109

110-
111110
"""
112111
## Preprocessing
113112
@@ -150,19 +149,24 @@ def encode_single_sample(wav_file, label):
150149
file = tf.io.read_file(wavs_path + wav_file + ".wav")
151150
# 2. Decode the wav file
152151
audio, _ = tf.audio.decode_wav(file)
153-
audio = tf.squeeze(audio, axis=-1)
152+
audio = ops.squeeze(audio)
154153
# 3. Change type to float
155-
audio = tf.cast(audio, tf.float32)
154+
audio = ops.cast(audio, "float32")
156155
# 4. Get the spectrogram
157-
spectrogram = tf.signal.stft(
158-
audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
156+
stft_output = ops.stft(
157+
audio,
158+
sequence_length=frame_length,
159+
sequence_stride=frame_step,
160+
fft_length=fft_length,
161+
center=False,
159162
)
160-
# 5. We only need the magnitude, which can be derived by applying tf.abs
161-
spectrogram = tf.abs(spectrogram)
162-
spectrogram = tf.math.pow(spectrogram, 0.5)
163+
# 5. We only need the magnitude, which can be computed from real and imaginary parts
164+
# stft returns (real, imag) tuple - compute magnitude as sqrt(real^2 + imag^2)
165+
spectrogram = ops.sqrt(ops.square(stft_output[0]) + ops.square(stft_output[1]))
166+
spectrogram = ops.power(spectrogram, 0.5)
163167
# 6. normalisation
164-
means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
165-
stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
168+
means = ops.mean(spectrogram, axis=1, keepdims=True)
169+
stddevs = ops.std(spectrogram, axis=1, keepdims=True)
166170
spectrogram = (spectrogram - means) / (stddevs + 1e-10)
167171
###########################################
168172
## Process the label
@@ -192,7 +196,7 @@ def encode_single_sample(wav_file, label):
192196
)
193197
train_dataset = (
194198
train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
195-
.padded_batch(batch_size)
199+
.padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
196200
.prefetch(buffer_size=tf.data.AUTOTUNE)
197201
)
198202

@@ -202,11 +206,10 @@ def encode_single_sample(wav_file, label):
202206
)
203207
validation_dataset = (
204208
validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
205-
.padded_batch(batch_size)
209+
.padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
206210
.prefetch(buffer_size=tf.data.AUTOTUNE)
207211
)
208212

209-
210213
"""
211214
## Visualize the data
212215
@@ -245,14 +248,24 @@ def encode_single_sample(wav_file, label):
245248

246249
def CTCLoss(y_true, y_pred):
    """Training-time CTC loss between padded labels and model logits.

    Args:
        y_true: Padded integer label batch, shape (batch, max_label_len).
        y_pred: Model output (softmax over characters), shape
            (batch, time_steps, num_classes).

    Returns:
        Per-sample CTC loss values.
    """
    # Static/dynamic sizes of the incoming batch.
    batch_len = ops.shape(y_true)[0]
    input_length = ops.shape(y_pred)[1]
    label_length = ops.shape(y_true)[1]

    # CTC needs an explicit length per sample; every sample in a padded
    # batch shares the same (maximum) length here.
    input_length = input_length * ops.ones(shape=(batch_len,), dtype="int32")
    label_length = label_length * ops.ones(shape=(batch_len,), dtype="int32")

    # Backend-agnostic CTC loss from keras.ops. With
    # StringLookup(oov_token=""), index 0 is reserved, so 0 doubles as
    # the blank/mask index.
    return ops.nn.ctc_loss(
        target=ops.cast(y_true, "int32"),
        output=y_pred,
        target_length=label_length,
        output_length=input_length,
        mask_index=0,
    )
257270

258271

@@ -339,13 +352,31 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
339352
# A utility function to decode the output of the network
340353
def decode_batch_predictions(pred):
    """Greedy-decode a batch of network outputs into transcription strings.

    Args:
        pred: Model output batch, shape (batch, time_steps, num_classes).

    Returns:
        List of decoded UTF-8 strings, one per batch element.
    """
    # Every sample uses the full number of time steps.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]

    # Backend-agnostic greedy CTC decoding via keras.ops.
    decoded = ops.nn.ctc_decode(
        inputs=pred,
        sequence_lengths=ops.cast(input_len, "int32"),
        strategy="greedy",
        mask_index=0,
    )

    # ctc_decode returns (decoded_sequences, log_probabilities); for the
    # greedy strategy decoded_sequences has shape (1, batch, max_length),
    # so [0][0] selects the (batch, max_length) array. Convert to numpy
    # once for the whole batch.
    sequences = ops.convert_to_numpy(decoded[0][0])

    # Map index sequences back to text, dropping padding/mask entries.
    output_text = []
    for seq in sequences:
        seq = seq[seq > 0]  # 0 is the mask index
        output_text.append(
            tf.strings.reduce_join(num_to_char(seq)).numpy().decode("utf-8")
        )
    return output_text
350381

351382

@@ -360,16 +391,23 @@ def __init__(self, dataset):
360391
def on_epoch_end(self, epoch: int, logs=None):
361392
predictions = []
362393
targets = []
363-
for batch in self.dataset:
394+
# Limit to 10 batches to avoid long evaluation times
395+
for i, batch in enumerate(self.dataset):
396+
if i >= 10:
397+
break
364398
X, y = batch
365-
batch_predictions = model.predict(X)
399+
print(f"Batch {i}: X shape = {X.shape}, y shape = {y.shape}")
400+
batch_predictions = model.predict(X, verbose=0)
401+
print(f"Batch {i}: predictions shape = {batch_predictions.shape}")
366402
batch_predictions = decode_batch_predictions(batch_predictions)
403+
print(f"Batch {i}: decoded {len(batch_predictions)} predictions")
367404
predictions.extend(batch_predictions)
368405
for label in y:
369406
label = (
370407
tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
371408
)
372409
targets.append(label)
410+
print(f"\nTotal: {len(predictions)} predictions, {len(targets)} targets")
373411
wer_score = wer(targets, predictions)
374412
print("-" * 100)
375413
print(f"Word Error Rate: {wer_score:.4f}")
@@ -396,7 +434,6 @@ def on_epoch_end(self, epoch: int, logs=None):
396434
callbacks=[validation_callback],
397435
)
398436

399-
400437
"""
401438
## Inference
402439
"""
@@ -421,12 +458,11 @@ def on_epoch_end(self, epoch: int, logs=None):
421458
print(f"Prediction: {predictions[i]}")
422459
print("-" * 100)
423460

424-
425461
"""
426462
## Conclusion
427463
428464
In practice, you should train for around 50 epochs or more. Each epoch
429-
takes approximately 5-6mn using a `GeForce RTX 2080 Ti` GPU.
465+
takes approximately 8-10 minutes using a `Colab A100` GPU.
430466
The model we trained at 50 epochs has a `Word Error Rate (WER) ≈ 16% to 17%`.
431467
432468
Some of the transcriptions around epoch 50:
@@ -458,6 +494,7 @@ def on_epoch_end(self, epoch: int, logs=None):
458494
Example available on HuggingFace.
459495
| Trained Model | Demo |
460496
| :--: | :--: |
461-
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |
462-
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) |
463500
"""
82.7 KB
Loading

0 commit comments

Comments
 (0)