
Commit 2158c91

Fixes to the "English-to-Spanish Translation with a Sequence-to-Sequence Transformer" Code Example (#1997)
* bugfix: Encoder and decoder inputs were flipped. Given 30 epochs of training, the model never ended up producing sensible output. Examples:

  1) Tom didn't like Mary. → [start] ha estoy qué
  2) Tom called Mary and canceled their date. → [start] sola qué yo pasatiempo visto campo

  When fitting the model, the following relevant warning was emitted:

  ```
  UserWarning: The structure of `inputs` doesn't match the expected structure: ['encoder_inputs', 'decoder_inputs']. Received: the structure of inputs={'encoder_inputs': '*', 'decoder_inputs': '*'}
  ```

  After the fix, the model outputs sentences that are close to proper Spanish:

  1) That's what Tom told me. → [start] eso es lo que tom me dijo [end]
  2) Does Tom like cheeseburgers? → [start] a tom le gustan las queso de queso [end]

* Fix compute_mask in PositionalEmbedding. The removed check essentially disabled the mask calculation: this layer is the first one to receive the input and thus never has a previous mask. With this change, the mask is now passed on to the encoder. This looks like a regression; the initial commit of the example is very similar to this fix.

* Propagate both the encoder and decoder sequence masks to the decoder. As per https://github.com/tensorflow/tensorflow/blob/6550e4bd80223cdb8be6c3afd1f81e86a4d433c3/tensorflow/python/keras/engine/base_layer.py#L965, the inputs should be a list, not kwargs. When this is done, both masks are received as a tuple in the mask argument.

* Apply both padding masks in the attention layers and during loss computation.

* Regenerate the ipynb/md files for the NMT example.
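For readers following the change, here is a minimal, hypothetical sketch (not part of this commit; names such as `TinyDecoderBlock` and `toy_transformer` are invented for illustration) of the two Keras conventions the fixes rely on: a layer that receives its inputs as a list gets the matching padding masks as a pair in its `mask` argument, and a functional model with named inputs is wired and called with a dict keyed by those names.

```python
# Illustrative toy sketch only -- not the example's actual code.
import keras
from keras import layers, ops


class TinyDecoderBlock(layers.Layer):
    """Toy block: receives [decoder_seq, encoder_seq] as a list, so Keras
    delivers both padding masks as a pair in the `mask` argument."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.proj = layers.Dense(16)
        self.supports_masking = True

    def call(self, inputs, mask=None):
        decoder_seq, encoder_seq = inputs
        if mask is not None:
            decoder_mask, encoder_mask = mask  # one padding mask per list element
            # Zero out padded positions before combining the two sequences.
            decoder_seq = decoder_seq * ops.cast(
                ops.expand_dims(decoder_mask, -1), decoder_seq.dtype
            )
            encoder_seq = encoder_seq * ops.cast(
                ops.expand_dims(encoder_mask, -1), encoder_seq.dtype
            )
        return self.proj(decoder_seq) + self.proj(encoder_seq)

    def compute_mask(self, inputs, mask=None):
        # The output is aligned with the decoder sequence, so keep its mask.
        return None if mask is None else mask[0]


encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")

# mask_zero=True attaches a padding mask that propagates through the graph.
embed = layers.Embedding(1000, 16, mask_zero=True)
enc = embed(encoder_inputs)
dec = embed(decoder_inputs)

x = TinyDecoderBlock()([dec, enc])  # list input -> both masks arrive in `mask`
outputs = layers.Dense(1000, activation="softmax")(x)

# Inputs passed as a dict keyed by the Input names, as in the fixed example.
model = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    outputs,
    name="toy_transformer",
)

preds = model(
    {
        "encoder_inputs": ops.ones((2, 6), dtype="int64"),
        "decoder_inputs": ops.ones((2, 6), dtype="int64"),
    }
)
print(preds.shape)  # (2, 6, 1000)
```

In the actual example, `TransformerDecoder.call` unpacks `[x, encoder_outputs]` the same way and forwards the two masks to the `query_mask`/`key_mask` arguments of its attention layers, while padded target positions are excluded from the loss via `SparseCategoricalCrossentropy(ignore_class=0)`, as shown in the diff below.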
1 parent 5f16942 commit 2158c91

File tree

3 files changed: +146 −90 lines changed

examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb

Lines changed: 39 additions & 29 deletions
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
 "**Date created:** 2021/05/26<br>\n",
-"**Last modified:** 2023/02/25<br>\n",
+"**Last modified:** 2024/11/18<br>\n",
 "**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task."
 ]
 },
@@ -84,7 +84,7 @@
 "import keras\n",
 "from keras import layers\n",
 "from keras import ops\n",
-"from keras.layers import TextVectorization\n"
+"from keras.layers import TextVectorization"
 ]
 },
 {
@@ -213,7 +213,7 @@
 "The English layer will use the default string standardization (strip punctuation characters)\n",
 "and splitting scheme (split on whitespace), while\n",
 "the Spanish layer will use a custom standardization, where we add the character\n",
-"`\"\u00bf\"` to the set of punctuation characters to be stripped.\n",
+"`\"¿\"` to the set of punctuation characters to be stripped.\n",
 "\n",
 "Note: in a production-grade machine translation model, I would not recommend\n",
 "stripping the punctuation characters in either language. Instead, I would recommend turning\n",
@@ -229,7 +229,7 @@
 },
 "outputs": [],
 "source": [
-"strip_chars = string.punctuation + \"\u00bf\"\n",
+"strip_chars = string.punctuation + \"¿\"\n",
 "strip_chars = strip_chars.replace(\"[\", \"\")\n",
 "strip_chars = strip_chars.replace(\"]\", \"\")\n",
 "\n",
@@ -441,10 +441,7 @@
 "        return embedded_tokens + embedded_positions\n",
 "\n",
 "    def compute_mask(self, inputs, mask=None):\n",
-"        if mask is None:\n",
-"            return None\n",
-"        else:\n",
-"            return ops.not_equal(inputs, 0)\n",
+"        return ops.not_equal(inputs, 0)\n",
 "\n",
 "    def get_config(self):\n",
 "        config = super().get_config()\n",
@@ -481,24 +478,30 @@
 "        self.layernorm_3 = layers.LayerNormalization()\n",
 "        self.supports_masking = True\n",
 "\n",
-"    def call(self, inputs, encoder_outputs, mask=None):\n",
+"    def call(self, inputs, mask=None):\n",
+"        inputs, encoder_outputs = inputs\n",
 "        causal_mask = self.get_causal_attention_mask(inputs)\n",
-"        if mask is not None:\n",
-"            padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
-"            padding_mask = ops.minimum(padding_mask, causal_mask)\n",
+"\n",
+"        if mask is None:\n",
+"            inputs_padding_mask, encoder_outputs_padding_mask = None, None\n",
 "        else:\n",
-"            padding_mask = None\n",
+"            inputs_padding_mask, encoder_outputs_padding_mask = mask\n",
 "\n",
 "        attention_output_1 = self.attention_1(\n",
-"            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
+"            query=inputs,\n",
+"            value=inputs,\n",
+"            key=inputs,\n",
+"            attention_mask=causal_mask,\n",
+"            query_mask=inputs_padding_mask,\n",
 "        )\n",
 "        out_1 = self.layernorm_1(inputs + attention_output_1)\n",
 "\n",
 "        attention_output_2 = self.attention_2(\n",
 "            query=out_1,\n",
 "            value=encoder_outputs,\n",
 "            key=encoder_outputs,\n",
-"            attention_mask=padding_mask,\n",
+"            query_mask=inputs_padding_mask,\n",
+"            key_mask=encoder_outputs_padding_mask,\n",
 "        )\n",
 "        out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
 "\n",
@@ -527,8 +530,7 @@
 "                \"num_heads\": self.num_heads,\n",
 "            }\n",
 "        )\n",
-"        return config\n",
-""
+"        return config\n"
 ]
 },
 {
@@ -560,14 +562,15 @@
 "decoder_inputs = keras.Input(shape=(None,), dtype=\"int64\", name=\"decoder_inputs\")\n",
 "encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name=\"decoder_state_inputs\")\n",
 "x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)\n",
-"x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)\n",
+"x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])\n",
 "x = layers.Dropout(0.5)(x)\n",
 "decoder_outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
 "decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)\n",
 "\n",
-"decoder_outputs = decoder([decoder_inputs, encoder_outputs])\n",
 "transformer = keras.Model(\n",
-"    [encoder_inputs, decoder_inputs], decoder_outputs, name=\"transformer\"\n",
+"    {\"encoder_inputs\": encoder_inputs, \"decoder_inputs\": decoder_inputs},\n",
+"    decoder_outputs,\n",
+"    name=\"transformer\",\n",
 ")"
 ]
 },
@@ -598,7 +601,9 @@
 "\n",
 "transformer.summary()\n",
 "transformer.compile(\n",
-"    \"rmsprop\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n",
+"    \"rmsprop\",\n",
+"    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),\n",
+"    metrics=[\"accuracy\"],\n",
 ")\n",
 "transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)"
 ]
@@ -635,7 +640,12 @@
 "    decoded_sentence = \"[start]\"\n",
 "    for i in range(max_decoded_sentence_length):\n",
 "        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n",
-"        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])\n",
+"        predictions = transformer(\n",
+"            {\n",
+"                \"encoder_inputs\": tokenized_input_sentence,\n",
+"                \"decoder_inputs\": tokenized_target_sentence,\n",
+"            }\n",
+"        )\n",
 "\n",
 "        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here\n",
 "        sampled_token_index = ops.convert_to_numpy(\n",
@@ -664,19 +674,19 @@
 "After 30 epochs, we get results such as:\n",
 "\n",
 "> She handed him the money.\n",
-"> [start] ella le pas\u00f3 el dinero [end]\n",
+"> [start] ella le pasó el dinero [end]\n",
 "\n",
 "> Tom has never heard Mary sing.\n",
-"> [start] tom nunca ha o\u00eddo cantar a mary [end]\n",
+"> [start] tom nunca ha oído cantar a mary [end]\n",
 "\n",
 "> Perhaps she will come tomorrow.\n",
-"> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n",
+"> [start] tal vez ella vendrá mañana [end]\n",
 "\n",
 "> I love to write.\n",
 "> [start] me encanta escribir [end]\n",
 "\n",
 "> His French is improving little by little.\n",
-"> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n",
+"> [start] su francés va a [UNK] sólo un poco [end]\n",
 "\n",
 "> My hotel told me to call you.\n",
 "> [start] mi hotel me dijo que te [UNK] [end]"
@@ -693,7 +703,7 @@
 "toc_visible": true
 },
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "venv",
 "language": "python",
 "name": "python3"
 },
@@ -707,9 +717,9 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.10.12"
 }
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
