Merged
118 changes: 78 additions & 40 deletions examples/audio/ctc_asr.py
@@ -2,9 +2,10 @@
Title: Automatic Speech Recognition using CTC
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
Date created: 2021/09/26
Last modified: 2021/09/26
Last modified: 2026/01/27
Description: Training a CTC-based model for automatic speech recognition.
Accelerator: GPU
Migrated By: [Harshith K](https://github.com/kharshith-k/)
"""

"""
@@ -47,8 +48,8 @@
- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
"""

"""
@@ -58,8 +59,9 @@
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import layers
from keras import ops
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer
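# Note on the imports above: in Keras 3, `keras.ops` exposes backend-agnostic
# tensor operations (a NumPy-style array API plus `ops.nn` utilities), which
# is what allows the `tf.*` numerical calls in this example to be replaced
# with portable equivalents below.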
@@ -84,8 +86,8 @@

data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/wavs/"
metadata_path = data_path + "/metadata.csv"
wavs_path = data_path + "/LJSpeech-1.1/wavs/"
metadata_path = data_path + "/LJSpeech-1.1/metadata.csv"
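# Why the paths gained an "/LJSpeech-1.1" component: the archive unpacks into
# an LJSpeech-1.1/ subfolder, and (as assumed here) `get_file` now returns
# the directory containing that subfolder rather than the subfolder itself.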


# Read metadata file and parse it
@@ -95,7 +97,6 @@
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
metadata_df.head(3)


"""
We now split the data into training and validation sets.
"""
@@ -107,7 +108,6 @@
print(f"Size of the training set: {len(df_train)}")
print(f"Size of the training set: {len(df_val)}")


"""
## Preprocessing

@@ -150,19 +150,24 @@ def encode_single_sample(wav_file, label):
file = tf.io.read_file(wavs_path + wav_file + ".wav")
# 2. Decode the wav file
audio, _ = tf.audio.decode_wav(file)
audio = tf.squeeze(audio, axis=-1)
audio = ops.squeeze(audio)
# 3. Change type to float
audio = tf.cast(audio, tf.float32)
audio = ops.cast(audio, "float32")
# 4. Get the spectrogram
spectrogram = tf.signal.stft(
audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
stft_output = ops.stft(
audio,
sequence_length=frame_length,
sequence_stride=frame_step,
fft_length=fft_length,
center=False,
)
# 5. We only need the magnitude, which can be derived by applying tf.abs
spectrogram = tf.abs(spectrogram)
spectrogram = tf.math.pow(spectrogram, 0.5)
# 5. We only need the magnitude, which can be computed from real and imaginary parts
# stft returns (real, imag) tuple - compute magnitude as sqrt(real^2 + imag^2)
spectrogram = ops.sqrt(ops.square(stft_output[0]) + ops.square(stft_output[1]))
spectrogram = ops.power(spectrogram, 0.5)
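# For reference: `ops.stft` returns the transform as a (real, imag) pair
# rather than a complex tensor, so sqrt(real^2 + imag^2) above is exactly the
# magnitude that `tf.abs` produced on the complex output of `tf.signal.stft`;
# `ops.power(..., 0.5)` then applies the usual square-root compression to the
# magnitude spectrogram.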
# 6. Normalisation
means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
means = ops.mean(spectrogram, axis=1, keepdims=True)
stddevs = ops.std(spectrogram, axis=1, keepdims=True)
spectrogram = (spectrogram - means) / (stddevs + 1e-10)
###########################################
## Process the label
@@ -192,7 +197,7 @@ def encode_single_sample(wav_file, label):
)
train_dataset = (
train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
.padded_batch(batch_size)
.padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
.prefetch(buffer_size=tf.data.AUTOTUNE)
)
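# Note on the explicit `padded_shapes`: each spectrogram is a (time, freq)
# array with a variable number of frames but a fixed fft_length // 2 + 1
# frequency bins, and each label is a variable-length 1-D sequence, so
# spelling the shapes out as ([None, fft_length // 2 + 1], [None]) tells
# tf.data exactly which axes to pad.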

@@ -202,11 +207,10 @@ def encode_single_sample(wav_file, label):
)
validation_dataset = (
validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
.padded_batch(batch_size)
.padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
.prefetch(buffer_size=tf.data.AUTOTUNE)
)


"""
## Visualize the data

@@ -245,14 +249,24 @@

def CTCLoss(y_true, y_pred):
# Compute the training-time loss value
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
batch_len = ops.shape(y_true)[0]
input_length = ops.shape(y_pred)[1]
label_length = ops.shape(y_true)[1]

# Create length tensors - CTC needs to know the actual sequence lengths
input_length = input_length * ops.ones(shape=(batch_len,), dtype="int32")
label_length = label_length * ops.ones(shape=(batch_len,), dtype="int32")

# Use Keras ops CTC loss (backend-agnostic)
# Note: mask_index should match the blank token index
# With StringLookup(oov_token=""), index 0 is reserved, so we use 0 as mask
loss = ops.nn.ctc_loss(
target=ops.cast(y_true, "int32"),
output=y_pred,
target_length=label_length,
output_length=input_length,
mask_index=0,
)
return loss
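# Shape sketch for the loss above (shapes assumed from this pipeline):
# y_true is (batch, max_label_len) integer labels, y_pred is
# (batch, time, num_classes) softmax output, and ops.nn.ctc_loss returns a
# per-sample loss of shape (batch,), which Keras then reduces to a scalar.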


@@ -339,13 +353,31 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]

# Use Keras ops CTC decoder with greedy strategy (backend-agnostic)
decoded = ops.nn.ctc_decode(
inputs=pred,
sequence_lengths=ops.cast(input_len, "int32"),
strategy="greedy",
mask_index=0,
)

# ctc_decode returns a tuple of (decoded_sequences, log_probabilities)
# For greedy strategy, decoded_sequences has shape: (1, batch_size, max_length)
# So we need decoded[0][0] to get the batch with shape (batch_size, max_length)
decoded_sequences = decoded[0][0]

# Convert to numpy once for the whole batch
decoded_sequences = ops.convert_to_numpy(decoded_sequences)

# Iterate over the results and get back the text
output_text = []
for result in results:
result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
output_text.append(result)
for sequence in decoded_sequences:
# Remove padding/mask values (0 is the mask index)
sequence = sequence[sequence > 0]
# Convert indices to characters
text = tf.strings.reduce_join(num_to_char(sequence)).numpy().decode("utf-8")
output_text.append(text)
return output_text
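# Hedged usage sketch: given `preds = model.predict(batch_x)` of shape
# (batch, time, num_classes), `decode_batch_predictions(preds)` returns a
# list of `batch` plain-Python strings, one transcription per utterance.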


Expand All @@ -360,16 +392,23 @@ def __init__(self, dataset):
def on_epoch_end(self, epoch: int, logs=None):
predictions = []
targets = []
for batch in self.dataset:
# Limit to 10 batches to avoid long evaluation times
for i, batch in enumerate(self.dataset):
if i >= 10:
break
X, y = batch
batch_predictions = model.predict(X)
print(f"Batch {i}: X shape = {X.shape}, y shape = {y.shape}")
batch_predictions = model.predict(X, verbose=0)
print(f"Batch {i}: predictions shape = {batch_predictions.shape}")
batch_predictions = decode_batch_predictions(batch_predictions)
print(f"Batch {i}: decoded {len(batch_predictions)} predictions")
predictions.extend(batch_predictions)
for label in y:
label = (
tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
)
targets.append(label)
print(f"\nTotal: {len(predictions)} predictions, {len(targets)} targets")
wer_score = wer(targets, predictions)
print("-" * 100)
print(f"Word Error Rate: {wer_score:.4f}")
@@ -396,7 +435,6 @@ def on_epoch_end(self, epoch: int, logs=None):
callbacks=[validation_callback],
)


"""
## Inference
"""
@@ -421,7 +459,6 @@ def on_epoch_end(self, epoch: int, logs=None):
print(f"Prediction: {predictions[i]}")
print("-" * 100)


"""
## Conclusion

@@ -458,6 +495,7 @@ def on_epoch_end(self, epoch: int, logs=None):
Example available on HuggingFace.
| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |

"""