diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png new file mode 100644 index 0000000000..3e4559d82d Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png differ diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png new file mode 100644 index 0000000000..d5dbce9a38 Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png differ diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png new file mode 100644 index 0000000000..b8698888be Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png differ diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png new file mode 100644 index 0000000000..7d60426bc3 Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png differ diff --git a/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb b/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb new file mode 100644 index 0000000000..15ca8dc34c --- /dev/null +++ b/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb @@ -0,0 +1,1120 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Semantic Segmentation with KerasHub\n", + "\n", + "**Authors:** [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)
\n", + "**Date created:** 2024/10/11
\n", + "**Last modified:** 2024/10/11
\n", + "**Description:** DeepLabV3 training and inference with KerasHub." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png)\n", + "\n", + "## Background\n", + "Semantic segmentation is a type of computer vision task that involves assigning a\n", + "class label such as \"person\", \"bike\", or \"background\" to each individual pixel\n", + "of an image, effectively dividing the image into regions that correspond to\n", + "different object classes or categories.\n", + "\n", + "![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png)\n", + "\n", + "\n", + "\n", + "KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic\n", + "segmentation.\n", + "\n", + "This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed\n", + "by Google for image semantic segmentation with KerasHub. Its architecture\n", + "combines Atrous convolutions, contextual information aggregation, and powerful\n", + "backbones to achieve accurate and detailed semantic segmentation.\n", + "\n", + "DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to\n", + "refine the segmentation results, especially along object boundaries. Both models\n", + "have achieved state-of-the-art results on a variety of image segmentation\n", + "benchmarks.\n", + "\n", + "### References\n", + "[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)\n", + "[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup and Imports\n", + "\n", + "Let's install the dependencies and import the necessary modules.\n", + "\n", + "To run this tutorial, you will need to install the following packages:\n", + "\n", + "* `keras-hub`\n", + "* `keras`" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!pip install -q --upgrade keras-hub\n", + "!pip install -q --upgrade keras" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "After installing `keras` and `keras-hub`, set the backend for `keras`.\n", + "This guide can be run with any backend (Tensorflow, JAX, PyTorch)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "import keras\n", + "from keras import ops\n", + "import keras_hub\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Perform semantic segmentation with a pretrained DeepLabv3+ model\n", + "\n", + "The highest level API in the KerasHub semantic segmentation API is the\n", + "`keras_hub.models` API. 
This API includes fully pretrained semantic segmentation\n",
+    "models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.\n",
+    "\n",
+    "Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC\n",
+    "dataset.\n",
+    "Also, define the preprocessing function for the model to preprocess images and\n",
+    "labels.\n",
+    "**Note:** By default, the `from_preset()` method in KerasHub loads the\n",
+    "pretrained task weights with all the classes, 21 classes in this case."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(\n",
+    "    \"deeplab_v3_plus_resnet50_pascalvoc\"\n",
+    ")\n",
+    "\n",
+    "image_converter = keras_hub.layers.DeepLabV3ImageConverter(\n",
+    "    image_size=(512, 512),\n",
+    "    interpolation=\"bilinear\",\n",
+    ")\n",
+    "preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Let's visualize the results of this pretrained model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "filepath = keras.utils.get_file(\n",
+    "    origin=\"https://storage.googleapis.com/keras-cv/pictures/dog.jpeg\"\n",
+    ")\n",
+    "image = keras.utils.load_img(filepath)\n",
+    "image = keras.utils.img_to_array(image)\n",
+    "\n",
+    "image = preprocessor(image)\n",
+    "image = keras.ops.expand_dims(image, axis=0)\n",
+    "preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)\n",
+    "\n",
+    "\n",
+    "def plot_segmentation(original_image, predicted_mask):\n",
+    "    original_image = np.squeeze(original_image, axis=0)\n",
+    "    original_image = np.clip(original_image / 255.0, 0, 1)\n",
+    "    predicted_mask = np.squeeze(predicted_mask, axis=0)\n",
+    "    plt.figure(figsize=(5, 5))\n",
+    "\n",
+    "    plt.subplot(1, 2, 1)\n",
+    "    plt.imshow(original_image)\n",
+    "    plt.axis(\"off\")\n",
+    "\n",
+    "    plt.subplot(1, 2, 2)\n",
+    "    plt.imshow(predicted_mask, cmap=\"gray\")\n",
+    "    plt.axis(\"off\")\n",
+    "\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "plot_segmentation(image, preds)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Train a custom semantic segmentation model\n",
+    "In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3\n",
+    "semantic segmentation model. This includes data loading, augmentation, training,\n",
+    "metric evaluation, and inference!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Download the data\n",
+    "\n",
+    "We download the Pascal VOC 2012 dataset with the additional annotations from\n",
+    "[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz),\n",
+    "and split it into a training dataset `train_ds` and an evaluation dataset\n",
+    "`eval_ds`.\n",
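+    "\n",
+    "The masks returned by the loader contain integer class ids in `[0, 21)`, plus\n",
+    "the value `255` marking object-boundary pixels."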
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# @title helper functions\n", + "import logging\n", + "import multiprocessing\n", + "from builtins import open\n", + "import os.path\n", + "import random\n", + "import xml\n", + "\n", + "import tensorflow_datasets as tfds\n", + "\n", + "VOC_URL = \"https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar\"\n", + "\n", + "SBD_URL = \"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz\"\n", + "\n", + "# Note that this list doesn't contain the background class. In the\n", + "# classification use case, the label is 0 based (aeroplane -> 0), whereas in\n", + "# segmentation use case, the 0 is reserved for background, so aeroplane maps to\n", + "# 1.\n", + "CLASSES = [\n", + " \"aeroplane\",\n", + " \"bicycle\",\n", + " \"bird\",\n", + " \"boat\",\n", + " \"bottle\",\n", + " \"bus\",\n", + " \"car\",\n", + " \"cat\",\n", + " \"chair\",\n", + " \"cow\",\n", + " \"diningtable\",\n", + " \"dog\",\n", + " \"horse\",\n", + " \"motorbike\",\n", + " \"person\",\n", + " \"pottedplant\",\n", + " \"sheep\",\n", + " \"sofa\",\n", + " \"train\",\n", + " \"tvmonitor\",\n", + "]\n", + "# This is used to map between string class to index.\n", + "CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}\n", + "\n", + "# For the mask data in the PNG file, the encoded raw pixel value need to be\n", + "# converted to the proper class index. In the following map, [0, 0, 0] will be\n", + "# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also\n", + "# note that the mask class is 1 base since class 0 is reserved for the\n", + "# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`.\n", + "VOC_PNG_COLOR_VALUE = [\n", + " [0, 0, 0],\n", + " [128, 0, 0],\n", + " [0, 128, 0],\n", + " [128, 128, 0],\n", + " [0, 0, 128],\n", + " [128, 0, 128],\n", + " [0, 128, 128],\n", + " [128, 128, 128],\n", + " [64, 0, 0],\n", + " [192, 0, 0],\n", + " [64, 128, 0],\n", + " [192, 128, 0],\n", + " [64, 0, 128],\n", + " [192, 0, 128],\n", + " [64, 128, 128],\n", + " [192, 128, 128],\n", + " [0, 64, 0],\n", + " [128, 64, 0],\n", + " [0, 192, 0],\n", + " [128, 192, 0],\n", + " [0, 64, 128],\n", + "]\n", + "# Will be populated by maybe_populate_voc_color_mapping() below.\n", + "VOC_PNG_COLOR_MAPPING = None\n", + "\n", + "\n", + "def maybe_populate_voc_color_mapping():\n", + " \"\"\"Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.\"\"\"\n", + " global VOC_PNG_COLOR_MAPPING\n", + " if VOC_PNG_COLOR_MAPPING is None:\n", + " VOC_PNG_COLOR_MAPPING = [0] * (256**3)\n", + " for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):\n", + " VOC_PNG_COLOR_MAPPING[\n", + " (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]\n", + " ] = i\n", + " # There is a special mapping with [224, 224, 192] -> 255\n", + " VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255\n", + " VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)\n", + " return VOC_PNG_COLOR_MAPPING\n", + "\n", + "\n", + "def parse_annotation_data(annotation_file_path):\n", + " \"\"\"Parse the annotation XML file for the image.\n", + "\n", + " The annotation contains the metadata, as well as the object bounding box\n", + " information.\n", + "\n", + " \"\"\"\n", + " with open(annotation_file_path, \"r\") as f:\n", + " root = xml.etree.ElementTree.parse(f).getroot()\n", + "\n", + " size = root.find(\"size\")\n", + " 
width = int(size.find(\"width\").text)\n", + " height = int(size.find(\"height\").text)\n", + "\n", + " objects = []\n", + " for obj in root.findall(\"object\"):\n", + " # Get object's label name.\n", + " label = CLASS_TO_INDEX[obj.find(\"name\").text.lower()]\n", + " # Get objects' pose name.\n", + " pose = obj.find(\"pose\").text.lower()\n", + " is_truncated = obj.find(\"truncated\").text == \"1\"\n", + " is_difficult = obj.find(\"difficult\").text == \"1\"\n", + " bndbox = obj.find(\"bndbox\")\n", + " xmax = int(bndbox.find(\"xmax\").text)\n", + " xmin = int(bndbox.find(\"xmin\").text)\n", + " ymax = int(bndbox.find(\"ymax\").text)\n", + " ymin = int(bndbox.find(\"ymin\").text)\n", + " objects.append(\n", + " {\n", + " \"label\": label,\n", + " \"pose\": pose,\n", + " \"bbox\": [ymin, xmin, ymax, xmax],\n", + " \"is_truncated\": is_truncated,\n", + " \"is_difficult\": is_difficult,\n", + " }\n", + " )\n", + "\n", + " return {\"width\": width, \"height\": height, \"objects\": objects}\n", + "\n", + "\n", + "def get_image_ids(data_dir, split):\n", + " \"\"\"To get image ids from the \"train\", \"eval\" or \"trainval\" files of VOC data.\"\"\"\n", + " data_file_mapping = {\n", + " \"train\": \"train.txt\",\n", + " \"eval\": \"val.txt\",\n", + " \"trainval\": \"trainval.txt\",\n", + " }\n", + " with open(\n", + " os.path.join(data_dir, \"ImageSets\", \"Segmentation\", data_file_mapping[split]),\n", + " \"r\",\n", + " ) as f:\n", + " image_ids = f.read().splitlines()\n", + " logging.info(f\"Received {len(image_ids)} images for {split} dataset.\")\n", + " return image_ids\n", + "\n", + "\n", + "def get_sbd_image_ids(data_dir, split):\n", + " \"\"\"To get image ids from the \"sbd_train\", \"sbd_eval\" from files of SBD data.\"\"\"\n", + " data_file_mapping = {\"sbd_train\": \"train.txt\", \"sbd_eval\": \"val.txt\"}\n", + " with open(\n", + " os.path.join(data_dir, data_file_mapping[split]),\n", + " \"r\",\n", + " ) as f:\n", + " image_ids = f.read().splitlines()\n", + " logging.info(f\"Received {len(image_ids)} images for {split} dataset.\")\n", + " return image_ids\n", + "\n", + "\n", + "def parse_single_image(image_file_path):\n", + " \"\"\"Creates metadata of VOC images and path.\"\"\"\n", + " data_dir, image_file_name = os.path.split(image_file_path)\n", + " data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))\n", + " image_id, _ = os.path.splitext(image_file_name)\n", + " class_segmentation_file_path = os.path.join(\n", + " data_dir, \"SegmentationClass\", image_id + \".png\"\n", + " )\n", + " object_segmentation_file_path = os.path.join(\n", + " data_dir, \"SegmentationObject\", image_id + \".png\"\n", + " )\n", + " annotation_file_path = os.path.join(data_dir, \"Annotations\", image_id + \".xml\")\n", + " image_annotations = parse_annotation_data(annotation_file_path)\n", + "\n", + " result = {\n", + " \"image/filename\": image_id + \".jpg\",\n", + " \"image/file_path\": image_file_path,\n", + " \"segmentation/class/file_path\": class_segmentation_file_path,\n", + " \"segmentation/object/file_path\": object_segmentation_file_path,\n", + " }\n", + " result.update(image_annotations)\n", + " # Labels field should be same as the 'object.label'\n", + " labels = list(set([o[\"label\"] for o in result[\"objects\"]]))\n", + " result[\"labels\"] = sorted(labels)\n", + " return result\n", + "\n", + "\n", + "def parse_single_sbd_image(image_file_path):\n", + " \"\"\"Creates metadata of SBD images and path.\"\"\"\n", + " data_dir, image_file_name = os.path.split(image_file_path)\n", 
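+    "    # SBD keeps images under `img/`, and stores the class and instance\n",
+    "    # masks as MATLAB `.mat` files under `cls/` and `inst/`.\n",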
+ " data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))\n", + " image_id, _ = os.path.splitext(image_file_name)\n", + " class_segmentation_file_path = os.path.join(data_dir, \"cls\", image_id + \".mat\")\n", + " object_segmentation_file_path = os.path.join(data_dir, \"inst\", image_id + \".mat\")\n", + " result = {\n", + " \"image/filename\": image_id + \".jpg\",\n", + " \"image/file_path\": image_file_path,\n", + " \"segmentation/class/file_path\": class_segmentation_file_path,\n", + " \"segmentation/object/file_path\": object_segmentation_file_path,\n", + " }\n", + " return result\n", + "\n", + "\n", + "def build_metadata(data_dir, image_ids):\n", + " \"\"\"Transpose the metadata which convert from list of dict to dict of list.\"\"\"\n", + " # Parallel process all the images.\n", + " image_file_paths = [\n", + " os.path.join(data_dir, \"JPEGImages\", i + \".jpg\") for i in image_ids\n", + " ]\n", + " pool_size = 10 if len(image_ids) > 10 else len(image_ids)\n", + " with multiprocessing.Pool(pool_size) as p:\n", + " metadata = p.map(parse_single_image, image_file_paths)\n", + "\n", + " keys = [\n", + " \"image/filename\",\n", + " \"image/file_path\",\n", + " \"segmentation/class/file_path\",\n", + " \"segmentation/object/file_path\",\n", + " \"labels\",\n", + " \"width\",\n", + " \"height\",\n", + " ]\n", + " result = {}\n", + " for key in keys:\n", + " values = [value[key] for value in metadata]\n", + " result[key] = values\n", + "\n", + " # The ragged objects need some special handling\n", + " for key in [\"label\", \"pose\", \"bbox\", \"is_truncated\", \"is_difficult\"]:\n", + " values = []\n", + " objects = [value[\"objects\"] for value in metadata]\n", + " for object in objects:\n", + " values.append([o[key] for o in object])\n", + " result[\"objects/\" + key] = values\n", + " return result\n", + "\n", + "\n", + "def build_sbd_metadata(data_dir, image_ids):\n", + " \"\"\"Transpose the metadata which convert from list of dict to dict of list.\"\"\"\n", + " # Parallel process all the images.\n", + " image_file_paths = [os.path.join(data_dir, \"img\", i + \".jpg\") for i in image_ids]\n", + " pool_size = 10 if len(image_ids) > 10 else len(image_ids)\n", + " with multiprocessing.Pool(pool_size) as p:\n", + " metadata = p.map(parse_single_sbd_image, image_file_paths)\n", + "\n", + " keys = [\n", + " \"image/filename\",\n", + " \"image/file_path\",\n", + " \"segmentation/class/file_path\",\n", + " \"segmentation/object/file_path\",\n", + " ]\n", + " result = {}\n", + " for key in keys:\n", + " values = [value[key] for value in metadata]\n", + " result[key] = values\n", + " return result\n", + "\n", + "\n", + "def decode_png_mask(mask):\n", + " \"\"\"Decode the raw PNG image and convert it to 2D tensor with probably\n", + " class.\"\"\"\n", + " # Cast the mask to int32 since the original uint8 will overflow when\n", + " # multiplied with 256\n", + " mask = tf.cast(mask, tf.int32)\n", + " mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]\n", + " mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)\n", + " mask = tf.cast(mask, tf.uint8)\n", + " return mask\n", + "\n", + "\n", + "def load_images(example):\n", + " \"\"\"Loads VOC images for segmentation task from the provided paths\"\"\"\n", + " image_file_path = example.pop(\"image/file_path\")\n", + " segmentation_class_file_path = example.pop(\"segmentation/class/file_path\")\n", + " segmentation_object_file_path = example.pop(\"segmentation/object/file_path\")\n", + " image = 
tf.io.read_file(image_file_path)\n", + " image = tf.image.decode_jpeg(image)\n", + "\n", + " segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)\n", + " segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)\n", + " segmentation_class_mask = decode_png_mask(segmentation_class_mask)\n", + "\n", + " segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)\n", + " segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)\n", + " segmentation_object_mask = decode_png_mask(segmentation_object_mask)\n", + "\n", + " example.update(\n", + " {\n", + " \"image\": image,\n", + " \"class_segmentation\": segmentation_class_mask,\n", + " \"object_segmentation\": segmentation_object_mask,\n", + " }\n", + " )\n", + " return example\n", + "\n", + "\n", + "def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):\n", + " \"\"\"Loads SBD images for segmentation task from the provided paths\"\"\"\n", + " image = tf.io.read_file(image_file_path)\n", + " image = tf.image.decode_jpeg(image)\n", + "\n", + " segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)\n", + " segmentation_class_mask = segmentation_class_mask[\"GTcls\"][\"Segmentation\"][0][0]\n", + " segmentation_class_mask = segmentation_class_mask[..., np.newaxis]\n", + "\n", + " segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(\n", + " seg_obj_file_path\n", + " )\n", + " segmentation_object_mask = segmentation_object_mask[\"GTinst\"][\"Segmentation\"][0][0]\n", + " segmentation_object_mask = segmentation_object_mask[..., np.newaxis]\n", + "\n", + " return {\n", + " \"image\": image,\n", + " \"class_segmentation\": segmentation_class_mask,\n", + " \"object_segmentation\": segmentation_object_mask,\n", + " }\n", + "\n", + "\n", + "def build_dataset_from_metadata(metadata):\n", + " \"\"\"Builds TensorFlow dataset from the image metadata of VOC dataset.\"\"\"\n", + " # The objects need some manual conversion to ragged tensor.\n", + " metadata[\"labels\"] = tf.ragged.constant(metadata[\"labels\"])\n", + " metadata[\"objects/label\"] = tf.ragged.constant(metadata[\"objects/label\"])\n", + " metadata[\"objects/pose\"] = tf.ragged.constant(metadata[\"objects/pose\"])\n", + " metadata[\"objects/is_truncated\"] = tf.ragged.constant(\n", + " metadata[\"objects/is_truncated\"]\n", + " )\n", + " metadata[\"objects/is_difficult\"] = tf.ragged.constant(\n", + " metadata[\"objects/is_difficult\"]\n", + " )\n", + " metadata[\"objects/bbox\"] = tf.ragged.constant(\n", + " metadata[\"objects/bbox\"], ragged_rank=1\n", + " )\n", + "\n", + " dataset = tf.data.Dataset.from_tensor_slices(metadata)\n", + " dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)\n", + " return dataset\n", + "\n", + "\n", + "def build_sbd_dataset_from_metadata(metadata):\n", + " \"\"\"Builds TensorFlow dataset from the image metadata of SBD dataset.\"\"\"\n", + " img_filepath = metadata[\"image/file_path\"]\n", + " cls_filepath = metadata[\"segmentation/class/file_path\"]\n", + " obj_filepath = metadata[\"segmentation/object/file_path\"]\n", + "\n", + " def md_gen():\n", + " c = list(zip(img_filepath, cls_filepath, obj_filepath))\n", + " # random shuffling for each generator boosts up the quality.\n", + " random.shuffle(c)\n", + " for fp in c:\n", + " img_fp, cls_fp, obj_fp = fp\n", + " yield load_sbd_images(img_fp, cls_fp, obj_fp)\n", + "\n", + " dataset = tf.data.Dataset.from_generator(\n", + " md_gen,\n", + " output_signature=(\n", + " {\n", 
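+    "            # Image sizes vary across the dataset, hence the `None`\n",
+    "            # height and width dimensions in these specs.\n",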
+ " \"image\": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),\n", + " \"class_segmentation\": tf.TensorSpec(\n", + " shape=(None, None, 1), dtype=tf.uint8\n", + " ),\n", + " \"object_segmentation\": tf.TensorSpec(\n", + " shape=(None, None, 1), dtype=tf.uint8\n", + " ),\n", + " }\n", + " ),\n", + " )\n", + "\n", + " return dataset\n", + "\n", + "\n", + "def load(\n", + " split=\"sbd_train\",\n", + " data_dir=None,\n", + "):\n", + " \"\"\"Load the Pacal VOC 2012 dataset.\n", + "\n", + " This function will download the data tar file from remote if needed, and\n", + " untar to the local `data_dir`, and build dataset from it.\n", + "\n", + " It supports both VOC2012 and Semantic Boundaries Dataset (SBD).\n", + "\n", + " The returned segmentation masks will be int ranging from [0, num_classes),\n", + " as well as 255 which is the boundary mask.\n", + "\n", + " Args:\n", + " split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or\n", + " 'sbd_eval'. 'sbd_train' represents the training dataset for SBD\n", + " dataset, while 'train' represents the training dataset for VOC2012\n", + " dataset. Defaults to `sbd_train`.\n", + " data_dir: string, local directory path for the loaded data. This will be\n", + " used to download the data file, and unzip. It will be used as a\n", + " cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`\n", + " will be used.\n", + " \"\"\"\n", + " supported_split_value = [\n", + " \"train\",\n", + " \"eval\",\n", + " \"trainval\",\n", + " \"sbd_train\",\n", + " \"sbd_eval\",\n", + " ]\n", + " if split not in supported_split_value:\n", + " raise ValueError(\n", + " f\"The support value for `split` are {supported_split_value}. \"\n", + " f\"Got: {split}\"\n", + " )\n", + "\n", + " if data_dir is not None:\n", + " data_dir = os.path.expanduser(data_dir)\n", + "\n", + " if \"sbd\" in split:\n", + " return load_sbd(split, data_dir)\n", + " else:\n", + " return load_voc(split, data_dir)\n", + "\n", + "\n", + "def load_voc(\n", + " split=\"train\",\n", + " data_dir=None,\n", + "):\n", + " \"\"\"This function will download VOC data from a URL. If the data is already\n", + " present in the cache directory, it will load the data from that directory\n", + " instead.\n", + " \"\"\"\n", + " extracted_dir = os.path.join(\"VOCdevkit\", \"VOC2012\")\n", + " get_data = keras.utils.get_file(\n", + " fname=os.path.basename(VOC_URL),\n", + " origin=VOC_URL,\n", + " cache_dir=data_dir,\n", + " extract=True,\n", + " )\n", + " data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)\n", + " image_ids = get_image_ids(data_dir, split)\n", + " # len(metadata) = #samples, metadata[i] is a dict.\n", + " metadata = build_metadata(data_dir, image_ids)\n", + " maybe_populate_voc_color_mapping()\n", + " dataset = build_dataset_from_metadata(metadata)\n", + "\n", + " return dataset\n", + "\n", + "\n", + "def load_sbd(\n", + " split=\"sbd_train\",\n", + " data_dir=None,\n", + "):\n", + " \"\"\"This function will download SBD data from a URL. 
If the data is already\n", + " present in the cache directory, it will load the data from that directory\n", + " instead.\n", + " \"\"\"\n", + " extracted_dir = os.path.join(\"benchmark_RELEASE\", \"dataset\")\n", + " get_data = keras.utils.get_file(\n", + " fname=os.path.basename(SBD_URL),\n", + " origin=SBD_URL,\n", + " cache_dir=data_dir,\n", + " extract=True,\n", + " )\n", + " data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)\n", + " image_ids = get_sbd_image_ids(data_dir, split)\n", + " # len(metadata) = #samples, metadata[i] is a dict.\n", + " metadata = build_sbd_metadata(data_dir, image_ids)\n", + "\n", + " dataset = build_sbd_dataset_from_metadata(metadata)\n", + " return dataset\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Load the dataset\n", + "\n", + "For training and evaluation, let's use \"sbd_train\" and \"sbd_eval.\" You can also\n", + "choose any of these datasets for the `load` function: 'train', 'eval', 'trainval',\n", + "'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the\n", + "SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "train_ds = load(split=\"sbd_train\", data_dir=\"segmentation\")\n", + "eval_ds = load(split=\"sbd_eval\", data_dir=\"segmentation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Preprocess the data\n", + "\n", + "The preprocess_inputs utility function preprocesses inputs, converting them into\n", + "a dictionary containing images and segmentation_masks. Both images and\n", + "segmentation masks are resized to 512x512. The resulting dataset is then batched\n", + "into groups of four image and segmentation mask pairs." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def preprocess_inputs(inputs):\n", + " def unpackage_inputs(inputs):\n", + " return {\n", + " \"images\": inputs[\"image\"],\n", + " \"segmentation_masks\": inputs[\"class_segmentation\"],\n", + " }\n", + "\n", + " outputs = inputs.map(unpackage_inputs)\n", + " outputs = outputs.map(keras.layers.Resizing(height=512, width=512))\n", + " outputs = outputs.batch(4, drop_remainder=True)\n", + " return outputs\n", + "\n", + "\n", + "train_ds = preprocess_inputs(train_ds)\n", + "batch = train_ds.take(1).get_single_element()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "A batch of this preprocessed input training data can be visualized using the\n", + "`plot_images_masks` function. This function takes a batch of images and\n", + "segmentation masks and prediction masks as input and displays them in a grid." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "def plot_images_masks(images, masks, pred_masks=None):\n",
+    "    images = (images - np.min(images)) / (np.max(images) - np.min(images))\n",
+    "    masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks))\n",
+    "    if pred_masks is not None:\n",
+    "        pred_masks = (pred_masks - pred_masks.min()) / (\n",
+    "            pred_masks.max() - pred_masks.min()\n",
+    "        )\n",
+    "    num_images = len(images)\n",
+    "    plt.figure(figsize=(8, 4))\n",
+    "    rows = 3 if pred_masks is not None else 2\n",
+    "\n",
+    "    for i in range(num_images):\n",
+    "        plt.subplot(rows, num_images, i + 1)\n",
+    "        plt.imshow(images[i])\n",
+    "        plt.axis(\"off\")\n",
+    "\n",
+    "        plt.subplot(rows, num_images, num_images + i + 1)\n",
+    "        plt.imshow(masks[i], cmap=\"gray\")\n",
+    "        plt.axis(\"off\")\n",
+    "\n",
+    "        if pred_masks is not None:\n",
+    "            plt.subplot(rows, num_images, i + 1 + 2 * num_images)\n",
+    "            plt.imshow(pred_masks[i, ..., 0], cmap=\"gray\")\n",
+    "            plt.axis(\"off\")\n",
+    "\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "plot_images_masks(batch[\"images\"], batch[\"segmentation_masks\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The same preprocessing is applied to the evaluation dataset `eval_ds`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "eval_ds = preprocess_inputs(eval_ds)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Data Augmentation\n",
+    "\n",
+    "Keras provides a variety of image augmentation options. In this example, we will\n",
+    "use the `RandomFlip` augmentation to augment the training dataset. The\n",
+    "`RandomFlip` augmentation randomly flips the images in the training dataset\n",
+    "horizontally or vertically. This can help improve the model's robustness to\n",
+    "changes in the orientation of the objects in the images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "train_ds = train_ds.map(keras.layers.RandomFlip())\n",
+    "batch = train_ds.take(1).get_single_element()\n",
+    "\n",
+    "plot_images_masks(batch[\"images\"], batch[\"segmentation_masks\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Model Configuration\n",
+    "\n",
+    "Feel free to modify the model training configurations and note how the\n",
+    "training results change. This is a great exercise to get a better\n",
+    "understanding of the training pipeline.\n",
+    "\n",
+    "The learning rate schedule is used by the optimizer to calculate the learning\n",
+    "rate for each epoch. The optimizer then uses the learning rate to update the\n",
+    "weights of the model.\n",
+    "In this case, the learning rate schedule uses a cosine decay function. A cosine\n",
+    "decay function starts high and then decreases over time, eventually reaching\n",
+    "zero. The cardinality of the VOC dataset is 2124 with a batch size of 4. The\n",
+    "dataset cardinality is important for learning rate decay because it determines\n",
+    "how many steps the model will train for. The initial learning rate is\n",
+    "proportional to 0.007 and the decay steps are 2124. This means that the learning\n",
+    "rate will start at `INITIAL_LR` and then decrease to zero over 2124 steps.\n",
+    "![png](/img/guides/semantic_segmentation_deeplab_v3_plus/learning_rate_schedule.png)\n",
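+    "\n",
+    "As a quick sanity check (a sketch, not part of the training pipeline), the\n",
+    "schedule above can be plotted directly, assuming the same initial learning\n",
+    "rate (`0.007 * 4 / 16`) and 2124 decay steps:\n",
+    "\n",
+    "```python\n",
+    "# Sketch: plot the cosine decay schedule defined in the next cell.\n",
+    "schedule = keras.optimizers.schedules.CosineDecay(0.007 * 4 / 16, decay_steps=2124)\n",
+    "steps = list(range(0, 2124, 10))\n",
+    "plt.plot(steps, [float(schedule(step)) for step in steps])\n",
+    "plt.xlabel(\"step\")\n",
+    "plt.ylabel(\"learning rate\")\n",
+    "plt.show()\n",
+    "```"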
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "BATCH_SIZE = 4\n",
+    "INITIAL_LR = 0.007 * BATCH_SIZE / 16\n",
+    "EPOCHS = 1\n",
+    "NUM_CLASSES = 21\n",
+    "learning_rate = keras.optimizers.schedules.CosineDecay(\n",
+    "    INITIAL_LR,\n",
+    "    decay_steps=EPOCHS * 2124,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for\n",
+    "the model. This implementation can be used as either DeepLabV3 or, with an\n",
+    "additional decoder block, DeepLabV3+.\n",
+    "For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model by setting\n",
+    "`low_level_feature_key` to `P2`, the pyramid-level output of\n",
+    "`resnet_50_imagenet` from which the decoder block extracts features.\n",
+    "To use this model as the plain DeepLabV3 architecture, omit\n",
+    "`low_level_feature_key`, which defaults to `None`.\n",
+    "\n",
+    "Then we create a `DeepLabV3ImageSegmenter` instance.\n",
+    "The `num_classes` parameter specifies the number of classes the model will be\n",
+    "trained to segment, and the `preprocessor` argument applies preprocessing to\n",
+    "the image inputs and masks."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "image_encoder = keras_hub.models.Backbone.from_preset(\"resnet_50_imagenet\")\n",
+    "\n",
+    "deeplab_backbone = keras_hub.models.DeepLabV3Backbone(\n",
+    "    image_encoder=image_encoder,\n",
+    "    low_level_feature_key=\"P2\",\n",
+    "    spatial_pyramid_pooling_key=\"P5\",\n",
+    "    dilation_rates=[6, 12, 18],\n",
+    "    upsampling_size=8,\n",
+    ")\n",
+    "\n",
+    "model = keras_hub.models.DeepLabV3ImageSegmenter(\n",
+    "    backbone=deeplab_backbone,\n",
+    "    num_classes=21,\n",
+    "    activation=\"softmax\",\n",
+    "    preprocessor=preprocessor,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Compile the model\n",
+    "\n",
+    "The `model.compile()` function sets up the training process for the model. It defines:\n",
+    "- the optimization algorithm - Stochastic Gradient Descent (SGD)\n",
+    "- the loss function - categorical cross-entropy\n",
+    "- the evaluation metrics - Mean IoU and categorical accuracy\n",
+    "\n",
+    "Semantic segmentation evaluation metrics:\n",
+    "\n",
+    "Mean Intersection over Union (MeanIoU):\n",
+    "MeanIoU measures how well a semantic segmentation model accurately identifies\n",
+    "and delineates different objects or regions in an image. It calculates the\n",
+    "overlap between predicted and actual object boundaries, providing a score\n",
+    "between 0 and 1, where 1 represents a perfect match.\n",
+    "\n",
+    "Categorical Accuracy:\n",
+    "Categorical Accuracy measures the proportion of correctly classified pixels in\n",
+    "an image. It gives a simple percentage indicating how accurately the model\n",
+    "predicts the categories of pixels in the entire image.\n",
+    "\n",
+    "In essence, MeanIoU emphasizes the accuracy of identifying specific object\n",
+    "boundaries, while Categorical Accuracy gives a broad overview of overall\n",
+    "pixel-level correctness.\n",
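+    "\n",
+    "As a toy illustration of the difference (a sketch, not part of the training\n",
+    "pipeline), consider a hypothetical 2x2 image with three classes:\n",
+    "\n",
+    "```python\n",
+    "# One pixel of class 2 is mispredicted as class 1.\n",
+    "y_true = tf.one_hot([[0, 1], [2, 1]], depth=3)\n",
+    "y_pred = tf.one_hot([[0, 1], [1, 1]], depth=3)\n",
+    "miou = keras.metrics.MeanIoU(num_classes=3, sparse_y_true=False, sparse_y_pred=False)\n",
+    "miou.update_state(y_true, y_pred)\n",
+    "acc = keras.metrics.CategoricalAccuracy()\n",
+    "acc.update_state(y_true, y_pred)\n",
+    "# Accuracy is 3/4, but MeanIoU is ~0.56: the missed class 2 drags its\n",
+    "# per-class IoU to 0 before averaging.\n",
+    "print(miou.result(), acc.result())\n",
+    "```"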
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "model.compile(\n",
+    "    optimizer=keras.optimizers.SGD(\n",
+    "        learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0\n",
+    "    ),\n",
+    "    loss=keras.losses.CategoricalCrossentropy(from_logits=False),\n",
+    "    metrics=[\n",
+    "        keras.metrics.MeanIoU(\n",
+    "            num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False\n",
+    "        ),\n",
+    "        keras.metrics.CategoricalAccuracy(),\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "model.summary()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The utility function `dict_to_tuple` transforms the dictionaries of the\n",
+    "training and validation datasets into tuples of images and one-hot encoded\n",
+    "segmentation masks, which are used during training and evaluation of the\n",
+    "DeepLabv3+ model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "def dict_to_tuple(x):\n",
+    "    return x[\"images\"], tf.one_hot(\n",
+    "        tf.cast(tf.squeeze(x[\"segmentation_masks\"], axis=-1), \"int32\"), 21\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "train_ds = train_ds.map(dict_to_tuple)\n",
+    "eval_ds = eval_ds.map(dict_to_tuple)\n",
+    "\n",
+    "model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "## Predictions with trained model\n",
+    "Now that the DeepLabv3+ model training has completed, let's test it by making\n",
+    "predictions on a few sample images.\n",
+    "\n",
+    "**Note:** For demonstration purposes, the model has been trained for only 1\n",
+    "epoch; for better accuracy and results, train for more epochs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "test_ds = load(split=\"sbd_eval\")\n",
+    "test_ds = preprocess_inputs(test_ds)\n",
+    "# Convert the test dicts to (image, one-hot mask) tuples, as for training.\n",
+    "test_ds = test_ds.map(dict_to_tuple)\n",
+    "\n",
+    "images, masks = next(iter(test_ds.take(1)))\n",
+    "images = ops.convert_to_tensor(images)\n",
+    "masks = ops.convert_to_tensor(masks)\n",
+    "preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)\n",
+    "masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)\n",
+    "\n",
+    "plot_images_masks(images, masks, preds)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Here are some additional tips for using the KerasHub DeepLabv3 model:\n",
+    "\n",
+    "- The model can be trained on a variety of datasets, including the COCO dataset,\n",
+    "the PASCAL VOC dataset, and the Cityscapes dataset.\n",
+    "- The model can be fine-tuned on a custom dataset to improve its performance on a\n",
+    "specific task.\n",
+    "- The model can be used to perform real-time inference on images.\n",
+    "- Also, check out KerasHub's other segmentation models."
+ ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "semantic_segmentation_deeplab_v3", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/guides/keras_hub/semantic_segmentation_deeplab_v3.py b/guides/keras_hub/semantic_segmentation_deeplab_v3.py new file mode 100644 index 0000000000..d6bc3c1d61 --- /dev/null +++ b/guides/keras_hub/semantic_segmentation_deeplab_v3.py @@ -0,0 +1,856 @@ +""" +Title: Semantic Segmentation with KerasHub +Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit) +Date created: 2024/10/11 +Last modified: 2024/10/11 +Description: DeepLabV3 training and inference with KerasHub. +Accelerator: GPU +""" + +""" +![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png) + +## Background +Semantic segmentation is a type of computer vision task that involves assigning a +class label such as "person", "bike", or "background" to each individual pixel +of an image, effectively dividing the image into regions that correspond to +different object classes or categories. + +![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png) + + + +KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic +segmentation. + +This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed +by Google for image semantic segmentation with KerasHub. Its architecture +combines Atrous convolutions, contextual information aggregation, and powerful +backbones to achieve accurate and detailed semantic segmentation. + +DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to +refine the segmentation results, especially along object boundaries. Both models +have achieved state-of-the-art results on a variety of image segmentation +benchmarks. + +### References +[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) +[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) +""" + +""" +## Setup and Imports + +Let's install the dependencies and import the necessary modules. + +To run this tutorial, you will need to install the following packages: + +* `keras-hub` +* `keras` +""" + +"""shell +pip install -q --upgrade keras-hub +pip install -q --upgrade keras +""" + +""" +After installing `keras` and `keras-hub`, set the backend for `keras`. +This guide can be run with any backend (Tensorflow, JAX, PyTorch). +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" +import keras +from keras import ops +import keras_hub +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt + +""" +## Perform semantic segmentation with a pretrained DeepLabv3+ model + +The highest level API in the KerasHub semantic segmentation API is the +`keras_hub.models` API. This API includes fully pretrained semantic segmentation +models, such as `keras_hub.models.DeepLabV3ImageSegmenter`. 
+ +Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC +dataset. +Also, define the preprocessing function for the model to preprocess images and +labels. +**Note:** By default `from_preset()` method in KerasHub loads the pretrained +task weights with all the classes, 21 classes in this case. +""" + +model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplab_v3_plus_resnet50_pascalvoc" +) + +image_converter = keras_hub.layers.DeepLabV3ImageConverter( + image_size=(512, 512), + interpolation="bilinear", +) +preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter) + +""" +Let us visualize the results of this pretrained model +""" +filepath = keras.utils.get_file( + origin="https://storage.googleapis.com/keras-cv/pictures/dog.jpeg" +) +image = keras.utils.load_img(filepath) +image = keras.utils.img_to_array(image) + +image = preprocessor(image) +image = keras.ops.expand_dims(image, axis=0) +preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1) + + +def plot_segmentation(original_image, predicted_mask): + original_image = np.squeeze(original_image, axis=0) + original_image = np.clip(original_image / 255.0, 0, 1) + predicted_mask = np.squeeze(predicted_mask, axis=0) + plt.figure(figsize=(5, 5)) + + plt.subplot(1, 2, 1) + plt.imshow(original_image) + plt.axis("off") + + plt.subplot(1, 2, 2) + plt.imshow(predicted_mask, cmap="gray") + plt.axis("off") + + plt.tight_layout() + plt.show() + + +plot_segmentation(image, preds) + +""" +## Train a custom semantic segmentation model +In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3 +semantic segmentation model. This includes data loading, augmentation, training, +metric evaluation, and inference! +""" + +""" +## Download the data + +We download Pascal VOC 2012 dataset with additional annotations provided here +[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) +and split them into train dataset `train_ds` and `eval_ds`. +""" + +# @title helper functions +import logging +import multiprocessing +from builtins import open +import os.path +import random +import xml + +import tensorflow_datasets as tfds + +VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" + +SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz" + +# Note that this list doesn't contain the background class. In the +# classification use case, the label is 0 based (aeroplane -> 0), whereas in +# segmentation use case, the 0 is reserved for background, so aeroplane maps to +# 1. +CLASSES = [ + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] +# This is used to map between string class to index. +CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)} + +# For the mask data in the PNG file, the encoded raw pixel value need to be +# converted to the proper class index. In the following map, [0, 0, 0] will be +# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also +# note that the mask class is 1 base since class 0 is reserved for the +# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`. 
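+# Concretely, a color [R, G, B] is flattened to the integer
+# (R * 256 + G) * 256 + B, so [128, 0, 0] becomes 8388608 and, once the
+# mapping below is populated, VOC_PNG_COLOR_MAPPING[8388608] == 1.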
+VOC_PNG_COLOR_VALUE = [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], +] +# Will be populated by maybe_populate_voc_color_mapping() below. +VOC_PNG_COLOR_MAPPING = None + + +def maybe_populate_voc_color_mapping(): + """Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.""" + global VOC_PNG_COLOR_MAPPING + if VOC_PNG_COLOR_MAPPING is None: + VOC_PNG_COLOR_MAPPING = [0] * (256**3) + for i, colormap in enumerate(VOC_PNG_COLOR_VALUE): + VOC_PNG_COLOR_MAPPING[ + (colormap[0] * 256 + colormap[1]) * 256 + colormap[2] + ] = i + # There is a special mapping with [224, 224, 192] -> 255 + VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255 + VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING) + return VOC_PNG_COLOR_MAPPING + + +def parse_annotation_data(annotation_file_path): + """Parse the annotation XML file for the image. + + The annotation contains the metadata, as well as the object bounding box + information. + + """ + with open(annotation_file_path, "r") as f: + root = xml.etree.ElementTree.parse(f).getroot() + + size = root.find("size") + width = int(size.find("width").text) + height = int(size.find("height").text) + + objects = [] + for obj in root.findall("object"): + # Get object's label name. + label = CLASS_TO_INDEX[obj.find("name").text.lower()] + # Get objects' pose name. + pose = obj.find("pose").text.lower() + is_truncated = obj.find("truncated").text == "1" + is_difficult = obj.find("difficult").text == "1" + bndbox = obj.find("bndbox") + xmax = int(bndbox.find("xmax").text) + xmin = int(bndbox.find("xmin").text) + ymax = int(bndbox.find("ymax").text) + ymin = int(bndbox.find("ymin").text) + objects.append( + { + "label": label, + "pose": pose, + "bbox": [ymin, xmin, ymax, xmax], + "is_truncated": is_truncated, + "is_difficult": is_difficult, + } + ) + + return {"width": width, "height": height, "objects": objects} + + +def get_image_ids(data_dir, split): + """To get image ids from the "train", "eval" or "trainval" files of VOC data.""" + data_file_mapping = { + "train": "train.txt", + "eval": "val.txt", + "trainval": "trainval.txt", + } + with open( + os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]), + "r", + ) as f: + image_ids = f.read().splitlines() + logging.info(f"Received {len(image_ids)} images for {split} dataset.") + return image_ids + + +def get_sbd_image_ids(data_dir, split): + """To get image ids from the "sbd_train", "sbd_eval" from files of SBD data.""" + data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"} + with open( + os.path.join(data_dir, data_file_mapping[split]), + "r", + ) as f: + image_ids = f.read().splitlines() + logging.info(f"Received {len(image_ids)} images for {split} dataset.") + return image_ids + + +def parse_single_image(image_file_path): + """Creates metadata of VOC images and path.""" + data_dir, image_file_name = os.path.split(image_file_path) + data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir)) + image_id, _ = os.path.splitext(image_file_name) + class_segmentation_file_path = os.path.join( + data_dir, "SegmentationClass", image_id + ".png" + ) + object_segmentation_file_path = os.path.join( + data_dir, "SegmentationObject", image_id + ".png" + ) + 
annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml") + image_annotations = parse_annotation_data(annotation_file_path) + + result = { + "image/filename": image_id + ".jpg", + "image/file_path": image_file_path, + "segmentation/class/file_path": class_segmentation_file_path, + "segmentation/object/file_path": object_segmentation_file_path, + } + result.update(image_annotations) + # Labels field should be same as the 'object.label' + labels = list(set([o["label"] for o in result["objects"]])) + result["labels"] = sorted(labels) + return result + + +def parse_single_sbd_image(image_file_path): + """Creates metadata of SBD images and path.""" + data_dir, image_file_name = os.path.split(image_file_path) + data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir)) + image_id, _ = os.path.splitext(image_file_name) + class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat") + object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat") + result = { + "image/filename": image_id + ".jpg", + "image/file_path": image_file_path, + "segmentation/class/file_path": class_segmentation_file_path, + "segmentation/object/file_path": object_segmentation_file_path, + } + return result + + +def build_metadata(data_dir, image_ids): + """Transpose the metadata which convert from list of dict to dict of list.""" + # Parallel process all the images. + image_file_paths = [ + os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids + ] + pool_size = 10 if len(image_ids) > 10 else len(image_ids) + with multiprocessing.Pool(pool_size) as p: + metadata = p.map(parse_single_image, image_file_paths) + + keys = [ + "image/filename", + "image/file_path", + "segmentation/class/file_path", + "segmentation/object/file_path", + "labels", + "width", + "height", + ] + result = {} + for key in keys: + values = [value[key] for value in metadata] + result[key] = values + + # The ragged objects need some special handling + for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]: + values = [] + objects = [value["objects"] for value in metadata] + for object in objects: + values.append([o[key] for o in object]) + result["objects/" + key] = values + return result + + +def build_sbd_metadata(data_dir, image_ids): + """Transpose the metadata which convert from list of dict to dict of list.""" + # Parallel process all the images. 
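+    # Each worker parses one image's metadata; the per-image dicts are then
+    # transposed into a dict of per-key lists below.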
+ image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids] + pool_size = 10 if len(image_ids) > 10 else len(image_ids) + with multiprocessing.Pool(pool_size) as p: + metadata = p.map(parse_single_sbd_image, image_file_paths) + + keys = [ + "image/filename", + "image/file_path", + "segmentation/class/file_path", + "segmentation/object/file_path", + ] + result = {} + for key in keys: + values = [value[key] for value in metadata] + result[key] = values + return result + + +def decode_png_mask(mask): + """Decode the raw PNG image and convert it to 2D tensor with probably + class.""" + # Cast the mask to int32 since the original uint8 will overflow when + # multiplied with 256 + mask = tf.cast(mask, tf.int32) + mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2] + mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1) + mask = tf.cast(mask, tf.uint8) + return mask + + +def load_images(example): + """Loads VOC images for segmentation task from the provided paths""" + image_file_path = example.pop("image/file_path") + segmentation_class_file_path = example.pop("segmentation/class/file_path") + segmentation_object_file_path = example.pop("segmentation/object/file_path") + image = tf.io.read_file(image_file_path) + image = tf.image.decode_jpeg(image) + + segmentation_class_mask = tf.io.read_file(segmentation_class_file_path) + segmentation_class_mask = tf.image.decode_png(segmentation_class_mask) + segmentation_class_mask = decode_png_mask(segmentation_class_mask) + + segmentation_object_mask = tf.io.read_file(segmentation_object_file_path) + segmentation_object_mask = tf.image.decode_png(segmentation_object_mask) + segmentation_object_mask = decode_png_mask(segmentation_object_mask) + + example.update( + { + "image": image, + "class_segmentation": segmentation_class_mask, + "object_segmentation": segmentation_object_mask, + } + ) + return example + + +def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path): + """Loads SBD images for segmentation task from the provided paths""" + image = tf.io.read_file(image_file_path) + image = tf.image.decode_jpeg(image) + + segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path) + segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0] + segmentation_class_mask = segmentation_class_mask[..., np.newaxis] + + segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat( + seg_obj_file_path + ) + segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0] + segmentation_object_mask = segmentation_object_mask[..., np.newaxis] + + return { + "image": image, + "class_segmentation": segmentation_class_mask, + "object_segmentation": segmentation_object_mask, + } + + +def build_dataset_from_metadata(metadata): + """Builds TensorFlow dataset from the image metadata of VOC dataset.""" + # The objects need some manual conversion to ragged tensor. 
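+    # Each image has a variable number of objects, so the per-object fields
+    # cannot form a dense tensor; `tf.ragged.constant` keeps one
+    # variable-length row per image instead.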
+ metadata["labels"] = tf.ragged.constant(metadata["labels"]) + metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"]) + metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"]) + metadata["objects/is_truncated"] = tf.ragged.constant( + metadata["objects/is_truncated"] + ) + metadata["objects/is_difficult"] = tf.ragged.constant( + metadata["objects/is_difficult"] + ) + metadata["objects/bbox"] = tf.ragged.constant( + metadata["objects/bbox"], ragged_rank=1 + ) + + dataset = tf.data.Dataset.from_tensor_slices(metadata) + dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE) + return dataset + + +def build_sbd_dataset_from_metadata(metadata): + """Builds TensorFlow dataset from the image metadata of SBD dataset.""" + img_filepath = metadata["image/file_path"] + cls_filepath = metadata["segmentation/class/file_path"] + obj_filepath = metadata["segmentation/object/file_path"] + + def md_gen(): + c = list(zip(img_filepath, cls_filepath, obj_filepath)) + # random shuffling for each generator boosts up the quality. + random.shuffle(c) + for fp in c: + img_fp, cls_fp, obj_fp = fp + yield load_sbd_images(img_fp, cls_fp, obj_fp) + + dataset = tf.data.Dataset.from_generator( + md_gen, + output_signature=( + { + "image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8), + "class_segmentation": tf.TensorSpec( + shape=(None, None, 1), dtype=tf.uint8 + ), + "object_segmentation": tf.TensorSpec( + shape=(None, None, 1), dtype=tf.uint8 + ), + } + ), + ) + + return dataset + + +def load( + split="sbd_train", + data_dir=None, +): + """Load the Pacal VOC 2012 dataset. + + This function will download the data tar file from remote if needed, and + untar to the local `data_dir`, and build dataset from it. + + It supports both VOC2012 and Semantic Boundaries Dataset (SBD). + + The returned segmentation masks will be int ranging from [0, num_classes), + as well as 255 which is the boundary mask. + + Args: + split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or + 'sbd_eval'. 'sbd_train' represents the training dataset for SBD + dataset, while 'train' represents the training dataset for VOC2012 + dataset. Defaults to `sbd_train`. + data_dir: string, local directory path for the loaded data. This will be + used to download the data file, and unzip. It will be used as a + cache directory. Defaults to None, and `~/.keras/pascal_voc_2012` + will be used. + """ + supported_split_value = [ + "train", + "eval", + "trainval", + "sbd_train", + "sbd_eval", + ] + if split not in supported_split_value: + raise ValueError( + f"The support value for `split` are {supported_split_value}. " + f"Got: {split}" + ) + + if data_dir is not None: + data_dir = os.path.expanduser(data_dir) + + if "sbd" in split: + return load_sbd(split, data_dir) + else: + return load_voc(split, data_dir) + + +def load_voc( + split="train", + data_dir=None, +): + """This function will download VOC data from a URL. If the data is already + present in the cache directory, it will load the data from that directory + instead. + """ + extracted_dir = os.path.join("VOCdevkit", "VOC2012") + get_data = keras.utils.get_file( + fname=os.path.basename(VOC_URL), + origin=VOC_URL, + cache_dir=data_dir, + extract=True, + ) + data_dir = os.path.join(os.path.dirname(get_data), extracted_dir) + image_ids = get_image_ids(data_dir, split) + # len(metadata) = #samples, metadata[i] is a dict. 
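+    # build_metadata transposes those dicts into a dict of per-key lists that
+    # tf.data can slice into per-image examples.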
+ metadata = build_metadata(data_dir, image_ids) + maybe_populate_voc_color_mapping() + dataset = build_dataset_from_metadata(metadata) + + return dataset + + +def load_sbd( + split="sbd_train", + data_dir=None, +): + """This function will download SBD data from a URL. If the data is already + present in the cache directory, it will load the data from that directory + instead. + """ + extracted_dir = os.path.join("benchmark_RELEASE", "dataset") + get_data = keras.utils.get_file( + fname=os.path.basename(SBD_URL), + origin=SBD_URL, + cache_dir=data_dir, + extract=True, + ) + data_dir = os.path.join(os.path.dirname(get_data), extracted_dir) + image_ids = get_sbd_image_ids(data_dir, split) + # len(metadata) = #samples, metadata[i] is a dict. + metadata = build_sbd_metadata(data_dir, image_ids) + + dataset = build_sbd_dataset_from_metadata(metadata) + return dataset + + +""" +## Load the dataset + +For training and evaluation, let's use "sbd_train" and "sbd_eval." You can also +choose any of these datasets for the `load` function: 'train', 'eval', 'trainval', +'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the +SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset. +""" +train_ds = load(split="sbd_train", data_dir="segmentation") +eval_ds = load(split="sbd_eval", data_dir="segmentation") + +""" +## Preprocess the data + +The preprocess_inputs utility function preprocesses inputs, converting them into +a dictionary containing images and segmentation_masks. Both images and +segmentation masks are resized to 512x512. The resulting dataset is then batched +into groups of four image and segmentation mask pairs. +""" + + +def preprocess_inputs(inputs): + def unpackage_inputs(inputs): + return { + "images": inputs["image"], + "segmentation_masks": inputs["class_segmentation"], + } + + outputs = inputs.map(unpackage_inputs) + outputs = outputs.map(keras.layers.Resizing(height=512, width=512)) + outputs = outputs.batch(4, drop_remainder=True) + return outputs + + +train_ds = preprocess_inputs(train_ds) +batch = train_ds.take(1).get_single_element() + +""" +A batch of this preprocessed input training data can be visualized using the +`plot_images_masks` function. This function takes a batch of images and +segmentation masks and prediction masks as input and displays them in a grid. +""" + + +def plot_images_masks(images, masks, pred_masks=None): + images = (images - np.min(images)) / (np.max(images) - np.min(images)) + masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks)) + if pred_masks is not None: + pred_masks = (pred_masks - pred_masks.min()) / ( + pred_masks.max() - pred_masks.min() + ) + num_images = len(images) + plt.figure(figsize=(8, 4)) + rows = 3 if pred_masks is not None else 2 + + for i in range(num_images): + plt.subplot(rows, num_images, i + 1) + plt.imshow(images[i]) + plt.axis("off") + + plt.subplot(rows, num_images, num_images + i + 1) + plt.imshow(masks[i], cmap="gray") + plt.axis("off") + + if pred_masks is not None: + plt.subplot(rows, num_images, i + 1 + 2 * num_images) + plt.imshow(pred_masks[i, ..., 0], cmap="gray") + plt.axis("off") + + plt.show() + + +plot_images_masks(batch["images"], batch["segmentation_masks"]) + +""" +The preprocessing is applied to the evaluation dataset `eval_ds`. +""" +eval_ds = preprocess_inputs(eval_ds) + +""" +## Data Augmentation + +Keras provides a variety of image augmentation options. In this example, we will +use the `RandomFlip` augmentation to augment the training dataset. 
The
+`RandomFlip` augmentation randomly flips the images in the training dataset
+horizontally or vertically. This can help to improve the model's robustness to
+changes in the orientation of the objects in the images.
+"""
+
+train_ds = train_ds.map(keras.layers.RandomFlip())
+batch = train_ds.take(1).get_single_element()
+
+plot_images_masks(batch["images"], batch["segmentation_masks"])
+
+"""
+## Model Configuration
+
+Feel free to modify the model training configurations and observe how the
+training results change. This is a great exercise for getting a better
+understanding of the training pipeline.
+
+The optimizer uses the learning rate schedule to compute the learning rate at
+each training step, and then uses that learning rate to update the weights of
+the model.
+In this case, the learning rate schedule uses a cosine decay function, which
+starts high and then decreases over time, eventually reaching zero. With a
+batch size of 4, the training dataset has a cardinality of 2124 batches. The
+dataset cardinality is important for learning rate decay because it determines
+how many steps the model will train for. The initial learning rate is scaled
+down from a base of 0.007 by the factor `BATCH_SIZE / 16`, and the decay runs
+for 2124 steps. This means that the learning rate starts at `INITIAL_LR`
+(0.007 * 4 / 16 = 0.00175), falls to roughly half of that at step 1062, and
+decreases to zero by step 2124.
+![png](/img/guides/semantic_segmentation_deeplab_v3_plus/learning_rate_schedule.png)
+"""
+
+BATCH_SIZE = 4
+INITIAL_LR = 0.007 * BATCH_SIZE / 16
+EPOCHS = 1
+NUM_CLASSES = 21
+learning_rate = keras.optimizers.schedules.CosineDecay(
+    INITIAL_LR,
+    decay_steps=EPOCHS * 2124,
+)
+
+"""
+Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for
+the model. This implementation can be used either as DeepLabV3 or, with an
+additional decoder block, as DeepLabV3+.
+For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model and set
+`low_level_feature_key` to `P2`, a pyramid-level output of
+`resnet_50_imagenet` from which the decoder block extracts low-level features.
+To use this model as a plain DeepLabV3 architecture, omit
+`low_level_feature_key`, which defaults to `None`.
+
+Then we create a `DeepLabV3ImageSegmenter` instance.
+The `num_classes` parameter specifies the number of classes that the model
+will be trained to segment, and the `preprocessor` argument applies
+preprocessing to the image inputs and masks.
+"""
+
+image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")
+
+deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
+    image_encoder=image_encoder,
+    low_level_feature_key="P2",
+    spatial_pyramid_pooling_key="P5",
+    dilation_rates=[6, 12, 18],
+    upsampling_size=8,
+)
+
+model = keras_hub.models.DeepLabV3ImageSegmenter(
+    backbone=deeplab_backbone,
+    num_classes=21,
+    activation="softmax",
+    preprocessor=preprocessor,
+)
+
+"""
+## Compile the model
+
+The `model.compile()` function sets up the training process for the model.
+It defines:
+- the optimization algorithm: Stochastic Gradient Descent (SGD)
+- the loss function: categorical cross-entropy
+- the evaluation metrics: Mean IoU and categorical accuracy
+
+Semantic segmentation evaluation metrics:
+
+Mean Intersection over Union (MeanIoU):
+MeanIoU measures how well a semantic segmentation model identifies and
+delineates the different objects or regions in an image. It calculates the
+overlap between the predicted and ground-truth regions, providing a score
+between 0 and 1, where 1 represents a perfect match.
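+Concretely, for each class,
+IoU = true_positives / (true_positives + false_positives + false_negatives),
+computed over all pixels, and MeanIoU averages this ratio across the classes.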
+
+Categorical Accuracy:
+Categorical Accuracy measures the proportion of correctly classified pixels in
+an image. It gives a simple percentage indicating how accurately the model
+predicts the pixel categories across the entire image.
+
+In essence, MeanIoU emphasizes the accuracy of identifying specific object
+boundaries, while Categorical Accuracy gives a broad overview of overall
+pixel-level correctness.
+"""
+
+model.compile(
+    optimizer=keras.optimizers.SGD(
+        learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
+    ),
+    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
+    metrics=[
+        keras.metrics.MeanIoU(
+            num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
+        ),
+        keras.metrics.CategoricalAccuracy(),
+    ],
+)
+
+model.summary()
+
+"""
+The utility function `dict_to_tuple` transforms the dictionaries of the
+training and validation datasets into tuples of images and one-hot encoded
+segmentation masks, which are used during training and evaluation of the
+DeepLabv3+ model.
+"""
+
+
+def dict_to_tuple(x):
+    return x["images"], tf.one_hot(
+        tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21
+    )
+
+
+train_ds = train_ds.map(dict_to_tuple)
+eval_ds = eval_ds.map(dict_to_tuple)
+
+model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)
+
+"""
+## Predictions with trained model
+Now that the training of the DeepLabv3+ model has completed, let's test it by
+making predictions on a few sample images.
+Note: For demonstration purposes, the model has been trained for only 1 epoch;
+for better accuracy and results, train it for more epochs.
+"""
+
+test_ds = load(split="sbd_eval")
+test_ds = preprocess_inputs(test_ds)
+
+images, masks = next(iter(train_ds.take(1)))
+images = ops.convert_to_tensor(images)
+masks = ops.convert_to_tensor(masks)
+preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)
+masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)
+
+plot_images_masks(images, masks, preds)
+
+"""
+Here are some additional tips for using the KerasHub DeepLabv3 model:
+
+- The model can be trained on a variety of datasets, including the COCO dataset, the
+PASCAL VOC dataset, and the Cityscapes dataset.
+- The model can be fine-tuned on a custom dataset to improve its performance on a
+specific task.
+- The model can be used to perform real-time inference on images.
+- Also, check out KerasHub's other segmentation models.
+"""
diff --git a/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md b/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md
new file mode 100644
index 0000000000..19d5b2c3aa
--- /dev/null
+++ b/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md
@@ -0,0 +1,973 @@
+# Semantic Segmentation with KerasHub
+
+**Authors:** [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)<br>
+**Date created:** 2024/10/11
+**Last modified:** 2024/10/11
+**Description:** DeepLabV3 training and inference with KerasHub. + + + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/keras_hub/semantic_segmentation_deeplab_v3.py) + + + +![](https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png) + +--- +## Background +Semantic segmentation is a type of computer vision task that involves assigning a +class label such as "person", "bike", or "background" to each individual pixel +of an image, effectively dividing the image into regions that correspond to +different object classes or categories. + +![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png) + + + +KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic +segmentation. + +This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed +by Google for image semantic segmentation with KerasHub. Its architecture +combines Atrous convolutions, contextual information aggregation, and powerful +backbones to achieve accurate and detailed semantic segmentation. + +DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to +refine the segmentation results, especially along object boundaries. Both models +have achieved state-of-the-art results on a variety of image segmentation +benchmarks. + +### References +[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) +[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) + +--- +## Setup and Imports + +Let's install the dependencies and import the necessary modules. + +To run this tutorial, you will need to install the following packages: + +* `keras-hub` +* `keras` + + +```python +!pip install -q --upgrade keras-hub +!pip install -q --upgrade keras +``` + +After installing `keras` and `keras-hub`, set the backend for `keras`. +This guide can be run with any backend (Tensorflow, JAX, PyTorch). + + +```python +import os + +os.environ["KERAS_BACKEND"] = "jax" +import keras +from keras import ops +import keras_hub +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt +``` + +--- +## Perform semantic segmentation with a pretrained DeepLabv3+ model + +The highest level API in the KerasHub semantic segmentation API is the +`keras_hub.models` API. This API includes fully pretrained semantic segmentation +models, such as `keras_hub.models.DeepLabV3ImageSegmenter`. + +Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC +dataset. +Also, define the preprocessing function for the model to preprocess images and +labels. +**Note:** By default `from_preset()` method in KerasHub loads the pretrained +task weights with all the classes, 21 classes in this case. 
+ + +```python +model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset( + "deeplab_v3_plus_resnet50_pascalvoc" +) + +image_converter = keras_hub.layers.DeepLabV3ImageConverter( + image_size=(512, 512), + interpolation="bilinear", +) +preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter) +``` + +Let us visualize the results of this pretrained model + + +```python +filepath = keras.utils.get_file( + origin="https://storage.googleapis.com/keras-cv/pictures/dog.jpeg" +) +image = keras.utils.load_img(filepath) +image = keras.utils.img_to_array(image) + +image = preprocessor(image) +image = keras.ops.expand_dims(image, axis=0) +preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1) + + +def plot_segmentation(original_image, predicted_mask): + original_image = np.squeeze(original_image, axis=0) + original_image = np.clip(original_image / 255.0, 0, 1) + predicted_mask = np.squeeze(predicted_mask, axis=0) + plt.figure(figsize=(5, 5)) + + plt.subplot(1, 2, 1) + plt.imshow(original_image) + plt.axis("off") + + plt.subplot(1, 2, 2) + plt.imshow(predicted_mask, cmap="gray") + plt.axis("off") + + plt.tight_layout() + plt.show() + + +plot_segmentation(image, preds) +``` + + +![png](/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png) + + + +--- +## Train a custom semantic segmentation model +In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3 +semantic segmentation model. This includes data loading, augmentation, training, +metric evaluation, and inference! + +--- +## Download the data + +We download Pascal VOC 2012 dataset with additional annotations provided here +[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) +and split them into train dataset `train_ds` and `eval_ds`. + + +```python +# @title helper functions +import logging +import multiprocessing +from builtins import open +import os.path +import random +import xml + +import tensorflow_datasets as tfds + +VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" + +SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz" + +# Note that this list doesn't contain the background class. In the +# classification use case, the label is 0 based (aeroplane -> 0), whereas in +# segmentation use case, the 0 is reserved for background, so aeroplane maps to +# 1. +CLASSES = [ + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] +# This is used to map between string class to index. +CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)} + +# For the mask data in the PNG file, the encoded raw pixel value need to be +# converted to the proper class index. In the following map, [0, 0, 0] will be +# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also +# note that the mask class is 1 base since class 0 is reserved for the +# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`. 
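+# Likewise, [192, 128, 128] (index 15 below) is mapped to the `person` class.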
+VOC_PNG_COLOR_VALUE = [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], +] +# Will be populated by maybe_populate_voc_color_mapping() below. +VOC_PNG_COLOR_MAPPING = None + + +def maybe_populate_voc_color_mapping(): + """Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.""" + global VOC_PNG_COLOR_MAPPING + if VOC_PNG_COLOR_MAPPING is None: + VOC_PNG_COLOR_MAPPING = [0] * (256**3) + for i, colormap in enumerate(VOC_PNG_COLOR_VALUE): + VOC_PNG_COLOR_MAPPING[ + (colormap[0] * 256 + colormap[1]) * 256 + colormap[2] + ] = i + # There is a special mapping with [224, 224, 192] -> 255 + VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255 + VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING) + return VOC_PNG_COLOR_MAPPING + + +def parse_annotation_data(annotation_file_path): + """Parse the annotation XML file for the image. + + The annotation contains the metadata, as well as the object bounding box + information. + + """ + with open(annotation_file_path, "r") as f: + root = xml.etree.ElementTree.parse(f).getroot() + + size = root.find("size") + width = int(size.find("width").text) + height = int(size.find("height").text) + + objects = [] + for obj in root.findall("object"): + # Get object's label name. + label = CLASS_TO_INDEX[obj.find("name").text.lower()] + # Get objects' pose name. + pose = obj.find("pose").text.lower() + is_truncated = obj.find("truncated").text == "1" + is_difficult = obj.find("difficult").text == "1" + bndbox = obj.find("bndbox") + xmax = int(bndbox.find("xmax").text) + xmin = int(bndbox.find("xmin").text) + ymax = int(bndbox.find("ymax").text) + ymin = int(bndbox.find("ymin").text) + objects.append( + { + "label": label, + "pose": pose, + "bbox": [ymin, xmin, ymax, xmax], + "is_truncated": is_truncated, + "is_difficult": is_difficult, + } + ) + + return {"width": width, "height": height, "objects": objects} + + +def get_image_ids(data_dir, split): + """To get image ids from the "train", "eval" or "trainval" files of VOC data.""" + data_file_mapping = { + "train": "train.txt", + "eval": "val.txt", + "trainval": "trainval.txt", + } + with open( + os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]), + "r", + ) as f: + image_ids = f.read().splitlines() + logging.info(f"Received {len(image_ids)} images for {split} dataset.") + return image_ids + + +def get_sbd_image_ids(data_dir, split): + """To get image ids from the "sbd_train", "sbd_eval" from files of SBD data.""" + data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"} + with open( + os.path.join(data_dir, data_file_mapping[split]), + "r", + ) as f: + image_ids = f.read().splitlines() + logging.info(f"Received {len(image_ids)} images for {split} dataset.") + return image_ids + + +def parse_single_image(image_file_path): + """Creates metadata of VOC images and path.""" + data_dir, image_file_name = os.path.split(image_file_path) + data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir)) + image_id, _ = os.path.splitext(image_file_name) + class_segmentation_file_path = os.path.join( + data_dir, "SegmentationClass", image_id + ".png" + ) + object_segmentation_file_path = os.path.join( + data_dir, "SegmentationObject", image_id + ".png" + ) + 
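+    # The XML annotation file alongside the image provides the per-object
+    # labels, poses, and bounding boxes that get merged into this record.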
annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml") + image_annotations = parse_annotation_data(annotation_file_path) + + result = { + "image/filename": image_id + ".jpg", + "image/file_path": image_file_path, + "segmentation/class/file_path": class_segmentation_file_path, + "segmentation/object/file_path": object_segmentation_file_path, + } + result.update(image_annotations) + # Labels field should be same as the 'object.label' + labels = list(set([o["label"] for o in result["objects"]])) + result["labels"] = sorted(labels) + return result + + +def parse_single_sbd_image(image_file_path): + """Creates metadata of SBD images and path.""" + data_dir, image_file_name = os.path.split(image_file_path) + data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir)) + image_id, _ = os.path.splitext(image_file_name) + class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat") + object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat") + result = { + "image/filename": image_id + ".jpg", + "image/file_path": image_file_path, + "segmentation/class/file_path": class_segmentation_file_path, + "segmentation/object/file_path": object_segmentation_file_path, + } + return result + + +def build_metadata(data_dir, image_ids): + """Transpose the metadata which convert from list of dict to dict of list.""" + # Parallel process all the images. + image_file_paths = [ + os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids + ] + pool_size = 10 if len(image_ids) > 10 else len(image_ids) + with multiprocessing.Pool(pool_size) as p: + metadata = p.map(parse_single_image, image_file_paths) + + keys = [ + "image/filename", + "image/file_path", + "segmentation/class/file_path", + "segmentation/object/file_path", + "labels", + "width", + "height", + ] + result = {} + for key in keys: + values = [value[key] for value in metadata] + result[key] = values + + # The ragged objects need some special handling + for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]: + values = [] + objects = [value["objects"] for value in metadata] + for object in objects: + values.append([o[key] for o in object]) + result["objects/" + key] = values + return result + + +def build_sbd_metadata(data_dir, image_ids): + """Transpose the metadata which convert from list of dict to dict of list.""" + # Parallel process all the images. 
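+    # SBD stores images under `img/` and its class/instance masks as MATLAB
+    # `.mat` files under `cls/` and `inst/` (see `parse_single_sbd_image`).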
+ image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids] + pool_size = 10 if len(image_ids) > 10 else len(image_ids) + with multiprocessing.Pool(pool_size) as p: + metadata = p.map(parse_single_sbd_image, image_file_paths) + + keys = [ + "image/filename", + "image/file_path", + "segmentation/class/file_path", + "segmentation/object/file_path", + ] + result = {} + for key in keys: + values = [value[key] for value in metadata] + result[key] = values + return result + + +def decode_png_mask(mask): + """Decode the raw PNG image and convert it to 2D tensor with probably + class.""" + # Cast the mask to int32 since the original uint8 will overflow when + # multiplied with 256 + mask = tf.cast(mask, tf.int32) + mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2] + mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1) + mask = tf.cast(mask, tf.uint8) + return mask + + +def load_images(example): + """Loads VOC images for segmentation task from the provided paths""" + image_file_path = example.pop("image/file_path") + segmentation_class_file_path = example.pop("segmentation/class/file_path") + segmentation_object_file_path = example.pop("segmentation/object/file_path") + image = tf.io.read_file(image_file_path) + image = tf.image.decode_jpeg(image) + + segmentation_class_mask = tf.io.read_file(segmentation_class_file_path) + segmentation_class_mask = tf.image.decode_png(segmentation_class_mask) + segmentation_class_mask = decode_png_mask(segmentation_class_mask) + + segmentation_object_mask = tf.io.read_file(segmentation_object_file_path) + segmentation_object_mask = tf.image.decode_png(segmentation_object_mask) + segmentation_object_mask = decode_png_mask(segmentation_object_mask) + + example.update( + { + "image": image, + "class_segmentation": segmentation_class_mask, + "object_segmentation": segmentation_object_mask, + } + ) + return example + + +def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path): + """Loads SBD images for segmentation task from the provided paths""" + image = tf.io.read_file(image_file_path) + image = tf.image.decode_jpeg(image) + + segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path) + segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0] + segmentation_class_mask = segmentation_class_mask[..., np.newaxis] + + segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat( + seg_obj_file_path + ) + segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0] + segmentation_object_mask = segmentation_object_mask[..., np.newaxis] + + return { + "image": image, + "class_segmentation": segmentation_class_mask, + "object_segmentation": segmentation_object_mask, + } + + +def build_dataset_from_metadata(metadata): + """Builds TensorFlow dataset from the image metadata of VOC dataset.""" + # The objects need some manual conversion to ragged tensor. 
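+    # For example, `metadata["objects/label"]` might look like [[1, 15], [7]]
+    # when the first image contains two objects and the second only one; such
+    # variable-length rows need `tf.ragged.constant` instead of a dense tensor
+    # (the values here are purely illustrative).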
+    metadata["labels"] = tf.ragged.constant(metadata["labels"])
+    metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
+    metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])
+    metadata["objects/is_truncated"] = tf.ragged.constant(
+        metadata["objects/is_truncated"]
+    )
+    metadata["objects/is_difficult"] = tf.ragged.constant(
+        metadata["objects/is_difficult"]
+    )
+    metadata["objects/bbox"] = tf.ragged.constant(
+        metadata["objects/bbox"], ragged_rank=1
+    )
+
+    dataset = tf.data.Dataset.from_tensor_slices(metadata)
+    dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
+    return dataset
+
+
+def build_sbd_dataset_from_metadata(metadata):
+    """Builds TensorFlow dataset from the image metadata of SBD dataset."""
+    img_filepath = metadata["image/file_path"]
+    cls_filepath = metadata["segmentation/class/file_path"]
+    obj_filepath = metadata["segmentation/object/file_path"]
+
+    def md_gen():
+        c = list(zip(img_filepath, cls_filepath, obj_filepath))
+        # Random shuffling on each generator pass improves shuffling quality.
+        random.shuffle(c)
+        for fp in c:
+            img_fp, cls_fp, obj_fp = fp
+            yield load_sbd_images(img_fp, cls_fp, obj_fp)
+
+    dataset = tf.data.Dataset.from_generator(
+        md_gen,
+        output_signature=(
+            {
+                "image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
+                "class_segmentation": tf.TensorSpec(
+                    shape=(None, None, 1), dtype=tf.uint8
+                ),
+                "object_segmentation": tf.TensorSpec(
+                    shape=(None, None, 1), dtype=tf.uint8
+                ),
+            }
+        ),
+    )
+
+    return dataset
+
+
+def load(
+    split="sbd_train",
+    data_dir=None,
+):
+    """Load the Pascal VOC 2012 dataset.
+
+    This function downloads the data tar file from a remote location if
+    needed, untars it to the local `data_dir`, and builds a dataset from it.
+
+    It supports both VOC2012 and the Semantic Boundaries Dataset (SBD).
+
+    The returned segmentation masks are integers in the range
+    [0, num_classes), plus the value 255, which marks boundary pixels.
+
+    Args:
+        split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or
+            'sbd_eval'. 'sbd_train' represents the training dataset for SBD
+            dataset, while 'train' represents the training dataset for VOC2012
+            dataset. Defaults to `sbd_train`.
+        data_dir: string, local directory path for the loaded data. This will be
+            used to download the data file, and unzip. It will be used as a
+            cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`
+            will be used.
+    """
+    supported_split_value = [
+        "train",
+        "eval",
+        "trainval",
+        "sbd_train",
+        "sbd_eval",
+    ]
+    if split not in supported_split_value:
+        raise ValueError(
+            f"The supported values for `split` are {supported_split_value}. "
+            f"Got: {split}"
+        )
+
+    if data_dir is not None:
+        data_dir = os.path.expanduser(data_dir)
+
+    if "sbd" in split:
+        return load_sbd(split, data_dir)
+    else:
+        return load_voc(split, data_dir)
+
+
+def load_voc(
+    split="train",
+    data_dir=None,
+):
+    """Download VOC data from a URL, or, if the data is already present in the
+    cache directory, load it from that directory instead.
+    """
+    extracted_dir = os.path.join("VOCdevkit", "VOC2012")
+    get_data = keras.utils.get_file(
+        fname=os.path.basename(VOC_URL),
+        origin=VOC_URL,
+        cache_dir=data_dir,
+        extract=True,
+    )
+    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+    image_ids = get_image_ids(data_dir, split)
+    # `metadata` is a dict of lists; each list holds one entry per sample.
+ metadata = build_metadata(data_dir, image_ids) + maybe_populate_voc_color_mapping() + dataset = build_dataset_from_metadata(metadata) + + return dataset + + +def load_sbd( + split="sbd_train", + data_dir=None, +): + """This function will download SBD data from a URL. If the data is already + present in the cache directory, it will load the data from that directory + instead. + """ + extracted_dir = os.path.join("benchmark_RELEASE", "dataset") + get_data = keras.utils.get_file( + fname=os.path.basename(SBD_URL), + origin=SBD_URL, + cache_dir=data_dir, + extract=True, + ) + data_dir = os.path.join(os.path.dirname(get_data), extracted_dir) + image_ids = get_sbd_image_ids(data_dir, split) + # len(metadata) = #samples, metadata[i] is a dict. + metadata = build_sbd_metadata(data_dir, image_ids) + + dataset = build_sbd_dataset_from_metadata(metadata) + return dataset + +``` + +--- +## Load the dataset + +For training and evaluation, let's use "sbd_train" and "sbd_eval." You can also +choose any of these datasets for the `load` function: 'train', 'eval', 'trainval', +'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the +SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset. + + +```python +train_ds = load(split="sbd_train", data_dir="segmentation") +eval_ds = load(split="sbd_eval", data_dir="segmentation") +``` + +--- +## Preprocess the data + +The preprocess_inputs utility function preprocesses inputs, converting them into +a dictionary containing images and segmentation_masks. Both images and +segmentation masks are resized to 512x512. The resulting dataset is then batched +into groups of four image and segmentation mask pairs. + + +```python + +def preprocess_inputs(inputs): + def unpackage_inputs(inputs): + return { + "images": inputs["image"], + "segmentation_masks": inputs["class_segmentation"], + } + + outputs = inputs.map(unpackage_inputs) + outputs = outputs.map(keras.layers.Resizing(height=512, width=512)) + outputs = outputs.batch(4, drop_remainder=True) + return outputs + + +train_ds = preprocess_inputs(train_ds) +batch = train_ds.take(1).get_single_element() +``` + +A batch of this preprocessed input training data can be visualized using the +`plot_images_masks` function. This function takes a batch of images and +segmentation masks and prediction masks as input and displays them in a grid. + + +```python + +def plot_images_masks(images, masks, pred_masks=None): + images = (images - np.min(images)) / (np.max(images) - np.min(images)) + masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks)) + if pred_masks is not None: + pred_masks = (pred_masks - pred_masks.min()) / ( + pred_masks.max() - pred_masks.min() + ) + num_images = len(images) + plt.figure(figsize=(8, 4)) + rows = 3 if pred_masks is not None else 2 + + for i in range(num_images): + plt.subplot(rows, num_images, i + 1) + plt.imshow(images[i]) + plt.axis("off") + + plt.subplot(rows, num_images, num_images + i + 1) + plt.imshow(masks[i], cmap="gray") + plt.axis("off") + + if pred_masks is not None: + plt.subplot(rows, num_images, i + 1 + 2 * num_images) + plt.imshow(pred_masks[i, ..., 0], cmap="gray") + plt.axis("off") + + plt.show() + + +plot_images_masks(batch["images"], batch["segmentation_masks"]) +``` + + + +![png](/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png) + + + +The preprocessing is applied to the evaluation dataset `eval_ds`. 
+

```python
eval_ds = preprocess_inputs(eval_ds)
```

---
## Data Augmentation

Keras provides a variety of image augmentation options. In this example, we will
use the `RandomFlip` augmentation to augment the training dataset. The
`RandomFlip` augmentation randomly flips the images in the training dataset
horizontally or vertically. This can help to improve the model's robustness to
changes in the orientation of the objects in the images.


```python
train_ds = train_ds.map(keras.layers.RandomFlip())
batch = train_ds.take(1).get_single_element()

plot_images_masks(batch["images"], batch["segmentation_masks"])
```



![png](/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png)



---
## Model Configuration

Feel free to modify the model training configurations and observe how the
training results change. This is a great exercise for getting a better
understanding of the training pipeline.

The optimizer uses the learning rate schedule to compute the learning rate at
each training step, and then uses that learning rate to update the weights of
the model.
In this case, the learning rate schedule uses a cosine decay function, which
starts high and then decreases over time, eventually reaching zero. With a
batch size of 4, the training dataset has a cardinality of 2124 batches. The
dataset cardinality is important for learning rate decay because it determines
how many steps the model will train for. The initial learning rate is scaled
down from a base of 0.007 by the factor `BATCH_SIZE / 16`, and the decay runs
for 2124 steps. This means that the learning rate starts at `INITIAL_LR`
(0.007 * 4 / 16 = 0.00175), falls to roughly half of that at step 1062, and
decreases to zero by step 2124.
![png](/guides/img/semantic_segmentation_deeplab_v3_plus/learning_rate_schedule.png)


```python
BATCH_SIZE = 4
INITIAL_LR = 0.007 * BATCH_SIZE / 16
EPOCHS = 1
NUM_CLASSES = 21
learning_rate = keras.optimizers.schedules.CosineDecay(
    INITIAL_LR,
    decay_steps=EPOCHS * 2124,
)
```

Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for
the model. This implementation can be used either as DeepLabV3 or, with an
additional decoder block, as DeepLabV3+.
For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model and set
`low_level_feature_key` to `P2`, a pyramid-level output of
`resnet_50_imagenet` from which the decoder block extracts low-level features.
To use this model as a plain DeepLabV3 architecture, omit
`low_level_feature_key`, which defaults to `None`.

Then we create a `DeepLabV3ImageSegmenter` instance.
The `num_classes` parameter specifies the number of classes that the model
will be trained to segment, and the `preprocessor` argument applies
preprocessing to the image inputs and masks.


```python
image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")

deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    low_level_feature_key="P2",
    spatial_pyramid_pooling_key="P5",
    dilation_rates=[6, 12, 18],
    upsampling_size=8,
)

model = keras_hub.models.DeepLabV3ImageSegmenter(
    backbone=deeplab_backbone,
    num_classes=21,
    activation="softmax",
    preprocessor=preprocessor,
)
```

---
## Compile the model

The `model.compile()` function sets up the training process for the model.
It defines:
- the optimization algorithm: Stochastic Gradient Descent (SGD)
- the loss function: categorical cross-entropy
- the evaluation metrics: Mean IoU and categorical accuracy

Semantic segmentation evaluation metrics:

Mean Intersection over Union (MeanIoU):
MeanIoU measures how well a semantic segmentation model identifies and
delineates the different objects or regions in an image. It calculates the
overlap between the predicted and ground-truth regions, providing a score
between 0 and 1, where 1 represents a perfect match. Concretely, for each
class, IoU = true_positives / (true_positives + false_positives + false_negatives),
computed over all pixels, and MeanIoU averages this ratio across the classes.

Categorical Accuracy:
Categorical Accuracy measures the proportion of correctly classified pixels in
an image. It gives a simple percentage indicating how accurately the model
predicts the pixel categories across the entire image.

In essence, MeanIoU emphasizes the accuracy of identifying specific object
boundaries, while Categorical Accuracy gives a broad overview of overall
pixel-level correctness.


```python
model.compile(
    optimizer=keras.optimizers.SGD(
        learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
    ),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[
        keras.metrics.MeanIoU(
            num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
        ),
        keras.metrics.CategoricalAccuracy(),
    ],
)

model.summary()
```
Preprocessor: "deep_lab_v3_image_segmenter_preprocessor"
+
+ + + + +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)                                                                                     Config ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ deep_lab_v3_image_converter (DeepLabV3ImageConverter)         │                   Image size: (512, 512) │
+└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘
+
+ + + + +
Model: "deep_lab_v3_image_segmenter"
+
+ + + + +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)                                   Output Shape                                    Param # ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
+│ inputs (InputLayer)                           │ (None, None, None, 3)              │                   0 │
+├───────────────────────────────────────────────┼────────────────────────────────────┼─────────────────────┤
+│ deep_lab_v3_backbone (DeepLabV3Backbone)      │ (None, None, None, 256)            │          39,190,656 │
+├───────────────────────────────────────────────┼────────────────────────────────────┼─────────────────────┤
+│ segmentation_output (Conv2D)                  │ (None, None, None, 21)             │               5,376 │
+└───────────────────────────────────────────────┴────────────────────────────────────┴─────────────────────┘
+
+ + + + +
 Total params: 39,196,032 (149.52 MB)
+
+ + + + +
 Trainable params: 39,139,232 (149.30 MB)
+
+ + + + +
 Non-trainable params: 56,800 (221.88 KB)
+
+ + + +The utility function `dict_to_tuple` effectively transforms the dictionaries of +training and validation datasets into tuples of images and one-hot encoded +segmentation masks, which is used during training and evaluation of the +DeepLabv3+ model. + + +```python + +def dict_to_tuple(x): + + return x["images"], tf.one_hot( + tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21 + ) + + +train_ds = train_ds.map(dict_to_tuple) +eval_ds = eval_ds.map(dict_to_tuple) + +model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS) +``` + + +
+``` + 1/Unknown 40s 40s/step - categorical_accuracy: 0.0494 - loss: 3.4081 - mean_io_u: 0.0112 +``` +
+ 2124/2124 ━━━━━━━━━━━━━━━━━━━━ 279s 113ms/step - categorical_accuracy: 0.7188 - loss: 1.1003 - mean_io_u: 0.0934 - val_categorical_accuracy: 0.8222 - val_loss: 0.5761 - val_mean_io_u: 0.3481 + + + + + +
+``` + + +``` +
+
---
## Predictions with trained model
Now that the training of the DeepLabv3+ model has completed, let's test it by
making predictions on a few sample images.
Note: For demonstration purposes, the model has been trained for only 1 epoch;
for better accuracy and results, train it for more epochs.


```python
test_ds = load(split="sbd_eval")
test_ds = preprocess_inputs(test_ds)

images, masks = next(iter(train_ds.take(1)))
images = ops.convert_to_tensor(images)
masks = ops.convert_to_tensor(masks)
preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)
masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)

plot_images_masks(images, masks, preds)
```

![png](/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png)



Here are some additional tips for using the KerasHub DeepLabv3 model:

- The model can be trained on a variety of datasets, including the COCO dataset, the
PASCAL VOC dataset, and the Cityscapes dataset.
- The model can be fine-tuned on a custom dataset to improve its performance on a
specific task.
- The model can be used to perform real-time inference on images (see the
sketch below).
- Also, check out KerasHub's other segmentation models.
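As a follow-up to the inference tip above, here is a minimal sketch of
single-image inference with the model fine-tuned in this guide. The file path
is a placeholder you would replace with your own image; everything else reuses
objects defined earlier (`model`, `preprocessor`, `keras`, `ops`, and the
`plot_segmentation` helper).


```python
# Placeholder path; substitute any RGB image you want to segment.
filepath = "my_image.jpg"
image = keras.utils.img_to_array(keras.utils.load_img(filepath))

# Apply the same resizing/rescaling used during training, then add a batch axis.
image = preprocessor(image)
image = ops.expand_dims(image, axis=0)

# Per-pixel class indices in [0, NUM_CLASSES), shape (1, 512, 512, 1).
preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)

# Reuse the helper from the pretrained-model section to display the result.
plot_segmentation(image, preds)
```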