Title: Automatic Speech Recognition using CTC
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
Date created: 2021/09/26
Last modified: 2026/01/27
Description: Training a CTC-based model for automatic speech recognition.
Accelerator: GPU
Converted to Keras 3 by: [Harshith K](https://github.com/kharshith-k/)
"""
910
1011"""
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
"""
5353
5454"""
5858import pandas as pd
5959import numpy as np
6060import tensorflow as tf
61- from tensorflow import keras
62- from tensorflow .keras import layers
61+ import keras
62+ from keras import layers
63+ from keras import ops
6364import matplotlib .pyplot as plt
6465from IPython import display
6566from jiwer import wer
8485
# Download and extract the LJSpeech dataset (~2.6 GB).
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
# keras.utils.get_file extracts the archive into a directory named after it,
# so the wav files and metadata sit one level deeper, under "LJSpeech-1.1/".
wavs_path = data_path + "/LJSpeech-1.1/wavs/"
metadata_path = data_path + "/LJSpeech-1.1/metadata.csv"
8990
9091
# Read the metadata file, then shuffle its rows so the later
# train/validation split is random rather than ordered by file name.
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
metadata_df.head(3)
9798
98-
9999"""
100100We now split the data into training and validation set.
101101"""
# Report the split sizes. The second line previously mislabeled the
# validation set as "training set".
print(f"Size of the training set: {len(df_train)}")
print(f"Size of the validation set: {len(df_val)}")
109109
110-
111110"""
112111## Preprocessing
113112
@@ -150,19 +149,24 @@ def encode_single_sample(wav_file, label):
150149 file = tf .io .read_file (wavs_path + wav_file + ".wav" )
151150 # 2. Decode the wav file
152151 audio , _ = tf .audio .decode_wav (file )
153- audio = tf .squeeze (audio , axis = - 1 )
152+ audio = ops .squeeze (audio )
154153 # 3. Change type to float
155- audio = tf .cast (audio , tf . float32 )
154+ audio = ops .cast (audio , " float32" )
156155 # 4. Get the spectrogram
157- spectrogram = tf .signal .stft (
158- audio , frame_length = frame_length , frame_step = frame_step , fft_length = fft_length
156+ stft_output = ops .stft (
157+ audio ,
158+ sequence_length = frame_length ,
159+ sequence_stride = frame_step ,
160+ fft_length = fft_length ,
161+ center = False ,
159162 )
160- # 5. We only need the magnitude, which can be derived by applying tf.abs
161- spectrogram = tf .abs (spectrogram )
162- spectrogram = tf .math .pow (spectrogram , 0.5 )
163+ # 5. We only need the magnitude, which can be computed from real and imaginary parts
164+ # stft returns (real, imag) tuple - compute magnitude as sqrt(real^2 + imag^2)
165+ spectrogram = ops .sqrt (ops .square (stft_output [0 ]) + ops .square (stft_output [1 ]))
166+ spectrogram = ops .power (spectrogram , 0.5 )
163167 # 6. normalisation
164- means = tf . math . reduce_mean (spectrogram , 1 , keepdims = True )
165- stddevs = tf . math . reduce_std (spectrogram , 1 , keepdims = True )
168+ means = ops . mean (spectrogram , axis = 1 , keepdims = True )
169+ stddevs = ops . std (spectrogram , axis = 1 , keepdims = True )
166170 spectrogram = (spectrogram - means ) / (stddevs + 1e-10 )
167171 ###########################################
168172 ## Process the label
@@ -192,7 +196,7 @@ def encode_single_sample(wav_file, label):
192196)
# Build the training pipeline: encode each (wav, label) pair, pad to the
# longest item in every batch, and prefetch for throughput.
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    # Spectrograms are (time, fft_length // 2 + 1); time and label length
    # vary per sample, so both are padded to the batch maximum.
    .padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
198202
@@ -202,11 +206,10 @@ def encode_single_sample(wav_file, label):
202206)
# Build the validation pipeline, mirroring the training pipeline exactly.
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
208212
209-
210213"""
211214## Visualize the data
212215
@@ -245,14 +248,24 @@ def encode_single_sample(wav_file, label):
245248
def CTCLoss(y_true, y_pred):
    """Compute the CTC loss for a padded batch.

    Args:
        y_true: Integer label indices, shape (batch, max_label_len).
        y_pred: Model output probabilities over the vocabulary, shape
            (batch, time_steps, vocab_size).

    Returns:
        Per-sample CTC loss values.
    """
    batch_len = ops.shape(y_true)[0]
    input_length = ops.shape(y_pred)[1]
    label_length = ops.shape(y_true)[1]

    # CTC needs per-sample sequence lengths; here every sample in the
    # padded batch reports the full (padded) length.
    input_length = input_length * ops.ones(shape=(batch_len,), dtype="int32")
    label_length = label_length * ops.ones(shape=(batch_len,), dtype="int32")

    # Backend-agnostic CTC loss from keras.ops.
    # mask_index=0 matches the index reserved by StringLookup(oov_token=""),
    # which serves as the blank token — TODO confirm against the vocabulary setup.
    loss = ops.nn.ctc_loss(
        target=ops.cast(y_true, "int32"),
        output=y_pred,
        target_length=label_length,
        output_length=input_length,
        mask_index=0,
    )
    return loss
257270
258271
@@ -339,13 +352,31 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
339352# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    """Greedy-decode a batch of network outputs into text transcriptions.

    Args:
        pred: Model output probabilities, shape (batch, time_steps, vocab_size).

    Returns:
        A list of decoded transcription strings, one per batch element.
    """
    # Every sample uses the full number of time steps.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]

    # Backend-agnostic greedy CTC decoding. For complex tasks, a beam-search
    # strategy can be used instead. mask_index=0 is the blank token.
    decoded = ops.nn.ctc_decode(
        inputs=pred,
        sequence_lengths=ops.cast(input_len, "int32"),
        strategy="greedy",
        mask_index=0,
    )

    # ctc_decode returns (decoded_sequences, log_probabilities); with the
    # greedy strategy, decoded_sequences has shape (1, batch, max_length),
    # so decoded[0][0] is the (batch, max_length) index tensor.
    decoded_sequences = ops.convert_to_numpy(decoded[0][0])

    output_text = []
    for sequence in decoded_sequences:
        # Drop mask/padding entries (index 0 and any negative fill values).
        sequence = sequence[sequence > 0]
        # Map character indices back to a string.
        text = tf.strings.reduce_join(num_to_char(sequence)).numpy().decode("utf-8")
        output_text.append(text)
    return output_text
350381
351382
@@ -360,16 +391,23 @@ def __init__(self, dataset):
360391 def on_epoch_end (self , epoch : int , logs = None ):
361392 predictions = []
362393 targets = []
363- for batch in self .dataset :
394+ # Limit to 10 batches to avoid long evaluation times
395+ for i , batch in enumerate (self .dataset ):
396+ if i >= 10 :
397+ break
364398 X , y = batch
365- batch_predictions = model .predict (X )
399+ print (f"Batch { i } : X shape = { X .shape } , y shape = { y .shape } " )
400+ batch_predictions = model .predict (X , verbose = 0 )
401+ print (f"Batch { i } : predictions shape = { batch_predictions .shape } " )
366402 batch_predictions = decode_batch_predictions (batch_predictions )
403+ print (f"Batch { i } : decoded { len (batch_predictions )} predictions" )
367404 predictions .extend (batch_predictions )
368405 for label in y :
369406 label = (
370407 tf .strings .reduce_join (num_to_char (label )).numpy ().decode ("utf-8" )
371408 )
372409 targets .append (label )
410+ print (f"\n Total: { len (predictions )} predictions, { len (targets )} targets" )
373411 wer_score = wer (targets , predictions )
374412 print ("-" * 100 )
375413 print (f"Word Error Rate: { wer_score :.4f} " )
@@ -396,7 +434,6 @@ def on_epoch_end(self, epoch: int, logs=None):
396434 callbacks = [validation_callback ],
397435)
398436
399-
400437"""
401438## Inference
402439"""
@@ -421,12 +458,11 @@ def on_epoch_end(self, epoch: int, logs=None):
421458 print (f"Prediction: { predictions [i ]} " )
422459 print ("-" * 100 )
423460
424-
425461"""
426462## Conclusion
427463
428464In practice, you should train for around 50 epochs or more. Each epoch
takes approximately 8-10 minutes using a `Colab A100` GPU.
430466The model we trained at 50 epochs has a `Word Error Rate (WER) ≈ 16% to 17%`.
431467
432468Some of the transcriptions around epoch 50:
@@ -458,6 +494,7 @@ def on_epoch_end(self, epoch: int, logs=None):
458494Example available on HuggingFace.
459495| Trained Model | Demo |
460496| :--: | :--: |
| [](https://huggingface.co/keras-io/ctc_asr) | [](https://huggingface.co/spaces/keras-io/ctc_asr) |
"""