
Commit 9f5c555

adapting the script multi_label_classification.py to be Backend-Agnostic (#2052)
* adapting the script multi_label_classification.py to be Backend-Agnostic

* addressing comments
1 parent cf31453 commit 9f5c555

3 files changed: +311 −180 lines

examples/nlp/ipynb/multi_label_classification.ipynb

Lines changed: 69 additions & 39 deletions
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)<br>\n",
 "**Date created:** 2020/09/25<br>\n",
-"**Last modified:** 2020/12/23<br>\n",
+"**Last modified:** 2025/02/27<br>\n",
 "**Description:** Implementing a large-scale multi-label text classification model."
 ]
 },
@@ -49,19 +49,22 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
-"from tensorflow.keras import layers\n",
-"from tensorflow import keras\n",
-"import tensorflow as tf\n",
+"import os\n",
+"\n",
+"os.environ[\"KERAS_BACKEND\"] = \"jax\" # or tensorflow, or torch\n",
+"\n",
+"import keras\n",
+"from keras import layers, ops\n",
 "\n",
 "from sklearn.model_selection import train_test_split\n",
-"from ast import literal_eval\n",
 "\n",
+"from ast import literal_eval\n",
 "import matplotlib.pyplot as plt\n",
 "import pandas as pd\n",
 "import numpy as np"
@@ -81,7 +84,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -106,7 +109,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -127,7 +130,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -148,7 +151,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -178,7 +181,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -203,7 +206,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -235,7 +238,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -275,14 +278,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
+"# For RaggedTensor\n",
+"import tensorflow as tf\n",
+"\n",
 "terms = tf.ragged.constant(train_df[\"terms\"].values)\n",
-"lookup = tf.keras.layers.StringLookup(output_mode=\"multi_hot\")\n",
+"lookup = layers.StringLookup(output_mode=\"multi_hot\")\n",
 "lookup.adapt(terms)\n",
 "vocab = lookup.get_vocabulary()\n",
 "\n",
@@ -294,7 +300,8 @@
 "\n",
 "\n",
 "print(\"Vocabulary:\\n\")\n",
-"print(vocab)\n"
+"print(vocab)\n",
+""
 ]
 },
 {
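For readers skimming the diff, here is a self-contained sketch of what the `StringLookup(output_mode="multi_hot")` cell does; the term lists below are made up for illustration:

```python
# Self-contained sketch of multi-hot label binarization with StringLookup.
# The term lists stand in for train_df["terms"]; real arXiv labels differ.
import tensorflow as tf  # only needed for RaggedTensor support
from keras import layers

terms = tf.ragged.constant([["cs.LG", "stat.ML"], ["cs.CV"], ["cs.LG"]])
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)

print(lookup.get_vocabulary())  # e.g. ['[UNK]', 'cs.LG', 'cs.CV', 'stat.ML']
print(lookup(terms))  # one multi-hot row per sample, one column per term
```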
@@ -310,7 +317,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -337,7 +344,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -361,7 +368,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -380,7 +387,8 @@
 " (dataframe[\"summaries\"].values, label_binarized)\n",
 " )\n",
 " dataset = dataset.shuffle(batch_size * 10) if is_train else dataset\n",
-" return dataset.batch(batch_size)\n"
+" return dataset.batch(batch_size)\n",
+""
 ]
 },
 {
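The hunk above only shows the tail of `make_dataset`; for orientation, the whole helper is approximately the following. The label-binarization lines and the `batch_size` value are assumptions carried over from the original example, not part of this diff:

```python
# Approximate reconstruction of make_dataset; only its tail is visible
# in the hunk above. `lookup` is the adapted StringLookup from the
# earlier cell; batch_size = 128 is an assumed value.
import tensorflow as tf

batch_size = 128  # assumed


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()  # multi-hot label vectors
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    # Shuffle only the training split; evaluation order stays deterministic.
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)
```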
@@ -394,7 +402,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -416,7 +424,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -450,7 +458,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -460,7 +468,8 @@
 "vocabulary = set()\n",
 "train_df[\"summaries\"].str.lower().str.split().apply(vocabulary.update)\n",
 "vocabulary_size = len(vocabulary)\n",
-"print(vocabulary_size)\n"
+"print(vocabulary_size)\n",
+""
 ]
 },
 {
@@ -475,7 +484,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -498,7 +507,8 @@
 ").prefetch(auto)\n",
 "test_dataset = test_dataset.map(\n",
 " lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto\n",
-").prefetch(auto)\n"
+").prefetch(auto)\n",
+""
 ]
 },
 {
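The cell touched here applies `text_vectorizer` inside the `tf.data` pipeline rather than inside the model, so vectorization runs on the CPU in parallel and overlaps with training via `prefetch`. A sketch of that pattern on a toy dataset (the `TextVectorization` configuration is an assumption; only the `map`/`prefetch` calls appear in this diff):

```python
# Sketch of vectorizing inside the tf.data pipeline, on a toy dataset.
import tensorflow as tf
from keras import layers

auto = tf.data.AUTOTUNE

# Toy stand-in for the (text, multi-hot label) training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (["a paper on graphs", "a paper on vision"], [[1.0, 0.0], [0.0, 1.0]])
).batch(2)

# Assumed configuration; the diff only shows the map/prefetch calls.
text_vectorizer = layers.TextVectorization(ngrams=2, output_mode="tf_idf")
text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

# Vectorize on the fly, in parallel, and overlap with training.
train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
```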
@@ -535,7 +545,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -550,7 +560,8 @@
 " layers.Dense(lookup.vocabulary_size(), activation=\"sigmoid\"),\n",
 " ] # More on why \"sigmoid\" has been used here in a moment.\n",
 " )\n",
-" return shallow_mlp_model\n"
+" return shallow_mlp_model\n",
+""
 ]
 },
 {
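The model factory this hunk closes is a plain feed-forward stack. The key design choice is the `sigmoid` output: each label gets an independent probability, whereas a softmax would wrongly force the labels to compete. A sketch under the assumption that the hidden layers match the original example (the `num_labels` parameter is a hypothetical simplification):

```python
# Sketch of the shallow MLP factory; the hidden widths (512, 256) are
# assumed from the original example, and num_labels is a hypothetical
# parameter standing in for lookup.vocabulary_size().
import keras
from keras import layers


def make_model(num_labels):
    return keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            # One independent sigmoid per label: multi-label, not multi-class.
            layers.Dense(num_labels, activation="sigmoid"),
        ]
    )


# e.g. shallow_mlp_model = make_model(lookup.vocabulary_size())
```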
@@ -582,7 +593,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -635,7 +646,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
@@ -676,20 +687,40 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 0,
 "metadata": {
 "colab_type": "code"
 },
 "outputs": [],
 "source": [
-"# Create a model for inference.\n",
-"model_for_inference = keras.Sequential([text_vectorizer, shallow_mlp_model])\n",
 "\n",
-"# Create a small dataset just for demoing inference.\n",
-"inference_dataset = make_dataset(test_df.sample(100), is_train=False)\n",
+"# We create a custom Model to override the predict method so\n",
+"# that it first vectorizes text data\n",
+"class ModelEndtoEnd(keras.Model):\n",
+"\n",
+" def predict(self, inputs):\n",
+" indices = text_vectorizer(inputs)\n",
+" return super().predict(indices)\n",
+"\n",
+"\n",
+"def get_inference_model(model):\n",
+" inputs = shallow_mlp_model.inputs\n",
+" outputs = shallow_mlp_model.outputs\n",
+" end_to_end_model = ModelEndtoEnd(inputs, outputs, name=\"end_to_end_model\")\n",
+" end_to_end_model.compile(\n",
+" optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"accuracy\"]\n",
+" )\n",
+" return end_to_end_model\n",
+"\n",
+"\n",
+"model_for_inference = get_inference_model(shallow_mlp_model)\n",
+"\n",
+"# Create a small dataset just for demonstrating inference.\n",
+"inference_dataset = make_dataset(test_df.sample(2), is_train=False)\n",
 "text_batch, label_batch = next(iter(inference_dataset))\n",
 "predicted_probabilities = model_for_inference.predict(text_batch)\n",
 "\n",
+"\n",
 "# Perform inference.\n",
 "for i, text in enumerate(text_batch[:5]):\n",
 " label = label_batch[i].numpy()[None, ...]\n",
@@ -731,16 +762,15 @@
 "tackle the multi-label binarization part and inverse-transforming the processed labels\n",
 "to the original form.\n",
 "\n",
-"Thanks [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending\n",
-"this code example by the binary accuracy."
+"Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric."
 ]
 }
 ],
 "metadata": {
 "accelerator": "GPU",
 "colab": {
 "collapsed_sections": [],
-"name": "..\\examples\\nlp\\multi_label_classification",
+"name": "multi_label_classification",
 "private_outputs": false,
 "provenance": [],
 "toc_visible": true
@@ -765,4 +795,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
