Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 82 additions & 28 deletions examples/audio/ctc_asr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
"""
Title: Automatic Speech Recognition using CTC
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
Date created: 2021/09/26
Last modified: 2021/09/26
Description: Training a CTC-based model for automatic speech recognition.
Accelerator: GPU
Title: FILLME
Author: FILLME
Date created: FILLME
Last modified: FILLME
Description: FILLME
"""

"""
# Automatic Speech Recognition using CTC

**Authors:** [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung
Huynh](https://www.linkedin.com/in/parkerhuynh/)<br>
**Date created:** 2021/09/26<br>
**Last modified:** 2026/01/22<br>
**Description:** Training a CTC-based model for automatic speech recognition.
"""

"""
Expand All @@ -25,6 +34,7 @@
how the input aligns with the output (how the characters in the transcript
align to the audio). The model we create is similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).

We will use the LJSpeech dataset from the
[LibriVox](https://librivox.org/) project. It consists of short
Expand All @@ -47,8 +57,9 @@
- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
"""

"""
Expand All @@ -58,8 +69,9 @@
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import layers
from keras import ops
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer
Expand All @@ -85,7 +97,7 @@
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/wavs/"
metadata_path = data_path + "/metadata.csv"
metadata_path = data_path + "/LJSpeech-1.1" + "/metadata.csv"


# Read metadata file and parse it
Expand All @@ -95,7 +107,6 @@
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
metadata_df.head(3)


"""
We now split the data into training and validation set.
"""
Expand All @@ -107,7 +118,6 @@
print(f"Size of the training set: {len(df_train)}")
print(f"Size of the validation set: {len(df_val)}")


"""
## Preprocessing

Expand Down Expand Up @@ -206,7 +216,6 @@ def encode_single_sample(wav_file, label):
.prefetch(buffer_size=tf.data.AUTOTUNE)
)


"""
## Visualize the data

Expand Down Expand Up @@ -245,20 +254,31 @@ def encode_single_sample(wav_file, label):

def CTCLoss(y_true, y_pred):
# Compute the training-time loss value
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
batch_len = ops.shape(y_true)[0]
input_length = ops.shape(y_pred)[1]
label_length = ops.shape(y_true)[1]

# Create length tensors - CTC needs to know the actual sequence lengths
input_length = input_length * ops.ones(shape=(batch_len,), dtype="int32")
label_length = label_length * ops.ones(shape=(batch_len,), dtype="int32")

# Use TensorFlow's CTC loss (no backend-agnostic equivalent in Keras 3)
# blank_index=-1 means the blank label is the last class (output_dim + 1)
loss = tf.nn.ctc_loss(
labels=ops.cast(y_true, "int32"),
logits=y_pred,
label_length=label_length,
logit_length=input_length,
logits_time_major=False,
blank_index=-1,
)
return loss


"""
We now define our model. We will define a model similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
"""


Expand Down Expand Up @@ -339,11 +359,22 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]

# Use TensorFlow's ctc_greedy_decoder (no backend-agnostic equivalent in Keras 3)
# Note: TF's decoder expects time-major format [time, batch, dim]
results, _ = tf.nn.ctc_greedy_decoder(
inputs=ops.transpose(pred, axes=[1, 0, 2]),
sequence_length=ops.cast(input_len, "int32"),
)

# ctc_greedy_decoder returns a list of SparseTensor, take the first one
results = tf.sparse.to_dense(results[0], default_value=-1)

# Iterate over the results and get back the text
output_text = []
for result in results:
# Remove padding values (-1) - using TensorFlow for boolean indexing
result = tf.boolean_mask(result, result >= 0)
result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
output_text.append(result)
return output_text
Expand Down Expand Up @@ -384,6 +415,28 @@ def on_epoch_end(self, epoch: int, logs=None):
Let's start the training process.
"""

# Fix the wavs_path which was missing the subdirectory
wavs_path = data_path + "/LJSpeech-1.1/wavs/"

# Re-create the datasets with the correct path
train_dataset = tf.data.Dataset.from_tensor_slices(
(list(df_train["file_name"]), list(df_train["normalized_transcription"]))
)
train_dataset = (
train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
.padded_batch(batch_size)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)

validation_dataset = tf.data.Dataset.from_tensor_slices(
(list(df_val["file_name"]), list(df_val["normalized_transcription"]))
)
validation_dataset = (
validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
.padded_batch(batch_size)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Define the number of epochs.
epochs = 1
# Callback function to check transcription on the val set.
Expand All @@ -396,7 +449,6 @@ def on_epoch_end(self, epoch: int, logs=None):
callbacks=[validation_callback],
)


"""
## Inference
"""
Expand All @@ -421,7 +473,6 @@ def on_epoch_end(self, epoch: int, logs=None):
print(f"Prediction: {predictions[i]}")
print("-" * 100)


"""
## Conclusion

Expand Down Expand Up @@ -458,6 +509,9 @@ def on_epoch_end(self, epoch: int, logs=None):
Example available on HuggingFace.
| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |
"""
80 changes: 62 additions & 18 deletions examples/audio/ipynb/ctc_asr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Authors:** [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)<br>\n",
"**Date created:** 2021/09/26<br>\n",
"**Last modified:** 2021/09/26<br>\n",
"**Last modified:** 2026/01/22<br>\n",
"**Description:** Training a CTC-based model for automatic speech recognition."
]
},
Expand Down Expand Up @@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -82,8 +82,9 @@
"import pandas as pd\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"import matplotlib.pyplot as plt\n",
"from IPython import display\n",
"from jiwer import wer"
Expand Down Expand Up @@ -114,7 +115,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -123,7 +124,7 @@
"data_url = \"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2\"\n",
"data_path = keras.utils.get_file(\"LJSpeech-1.1\", data_url, untar=True)\n",
"wavs_path = data_path + \"/wavs/\"\n",
"metadata_path = data_path + \"/metadata.csv\"\n",
"metadata_path = data_path + \"/LJSpeech-1.1\" + \"/metadata.csv\"\n",
"\n",
"\n",
"# Read metadata file and parse it\n",
Expand Down Expand Up @@ -362,14 +363,24 @@
"source": [
"def CTCLoss(y_true, y_pred):\n",
" # Compute the training-time loss value\n",
" batch_len = tf.cast(tf.shape(y_true)[0], dtype=\"int64\")\n",
" input_length = tf.cast(tf.shape(y_pred)[1], dtype=\"int64\")\n",
" label_length = tf.cast(tf.shape(y_true)[1], dtype=\"int64\")\n",
"\n",
" input_length = input_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
" label_length = label_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n",
"\n",
" loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)\n",
" batch_len = ops.shape(y_true)[0]\n",
" input_length = ops.shape(y_pred)[1]\n",
" label_length = ops.shape(y_true)[1]\n",
"\n",
" # Create length tensors - CTC needs to know the actual sequence lengths\n",
" input_length = input_length * ops.ones(shape=(batch_len,), dtype=\"int32\")\n",
" label_length = label_length * ops.ones(shape=(batch_len,), dtype=\"int32\")\n",
"\n",
" # Use TensorFlow's CTC loss (no backend-agnostic equivalent in Keras 3)\n",
" # blank_index=-1 means the blank label is the last class (output_dim + 1)\n",
" loss = tf.nn.ctc_loss(\n",
" labels=ops.cast(y_true, \"int32\"),\n",
" logits=y_pred,\n",
" label_length=label_length,\n",
" logit_length=input_length,\n",
" logits_time_major=False,\n",
" blank_index=-1\n",
" )\n",
" return loss"
]
},
Expand Down Expand Up @@ -481,11 +492,22 @@
"# A utility function to decode the output of the network\n",
"def decode_batch_predictions(pred):\n",
" input_len = np.ones(pred.shape[0]) * pred.shape[1]\n",
" # Use greedy search. For complex tasks, you can use beam search\n",
" results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]\n",
" \n",
" # Use TensorFlow's ctc_greedy_decoder (no backend-agnostic equivalent in Keras 3)\n",
" # Note: TF's decoder expects time-major format [time, batch, dim]\n",
" results, _ = tf.nn.ctc_greedy_decoder(\n",
" inputs=ops.transpose(pred, axes=[1, 0, 2]),\n",
" sequence_length=ops.cast(input_len, \"int32\")\n",
" )\n",
" \n",
" # ctc_greedy_decoder returns a list of SparseTensor, take the first one\n",
" results = tf.sparse.to_dense(results[0], default_value=-1)\n",
" \n",
" # Iterate over the results and get back the text\n",
" output_text = []\n",
" for result in results:\n",
" # Remove padding values (-1) - using TensorFlow for boolean indexing\n",
" result = tf.boolean_mask(result, result >= 0)\n",
" result = tf.strings.reduce_join(num_to_char(result)).numpy().decode(\"utf-8\")\n",
" output_text.append(result)\n",
" return output_text\n",
Expand Down Expand Up @@ -539,6 +561,28 @@
},
"outputs": [],
"source": [
"# Fix the wavs_path which was missing the subdirectory\n",
"wavs_path = data_path + \"/LJSpeech-1.1/wavs/\"\n",
"\n",
"# Re-create the datasets with the correct path\n",
"train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" (list(df_train[\"file_name\"]), list(df_train[\"normalized_transcription\"]))\n",
")\n",
"train_dataset = (\n",
" train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)\n",
" .padded_batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")\n",
"\n",
"validation_dataset = tf.data.Dataset.from_tensor_slices(\n",
" (list(df_val[\"file_name\"]), list(df_val[\"normalized_transcription\"]))\n",
")\n",
"validation_dataset = (\n",
" validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)\n",
" .padded_batch(batch_size)\n",
" .prefetch(buffer_size=tf.data.AUTOTUNE)\n",
")\n",
"\n",
"# Define the number of epochs.\n",
"epochs = 1\n",
"# Callback function to check transcription on the val set.\n",
Expand Down Expand Up @@ -645,7 +689,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "keras_io_311_env (3.11.3)",
"language": "python",
"name": "python3"
},
Expand All @@ -659,7 +703,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.11.3"
}
},
"nbformat": 4,
Expand Down