Skip to content

Commit b53446b

Browse files
committed
Fixes to neural_machine_translation_with_keras_hub example.
This example was no longer running because of: - a breaking change in `get_file`. - use of `to_tensor` on a non-ragged tensor, which was replaced with the cross-backend `convert_to_tensor(ragged=False)`. - decoding of a string that was no longer needed. Fixes #2176
1 parent ce7228b commit b53446b

File tree

3 files changed

+72
-99
lines changed

3 files changed

+72
-99
lines changed

examples/nlp/ipynb/neural_machine_translation_with_keras_hub.ipynb

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@
6969
"outputs": [],
7070
"source": [
7171
"!pip install -q --upgrade rouge-score\n",
72-
"!pip install -q --upgrade keras-hub\n",
73-
"!pip install -q --upgrade keras # Upgrade to Keras 3."
72+
"!pip install -q --upgrade keras-hub"
7473
]
7574
},
7675
{
@@ -88,10 +87,7 @@
8887
"import keras\n",
8988
"from keras import ops\n",
9089
"\n",
91-
"import tensorflow.data as tf_data\n",
92-
"from tensorflow_text.tools.wordpiece_vocab import (\n",
93-
" bert_vocab_from_dataset as bert_vocab,\n",
94-
")"
90+
"import tensorflow.data as tf_data"
9591
]
9692
},
9793
{
@@ -147,7 +143,7 @@
147143
" origin=\"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\",\n",
148144
" extract=True,\n",
149145
")\n",
150-
"text_file = pathlib.Path(text_file).parent / \"spa-eng\" / \"spa.txt\""
146+
"text_file = pathlib.Path(text_file) / \"spa-eng\" / \"spa.txt\""
151147
]
152148
},
153149
{
@@ -435,8 +431,6 @@
435431
"source": [
436432
"\n",
437433
"def preprocess_batch(eng, spa):\n",
438-
" batch_size = ops.shape(spa)[0]\n",
439-
"\n",
440434
" eng = eng_tokenizer(eng)\n",
441435
" spa = spa_tokenizer(spa)\n",
442436
"\n",
@@ -659,12 +653,15 @@
659653
" batch_size = 1\n",
660654
"\n",
661655
" # Tokenize the encoder input.\n",
662-
" encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))\n",
663-
" if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:\n",
664-
" pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)\n",
665-
" encoder_input_tokens = ops.concatenate(\n",
666-
" [encoder_input_tokens.to_tensor(), pads], 1\n",
656+
" encoder_input_tokens = ops.convert_to_tensor(\n",
657+
" eng_tokenizer(input_sentences), sparse=False, ragged=False\n",
658+
" )\n",
659+
" if ops.shape(encoder_input_tokens)[1] < MAX_SEQUENCE_LENGTH:\n",
660+
" pads = ops.zeros(\n",
661+
" (1, MAX_SEQUENCE_LENGTH - ops.shape(encoder_input_tokens)[1]),\n",
662+
" dtype=encoder_input_tokens.dtype,\n",
667663
" )\n",
664+
" encoder_input_tokens = ops.concatenate([encoder_input_tokens, pads], 1)\n",
668665
"\n",
669666
" # Define a function that outputs the next token's probability given the\n",
670667
" # input sequence.\n",
@@ -693,8 +690,7 @@
693690
"test_eng_texts = [pair[0] for pair in test_pairs]\n",
694691
"for i in range(2):\n",
695692
" input_sentence = random.choice(test_eng_texts)\n",
696-
" translated = decode_sequences([input_sentence])\n",
697-
" translated = translated.numpy()[0].decode(\"utf-8\")\n",
693+
" translated = decode_sequences([input_sentence])[0]\n",
698694
" translated = (\n",
699695
" translated.replace(\"[PAD]\", \"\")\n",
700696
" .replace(\"[START]\", \"\")\n",
@@ -740,8 +736,7 @@
740736
" input_sentence = test_pair[0]\n",
741737
" reference_sentence = test_pair[1]\n",
742738
"\n",
743-
" translated_sentence = decode_sequences([input_sentence])\n",
744-
" translated_sentence = translated_sentence.numpy()[0].decode(\"utf-8\")\n",
739+
" translated_sentence = decode_sequences([input_sentence])[0]\n",
745740
" translated_sentence = (\n",
746741
" translated_sentence.replace(\"[PAD]\", \"\")\n",
747742
" .replace(\"[START]\", \"\")\n",

examples/nlp/md/neural_machine_translation_with_keras_hub.md

Lines changed: 48 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ Before we start implementing the pipeline, let's import all the libraries we nee
4848
```python
4949
!pip install -q --upgrade rouge-score
5050
!pip install -q --upgrade keras-hub
51-
!pip install -q --upgrade keras # Upgrade to Keras 3.
5251
```
5352

53+
5454
```python
5555
import keras_hub
5656
import pathlib
@@ -60,18 +60,8 @@ import keras
6060
from keras import ops
6161

6262
import tensorflow.data as tf_data
63-
from tensorflow_text.tools.wordpiece_vocab import (
64-
bert_vocab_from_dataset as bert_vocab,
65-
)
6663
```
67-
<div class="k-default-codeblock">
68-
```
69-
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
70-
tensorflow 2.15.1 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.
71-
7264

73-
```
74-
</div>
7565
Let's also define our parameters/hyperparameters.
7666

7767

@@ -100,16 +90,17 @@ text_file = keras.utils.get_file(
10090
origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
10191
extract=True,
10292
)
103-
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"
93+
text_file = pathlib.Path(text_file) / "spa-eng" / "spa.txt"
10494
```
10595

10696
<div class="k-default-codeblock">
10797
```
10898
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip
109-
2638744/2638744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
11099
100+
2638744/2638744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
111101
```
112102
</div>
103+
113104
---
114105
## Parsing the data
115106

@@ -139,14 +130,14 @@ for _ in range(5):
139130

140131
<div class="k-default-codeblock">
141132
```
142-
('tom heard that mary had bought a new computer.', 'tom oyó que mary se había comprado un computador nuevo.')
143-
('will you stay at home?', '¿te vas a quedar en casa?')
144-
('where is this train going?', '¿adónde va este tren?')
145-
('tom panicked.', 'tom entró en pánico.')
146-
("we'll help you rescue tom.", 'te ayudaremos a rescatar a tom.')
147-
133+
('i went to bed a little earlier than usual.', 'me fui a la cama un poco antes de lo habitual.')
134+
('she trusted you.', 'ella confiaba en ti.')
135+
('tom is more intelligent than i am.', 'tom es más inteligente que yo.')
136+
('he kept on smoking all the time.', 'seguía fumando todo el tiempo.')
137+
("it's two miles from here to the station.", 'son dos millas de aquí a la estación.')
148138
```
149139
</div>
140+
150141
Now, let's split the sentence pairs into a training set, a validation set,
151142
and a test set.
152143

@@ -172,9 +163,9 @@ print(f"{len(test_pairs)} test pairs")
172163
83276 training pairs
173164
17844 validation pairs
174165
17844 test pairs
175-
176166
```
177167
</div>
168+
178169
---
179170
## Tokenizing the data
180171

@@ -236,11 +227,11 @@ print("Spanish Tokens: ", spa_vocab[100:110])
236227

237228
<div class="k-default-codeblock">
238229
```
239-
English Tokens: ['at', 'know', 'him', 'there', 'go', 'they', 'her', 'has', 'time', 'will']
240-
Spanish Tokens: ['le', 'para', 'te', 'mary', 'las', 'más', 'al', 'yo', 'tu', 'estoy']
241-
230+
English Tokens: ['him', 'there', 'they', 'go', 'her', 'has', 're', 'will', 'time', 'll']
231+
Spanish Tokens: ['le', 'qué', 'ella', 'te', 'para', 'mary', 'las', 'más', 'al', 'yo']
242232
```
243233
</div>
234+
244235
Now, let's define the tokenizers. We will configure the tokenizers with the
245236
vocabularies trained above.
246237

@@ -283,20 +274,16 @@ print(
283274

284275
<div class="k-default-codeblock">
285276
```
286-
English sentence: i am leaving the books here.
287-
Tokens: tf.Tensor([ 35 163 931 66 356 119 12], shape=(7,), dtype=int32)
288-
Recovered text after detokenizing: tf.Tensor(b'i am leaving the books here .', shape=(), dtype=string)
289-
```
290-
</div>
291-
292-
<div class="k-default-codeblock">
293-
```
294-
Spanish sentence: dejo los libros aquí.
295-
Tokens: tf.Tensor([2962 93 350 122 14], shape=(5,), dtype=int32)
296-
Recovered text after detokenizing: tf.Tensor(b'dejo los libros aqu\xc3\xad .', shape=(), dtype=string)
277+
English sentence: do you need a ride?
278+
Tokens: tf.Tensor([ 75 66 145 26 1075 25], shape=(6,), dtype=int32)
279+
Recovered text after detokenizing: do you need a ride ?
297280
281+
Spanish sentence: ¿necesitas que te lleven?
282+
Tokens: tf.Tensor([ 63 592 80 103 2994 128 29], shape=(7,), dtype=int32)
283+
Recovered text after detokenizing: ¿ necesitas que te lleven ?
298284
```
299285
</div>
286+
300287
---
301288
## Format datasets
302289

@@ -323,8 +310,6 @@ This can be easily done using `keras_hub.layers.StartEndPacker`.
323310
```python
324311

325312
def preprocess_batch(eng, spa):
326-
batch_size = ops.shape(spa)[0]
327-
328313
eng = eng_tokenizer(eng)
329314
spa = spa_tokenizer(spa)
330315

@@ -384,9 +369,9 @@ for inputs, targets in train_ds.take(1):
384369
inputs["encoder_inputs"].shape: (64, 40)
385370
inputs["decoder_inputs"].shape: (64, 40)
386371
targets.shape: (64, 40)
387-
388372
```
389373
</div>
374+
390375
---
391376
## Building the model
392377

@@ -508,7 +493,7 @@ transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
508493
│ transformer_encoder │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,315,072</span> │ token_and_positi… │
509494
│ (<span style="color: #0087ff; text-decoration-color: #0087ff">TransformerEncode…</span> │ │ │ │
510495
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
511-
functional_3 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">9,283,992</span> │ decoder_inputs[<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │
496+
functional_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">9,283,992</span> │ decoder_inputs[<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │
512497
│ (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">15000</span>) │ │ transformer_enco… │
513498
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
514499
</pre>
@@ -533,14 +518,15 @@ transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
533518

534519

535520

521+
536522
<div class="k-default-codeblock">
537523
```
538-
1302/1302 ━━━━━━━━━━━━━━━━━━━━ 1701s 1s/step - accuracy: 0.8168 - loss: 1.4819 - val_accuracy: 0.8650 - val_loss: 0.8129
539-
540-
<keras.src.callbacks.history.History at 0x7efdd7ee6a50>
524+
1302/1302 ━━━━━━━━━━━━━━━━━━━━ 688s 527ms/step - accuracy: 0.8385 - loss: 1.1014 - val_accuracy: 0.8661 - val_loss: 0.8040
541525
526+
<keras.src.callbacks.history.History at 0x3520df0d0>
542527
```
543528
</div>
529+
544530
---
545531
## Decoding test sentences (qualitative analysis)
546532

@@ -561,12 +547,15 @@ def decode_sequences(input_sentences):
561547
batch_size = 1
562548

563549
# Tokenize the encoder input.
564-
encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
565-
if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
566-
pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)
567-
encoder_input_tokens = ops.concatenate(
568-
[encoder_input_tokens.to_tensor(), pads], 1
550+
encoder_input_tokens = ops.convert_to_tensor(
551+
eng_tokenizer(input_sentences), sparse=False, ragged=False
552+
)
553+
if ops.shape(encoder_input_tokens)[1] < MAX_SEQUENCE_LENGTH:
554+
pads = ops.zeros(
555+
(1, MAX_SEQUENCE_LENGTH - ops.shape(encoder_input_tokens)[1]),
556+
dtype=encoder_input_tokens.dtype,
569557
)
558+
encoder_input_tokens = ops.concatenate([encoder_input_tokens, pads], 1)
570559

571560
# Define a function that outputs the next token's probability given the
572561
# input sequence.
@@ -595,8 +584,7 @@ def decode_sequences(input_sentences):
595584
test_eng_texts = [pair[0] for pair in test_pairs]
596585
for i in range(2):
597586
input_sentence = random.choice(test_eng_texts)
598-
translated = decode_sequences([input_sentence])
599-
translated = translated.numpy()[0].decode("utf-8")
587+
translated = decode_sequences([input_sentence])[0]
600588
translated = (
601589
translated.replace("[PAD]", "")
602590
.replace("[START]", "")
@@ -612,23 +600,19 @@ for i in range(2):
612600
<div class="k-default-codeblock">
613601
```
614602
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
615-
I0000 00:00:1714519073.816969 34774 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
603+
I0000 00:00:1761330728.196220 3674624 service.cc:152] XLA service 0x600002a54000 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
604+
I0000 00:00:1761330728.196232 3674624 service.cc:160] StreamExecutor device (0): Host, Default Version
605+
I0000 00:00:1761330728.304584 3674624 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
616606
617607
** Example 0 **
618-
i got the ticket free of charge.
619-
me pregunto la comprome .
620-
```
621-
</div>
622-
623-
<div class="k-default-codeblock">
624-
```
608+
tom used to play the piano professionally.
609+
tom se a la fiesta de la vida .
610+
625611
** Example 1 **
626-
i think maybe that's all you have to do.
627-
creo que tom le dije que hacer eso .
612+
i had to leave boston.
613+
tuve que ir a boston .
628614
```
629615
</div>
630-
631-
632616

633617
---
634618
## Evaluating our model (quantitative analysis)
@@ -651,8 +635,7 @@ for test_pair in test_pairs[:30]:
651635
input_sentence = test_pair[0]
652636
reference_sentence = test_pair[1]
653637

654-
translated_sentence = decode_sequences([input_sentence])
655-
translated_sentence = translated_sentence.numpy()[0].decode("utf-8")
638+
translated_sentence = decode_sequences([input_sentence])[0]
656639
translated_sentence = (
657640
translated_sentence.replace("[PAD]", "")
658641
.replace("[START]", "")
@@ -669,11 +652,11 @@ print("ROUGE-2 Score: ", rouge_2.result())
669652

670653
<div class="k-default-codeblock">
671654
```
672-
ROUGE-1 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.30989552>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.37136248>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.33032653>}
673-
ROUGE-2 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.08999339>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.09524643>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.08855649>}
674-
655+
ROUGE-1 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.3267246186733246>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.3378041982650757>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.320748895406723>}
656+
ROUGE-2 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.0940079391002655>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.10507937520742416>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.09657182544469833>}
675657
```
676658
</div>
659+
677660
After 10 epochs, the scores are as follows:
678661

679662
| | **ROUGE-1** | **ROUGE-2** |

0 commit comments

Comments
 (0)