22Title: Handwriting recognition
33Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
44Date created: 2021/08/16
5- Last modified: 2023/07/06
5+ Last modified: 2024/09/01
66Description: Training a handwriting recognition model with variable-length sequences.
77Accelerator: GPU
88"""
4545## Imports
4646"""
4747
48- from tensorflow . keras . layers import StringLookup
49- from tensorflow import keras
50-
48+ import keras
49+ from keras . layers import StringLookup
50+ from keras import ops
5151import matplotlib .pyplot as plt
5252import tensorflow as tf
5353import numpy as np
5454import os
5555
5656np .random .seed (42 )
57- tf . random . set_seed (42 )
57+ keras . utils . set_random_seed (42 )
5858
5959"""
6060## Dataset splitting
@@ -213,8 +213,8 @@ def distortion_free_resize(image, img_size):
213213 image = tf .image .resize (image , size = (h , w ), preserve_aspect_ratio = True )
214214
215215 # Check the amount of padding needed to be done.
216- pad_height = h - tf .shape (image )[0 ]
217- pad_width = w - tf .shape (image )[1 ]
216+ pad_height = h - ops .shape (image )[0 ]
217+ pad_width = w - ops .shape (image )[1 ]
218218
219219 # Only necessary if you want to do the same amount of padding on both sides.
220220 if pad_height % 2 != 0 :
@@ -240,7 +240,7 @@ def distortion_free_resize(image, img_size):
240240 ],
241241 )
242242
243- image = tf .transpose (image , perm = [ 1 , 0 , 2 ] )
243+ image = ops .transpose (image , ( 1 , 0 , 2 ) )
244244 image = tf .image .flip_left_right (image )
245245 return image
246246
@@ -267,13 +267,13 @@ def preprocess_image(image_path, img_size=(image_width, image_height)):
267267 image = tf .io .read_file (image_path )
268268 image = tf .image .decode_png (image , 1 )
269269 image = distortion_free_resize (image , img_size )
270- image = tf .cast (image , tf .float32 ) / 255.0
270+ image = ops .cast (image , tf .float32 ) / 255.0
271271 return image
272272
273273
274274def vectorize_label (label ):
275275 label = char_to_num (tf .strings .unicode_split (label , input_encoding = "UTF-8" ))
276- length = tf .shape (label )[0 ]
276+ length = ops .shape (label )[0 ]
277277 pad_amount = max_len - length
278278 label = tf .pad (label , paddings = [[0 , pad_amount ]], constant_values = padding_token )
279279 return label
@@ -312,7 +312,7 @@ def prepare_dataset(image_paths, labels):
312312 for i in range (16 ):
313313 img = images [i ]
314314 img = tf .image .flip_left_right (img )
315- img = tf .transpose (img , perm = [ 1 , 0 , 2 ] )
315+ img = ops .transpose (img , ( 1 , 0 , 2 ) )
316316 img = (img * 255.0 ).numpy ().clip (0 , 255 ).astype (np .uint8 )
317317 img = img [:, :, 0 ]
318318
@@ -346,15 +346,15 @@ def prepare_dataset(image_paths, labels):
346346class CTCLayer (keras .layers .Layer ):
347347 def __init__ (self , name = None ):
348348 super ().__init__ (name = name )
349- self .loss_fn = keras .backend .ctc_batch_cost
349+ self .loss_fn = tf . keras .backend .ctc_batch_cost
350350
351351 def call (self , y_true , y_pred ):
352- batch_len = tf .cast (tf .shape (y_true )[0 ], dtype = "int64" )
353- input_length = tf .cast (tf .shape (y_pred )[1 ], dtype = "int64" )
354- label_length = tf .cast (tf .shape (y_true )[1 ], dtype = "int64" )
352+ batch_len = ops .cast (ops .shape (y_true )[0 ], dtype = "int64" )
353+ input_length = ops .cast (ops .shape (y_pred )[1 ], dtype = "int64" )
354+ label_length = ops .cast (ops .shape (y_true )[1 ], dtype = "int64" )
355355
356- input_length = input_length * tf .ones (shape = (batch_len , 1 ), dtype = "int64" )
357- label_length = label_length * tf .ones (shape = (batch_len , 1 ), dtype = "int64" )
356+ input_length = input_length * ops .ones (shape = (batch_len , 1 ), dtype = "int64" )
357+ label_length = label_length * ops .ones (shape = (batch_len , 1 ), dtype = "int64" )
358358 loss = self .loss_fn (y_true , y_pred , input_length , label_length )
359359 self .add_loss (loss )
360360
@@ -455,14 +455,14 @@ def build_model():
455455
456456def calculate_edit_distance (labels , predictions ):
457457 # Get a single batch and convert its labels to sparse tensors.
458- saprse_labels = tf .cast (tf .sparse .from_dense (labels ), dtype = tf .int64 )
458+ saprse_labels = ops .cast (tf .sparse .from_dense (labels ), dtype = tf .int64 )
459459
460460 # Make predictions and convert them to sparse tensors.
461461 input_len = np .ones (predictions .shape [0 ]) * predictions .shape [1 ]
462- predictions_decoded = keras .backend .ctc_decode (
463- predictions , input_length = input_len , greedy = True
462+ predictions_decoded = keras .ops . nn .ctc_decode (
463+ predictions , sequence_lengths = input_len
464464 )[0 ][0 ][:, :max_len ]
465- sparse_predictions = tf .cast (
465+ sparse_predictions = ops .cast (
466466 tf .sparse .from_dense (predictions_decoded ), dtype = tf .int64
467467 )
468468
@@ -501,7 +501,7 @@ def on_epoch_end(self, epoch, logs=None):
501501
502502model = build_model ()
503503prediction_model = keras .models .Model (
504- model .get_layer (name = "image" ).input , model .get_layer (name = "dense2" ).output
504+ model .get_layer (name = "image" ).output , model .get_layer (name = "dense2" ).output
505505)
506506edit_distance_callback = EditDistanceCallback (prediction_model )
507507
@@ -523,14 +523,19 @@ def on_epoch_end(self, epoch, logs=None):
523523def decode_batch_predictions (pred ):
524524 input_len = np .ones (pred .shape [0 ]) * pred .shape [1 ]
525525 # Use greedy search. For complex tasks, you can use beam search.
526- results = keras .backend . ctc_decode (pred , input_length = input_len , greedy = True )[0 ][0 ][
526+ results = keras .ops . nn . ctc_decode (pred , sequence_lengths = input_len )[0 ][0 ][
527527 :, :max_len
528528 ]
529529 # Iterate over the results and get back the text.
530530 output_text = []
531531 for res in results :
532532 res = tf .gather (res , tf .where (tf .math .not_equal (res , - 1 )))
533- res = tf .strings .reduce_join (num_to_char (res )).numpy ().decode ("utf-8" )
533+ res = (
534+ tf .strings .reduce_join (num_to_char (res ))
535+ .numpy ()
536+ .decode ("utf-8" )
537+ .replace ("[UNK]" , "" )
538+ )
534539 output_text .append (res )
535540 return output_text
536541
@@ -546,7 +551,7 @@ def decode_batch_predictions(pred):
546551 for i in range (16 ):
547552 img = batch_images [i ]
548553 img = tf .image .flip_left_right (img )
549- img = tf .transpose (img , perm = [ 1 , 0 , 2 ] )
554+ img = ops .transpose (img , ( 1 , 0 , 2 ) )
550555 img = (img * 255.0 ).numpy ().clip (0 , 255 ).astype (np .uint8 )
551556 img = img [:, :, 0 ]
552557
0 commit comments