Skip to content

Commit 2aca792

Browse files
authored
Update Handwritten Recognition example to keras version3 (#1916)
* Update Handwritten Recognition example to keras version3 * Update with .py file on keras3 changes * Updated with .py file on keras3 changes * Update .py file with reformatting changes * .py file with reformatting changes * Replace modified handwritten_recognition.ipynb and handwritten_recognition.py * handwriting_recognition.py reformatted * Reformatted .py file * Updated py and ipynb file changes * Update .md file
1 parent 695a0b5 commit 2aca792

File tree

5 files changed

+221
-163
lines changed

5 files changed

+221
-163
lines changed

examples/vision/handwriting_recognition.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: Handwriting recognition
33
Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
44
Date created: 2021/08/16
5-
Last modified: 2023/07/06
5+
Last modified: 2024/09/01
66
Description: Training a handwriting recognition model with variable-length sequences.
77
Accelerator: GPU
88
"""
@@ -45,16 +45,16 @@
4545
## Imports
4646
"""
4747

48-
from tensorflow.keras.layers import StringLookup
49-
from tensorflow import keras
50-
48+
import keras
49+
from keras.layers import StringLookup
50+
from keras import ops
5151
import matplotlib.pyplot as plt
5252
import tensorflow as tf
5353
import numpy as np
5454
import os
5555

5656
np.random.seed(42)
57-
tf.random.set_seed(42)
57+
keras.utils.set_random_seed(42)
5858

5959
"""
6060
## Dataset splitting
@@ -213,8 +213,8 @@ def distortion_free_resize(image, img_size):
213213
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
214214

215215
# Check the amount of padding needed to be done.
216-
pad_height = h - tf.shape(image)[0]
217-
pad_width = w - tf.shape(image)[1]
216+
pad_height = h - ops.shape(image)[0]
217+
pad_width = w - ops.shape(image)[1]
218218

219219
# Only necessary if you want to do the same amount of padding on both sides.
220220
if pad_height % 2 != 0:
@@ -240,7 +240,7 @@ def distortion_free_resize(image, img_size):
240240
],
241241
)
242242

243-
image = tf.transpose(image, perm=[1, 0, 2])
243+
image = ops.transpose(image, (1, 0, 2))
244244
image = tf.image.flip_left_right(image)
245245
return image
246246

@@ -267,13 +267,13 @@ def preprocess_image(image_path, img_size=(image_width, image_height)):
267267
image = tf.io.read_file(image_path)
268268
image = tf.image.decode_png(image, 1)
269269
image = distortion_free_resize(image, img_size)
270-
image = tf.cast(image, tf.float32) / 255.0
270+
image = ops.cast(image, tf.float32) / 255.0
271271
return image
272272

273273

274274
def vectorize_label(label):
275275
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
276-
length = tf.shape(label)[0]
276+
length = ops.shape(label)[0]
277277
pad_amount = max_len - length
278278
label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
279279
return label
@@ -312,7 +312,7 @@ def prepare_dataset(image_paths, labels):
312312
for i in range(16):
313313
img = images[i]
314314
img = tf.image.flip_left_right(img)
315-
img = tf.transpose(img, perm=[1, 0, 2])
315+
img = ops.transpose(img, (1, 0, 2))
316316
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
317317
img = img[:, :, 0]
318318

@@ -346,15 +346,15 @@ def prepare_dataset(image_paths, labels):
346346
class CTCLayer(keras.layers.Layer):
347347
def __init__(self, name=None):
348348
super().__init__(name=name)
349-
self.loss_fn = keras.backend.ctc_batch_cost
349+
self.loss_fn = tf.keras.backend.ctc_batch_cost
350350

351351
def call(self, y_true, y_pred):
352-
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
353-
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
354-
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
352+
batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
353+
input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
354+
label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")
355355

356-
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
357-
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
356+
input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
357+
label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
358358
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
359359
self.add_loss(loss)
360360

@@ -455,14 +455,14 @@ def build_model():
455455

456456
def calculate_edit_distance(labels, predictions):
457457
# Get a single batch and convert its labels to sparse tensors.
458-
saprse_labels = tf.cast(tf.sparse.from_dense(labels), dtype=tf.int64)
458+
saprse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)
459459

460460
# Make predictions and convert them to sparse tensors.
461461
input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
462-
predictions_decoded = keras.backend.ctc_decode(
463-
predictions, input_length=input_len, greedy=True
462+
predictions_decoded = keras.ops.nn.ctc_decode(
463+
predictions, sequence_lengths=input_len
464464
)[0][0][:, :max_len]
465-
sparse_predictions = tf.cast(
465+
sparse_predictions = ops.cast(
466466
tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
467467
)
468468

@@ -501,7 +501,7 @@ def on_epoch_end(self, epoch, logs=None):
501501

502502
model = build_model()
503503
prediction_model = keras.models.Model(
504-
model.get_layer(name="image").input, model.get_layer(name="dense2").output
504+
model.get_layer(name="image").output, model.get_layer(name="dense2").output
505505
)
506506
edit_distance_callback = EditDistanceCallback(prediction_model)
507507

@@ -523,14 +523,19 @@ def on_epoch_end(self, epoch, logs=None):
523523
def decode_batch_predictions(pred):
524524
input_len = np.ones(pred.shape[0]) * pred.shape[1]
525525
# Use greedy search. For complex tasks, you can use beam search.
526-
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
526+
results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
527527
:, :max_len
528528
]
529529
# Iterate over the results and get back the text.
530530
output_text = []
531531
for res in results:
532532
res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
533-
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
533+
res = (
534+
tf.strings.reduce_join(num_to_char(res))
535+
.numpy()
536+
.decode("utf-8")
537+
.replace("[UNK]", "")
538+
)
534539
output_text.append(res)
535540
return output_text
536541

@@ -546,7 +551,7 @@ def decode_batch_predictions(pred):
546551
for i in range(16):
547552
img = batch_images[i]
548553
img = tf.image.flip_left_right(img)
549-
img = tf.transpose(img, perm=[1, 0, 2])
554+
img = ops.transpose(img, (1, 0, 2))
550555
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
551556
img = img[:, :, 0]
552557

45.3 KB
Loading
132 KB
Loading

0 commit comments

Comments
 (0)