
Commit 2676192

Adapt masked_language_modeling.py script to be backend-agnostic (#2054)
* adapting the script masked_language_modeling.py
* refactoring the script
* refactoring continues
* improved_implementation
* improved_implementation
* addressing last comments
* removing warnings
1 parent ba5b116 commit 2676192
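
For readers skimming the diff below, the backend-agnostic pattern the commit adopts boils down to two things: pick the backend through the `KERAS_BACKEND` environment variable before the first `keras` import, and route tensor math through `keras.ops` instead of `tf.*`. A minimal sketch of that pattern (illustrative only, not part of the commit; assumes Keras 3 with the chosen backend installed):

import os

# The backend must be chosen before keras is imported.
os.environ["KERAS_BACKEND"] = "torch"  # or "jax", or "tensorflow"

import keras
from keras import ops

x = keras.random.normal((2, 3))
y = ops.sum(x, axis=-1)  # backend-agnostic replacement for tf.reduce_sum
print(y)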

File tree

3 files changed: +1248 -260 lines changed


examples/nlp/ipynb/masked_language_modeling.ipynb

Lines changed: 53 additions & 45 deletions
@@ -62,20 +62,22 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
 "import os\n",
 "\n",
-"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+"os.environ[\"KERAS_BACKEND\"] = \"torch\"  # or jax, or tensorflow\n",
+"\n",
 "import keras_hub\n",
+"\n",
 "import keras\n",
-"import tensorflow as tf\n",
 "from keras import layers\n",
 "from keras.layers import TextVectorization\n",
+"\n",
 "from dataclasses import dataclass\n",
 "import pandas as pd\n",
 "import numpy as np\n",
@@ -95,7 +97,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -130,7 +132,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -142,7 +144,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -159,7 +161,6 @@
 "\n",
 "\n",
 "def get_data_from_text_files(folder_name):\n",
-"\n",
 "    pos_files = glob.glob(\"aclImdb/\" + folder_name + \"/pos/*.txt\")\n",
 "    pos_texts = get_text_list_from_files(pos_files)\n",
 "    neg_files = glob.glob(\"aclImdb/\" + folder_name + \"/neg/*.txt\")\n",
@@ -177,7 +178,7 @@
 "train_df = get_data_from_text_files(\"train\")\n",
 "test_df = get_data_from_text_files(\"test\")\n",
 "\n",
-"all_data = train_df.append(test_df)"
+"all_data = pd.concat([train_df, test_df], ignore_index=True)"
 ]
 },
 {
@@ -203,12 +204,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
+"# For data pre-processing and tf.data.Dataset\n",
+"import tensorflow as tf\n",
+"\n",
 "\n",
 "def custom_standardization(input_data):\n",
 "    lowercase = tf.strings.lower(input_data)\n",
@@ -276,9 +280,9 @@
 "    # Set input to [MASK] which is the last token for the 90% of tokens\n",
 "    # This means leaving 10% unchanged\n",
 "    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)\n",
-"    encoded_texts_masked[\n",
-"        inp_mask_2mask\n",
-"    ] = mask_token_id  # mask token is the last in the dict\n",
+"    encoded_texts_masked[inp_mask_2mask] = (\n",
+"        mask_token_id  # mask token is the last in the dict\n",
+"    )\n",
 "\n",
 "    # Set 10% to a random token\n",
 "    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)\n",
@@ -312,10 +316,8 @@
 "    config.BATCH_SIZE\n",
 ")\n",
 "\n",
-"# Build dataset for end to end model input (will be used at the end)\n",
-"test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(\n",
-"    (test_df.review.values, y_test)\n",
-").batch(config.BATCH_SIZE)\n",
+"# Dataset for end to end model input (will be used at the end)\n",
+"test_raw_classifier_ds = test_df\n",
 "\n",
 "# Prepare data for masked language model\n",
 "x_all_review = encode(all_data.review.values)\n",
@@ -345,7 +347,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -389,26 +391,14 @@
 "\n",
 "\n",
 "class MaskedLanguageModel(keras.Model):\n",
-"    def train_step(self, inputs):\n",
-"        if len(inputs) == 3:\n",
-"            features, labels, sample_weight = inputs\n",
-"        else:\n",
-"            features, labels = inputs\n",
-"            sample_weight = None\n",
 "\n",
-"        with tf.GradientTape() as tape:\n",
-"            predictions = self(features, training=True)\n",
-"            loss = loss_fn(labels, predictions, sample_weight=sample_weight)\n",
+"    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):\n",
 "\n",
-"        # Compute gradients\n",
-"        trainable_vars = self.trainable_variables\n",
-"        gradients = tape.gradient(loss, trainable_vars)\n",
-"\n",
-"        # Update weights\n",
-"        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
-"\n",
-"        # Compute our own metrics\n",
+"        loss = loss_fn(y, y_pred, sample_weight)\n",
 "        loss_tracker.update_state(loss, sample_weight=sample_weight)\n",
+"        return keras.ops.sum(loss)\n",
+"\n",
+"    def compute_metrics(self, x, y, y_pred, sample_weight):\n",
 "\n",
 "        # Return a dict mapping metric names to current value\n",
 "        return {\"loss\": loss_tracker.result()}\n",
@@ -505,7 +495,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -530,7 +520,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -596,24 +586,41 @@
 "When you want to deploy a model, it's best if it already includes its preprocessing\n",
 "pipeline, so that you don't have to reimplement the preprocessing logic in your\n",
 "production environment. Let's create an end-to-end model that incorporates\n",
-"the `TextVectorization` layer, and let's evaluate. Our model will accept raw strings\n",
-"as input."
+"the `TextVectorization` layer inside evaluate method, and let's evaluate. We will pass raw strings as input."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
+"\n",
+"# We create a custom Model to override the evaluate method so\n",
+"# that it first pre-process text data\n",
+"class ModelEndtoEnd(keras.Model):\n",
+"\n",
+"    def evaluate(self, inputs):\n",
+"        features = encode(inputs.review.values)\n",
+"        labels = inputs.sentiment.values\n",
+"        test_classifier_ds = (\n",
+"            tf.data.Dataset.from_tensor_slices((features, labels))\n",
+"            .shuffle(1000)\n",
+"            .batch(config.BATCH_SIZE)\n",
+"        )\n",
+"        return super().evaluate(test_classifier_ds)\n",
+"\n",
+"    # Build the model\n",
+"    def build(self, input_shape):\n",
+"        self.built = True\n",
+"\n",
 "\n",
 "def get_end_to_end(model):\n",
-"    inputs_string = keras.Input(shape=(1,), dtype=\"string\")\n",
-"    indices = vectorize_layer(inputs_string)\n",
-"    outputs = model(indices)\n",
-"    end_to_end_model = keras.Model(inputs_string, outputs, name=\"end_to_end_model\")\n",
+"    inputs = classifer_model.inputs[0]\n",
+"    outputs = classifer_model.outputs\n",
+"    end_to_end_model = ModelEndtoEnd(inputs, outputs, name=\"end_to_end_model\")\n",
 "    optimizer = keras.optimizers.Adam(learning_rate=config.LR)\n",
 "    end_to_end_model.compile(\n",
 "        optimizer=optimizer, loss=\"binary_crossentropy\", metrics=[\"accuracy\"]\n",
@@ -622,6 +629,7 @@
 "\n",
 "\n",
 "end_to_end_classification_model = get_end_to_end(classifer_model)\n",
+"# Pass raw text dataframe to the model\n",
 "end_to_end_classification_model.evaluate(test_raw_classifier_ds)"
 ]
 }
@@ -630,7 +638,7 @@
 "accelerator": "GPU",
 "colab": {
 "collapsed_sections": [],
-"name": "mlm_and_finetune_with_bert",
+"name": "masked_language_modeling",
 "private_outputs": false,
 "provenance": [],
 "toc_visible": true
@@ -655,4 +663,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}

examples/nlp/masked_language_modeling.py

Lines changed: 40 additions & 31 deletions
@@ -5,7 +5,7 @@
 Last modified: 2024/03/15
 Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.
 Accelerator: GPU
-Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
 """

 """
@@ -46,12 +46,14 @@

 import os

-os.environ["KERAS_BACKEND"] = "tensorflow"
+os.environ["KERAS_BACKEND"] = "torch"  # or jax, or tensorflow
+
 import keras_hub
+
 import keras
-import tensorflow as tf
 from keras import layers
 from keras.layers import TextVectorization
+
 from dataclasses import dataclass
 import pandas as pd
 import numpy as np
@@ -117,7 +119,7 @@ def get_data_from_text_files(folder_name):
 train_df = get_data_from_text_files("train")
 test_df = get_data_from_text_files("test")

-all_data = train_df.append(test_df)
+all_data = pd.concat([train_df, test_df], ignore_index=True)

 """
 ## Dataset preparation
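
The `append` to `concat` change is not cosmetic: `DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0, so the old line fails on current pandas. A minimal sketch of the replacement (toy frames, not the IMDB data):

import pandas as pd

train_df = pd.DataFrame({"review": ["a good movie"], "sentiment": [1]})
test_df = pd.DataFrame({"review": ["a bad movie"], "sentiment": [0]})

# Old, removed in pandas 2.0:  all_data = train_df.append(test_df)
all_data = pd.concat([train_df, test_df], ignore_index=True)
print(all_data)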
@@ -135,6 +137,9 @@ def get_data_from_text_files(folder_name):
 It masks 15% of all input tokens in each sequence at random.
 """

+# For data pre-processing and tf.data.Dataset
+import tensorflow as tf
+

 def custom_standardization(input_data):
     lowercase = tf.strings.lower(input_data)
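
The hunk above re-imports TensorFlow only for the input pipeline: even with the torch or jax backend selected, a Keras 3 model can consume a `tf.data.Dataset` in `fit`/`evaluate`, so the existing preprocessing code stays unchanged. A hedged sketch of that combination (toy data and shapes, assumes both torch and TensorFlow are installed):

import os

os.environ["KERAS_BACKEND"] = "torch"  # training runs on torch ...

import numpy as np
import keras
import tensorflow as tf  # ... but tf.data still feeds it

x = np.random.rand(64, 8).astype("float32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")
ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(64).batch(16)

model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(ds, epochs=1, verbose=0)  # a tf.data.Dataset is accepted by fit/evaluate on any backend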
@@ -238,10 +243,8 @@ def get_masked_input_and_labels(encoded_texts):
     config.BATCH_SIZE
 )

-# Build dataset for end to end model input (will be used at the end)
-test_raw_classifier_ds = tf.data.Dataset.from_tensor_slices(
-    (test_df.review.values, y_test)
-).batch(config.BATCH_SIZE)
+# Dataset for end to end model input (will be used at the end)
+test_raw_classifier_ds = test_df

 # Prepare data for masked language model
 x_all_review = encode(all_data.review.values)
@@ -301,26 +304,14 @@ def bert_module(query, key, value, i):


 class MaskedLanguageModel(keras.Model):
-    def train_step(self, inputs):
-        if len(inputs) == 3:
-            features, labels, sample_weight = inputs
-        else:
-            features, labels = inputs
-            sample_weight = None
-
-        with tf.GradientTape() as tape:
-            predictions = self(features, training=True)
-            loss = loss_fn(labels, predictions, sample_weight=sample_weight)

-        # Compute gradients
-        trainable_vars = self.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
+    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):

-        # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-
-        # Compute our own metrics
+        loss = loss_fn(y, y_pred, sample_weight)
         loss_tracker.update_state(loss, sample_weight=sample_weight)
+        return keras.ops.sum(loss)
+
+    def compute_metrics(self, x, y, y_pred, sample_weight):

         # Return a dict mapping metric names to current value
         return {"loss": loss_tracker.result()}
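
This hunk is the heart of the backend-agnostic change: the TensorFlow-specific `train_step` (with `tf.GradientTape` and manual `apply_gradients`) is replaced by `compute_loss`/`compute_metrics` overrides, so each backend's built-in `train_step` handles gradients and weight updates while only the loss and metric logic is customized. A minimal sketch of the same pattern on a toy regression model (the model, data, and names are illustrative, not from the example; assumes Keras 3):

import os

os.environ["KERAS_BACKEND"] = "torch"  # or "jax", or "tensorflow"

import keras
import numpy as np

loss_tracker = keras.metrics.Mean(name="loss")
loss_fn = keras.losses.MeanSquaredError()


class ToyModel(keras.Model):
    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # No GradientTape here: the backend's train_step computes and applies
        # the gradients; we only define how the loss is obtained.
        loss = loss_fn(y, y_pred, sample_weight)
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        return keras.ops.sum(loss)

    def compute_metrics(self, x, y, y_pred, sample_weight):
        return {"loss": loss_tracker.result()}


inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = ToyModel(inputs, outputs)
model.compile(optimizer="adam")
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)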
@@ -475,16 +466,33 @@ def create_classifier_bert_model():
 When you want to deploy a model, it's best if it already includes its preprocessing
 pipeline, so that you don't have to reimplement the preprocessing logic in your
 production environment. Let's create an end-to-end model that incorporates
-the `TextVectorization` layer, and let's evaluate. Our model will accept raw strings
-as input.
+the `TextVectorization` layer inside evaluate method, and let's evaluate. We will pass raw strings as input.
 """


+# We create a custom Model to override the evaluate method so
+# that it first pre-process text data
+class ModelEndtoEnd(keras.Model):
+
+    def evaluate(self, inputs):
+        features = encode(inputs.review.values)
+        labels = inputs.sentiment.values
+        test_classifier_ds = (
+            tf.data.Dataset.from_tensor_slices((features, labels))
+            .shuffle(1000)
+            .batch(config.BATCH_SIZE)
+        )
+        return super().evaluate(test_classifier_ds)
+
+    # Build the model
+    def build(self, input_shape):
+        self.built = True
+
+
 def get_end_to_end(model):
-    inputs_string = keras.Input(shape=(1,), dtype="string")
-    indices = vectorize_layer(inputs_string)
-    outputs = model(indices)
-    end_to_end_model = keras.Model(inputs_string, outputs, name="end_to_end_model")
+    inputs = classifer_model.inputs[0]
+    outputs = classifer_model.outputs
+    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
     optimizer = keras.optimizers.Adam(learning_rate=config.LR)
     end_to_end_model.compile(
         optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
@@ -493,4 +501,5 @@ def get_end_to_end(model):


 end_to_end_classification_model = get_end_to_end(classifer_model)
+# Pass raw text dataframe to the model
 end_to_end_classification_model.evaluate(test_raw_classifier_ds)
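
Note the design choice in the last two hunks: the previous end-to-end model placed a string `keras.Input` and the `TextVectorization` layer inside the model graph, which effectively ties the model to the TensorFlow backend; the new `ModelEndtoEnd` keeps the graph backend-agnostic and performs the vectorization inside its overridden `evaluate`. Its effect is roughly equivalent to preprocessing by hand and calling the stock `evaluate` on the wrapped classifier, sketched here with the names from the script (an illustration, not additional code from the commit):

# Roughly what ModelEndtoEnd.evaluate(test_df) does under the hood,
# using the names defined in the script.
features = encode(test_df.review.values)  # TextVectorization applied up front
labels = test_df.sentiment.values
eval_ds = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .shuffle(1000)
    .batch(config.BATCH_SIZE)
)
classifer_model.evaluate(eval_ds)  # same weights as the end-to-end wrapper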
