1010 " \n " ,
1111 " **Author:** [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)<br>\n " ,
1212 " **Date created:** 2020/09/25<br>\n " ,
13- " **Last modified:** 2020/12/23 <br>\n " ,
13+ " **Last modified:** 2025/02/27 <br>\n " ,
1414 " **Description:** Implementing a large-scale multi-label text classification model."
1515 ]
1616 },
4949 },
5050 {
5151 "cell_type" : " code" ,
52- "execution_count" : null ,
52+ "execution_count" : 0 ,
5353 "metadata" : {
5454 "colab_type" : " code"
5555 },
5656 "outputs" : [],
5757 "source" : [
58- " from tensorflow.keras import layers\n " ,
59- " from tensorflow import keras\n " ,
60- " import tensorflow as tf\n " ,
58+ " import os\n " ,
59+ " \n " ,
60+ " os.environ[\" KERAS_BACKEND\" ] = \" jax\" # or tensorflow, or torch\n " ,
61+ " \n " ,
62+ " import keras\n " ,
63+ " from keras import layers, ops\n " ,
6164 " \n " ,
6265 " from sklearn.model_selection import train_test_split\n " ,
63- " from ast import literal_eval\n " ,
6466 " \n " ,
67+ " from ast import literal_eval\n " ,
6568 " import matplotlib.pyplot as plt\n " ,
6669 " import pandas as pd\n " ,
6770 " import numpy as np"
8184 },
8285 {
8386 "cell_type" : " code" ,
84- "execution_count" : null ,
87+ "execution_count" : 0 ,
8588 "metadata" : {
8689 "colab_type" : " code"
8790 },
106109 },
107110 {
108111 "cell_type" : " code" ,
109- "execution_count" : null ,
112+ "execution_count" : 0 ,
110113 "metadata" : {
111114 "colab_type" : " code"
112115 },
127130 },
128131 {
129132 "cell_type" : " code" ,
130- "execution_count" : null ,
133+ "execution_count" : 0 ,
131134 "metadata" : {
132135 "colab_type" : " code"
133136 },
148151 },
149152 {
150153 "cell_type" : " code" ,
151- "execution_count" : null ,
154+ "execution_count" : 0 ,
152155 "metadata" : {
153156 "colab_type" : " code"
154157 },
178181 },
179182 {
180183 "cell_type" : " code" ,
181- "execution_count" : null ,
184+ "execution_count" : 0 ,
182185 "metadata" : {
183186 "colab_type" : " code"
184187 },
203206 },
204207 {
205208 "cell_type" : " code" ,
206- "execution_count" : null ,
209+ "execution_count" : 0 ,
207210 "metadata" : {
208211 "colab_type" : " code"
209212 },
235238 },
236239 {
237240 "cell_type" : " code" ,
238- "execution_count" : null ,
241+ "execution_count" : 0 ,
239242 "metadata" : {
240243 "colab_type" : " code"
241244 },
275278 },
276279 {
277280 "cell_type" : " code" ,
278- "execution_count" : null ,
281+ "execution_count" : 0 ,
279282 "metadata" : {
280283 "colab_type" : " code"
281284 },
282285 "outputs" : [],
283286 "source" : [
287+ " # For RaggedTensor\n " ,
288+ " import tensorflow as tf\n " ,
289+ " \n " ,
284290 " terms = tf.ragged.constant(train_df[\" terms\" ].values)\n " ,
285- " lookup = tf.keras. layers.StringLookup(output_mode=\" multi_hot\" )\n " ,
291+ " lookup = layers.StringLookup(output_mode=\" multi_hot\" )\n " ,
286292 " lookup.adapt(terms)\n " ,
287293 " vocab = lookup.get_vocabulary()\n " ,
288294 " \n " ,
294300 " \n " ,
295301 " \n " ,
296302 " print(\" Vocabulary:\\ n\" )\n " ,
297- " print(vocab)\n "
303+ " print(vocab)\n " ,
304+ " "
298305 ]
299306 },
300307 {
310317 },
311318 {
312319 "cell_type" : " code" ,
313- "execution_count" : null ,
320+ "execution_count" : 0 ,
314321 "metadata" : {
315322 "colab_type" : " code"
316323 },
337344 },
338345 {
339346 "cell_type" : " code" ,
340- "execution_count" : null ,
347+ "execution_count" : 0 ,
341348 "metadata" : {
342349 "colab_type" : " code"
343350 },
361368 },
362369 {
363370 "cell_type" : " code" ,
364- "execution_count" : null ,
371+ "execution_count" : 0 ,
365372 "metadata" : {
366373 "colab_type" : " code"
367374 },
380387 " (dataframe[\" summaries\" ].values, label_binarized)\n " ,
381388 " )\n " ,
382389 " dataset = dataset.shuffle(batch_size * 10) if is_train else dataset\n " ,
383- " return dataset.batch(batch_size)\n "
390+ " return dataset.batch(batch_size)\n " ,
391+ " "
384392 ]
385393 },
386394 {
394402 },
395403 {
396404 "cell_type" : " code" ,
397- "execution_count" : null ,
405+ "execution_count" : 0 ,
398406 "metadata" : {
399407 "colab_type" : " code"
400408 },
416424 },
417425 {
418426 "cell_type" : " code" ,
419- "execution_count" : null ,
427+ "execution_count" : 0 ,
420428 "metadata" : {
421429 "colab_type" : " code"
422430 },
450458 },
451459 {
452460 "cell_type" : " code" ,
453- "execution_count" : null ,
461+ "execution_count" : 0 ,
454462 "metadata" : {
455463 "colab_type" : " code"
456464 },
460468 " vocabulary = set()\n " ,
461469 " train_df[\" summaries\" ].str.lower().str.split().apply(vocabulary.update)\n " ,
462470 " vocabulary_size = len(vocabulary)\n " ,
463- " print(vocabulary_size)\n "
471+ " print(vocabulary_size)\n " ,
472+ " "
464473 ]
465474 },
466475 {
475484 },
476485 {
477486 "cell_type" : " code" ,
478- "execution_count" : null ,
487+ "execution_count" : 0 ,
479488 "metadata" : {
480489 "colab_type" : " code"
481490 },
498507 " ).prefetch(auto)\n " ,
499508 " test_dataset = test_dataset.map(\n " ,
500509 " lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto\n " ,
501- " ).prefetch(auto)\n "
510+ " ).prefetch(auto)\n " ,
511+ " "
502512 ]
503513 },
504514 {
535545 },
536546 {
537547 "cell_type" : " code" ,
538- "execution_count" : null ,
548+ "execution_count" : 0 ,
539549 "metadata" : {
540550 "colab_type" : " code"
541551 },
550560 " layers.Dense(lookup.vocabulary_size(), activation=\" sigmoid\" ),\n " ,
551561 " ] # More on why \" sigmoid\" has been used here in a moment.\n " ,
552562 " )\n " ,
553- " return shallow_mlp_model\n "
563+ " return shallow_mlp_model\n " ,
564+ " "
554565 ]
555566 },
556567 {
582593 },
583594 {
584595 "cell_type" : " code" ,
585- "execution_count" : null ,
596+ "execution_count" : 0 ,
586597 "metadata" : {
587598 "colab_type" : " code"
588599 },
635646 },
636647 {
637648 "cell_type" : " code" ,
638- "execution_count" : null ,
649+ "execution_count" : 0 ,
639650 "metadata" : {
640651 "colab_type" : " code"
641652 },
676687 },
677688 {
678689 "cell_type" : " code" ,
679- "execution_count" : null ,
690+ "execution_count" : 0 ,
680691 "metadata" : {
681692 "colab_type" : " code"
682693 },
683694 "outputs" : [],
684695 "source" : [
685- " # Create a model for inference.\n " ,
686- " model_for_inference = keras.Sequential([text_vectorizer, shallow_mlp_model])\n " ,
687696 " \n " ,
688- " # Create a small dataset just for demoing inference.\n " ,
689- " inference_dataset = make_dataset(test_df.sample(100), is_train=False)\n " ,
697+ " # We create a custom Model to override the predict method so\n " ,
698+ " # that it first vectorizes text data\n " ,
699+ " class ModelEndtoEnd(keras.Model):\n " ,
700+ " \n " ,
701+ " def predict(self, inputs):\n " ,
702+ " indices = text_vectorizer(inputs)\n " ,
703+ " return super().predict(indices)\n " ,
704+ " \n " ,
705+ " \n " ,
706+ " def get_inference_model(model):\n " ,
707+ " inputs = shallow_mlp_model.inputs\n " ,
708+ " outputs = shallow_mlp_model.outputs\n " ,
709+ " end_to_end_model = ModelEndtoEnd(inputs, outputs, name=\" end_to_end_model\" )\n " ,
710+ " end_to_end_model.compile(\n " ,
711+ " optimizer=\" adam\" , loss=\" binary_crossentropy\" , metrics=[\" accuracy\" ]\n " ,
712+ " )\n " ,
713+ " return end_to_end_model\n " ,
714+ " \n " ,
715+ " \n " ,
716+ " model_for_inference = get_inference_model(shallow_mlp_model)\n " ,
717+ " \n " ,
718+ " # Create a small dataset just for demonstrating inference.\n " ,
719+ " inference_dataset = make_dataset(test_df.sample(2), is_train=False)\n " ,
690720 " text_batch, label_batch = next(iter(inference_dataset))\n " ,
691721 " predicted_probabilities = model_for_inference.predict(text_batch)\n " ,
692722 " \n " ,
723+ " \n " ,
693724 " # Perform inference.\n " ,
694725 " for i, text in enumerate(text_batch[:5]):\n " ,
695726 " label = label_batch[i].numpy()[None, ...]\n " ,
731762 " tackle the multi-label binarization part and inverse-transforming the processed labels\n " ,
732763 " to the original form.\n " ,
733764 " \n " ,
734- " Thanks [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending\n " ,
735- " this code example by the binary accuracy."
765+ " Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric."
736766 ]
737767 }
738768 ],
739769 "metadata" : {
740770 "accelerator" : " GPU" ,
741771 "colab" : {
742772 "collapsed_sections" : [],
743- "name" : " .. \\ examples \\ nlp \\ multi_label_classification" ,
773+ "name" : " multi_label_classification" ,
744774 "private_outputs" : false ,
745775 "provenance" : [],
746776 "toc_visible" : true
765795 },
766796 "nbformat" : 4 ,
767797 "nbformat_minor" : 0
768- }
798+ }
0 commit comments