diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png
new file mode 100644
index 0000000000..3e4559d82d
Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png differ
diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png
new file mode 100644
index 0000000000..d5dbce9a38
Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png differ
diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png
new file mode 100644
index 0000000000..b8698888be
Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png differ
diff --git a/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png
new file mode 100644
index 0000000000..7d60426bc3
Binary files /dev/null and b/guides/img/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_1.png differ
diff --git a/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb b/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb
new file mode 100644
index 0000000000..15ca8dc34c
--- /dev/null
+++ b/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb
@@ -0,0 +1,1120 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Semantic Segmentation with KerasHub\n",
+ "\n",
+ "**Authors:** [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)
\n",
+ "**Date created:** 2024/10/11
\n",
+ "**Last modified:** 2024/10/11
\n",
+ "**Description:** DeepLabV3 training and inference with KerasHub."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "## Background\n",
+ "Semantic segmentation is a type of computer vision task that involves assigning a\n",
+ "class label such as \"person\", \"bike\", or \"background\" to each individual pixel\n",
+ "of an image, effectively dividing the image into regions that correspond to\n",
+ "different object classes or categories.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic\n",
+ "segmentation.\n",
+ "\n",
+ "This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed\n",
+ "by Google for image semantic segmentation with KerasHub. Its architecture\n",
+ "combines Atrous convolutions, contextual information aggregation, and powerful\n",
+ "backbones to achieve accurate and detailed semantic segmentation.\n",
+ "\n",
+ "DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to\n",
+ "refine the segmentation results, especially along object boundaries. Both models\n",
+ "have achieved state-of-the-art results on a variety of image segmentation\n",
+ "benchmarks.\n",
+ "\n",
+ "### References\n",
+ "[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)\n",
+ "[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup and Imports\n",
+ "\n",
+ "Let's install the dependencies and import the necessary modules.\n",
+ "\n",
+ "To run this tutorial, you will need to install the following packages:\n",
+ "\n",
+ "* `keras-hub`\n",
+ "* `keras`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q --upgrade keras-hub\n",
+ "!pip install -q --upgrade keras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "After installing `keras` and `keras-hub`, set the backend for `keras`.\n",
+ "This guide can be run with any backend (Tensorflow, JAX, PyTorch)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
+ "import keras\n",
+ "from keras import ops\n",
+ "import keras_hub\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Perform semantic segmentation with a pretrained DeepLabv3+ model\n",
+ "\n",
+ "The highest level API in the KerasHub semantic segmentation API is the\n",
+ "`keras_hub.models` API. This API includes fully pretrained semantic segmentation\n",
+ "models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.\n",
+ "\n",
+ "Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC\n",
+ "dataset.\n",
+ "Also, define the preprocessing function for the model to preprocess images and\n",
+ "labels.\n",
+ "**Note:** By default `from_preset()` method in KerasHub loads the pretrained\n",
+ "task weights with all the classes, 21 classes in this case."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(\n",
+ " \"deeplab_v3_plus_resnet50_pascalvoc\"\n",
+ ")\n",
+ "\n",
+ "image_converter = keras_hub.layers.DeepLabV3ImageConverter(\n",
+ " image_size=(512, 512),\n",
+ " interpolation=\"bilinear\",\n",
+ ")\n",
+ "preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let us visualize the results of this pretrained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "filepath = keras.utils.get_file(\n",
+ " origin=\"https://storage.googleapis.com/keras-cv/pictures/dog.jpeg\"\n",
+ ")\n",
+ "image = keras.utils.load_img(filepath)\n",
+ "image = keras.utils.img_to_array(image)\n",
+ "\n",
+ "image = preprocessor(image)\n",
+ "image = keras.ops.expand_dims(image, axis=0)\n",
+ "preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)\n",
+ "\n",
+ "\n",
+ "def plot_segmentation(original_image, predicted_mask):\n",
+ " original_image = np.squeeze(original_image, axis=0)\n",
+ " original_image = np.clip(original_image / 255.0, 0, 1)\n",
+ " predicted_mask = np.squeeze(predicted_mask, axis=0)\n",
+ " plt.figure(figsize=(5, 5))\n",
+ "\n",
+ " plt.subplot(1, 2, 1)\n",
+ " plt.imshow(original_image)\n",
+ " plt.axis(\"off\")\n",
+ "\n",
+ " plt.subplot(1, 2, 2)\n",
+ " plt.imshow(predicted_mask, cmap=\"gray\")\n",
+ " plt.axis(\"off\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "plot_segmentation(image, preds)"
+ ]
+ },
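+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The segmenter was loaded with the standard 21 Pascal VOC classes (background\n",
+ "plus 20 object categories). As an optional check, the sketch below lists which\n",
+ "class indices appear in the predicted mask and maps them back to readable\n",
+ "names, assuming the standard Pascal VOC class ordering (background = 0, then\n",
+ "the 20 object classes in alphabetical order)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Optional sketch: map predicted class indices back to Pascal VOC class names.\n",
+ "# Assumes the standard VOC ordering: 0 = background, then the 20 object classes.\n",
+ "voc_names = [\"background\"] + [\n",
+ "    \"aeroplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\", \"bus\", \"car\", \"cat\",\n",
+ "    \"chair\", \"cow\", \"diningtable\", \"dog\", \"horse\", \"motorbike\", \"person\",\n",
+ "    \"pottedplant\", \"sheep\", \"sofa\", \"train\", \"tvmonitor\",\n",
+ "]\n",
+ "predicted_ids = np.unique(np.asarray(preds))\n",
+ "print(\"Predicted classes:\", [voc_names[int(i)] for i in predicted_ids])"
+ ]
+ },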
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Train a custom semantic segmentation model\n",
+ "In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3\n",
+ "semantic segmentation model. This includes data loading, augmentation, training,\n",
+ "metric evaluation, and inference!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Download the data\n",
+ "\n",
+ "We download Pascal VOC 2012 dataset with additional annotations provided here\n",
+ "[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz)\n",
+ "and split them into train dataset `train_ds` and `eval_ds`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# @title helper functions\n",
+ "import logging\n",
+ "import multiprocessing\n",
+ "from builtins import open\n",
+ "import os.path\n",
+ "import random\n",
+ "import xml\n",
+ "\n",
+ "import tensorflow_datasets as tfds\n",
+ "\n",
+ "VOC_URL = \"https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar\"\n",
+ "\n",
+ "SBD_URL = \"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz\"\n",
+ "\n",
+ "# Note that this list doesn't contain the background class. In the\n",
+ "# classification use case, the label is 0 based (aeroplane -> 0), whereas in\n",
+ "# segmentation use case, the 0 is reserved for background, so aeroplane maps to\n",
+ "# 1.\n",
+ "CLASSES = [\n",
+ " \"aeroplane\",\n",
+ " \"bicycle\",\n",
+ " \"bird\",\n",
+ " \"boat\",\n",
+ " \"bottle\",\n",
+ " \"bus\",\n",
+ " \"car\",\n",
+ " \"cat\",\n",
+ " \"chair\",\n",
+ " \"cow\",\n",
+ " \"diningtable\",\n",
+ " \"dog\",\n",
+ " \"horse\",\n",
+ " \"motorbike\",\n",
+ " \"person\",\n",
+ " \"pottedplant\",\n",
+ " \"sheep\",\n",
+ " \"sofa\",\n",
+ " \"train\",\n",
+ " \"tvmonitor\",\n",
+ "]\n",
+ "# This is used to map between string class to index.\n",
+ "CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}\n",
+ "\n",
+ "# For the mask data in the PNG file, the encoded raw pixel value need to be\n",
+ "# converted to the proper class index. In the following map, [0, 0, 0] will be\n",
+ "# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also\n",
+ "# note that the mask class is 1 base since class 0 is reserved for the\n",
+ "# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`.\n",
+ "VOC_PNG_COLOR_VALUE = [\n",
+ " [0, 0, 0],\n",
+ " [128, 0, 0],\n",
+ " [0, 128, 0],\n",
+ " [128, 128, 0],\n",
+ " [0, 0, 128],\n",
+ " [128, 0, 128],\n",
+ " [0, 128, 128],\n",
+ " [128, 128, 128],\n",
+ " [64, 0, 0],\n",
+ " [192, 0, 0],\n",
+ " [64, 128, 0],\n",
+ " [192, 128, 0],\n",
+ " [64, 0, 128],\n",
+ " [192, 0, 128],\n",
+ " [64, 128, 128],\n",
+ " [192, 128, 128],\n",
+ " [0, 64, 0],\n",
+ " [128, 64, 0],\n",
+ " [0, 192, 0],\n",
+ " [128, 192, 0],\n",
+ " [0, 64, 128],\n",
+ "]\n",
+ "# Will be populated by maybe_populate_voc_color_mapping() below.\n",
+ "VOC_PNG_COLOR_MAPPING = None\n",
+ "\n",
+ "\n",
+ "def maybe_populate_voc_color_mapping():\n",
+ " \"\"\"Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.\"\"\"\n",
+ " global VOC_PNG_COLOR_MAPPING\n",
+ " if VOC_PNG_COLOR_MAPPING is None:\n",
+ " VOC_PNG_COLOR_MAPPING = [0] * (256**3)\n",
+ " for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):\n",
+ " VOC_PNG_COLOR_MAPPING[\n",
+ " (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]\n",
+ " ] = i\n",
+ " # There is a special mapping with [224, 224, 192] -> 255\n",
+ " VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255\n",
+ " VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)\n",
+ " return VOC_PNG_COLOR_MAPPING\n",
+ "\n",
+ "\n",
+ "def parse_annotation_data(annotation_file_path):\n",
+ " \"\"\"Parse the annotation XML file for the image.\n",
+ "\n",
+ " The annotation contains the metadata, as well as the object bounding box\n",
+ " information.\n",
+ "\n",
+ " \"\"\"\n",
+ " with open(annotation_file_path, \"r\") as f:\n",
+ " root = xml.etree.ElementTree.parse(f).getroot()\n",
+ "\n",
+ " size = root.find(\"size\")\n",
+ " width = int(size.find(\"width\").text)\n",
+ " height = int(size.find(\"height\").text)\n",
+ "\n",
+ " objects = []\n",
+ " for obj in root.findall(\"object\"):\n",
+ " # Get object's label name.\n",
+ " label = CLASS_TO_INDEX[obj.find(\"name\").text.lower()]\n",
+ " # Get objects' pose name.\n",
+ " pose = obj.find(\"pose\").text.lower()\n",
+ " is_truncated = obj.find(\"truncated\").text == \"1\"\n",
+ " is_difficult = obj.find(\"difficult\").text == \"1\"\n",
+ " bndbox = obj.find(\"bndbox\")\n",
+ " xmax = int(bndbox.find(\"xmax\").text)\n",
+ " xmin = int(bndbox.find(\"xmin\").text)\n",
+ " ymax = int(bndbox.find(\"ymax\").text)\n",
+ " ymin = int(bndbox.find(\"ymin\").text)\n",
+ " objects.append(\n",
+ " {\n",
+ " \"label\": label,\n",
+ " \"pose\": pose,\n",
+ " \"bbox\": [ymin, xmin, ymax, xmax],\n",
+ " \"is_truncated\": is_truncated,\n",
+ " \"is_difficult\": is_difficult,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " return {\"width\": width, \"height\": height, \"objects\": objects}\n",
+ "\n",
+ "\n",
+ "def get_image_ids(data_dir, split):\n",
+ " \"\"\"To get image ids from the \"train\", \"eval\" or \"trainval\" files of VOC data.\"\"\"\n",
+ " data_file_mapping = {\n",
+ " \"train\": \"train.txt\",\n",
+ " \"eval\": \"val.txt\",\n",
+ " \"trainval\": \"trainval.txt\",\n",
+ " }\n",
+ " with open(\n",
+ " os.path.join(data_dir, \"ImageSets\", \"Segmentation\", data_file_mapping[split]),\n",
+ " \"r\",\n",
+ " ) as f:\n",
+ " image_ids = f.read().splitlines()\n",
+ " logging.info(f\"Received {len(image_ids)} images for {split} dataset.\")\n",
+ " return image_ids\n",
+ "\n",
+ "\n",
+ "def get_sbd_image_ids(data_dir, split):\n",
+ " \"\"\"To get image ids from the \"sbd_train\", \"sbd_eval\" from files of SBD data.\"\"\"\n",
+ " data_file_mapping = {\"sbd_train\": \"train.txt\", \"sbd_eval\": \"val.txt\"}\n",
+ " with open(\n",
+ " os.path.join(data_dir, data_file_mapping[split]),\n",
+ " \"r\",\n",
+ " ) as f:\n",
+ " image_ids = f.read().splitlines()\n",
+ " logging.info(f\"Received {len(image_ids)} images for {split} dataset.\")\n",
+ " return image_ids\n",
+ "\n",
+ "\n",
+ "def parse_single_image(image_file_path):\n",
+ " \"\"\"Creates metadata of VOC images and path.\"\"\"\n",
+ " data_dir, image_file_name = os.path.split(image_file_path)\n",
+ " data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))\n",
+ " image_id, _ = os.path.splitext(image_file_name)\n",
+ " class_segmentation_file_path = os.path.join(\n",
+ " data_dir, \"SegmentationClass\", image_id + \".png\"\n",
+ " )\n",
+ " object_segmentation_file_path = os.path.join(\n",
+ " data_dir, \"SegmentationObject\", image_id + \".png\"\n",
+ " )\n",
+ " annotation_file_path = os.path.join(data_dir, \"Annotations\", image_id + \".xml\")\n",
+ " image_annotations = parse_annotation_data(annotation_file_path)\n",
+ "\n",
+ " result = {\n",
+ " \"image/filename\": image_id + \".jpg\",\n",
+ " \"image/file_path\": image_file_path,\n",
+ " \"segmentation/class/file_path\": class_segmentation_file_path,\n",
+ " \"segmentation/object/file_path\": object_segmentation_file_path,\n",
+ " }\n",
+ " result.update(image_annotations)\n",
+ " # Labels field should be same as the 'object.label'\n",
+ " labels = list(set([o[\"label\"] for o in result[\"objects\"]]))\n",
+ " result[\"labels\"] = sorted(labels)\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def parse_single_sbd_image(image_file_path):\n",
+ " \"\"\"Creates metadata of SBD images and path.\"\"\"\n",
+ " data_dir, image_file_name = os.path.split(image_file_path)\n",
+ " data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))\n",
+ " image_id, _ = os.path.splitext(image_file_name)\n",
+ " class_segmentation_file_path = os.path.join(data_dir, \"cls\", image_id + \".mat\")\n",
+ " object_segmentation_file_path = os.path.join(data_dir, \"inst\", image_id + \".mat\")\n",
+ " result = {\n",
+ " \"image/filename\": image_id + \".jpg\",\n",
+ " \"image/file_path\": image_file_path,\n",
+ " \"segmentation/class/file_path\": class_segmentation_file_path,\n",
+ " \"segmentation/object/file_path\": object_segmentation_file_path,\n",
+ " }\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def build_metadata(data_dir, image_ids):\n",
+ " \"\"\"Transpose the metadata which convert from list of dict to dict of list.\"\"\"\n",
+ " # Parallel process all the images.\n",
+ " image_file_paths = [\n",
+ " os.path.join(data_dir, \"JPEGImages\", i + \".jpg\") for i in image_ids\n",
+ " ]\n",
+ " pool_size = 10 if len(image_ids) > 10 else len(image_ids)\n",
+ " with multiprocessing.Pool(pool_size) as p:\n",
+ " metadata = p.map(parse_single_image, image_file_paths)\n",
+ "\n",
+ " keys = [\n",
+ " \"image/filename\",\n",
+ " \"image/file_path\",\n",
+ " \"segmentation/class/file_path\",\n",
+ " \"segmentation/object/file_path\",\n",
+ " \"labels\",\n",
+ " \"width\",\n",
+ " \"height\",\n",
+ " ]\n",
+ " result = {}\n",
+ " for key in keys:\n",
+ " values = [value[key] for value in metadata]\n",
+ " result[key] = values\n",
+ "\n",
+ " # The ragged objects need some special handling\n",
+ " for key in [\"label\", \"pose\", \"bbox\", \"is_truncated\", \"is_difficult\"]:\n",
+ " values = []\n",
+ " objects = [value[\"objects\"] for value in metadata]\n",
+ " for object in objects:\n",
+ " values.append([o[key] for o in object])\n",
+ " result[\"objects/\" + key] = values\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def build_sbd_metadata(data_dir, image_ids):\n",
+ " \"\"\"Transpose the metadata which convert from list of dict to dict of list.\"\"\"\n",
+ " # Parallel process all the images.\n",
+ " image_file_paths = [os.path.join(data_dir, \"img\", i + \".jpg\") for i in image_ids]\n",
+ " pool_size = 10 if len(image_ids) > 10 else len(image_ids)\n",
+ " with multiprocessing.Pool(pool_size) as p:\n",
+ " metadata = p.map(parse_single_sbd_image, image_file_paths)\n",
+ "\n",
+ " keys = [\n",
+ " \"image/filename\",\n",
+ " \"image/file_path\",\n",
+ " \"segmentation/class/file_path\",\n",
+ " \"segmentation/object/file_path\",\n",
+ " ]\n",
+ " result = {}\n",
+ " for key in keys:\n",
+ " values = [value[key] for value in metadata]\n",
+ " result[key] = values\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "def decode_png_mask(mask):\n",
+ " \"\"\"Decode the raw PNG image and convert it to 2D tensor with probably\n",
+ " class.\"\"\"\n",
+ " # Cast the mask to int32 since the original uint8 will overflow when\n",
+ " # multiplied with 256\n",
+ " mask = tf.cast(mask, tf.int32)\n",
+ " mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]\n",
+ " mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)\n",
+ " mask = tf.cast(mask, tf.uint8)\n",
+ " return mask\n",
+ "\n",
+ "\n",
+ "def load_images(example):\n",
+ " \"\"\"Loads VOC images for segmentation task from the provided paths\"\"\"\n",
+ " image_file_path = example.pop(\"image/file_path\")\n",
+ " segmentation_class_file_path = example.pop(\"segmentation/class/file_path\")\n",
+ " segmentation_object_file_path = example.pop(\"segmentation/object/file_path\")\n",
+ " image = tf.io.read_file(image_file_path)\n",
+ " image = tf.image.decode_jpeg(image)\n",
+ "\n",
+ " segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)\n",
+ " segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)\n",
+ " segmentation_class_mask = decode_png_mask(segmentation_class_mask)\n",
+ "\n",
+ " segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)\n",
+ " segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)\n",
+ " segmentation_object_mask = decode_png_mask(segmentation_object_mask)\n",
+ "\n",
+ " example.update(\n",
+ " {\n",
+ " \"image\": image,\n",
+ " \"class_segmentation\": segmentation_class_mask,\n",
+ " \"object_segmentation\": segmentation_object_mask,\n",
+ " }\n",
+ " )\n",
+ " return example\n",
+ "\n",
+ "\n",
+ "def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):\n",
+ " \"\"\"Loads SBD images for segmentation task from the provided paths\"\"\"\n",
+ " image = tf.io.read_file(image_file_path)\n",
+ " image = tf.image.decode_jpeg(image)\n",
+ "\n",
+ " segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)\n",
+ " segmentation_class_mask = segmentation_class_mask[\"GTcls\"][\"Segmentation\"][0][0]\n",
+ " segmentation_class_mask = segmentation_class_mask[..., np.newaxis]\n",
+ "\n",
+ " segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(\n",
+ " seg_obj_file_path\n",
+ " )\n",
+ " segmentation_object_mask = segmentation_object_mask[\"GTinst\"][\"Segmentation\"][0][0]\n",
+ " segmentation_object_mask = segmentation_object_mask[..., np.newaxis]\n",
+ "\n",
+ " return {\n",
+ " \"image\": image,\n",
+ " \"class_segmentation\": segmentation_class_mask,\n",
+ " \"object_segmentation\": segmentation_object_mask,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def build_dataset_from_metadata(metadata):\n",
+ " \"\"\"Builds TensorFlow dataset from the image metadata of VOC dataset.\"\"\"\n",
+ " # The objects need some manual conversion to ragged tensor.\n",
+ " metadata[\"labels\"] = tf.ragged.constant(metadata[\"labels\"])\n",
+ " metadata[\"objects/label\"] = tf.ragged.constant(metadata[\"objects/label\"])\n",
+ " metadata[\"objects/pose\"] = tf.ragged.constant(metadata[\"objects/pose\"])\n",
+ " metadata[\"objects/is_truncated\"] = tf.ragged.constant(\n",
+ " metadata[\"objects/is_truncated\"]\n",
+ " )\n",
+ " metadata[\"objects/is_difficult\"] = tf.ragged.constant(\n",
+ " metadata[\"objects/is_difficult\"]\n",
+ " )\n",
+ " metadata[\"objects/bbox\"] = tf.ragged.constant(\n",
+ " metadata[\"objects/bbox\"], ragged_rank=1\n",
+ " )\n",
+ "\n",
+ " dataset = tf.data.Dataset.from_tensor_slices(metadata)\n",
+ " dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)\n",
+ " return dataset\n",
+ "\n",
+ "\n",
+ "def build_sbd_dataset_from_metadata(metadata):\n",
+ " \"\"\"Builds TensorFlow dataset from the image metadata of SBD dataset.\"\"\"\n",
+ " img_filepath = metadata[\"image/file_path\"]\n",
+ " cls_filepath = metadata[\"segmentation/class/file_path\"]\n",
+ " obj_filepath = metadata[\"segmentation/object/file_path\"]\n",
+ "\n",
+ " def md_gen():\n",
+ " c = list(zip(img_filepath, cls_filepath, obj_filepath))\n",
+ " # random shuffling for each generator boosts up the quality.\n",
+ " random.shuffle(c)\n",
+ " for fp in c:\n",
+ " img_fp, cls_fp, obj_fp = fp\n",
+ " yield load_sbd_images(img_fp, cls_fp, obj_fp)\n",
+ "\n",
+ " dataset = tf.data.Dataset.from_generator(\n",
+ " md_gen,\n",
+ " output_signature=(\n",
+ " {\n",
+ " \"image\": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),\n",
+ " \"class_segmentation\": tf.TensorSpec(\n",
+ " shape=(None, None, 1), dtype=tf.uint8\n",
+ " ),\n",
+ " \"object_segmentation\": tf.TensorSpec(\n",
+ " shape=(None, None, 1), dtype=tf.uint8\n",
+ " ),\n",
+ " }\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ " return dataset\n",
+ "\n",
+ "\n",
+ "def load(\n",
+ " split=\"sbd_train\",\n",
+ " data_dir=None,\n",
+ "):\n",
+ " \"\"\"Load the Pacal VOC 2012 dataset.\n",
+ "\n",
+ " This function will download the data tar file from remote if needed, and\n",
+ " untar to the local `data_dir`, and build dataset from it.\n",
+ "\n",
+ " It supports both VOC2012 and Semantic Boundaries Dataset (SBD).\n",
+ "\n",
+ " The returned segmentation masks will be int ranging from [0, num_classes),\n",
+ " as well as 255 which is the boundary mask.\n",
+ "\n",
+ " Args:\n",
+ " split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or\n",
+ " 'sbd_eval'. 'sbd_train' represents the training dataset for SBD\n",
+ " dataset, while 'train' represents the training dataset for VOC2012\n",
+ " dataset. Defaults to `sbd_train`.\n",
+ " data_dir: string, local directory path for the loaded data. This will be\n",
+ " used to download the data file, and unzip. It will be used as a\n",
+ " cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`\n",
+ " will be used.\n",
+ " \"\"\"\n",
+ " supported_split_value = [\n",
+ " \"train\",\n",
+ " \"eval\",\n",
+ " \"trainval\",\n",
+ " \"sbd_train\",\n",
+ " \"sbd_eval\",\n",
+ " ]\n",
+ " if split not in supported_split_value:\n",
+ " raise ValueError(\n",
+ " f\"The support value for `split` are {supported_split_value}. \"\n",
+ " f\"Got: {split}\"\n",
+ " )\n",
+ "\n",
+ " if data_dir is not None:\n",
+ " data_dir = os.path.expanduser(data_dir)\n",
+ "\n",
+ " if \"sbd\" in split:\n",
+ " return load_sbd(split, data_dir)\n",
+ " else:\n",
+ " return load_voc(split, data_dir)\n",
+ "\n",
+ "\n",
+ "def load_voc(\n",
+ " split=\"train\",\n",
+ " data_dir=None,\n",
+ "):\n",
+ " \"\"\"This function will download VOC data from a URL. If the data is already\n",
+ " present in the cache directory, it will load the data from that directory\n",
+ " instead.\n",
+ " \"\"\"\n",
+ " extracted_dir = os.path.join(\"VOCdevkit\", \"VOC2012\")\n",
+ " get_data = keras.utils.get_file(\n",
+ " fname=os.path.basename(VOC_URL),\n",
+ " origin=VOC_URL,\n",
+ " cache_dir=data_dir,\n",
+ " extract=True,\n",
+ " )\n",
+ " data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)\n",
+ " image_ids = get_image_ids(data_dir, split)\n",
+ " # len(metadata) = #samples, metadata[i] is a dict.\n",
+ " metadata = build_metadata(data_dir, image_ids)\n",
+ " maybe_populate_voc_color_mapping()\n",
+ " dataset = build_dataset_from_metadata(metadata)\n",
+ "\n",
+ " return dataset\n",
+ "\n",
+ "\n",
+ "def load_sbd(\n",
+ " split=\"sbd_train\",\n",
+ " data_dir=None,\n",
+ "):\n",
+ " \"\"\"This function will download SBD data from a URL. If the data is already\n",
+ " present in the cache directory, it will load the data from that directory\n",
+ " instead.\n",
+ " \"\"\"\n",
+ " extracted_dir = os.path.join(\"benchmark_RELEASE\", \"dataset\")\n",
+ " get_data = keras.utils.get_file(\n",
+ " fname=os.path.basename(SBD_URL),\n",
+ " origin=SBD_URL,\n",
+ " cache_dir=data_dir,\n",
+ " extract=True,\n",
+ " )\n",
+ " data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)\n",
+ " image_ids = get_sbd_image_ids(data_dir, split)\n",
+ " # len(metadata) = #samples, metadata[i] is a dict.\n",
+ " metadata = build_sbd_metadata(data_dir, image_ids)\n",
+ "\n",
+ " dataset = build_sbd_dataset_from_metadata(metadata)\n",
+ " return dataset\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Load the dataset\n",
+ "\n",
+ "For training and evaluation, let's use \"sbd_train\" and \"sbd_eval.\" You can also\n",
+ "choose any of these datasets for the `load` function: 'train', 'eval', 'trainval',\n",
+ "'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the\n",
+ "SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = load(split=\"sbd_train\", data_dir=\"segmentation\")\n",
+ "eval_ds = load(split=\"sbd_eval\", data_dir=\"segmentation\")"
+ ]
+ },
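+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Before any preprocessing, each element of these datasets is a dictionary with an\n",
+ "`image` and its `class_segmentation` / `object_segmentation` masks, and the\n",
+ "spatial size varies from sample to sample. As an optional sanity check (a minimal\n",
+ "sketch, not required for training), we can pull one raw example and print its\n",
+ "shapes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Optional sketch: inspect one raw training example before preprocessing.\n",
+ "sample = next(iter(train_ds.take(1)))\n",
+ "print(\"image:\", sample[\"image\"].shape)\n",
+ "print(\"class_segmentation:\", sample[\"class_segmentation\"].shape)\n",
+ "print(\"object_segmentation:\", sample[\"object_segmentation\"].shape)"
+ ]
+ },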
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Preprocess the data\n",
+ "\n",
+ "The preprocess_inputs utility function preprocesses inputs, converting them into\n",
+ "a dictionary containing images and segmentation_masks. Both images and\n",
+ "segmentation masks are resized to 512x512. The resulting dataset is then batched\n",
+ "into groups of four image and segmentation mask pairs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def preprocess_inputs(inputs):\n",
+ " def unpackage_inputs(inputs):\n",
+ " return {\n",
+ " \"images\": inputs[\"image\"],\n",
+ " \"segmentation_masks\": inputs[\"class_segmentation\"],\n",
+ " }\n",
+ "\n",
+ " outputs = inputs.map(unpackage_inputs)\n",
+ " outputs = outputs.map(keras.layers.Resizing(height=512, width=512))\n",
+ " outputs = outputs.batch(4, drop_remainder=True)\n",
+ " return outputs\n",
+ "\n",
+ "\n",
+ "train_ds = preprocess_inputs(train_ds)\n",
+ "batch = train_ds.take(1).get_single_element()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "A batch of this preprocessed input training data can be visualized using the\n",
+ "`plot_images_masks` function. This function takes a batch of images and\n",
+ "segmentation masks and prediction masks as input and displays them in a grid."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def plot_images_masks(images, masks, pred_masks=None):\n",
+ " images = (images - np.min(images)) / (np.max(images) - np.min(images))\n",
+ " masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks))\n",
+ " if pred_masks is not None:\n",
+ " pred_masks = (pred_masks - pred_masks.min()) / (\n",
+ " pred_masks.max() - pred_masks.min()\n",
+ " )\n",
+ " num_images = len(images)\n",
+ " plt.figure(figsize=(8, 4))\n",
+ " rows = 3 if pred_masks is not None else 2\n",
+ "\n",
+ " for i in range(num_images):\n",
+ " plt.subplot(rows, num_images, i + 1)\n",
+ " plt.imshow(images[i])\n",
+ " plt.axis(\"off\")\n",
+ "\n",
+ " plt.subplot(rows, num_images, num_images + i + 1)\n",
+ " plt.imshow(masks[i], cmap=\"gray\")\n",
+ " plt.axis(\"off\")\n",
+ "\n",
+ " if pred_masks is not None:\n",
+ " plt.subplot(rows, num_images, i + 1 + 2 * num_images)\n",
+ " plt.imshow(pred_masks[i, ..., 0], cmap=\"gray\")\n",
+ " plt.axis(\"off\")\n",
+ "\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "plot_images_masks(batch[\"images\"], batch[\"segmentation_masks\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The preprocessing is applied to the evaluation dataset `eval_ds`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "eval_ds = preprocess_inputs(eval_ds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Data Augmentation\n",
+ "\n",
+ "Keras provides a variety of image augmentation options. In this example, we will\n",
+ "use the `RandomFlip` augmentation to augment the training dataset. The\n",
+ "`RandomFlip` augmentation randomly flips the images in the training dataset\n",
+ "horizontally or vertically. This can help to improve the model's robustness to\n",
+ "changes in the orientation of the objects in the images."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = train_ds.map(keras.layers.RandomFlip())\n",
+ "batch = train_ds.take(1).get_single_element()\n",
+ "\n",
+ "plot_images_masks(batch[\"images\"], batch[\"segmentation_masks\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model Configuration\n",
+ "\n",
+ "Please feel free to modify the configurations for model training and note how the\n",
+ "training results changes. This is an great exercise to get a better\n",
+ "understanding of the training pipeline.\n",
+ "\n",
+ "The learning rate schedule is used by the optimizer to calculate the learning\n",
+ "rate for each epoch. The optimizer then uses the learning rate to update the\n",
+ "weights of the model.\n",
+ "In this case, the learning rate schedule uses a cosine decay function. A cosine\n",
+ "decay function starts high and then decreases over time, eventually reaching\n",
+ "zero. The cardinality of the VOC dataset is 2124 with a batch size of 4. The\n",
+ "dataset cardinality is important for learning rate decay because it determines\n",
+ "how many steps the model will train for. The initial learning rate is\n",
+ "proportional to 0.007 and the decay steps are 2124. This means that the learning\n",
+ "rate will start at `INITIAL_LR` and then decrease to zero over 2124 steps.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "BATCH_SIZE = 4\n",
+ "INITIAL_LR = 0.007 * BATCH_SIZE / 16\n",
+ "EPOCHS = 1\n",
+ "NUM_CLASSES = 21\n",
+ "learning_rate = keras.optimizers.schedules.CosineDecay(\n",
+ " INITIAL_LR,\n",
+ " decay_steps=EPOCHS * 2124,\n",
+ ")"
+ ]
+ },
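+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "As an optional sanity check (a minimal sketch, not part of the training\n",
+ "pipeline), the schedule object can be called directly with a step index to see\n",
+ "the learning rate it will produce at that point in training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Optional sketch: query the cosine-decay schedule at a few step indices.\n",
+ "print(\"step 0:\", learning_rate(0))\n",
+ "print(\"step 1062:\", learning_rate(1062))\n",
+ "print(\"step 2124:\", learning_rate(2124))"
+ ]
+ },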
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's take the `resnet_50_imagenet` pretrained weights as a image encoder for\n",
+ "the model, this implementation can be used both as DeepLabV3 and DeepLabV3+ with\n",
+ "additional decoder block.\n",
+ "For DeepLabV3+, we instantiate a DeepLabV3Backbone model by providing\n",
+ "`low_level_feature_key` as `P2` a pyramid level output to extract features from\n",
+ "`resnet_50_imagenet` which acts as a decoder block.\n",
+ "To use this model as DeepLabV3 architecture, ignore the `low_level_feature_key`\n",
+ "which defaults to `None`.\n",
+ "\n",
+ "Then we create DeepLabV3ImageSegmenter instance.\n",
+ "The `num_classes` parameter specifies the number of classes that the model will\n",
+ "be trained to segment. `preprocessor` argument to apply preprocessing to image\n",
+ "input and masks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "image_encoder = keras_hub.models.Backbone.from_preset(\"resnet_50_imagenet\")\n",
+ "\n",
+ "deeplab_backbone = keras_hub.models.DeepLabV3Backbone(\n",
+ " image_encoder=image_encoder,\n",
+ " low_level_feature_key=\"P2\",\n",
+ " spatial_pyramid_pooling_key=\"P5\",\n",
+ " dilation_rates=[6, 12, 18],\n",
+ " upsampling_size=8,\n",
+ ")\n",
+ "\n",
+ "model = keras_hub.models.DeepLabV3ImageSegmenter(\n",
+ " backbone=deeplab_backbone,\n",
+ " num_classes=21,\n",
+ " activation=\"softmax\",\n",
+ " preprocessor=preprocessor,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Compile the model\n",
+ "\n",
+ "The model.compile() function sets up the training process for the model. It defines the\n",
+ "- optimization algorithm - Stochastic Gradient Descent (SGD)\n",
+ "- the loss function - categorical cross-entropy\n",
+ "- the evaluation metrics - Mean IoU and categorical accuracy\n",
+ "\n",
+ "Semantic segmentation evaluation metrics:\n",
+ "\n",
+ "Mean Intersection over Union (MeanIoU):\n",
+ "MeanIoU measures how well a semantic segmentation model accurately identifies\n",
+ "and delineates different objects or regions in an image. It calculates the\n",
+ "overlap between predicted and actual object boundaries, providing a score\n",
+ "between 0 and 1, where 1 represents a perfect match.\n",
+ "\n",
+ "Categorical Accuracy:\n",
+ "Categorical Accuracy measures the proportion of correctly classified pixels in\n",
+ "an image. It gives a simple percentage indicating how accurately the model\n",
+ "predicts the categories of pixels in the entire image.\n",
+ "\n",
+ "In essence, MeanIoU emphasizes the accuracy of identifying specific object\n",
+ "boundaries, while Categorical Accuracy gives a broad overview of overall\n",
+ "pixel-level correctness."
+ ]
+ },
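+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "To make the metrics concrete, here is a tiny, self-contained example (a minimal\n",
+ "sketch, not part of the training pipeline) that computes `MeanIoU` on a 2x2 toy\n",
+ "mask with 3 classes, using integer (sparse) labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# A minimal sketch: MeanIoU on a 2 x 2 toy mask with 3 classes.\n",
+ "# With integer labels, the default sparse settings can be used directly.\n",
+ "toy_true = np.array([[0, 0], [1, 2]])\n",
+ "toy_pred = np.array([[0, 0], [1, 1]])\n",
+ "toy_iou = keras.metrics.MeanIoU(num_classes=3)\n",
+ "toy_iou.update_state(toy_true, toy_pred)\n",
+ "# Per-class IoUs are 1.0, 0.5 and 0.0, so the mean is 0.5.\n",
+ "print(toy_iou.result())"
+ ]
+ },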
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model.compile(\n",
+ " optimizer=keras.optimizers.SGD(\n",
+ " learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0\n",
+ " ),\n",
+ " loss=keras.losses.CategoricalCrossentropy(from_logits=False),\n",
+ " metrics=[\n",
+ " keras.metrics.MeanIoU(\n",
+ " num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False\n",
+ " ),\n",
+ " keras.metrics.CategoricalAccuracy(),\n",
+ " ],\n",
+ ")\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The utility function `dict_to_tuple` effectively transforms the dictionaries of\n",
+ "training and validation datasets into tuples of images and one-hot encoded\n",
+ "segmentation masks, which is used during training and evaluation of the\n",
+ "DeepLabv3+ model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def dict_to_tuple(x):\n",
+ "\n",
+ " return x[\"images\"], tf.one_hot(\n",
+ " tf.cast(tf.squeeze(x[\"segmentation_masks\"], axis=-1), \"int32\"), 21\n",
+ " )\n",
+ "\n",
+ "\n",
+ "train_ds = train_ds.map(dict_to_tuple)\n",
+ "eval_ds = eval_ds.map(dict_to_tuple)\n",
+ "\n",
+ "model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Predictions with trained model\n",
+ "Now that the model training of DeepLabv3+ has completed, let's test it by making\n",
+ "predications\n",
+ "on a few sample images.\n",
+ "Note: For demonstration purpose the model has been trained on only 1 epoch, for\n",
+ "better accuracy and result train with more number of epochs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "test_ds = load(split=\"sbd_eval\")\n",
+ "test_ds = preprocess_inputs(test_ds)\n",
+ "\n",
+ "images, masks = next(iter(train_ds.take(1)))\n",
+ "images = ops.convert_to_tensor(images)\n",
+ "masks = ops.convert_to_tensor(masks)\n",
+ "preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)\n",
+ "masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)\n",
+ "\n",
+ "plot_images_masks(images, masks, preds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Here are some additional tips for using the KerasHub DeepLabv3 model:\n",
+ "\n",
+ "- The model can be trained on a variety of datasets, including the COCO dataset, the\n",
+ "PASCAL VOC dataset, and the Cityscapes dataset.\n",
+ "- The model can be fine-tuned on a custom dataset to improve its performance on a\n",
+ "specific task.\n",
+ "- The model can be used to perform real-time inference on images.\n",
+ "- Also, check out KerasHub's other segmentation models."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "semantic_segmentation_deeplab_v3",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/guides/keras_hub/semantic_segmentation_deeplab_v3.py b/guides/keras_hub/semantic_segmentation_deeplab_v3.py
new file mode 100644
index 0000000000..d6bc3c1d61
--- /dev/null
+++ b/guides/keras_hub/semantic_segmentation_deeplab_v3.py
@@ -0,0 +1,856 @@
+"""
+Title: Semantic Segmentation with KerasHub
+Authors: [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)
+Date created: 2024/10/11
+Last modified: 2024/10/11
+Description: DeepLabV3 training and inference with KerasHub.
+Accelerator: GPU
+"""
+
+"""
+
+
+## Background
+Semantic segmentation is a type of computer vision task that involves assigning a
+class label such as "person", "bike", or "background" to each individual pixel
+of an image, effectively dividing the image into regions that correspond to
+different object classes or categories.
+
+
+
+
+
+KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, and other models for
+semantic segmentation.
+
+This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed
+by Google, for image semantic segmentation with KerasHub. Its architecture
+combines atrous convolutions, contextual information aggregation, and powerful
+backbones to achieve accurate and detailed semantic segmentation.
+
+DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to
+refine the segmentation results, especially along object boundaries. Both models
+have achieved state-of-the-art results on a variety of image segmentation
+benchmarks.
+
+### References
+* [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
+* [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
+"""
+
+"""
+## Setup and Imports
+
+Let's install the dependencies and import the necessary modules.
+
+To run this tutorial, you will need to install the following packages:
+
+* `keras-hub`
+* `keras`
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras
+"""
+
+"""
+After installing `keras` and `keras-hub`, set the backend for `keras`.
+This guide can be run with any backend (TensorFlow, JAX, PyTorch).
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+import keras
+from keras import ops
+import keras_hub
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+"""
+## Perform semantic segmentation with a pretrained DeepLabv3+ model
+
+The highest level API in the KerasHub semantic segmentation API is the
+`keras_hub.models` API. This API includes fully pretrained semantic segmentation
+models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.
+
+Let's get started by constructing a DeepLabv3 model pretrained on the Pascal VOC
+dataset.
+We also define the preprocessing pipeline for the model to preprocess images and
+labels.
+**Note:** By default, the `from_preset()` method in KerasHub loads the pretrained
+task weights with all the classes, 21 classes in this case.
+"""
+
+model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+ "deeplab_v3_plus_resnet50_pascalvoc"
+)
+
+image_converter = keras_hub.layers.DeepLabV3ImageConverter(
+ image_size=(512, 512),
+ interpolation="bilinear",
+)
+preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)
+
+"""
+Let us visualize the results of this pretrained model.
+"""
+filepath = keras.utils.get_file(
+ origin="https://storage.googleapis.com/keras-cv/pictures/dog.jpeg"
+)
+image = keras.utils.load_img(filepath)
+image = keras.utils.img_to_array(image)
+
+image = preprocessor(image)
+image = keras.ops.expand_dims(image, axis=0)
+preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)
+
+
+def plot_segmentation(original_image, predicted_mask):
+ original_image = np.squeeze(original_image, axis=0)
+ original_image = np.clip(original_image / 255.0, 0, 1)
+ predicted_mask = np.squeeze(predicted_mask, axis=0)
+ plt.figure(figsize=(5, 5))
+
+ plt.subplot(1, 2, 1)
+ plt.imshow(original_image)
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(predicted_mask, cmap="gray")
+ plt.axis("off")
+
+ plt.tight_layout()
+ plt.show()
+
+
+plot_segmentation(image, preds)
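+
+"""
+The segmenter was loaded with the standard 21 Pascal VOC classes (background
+plus 20 object categories). As an optional check, the sketch below lists which
+class indices appear in the predicted mask and maps them back to readable
+names, assuming the standard Pascal VOC class ordering (background = 0, then
+the 20 object classes in alphabetical order).
+"""
+# Optional sketch: map predicted class indices back to Pascal VOC class names.
+# Assumes the standard VOC ordering: 0 = background, then the 20 object classes.
+voc_names = ["background"] + [
+    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
+    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
+    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
+]
+predicted_ids = np.unique(np.asarray(preds))
+print("Predicted classes:", [voc_names[int(i)] for i in predicted_ids])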
+
+"""
+## Train a custom semantic segmentation model
+In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3
+semantic segmentation model. This includes data loading, augmentation, training,
+metric evaluation, and inference!
+"""
+
+"""
+## Download the data
+
+We download the Pascal VOC 2012 dataset with the additional annotations from
+[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz),
+and split it into the training dataset `train_ds` and the evaluation dataset `eval_ds`.
+"""
+
+# @title helper functions
+import logging
+import multiprocessing
+from builtins import open
+import os.path
+import random
+import xml.etree.ElementTree
+
+import tensorflow_datasets as tfds
+
+VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
+
+SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+
+# Note that this list doesn't contain the background class. In the
+# classification use case, the label is 0-based (aeroplane -> 0), whereas in
+# the segmentation use case, 0 is reserved for the background, so aeroplane
+# maps to 1.
+CLASSES = [
+ "aeroplane",
+ "bicycle",
+ "bird",
+ "boat",
+ "bottle",
+ "bus",
+ "car",
+ "cat",
+ "chair",
+ "cow",
+ "diningtable",
+ "dog",
+ "horse",
+ "motorbike",
+ "person",
+ "pottedplant",
+ "sheep",
+ "sofa",
+ "train",
+ "tvmonitor",
+]
+# This is used to map from a string class name to an index.
+CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
+
+# For the mask data in the PNG file, the encoded raw pixel values need to be
+# converted to the proper class indices. In the following map, [0, 0, 0] is
+# converted to 0, [128, 0, 0] is converted to 1, and so on. Also note that the
+# mask classes are 1-based since class 0 is reserved for the background. The
+# color [128, 0, 0] (class 1) is mapped to `aeroplane`.
+VOC_PNG_COLOR_VALUE = [
+ [0, 0, 0],
+ [128, 0, 0],
+ [0, 128, 0],
+ [128, 128, 0],
+ [0, 0, 128],
+ [128, 0, 128],
+ [0, 128, 128],
+ [128, 128, 128],
+ [64, 0, 0],
+ [192, 0, 0],
+ [64, 128, 0],
+ [192, 128, 0],
+ [64, 0, 128],
+ [192, 0, 128],
+ [64, 128, 128],
+ [192, 128, 128],
+ [0, 64, 0],
+ [128, 64, 0],
+ [0, 192, 0],
+ [128, 192, 0],
+ [0, 64, 128],
+]
+# Will be populated by maybe_populate_voc_color_mapping() below.
+VOC_PNG_COLOR_MAPPING = None
+
+
+def maybe_populate_voc_color_mapping():
+ """Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory."""
+ global VOC_PNG_COLOR_MAPPING
+ if VOC_PNG_COLOR_MAPPING is None:
+ VOC_PNG_COLOR_MAPPING = [0] * (256**3)
+ for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):
+ VOC_PNG_COLOR_MAPPING[
+ (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]
+ ] = i
+ # There is a special mapping with [224, 224, 192] -> 255
+ VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255
+ VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)
+ return VOC_PNG_COLOR_MAPPING
+
+
+def parse_annotation_data(annotation_file_path):
+ """Parse the annotation XML file for the image.
+
+ The annotation contains the metadata, as well as the object bounding box
+ information.
+
+ """
+ with open(annotation_file_path, "r") as f:
+ root = xml.etree.ElementTree.parse(f).getroot()
+
+ size = root.find("size")
+ width = int(size.find("width").text)
+ height = int(size.find("height").text)
+
+ objects = []
+ for obj in root.findall("object"):
+ # Get object's label name.
+ label = CLASS_TO_INDEX[obj.find("name").text.lower()]
+ # Get objects' pose name.
+ pose = obj.find("pose").text.lower()
+ is_truncated = obj.find("truncated").text == "1"
+ is_difficult = obj.find("difficult").text == "1"
+ bndbox = obj.find("bndbox")
+ xmax = int(bndbox.find("xmax").text)
+ xmin = int(bndbox.find("xmin").text)
+ ymax = int(bndbox.find("ymax").text)
+ ymin = int(bndbox.find("ymin").text)
+ objects.append(
+ {
+ "label": label,
+ "pose": pose,
+ "bbox": [ymin, xmin, ymax, xmax],
+ "is_truncated": is_truncated,
+ "is_difficult": is_difficult,
+ }
+ )
+
+ return {"width": width, "height": height, "objects": objects}
+
+
+def get_image_ids(data_dir, split):
+ """To get image ids from the "train", "eval" or "trainval" files of VOC data."""
+ data_file_mapping = {
+ "train": "train.txt",
+ "eval": "val.txt",
+ "trainval": "trainval.txt",
+ }
+ with open(
+ os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),
+ "r",
+ ) as f:
+ image_ids = f.read().splitlines()
+ logging.info(f"Received {len(image_ids)} images for {split} dataset.")
+ return image_ids
+
+
+def get_sbd_image_ids(data_dir, split):
+ """To get image ids from the "sbd_train", "sbd_eval" from files of SBD data."""
+ data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}
+ with open(
+ os.path.join(data_dir, data_file_mapping[split]),
+ "r",
+ ) as f:
+ image_ids = f.read().splitlines()
+ logging.info(f"Received {len(image_ids)} images for {split} dataset.")
+ return image_ids
+
+
+def parse_single_image(image_file_path):
+ """Creates metadata of VOC images and path."""
+ data_dir, image_file_name = os.path.split(image_file_path)
+ data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
+ image_id, _ = os.path.splitext(image_file_name)
+ class_segmentation_file_path = os.path.join(
+ data_dir, "SegmentationClass", image_id + ".png"
+ )
+ object_segmentation_file_path = os.path.join(
+ data_dir, "SegmentationObject", image_id + ".png"
+ )
+ annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")
+ image_annotations = parse_annotation_data(annotation_file_path)
+
+ result = {
+ "image/filename": image_id + ".jpg",
+ "image/file_path": image_file_path,
+ "segmentation/class/file_path": class_segmentation_file_path,
+ "segmentation/object/file_path": object_segmentation_file_path,
+ }
+ result.update(image_annotations)
+ # Labels field should be same as the 'object.label'
+ labels = list(set([o["label"] for o in result["objects"]]))
+ result["labels"] = sorted(labels)
+ return result
+
+
+def parse_single_sbd_image(image_file_path):
+ """Creates metadata of SBD images and path."""
+ data_dir, image_file_name = os.path.split(image_file_path)
+ data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
+ image_id, _ = os.path.splitext(image_file_name)
+ class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat")
+ object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat")
+ result = {
+ "image/filename": image_id + ".jpg",
+ "image/file_path": image_file_path,
+ "segmentation/class/file_path": class_segmentation_file_path,
+ "segmentation/object/file_path": object_segmentation_file_path,
+ }
+ return result
+
+
+def build_metadata(data_dir, image_ids):
+ """Transpose the metadata which convert from list of dict to dict of list."""
+ # Parallel process all the images.
+ image_file_paths = [
+ os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids
+ ]
+ pool_size = 10 if len(image_ids) > 10 else len(image_ids)
+ with multiprocessing.Pool(pool_size) as p:
+ metadata = p.map(parse_single_image, image_file_paths)
+
+ keys = [
+ "image/filename",
+ "image/file_path",
+ "segmentation/class/file_path",
+ "segmentation/object/file_path",
+ "labels",
+ "width",
+ "height",
+ ]
+ result = {}
+ for key in keys:
+ values = [value[key] for value in metadata]
+ result[key] = values
+
+ # The ragged objects need some special handling
+ for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]:
+ values = []
+ objects = [value["objects"] for value in metadata]
+ for object in objects:
+ values.append([o[key] for o in object])
+ result["objects/" + key] = values
+ return result
+
+
+def build_sbd_metadata(data_dir, image_ids):
+ """Transpose the metadata which convert from list of dict to dict of list."""
+ # Parallel process all the images.
+ image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]
+ pool_size = 10 if len(image_ids) > 10 else len(image_ids)
+ with multiprocessing.Pool(pool_size) as p:
+ metadata = p.map(parse_single_sbd_image, image_file_paths)
+
+ keys = [
+ "image/filename",
+ "image/file_path",
+ "segmentation/class/file_path",
+ "segmentation/object/file_path",
+ ]
+ result = {}
+ for key in keys:
+ values = [value[key] for value in metadata]
+ result[key] = values
+ return result
+
+
+def decode_png_mask(mask):
+ """Decode the raw PNG image and convert it to 2D tensor with probably
+ class."""
+ # Cast the mask to int32 since the original uint8 will overflow when
+ # multiplied with 256
+ mask = tf.cast(mask, tf.int32)
+ mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
+ mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)
+ mask = tf.cast(mask, tf.uint8)
+ return mask
+
+
+def load_images(example):
+ """Loads VOC images for segmentation task from the provided paths"""
+ image_file_path = example.pop("image/file_path")
+ segmentation_class_file_path = example.pop("segmentation/class/file_path")
+ segmentation_object_file_path = example.pop("segmentation/object/file_path")
+ image = tf.io.read_file(image_file_path)
+ image = tf.image.decode_jpeg(image)
+
+ segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)
+ segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)
+ segmentation_class_mask = decode_png_mask(segmentation_class_mask)
+
+ segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)
+ segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)
+ segmentation_object_mask = decode_png_mask(segmentation_object_mask)
+
+ example.update(
+ {
+ "image": image,
+ "class_segmentation": segmentation_class_mask,
+ "object_segmentation": segmentation_object_mask,
+ }
+ )
+ return example
+
+
+def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
+ """Loads SBD images for segmentation task from the provided paths"""
+ image = tf.io.read_file(image_file_path)
+ image = tf.image.decode_jpeg(image)
+
+ segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)
+ segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0]
+ segmentation_class_mask = segmentation_class_mask[..., np.newaxis]
+
+ segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(
+ seg_obj_file_path
+ )
+ segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0]
+ segmentation_object_mask = segmentation_object_mask[..., np.newaxis]
+
+ return {
+ "image": image,
+ "class_segmentation": segmentation_class_mask,
+ "object_segmentation": segmentation_object_mask,
+ }
+
+
+def build_dataset_from_metadata(metadata):
+ """Builds TensorFlow dataset from the image metadata of VOC dataset."""
+ # The objects need some manual conversion to ragged tensor.
+ metadata["labels"] = tf.ragged.constant(metadata["labels"])
+ metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
+ metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])
+ metadata["objects/is_truncated"] = tf.ragged.constant(
+ metadata["objects/is_truncated"]
+ )
+ metadata["objects/is_difficult"] = tf.ragged.constant(
+ metadata["objects/is_difficult"]
+ )
+ metadata["objects/bbox"] = tf.ragged.constant(
+ metadata["objects/bbox"], ragged_rank=1
+ )
+
+ dataset = tf.data.Dataset.from_tensor_slices(metadata)
+ dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
+ return dataset
+
+
+def build_sbd_dataset_from_metadata(metadata):
+ """Builds TensorFlow dataset from the image metadata of SBD dataset."""
+ img_filepath = metadata["image/file_path"]
+ cls_filepath = metadata["segmentation/class/file_path"]
+ obj_filepath = metadata["segmentation/object/file_path"]
+
+ def md_gen():
+ c = list(zip(img_filepath, cls_filepath, obj_filepath))
+ # Random shuffling on each generator pass improves the shuffling quality.
+ random.shuffle(c)
+ for fp in c:
+ img_fp, cls_fp, obj_fp = fp
+ yield load_sbd_images(img_fp, cls_fp, obj_fp)
+
+ dataset = tf.data.Dataset.from_generator(
+ md_gen,
+ output_signature=(
+ {
+ "image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
+ "class_segmentation": tf.TensorSpec(
+ shape=(None, None, 1), dtype=tf.uint8
+ ),
+ "object_segmentation": tf.TensorSpec(
+ shape=(None, None, 1), dtype=tf.uint8
+ ),
+ }
+ ),
+ )
+
+ return dataset
+
+
+def load(
+ split="sbd_train",
+ data_dir=None,
+):
+ """Load the Pacal VOC 2012 dataset.
+
+ This function will download the data tar file from remote if needed, and
+ untar to the local `data_dir`, and build dataset from it.
+
+ It supports both VOC2012 and Semantic Boundaries Dataset (SBD).
+
+ The returned segmentation masks will be integers in the range [0, num_classes),
+ plus the value 255, which marks object boundaries.
+
+ Args:
+ split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or
+ 'sbd_eval'. 'sbd_train' represents the training dataset for SBD
+ dataset, while 'train' represents the training dataset for VOC2012
+ dataset. Defaults to `sbd_train`.
+ data_dir: string, local directory path for the loaded data. This will be
+ used to download and extract the data file, and acts as a
+ cache directory. Defaults to None, in which case
+ `~/.keras/pascal_voc_2012` will be used.
+ """
+ supported_split_value = [
+ "train",
+ "eval",
+ "trainval",
+ "sbd_train",
+ "sbd_eval",
+ ]
+ if split not in supported_split_value:
+ raise ValueError(
+ f"The support value for `split` are {supported_split_value}. "
+ f"Got: {split}"
+ )
+
+ if data_dir is not None:
+ data_dir = os.path.expanduser(data_dir)
+
+ if "sbd" in split:
+ return load_sbd(split, data_dir)
+ else:
+ return load_voc(split, data_dir)
+
+
+def load_voc(
+ split="train",
+ data_dir=None,
+):
+ """This function will download VOC data from a URL. If the data is already
+ present in the cache directory, it will load the data from that directory
+ instead.
+ """
+ extracted_dir = os.path.join("VOCdevkit", "VOC2012")
+ get_data = keras.utils.get_file(
+ fname=os.path.basename(VOC_URL),
+ origin=VOC_URL,
+ cache_dir=data_dir,
+ extract=True,
+ )
+ data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+ image_ids = get_image_ids(data_dir, split)
+ # len(metadata) = #samples, metadata[i] is a dict.
+ metadata = build_metadata(data_dir, image_ids)
+ maybe_populate_voc_color_mapping()
+ dataset = build_dataset_from_metadata(metadata)
+
+ return dataset
+
+
+def load_sbd(
+ split="sbd_train",
+ data_dir=None,
+):
+ """This function will download SBD data from a URL. If the data is already
+ present in the cache directory, it will load the data from that directory
+ instead.
+ """
+ extracted_dir = os.path.join("benchmark_RELEASE", "dataset")
+ get_data = keras.utils.get_file(
+ fname=os.path.basename(SBD_URL),
+ origin=SBD_URL,
+ cache_dir=data_dir,
+ extract=True,
+ )
+ data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+ image_ids = get_sbd_image_ids(data_dir, split)
+ # len(metadata) = #samples, metadata[i] is a dict.
+ metadata = build_sbd_metadata(data_dir, image_ids)
+
+ dataset = build_sbd_dataset_from_metadata(metadata)
+ return dataset
+
+
+"""
+## Load the dataset
+
+For training and evaluation, let's use "sbd_train" and "sbd_eval". You can also
+pass any of these splits to the `load` function: 'train', 'eval', 'trainval',
+'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the
+SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset.
+"""
+train_ds = load(split="sbd_train", data_dir="segmentation")
+eval_ds = load(split="sbd_eval", data_dir="segmentation")
+
+"""
+## Preprocess the data
+
+The `preprocess_inputs` utility function converts each input example into a
+dictionary with the keys `images` and `segmentation_masks`. Both the images and
+the segmentation masks are resized to 512x512. The resulting dataset is then
+batched into groups of four image and segmentation mask pairs.
+"""
+
+
+def preprocess_inputs(inputs):
+ def unpackage_inputs(inputs):
+ return {
+ "images": inputs["image"],
+ "segmentation_masks": inputs["class_segmentation"],
+ }
+
+ outputs = inputs.map(unpackage_inputs)
+ outputs = outputs.map(keras.layers.Resizing(height=512, width=512))
+ outputs = outputs.batch(4, drop_remainder=True)
+ return outputs
+
+
+train_ds = preprocess_inputs(train_ds)
+batch = train_ds.take(1).get_single_element()
+
+"""
+A batch of this preprocessed input training data can be visualized using the
+`plot_images_masks` function. This function takes a batch of images and
+segmentation masks (and, optionally, prediction masks) and displays them in a grid.
+"""
+
+
+def plot_images_masks(images, masks, pred_masks=None):
+ images = (images - np.min(images)) / (np.max(images) - np.min(images))
+ masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks))
+ if pred_masks is not None:
+ pred_masks = (pred_masks - pred_masks.min()) / (
+ pred_masks.max() - pred_masks.min()
+ )
+ num_images = len(images)
+ plt.figure(figsize=(8, 4))
+ rows = 3 if pred_masks is not None else 2
+
+ for i in range(num_images):
+ plt.subplot(rows, num_images, i + 1)
+ plt.imshow(images[i])
+ plt.axis("off")
+
+ plt.subplot(rows, num_images, num_images + i + 1)
+ plt.imshow(masks[i], cmap="gray")
+ plt.axis("off")
+
+ if pred_masks is not None:
+ plt.subplot(rows, num_images, i + 1 + 2 * num_images)
+ plt.imshow(pred_masks[i, ..., 0], cmap="gray")
+ plt.axis("off")
+
+ plt.show()
+
+
+plot_images_masks(batch["images"], batch["segmentation_masks"])
+
+"""
+The preprocessing is applied to the evaluation dataset `eval_ds`.
+"""
+eval_ds = preprocess_inputs(eval_ds)
+
+"""
+## Data Augmentation
+
+Keras provides a variety of image augmentation options. In this example, we will
+use the `RandomFlip` augmentation to augment the training dataset. The
+`RandomFlip` augmentation randomly flips the images in the training dataset
+horizontally or vertically. This can help to improve the model's robustness to
+changes in the orientation of the objects in the images.
+"""
+
+train_ds = train_ds.map(keras.layers.RandomFlip())
+batch = train_ds.take(1).get_single_element()
+
+plot_images_masks(batch["images"], batch["segmentation_masks"])
+
+"""
+## Model Configuration
+
+Please feel free to modify the configurations for model training and note how the
+training results change. This is a great exercise to get a better
+understanding of the training pipeline.
+
+The learning rate schedule is used by the optimizer to calculate the learning
+rate for each epoch. The optimizer then uses the learning rate to update the
+weights of the model.
+In this case, the learning rate schedule uses a cosine decay function. A cosine
+decay function starts high and then decreases over time, eventually reaching
+zero. At a batch size of 4, the training split yields 2124 batches, so one
+epoch corresponds to 2124 steps. This step count matters for learning rate
+decay because it determines how long the schedule runs. The initial learning
+rate is scaled from a base of 0.007 by the batch size (`0.007 * BATCH_SIZE / 16`),
+and the decay steps are set to `EPOCHS * 2124`. This means that the learning
+rate will start at `INITIAL_LR` and then decrease to zero over the training steps.
+
+"""
+
+BATCH_SIZE = 4
+INITIAL_LR = 0.007 * BATCH_SIZE / 16
+EPOCHS = 1
+NUM_CLASSES = 21
+learning_rate = keras.optimizers.schedules.CosineDecay(
+ INITIAL_LR,
+ decay_steps=EPOCHS * 2124,
+)
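+
+"""
+As a quick, purely illustrative check, the schedule object can be called with
+a step count to see the cosine decay in action, from `INITIAL_LR` at step 0
+down to zero at the last decay step:
+"""
+
+for step in (0, EPOCHS * 2124 // 2, EPOCHS * 2124):
+    print(f"Learning rate at step {step}: {float(learning_rate(step)):.6f}")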
+
+"""
+Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for
+the model. This implementation can be used either as DeepLabV3 or as DeepLabV3+
+(with the additional decoder block).
+For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model and set
+`low_level_feature_key` to `P2`, a pyramid-level output of
+`resnet_50_imagenet` whose features feed the decoder block.
+To use this model as the DeepLabV3 architecture, omit the
+`low_level_feature_key` argument, which defaults to `None`.
+
+Then we create a `DeepLabV3ImageSegmenter` instance.
+The `num_classes` parameter specifies the number of classes that the model will
+be trained to segment, and the `preprocessor` argument applies preprocessing to
+the input images and masks.
+"""
+
+image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")
+
+deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
+ image_encoder=image_encoder,
+ low_level_feature_key="P2",
+ spatial_pyramid_pooling_key="P5",
+ dilation_rates=[6, 12, 18],
+ upsampling_size=8,
+)
+
+model = keras_hub.models.DeepLabV3ImageSegmenter(
+ backbone=deeplab_backbone,
+ num_classes=21,
+ activation="softmax",
+ preprocessor=preprocessor,
+)
+
+"""
+## Compile the model
+
+The `model.compile()` function sets up the training process for the model. It defines:
+- the optimization algorithm - Stochastic Gradient Descent (SGD)
+- the loss function - categorical cross-entropy
+- the evaluation metrics - Mean IoU and categorical accuracy
+
+Semantic segmentation evaluation metrics:
+
+Mean Intersection over Union (MeanIoU):
+MeanIoU measures how well a semantic segmentation model accurately identifies
+and delineates different objects or regions in an image. It calculates the
+overlap between predicted and actual object boundaries, providing a score
+between 0 and 1, where 1 represents a perfect match.
+
+Categorical Accuracy:
+Categorical Accuracy measures the proportion of correctly classified pixels in
+an image. It gives a simple percentage indicating how accurately the model
+predicts the categories of pixels in the entire image.
+
+In essence, MeanIoU emphasizes the accuracy of identifying specific object
+boundaries, while Categorical Accuracy gives a broad overview of overall
+pixel-level correctness.
+"""
+
+model.compile(
+ optimizer=keras.optimizers.SGD(
+ learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
+ ),
+ loss=keras.losses.CategoricalCrossentropy(from_logits=False),
+ metrics=[
+ keras.metrics.MeanIoU(
+ num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
+ ),
+ keras.metrics.CategoricalAccuracy(),
+ ],
+)
+
+model.summary()
+
+"""
+The utility function `dict_to_tuple` transforms the dictionaries of the
+training and validation datasets into tuples of images and one-hot encoded
+segmentation masks, which are used during training and evaluation of the
+DeepLabv3+ model.
+"""
+
+
+def dict_to_tuple(x):
+
+ return x["images"], tf.one_hot(
+ tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21
+ )
+
+
+train_ds = train_ds.map(dict_to_tuple)
+eval_ds = eval_ds.map(dict_to_tuple)
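+
+"""
+One small aside: `tf.one_hot` encodes any value outside `[0, 21)` as an
+all-zero vector, so the 255 boundary pixels in the masks are not assigned to
+any class by this encoding. A minimal illustration:
+"""
+
+# Sums of the one-hot vectors for class 0, class 15, and the 255 boundary label.
+print(tf.reduce_sum(tf.one_hot([0, 15, 255], depth=21), axis=-1))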
+
+model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)
+
+"""
+## Predictions with trained model
+Now that training of the DeepLabv3+ model has completed, let's test it by
+making predictions on a few sample images.
+Note: For demonstration purposes, the model has been trained for only 1 epoch;
+for better accuracy and results, train for more epochs.
+"""
+
+test_ds = load(split="sbd_eval")
+test_ds = preprocess_inputs(test_ds)
+
+images, masks = next(iter(train_ds.take(1)))
+images = ops.convert_to_tensor(images)
+masks = ops.convert_to_tensor(masks)
+preds = ops.expand_dims(ops.argmax(model(images), axis=-1), axis=-1)
+masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)
+
+plot_images_masks(images, masks, preds)
+
+"""
+Here are some additional tips for using the KerasHub DeepLabv3 model:
+
+- The model can be trained on a variety of datasets, including the COCO dataset, the
+PASCAL VOC dataset, and the Cityscapes dataset.
+- The model can be fine-tuned on a custom dataset to improve its performance on a
+specific task.
+- The model can be used to perform real-time inference on images (see the minimal single-image sketch below).
+- Also, check out KerasHub's other segmentation models.
+"""
diff --git a/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md b/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md
new file mode 100644
index 0000000000..19d5b2c3aa
--- /dev/null
+++ b/guides/md/keras_hub/semantic_segmentation_deeplab_v3.md
@@ -0,0 +1,973 @@
+# Semantic Segmentation with KerasHub
+
+**Authors:** [Sachin Prasad](https://github.com/sachinprasadhs), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli), [Ian Stenbit](https://github.com/ianstenbit)
+**Date created:** 2024/10/11
+**Last modified:** 2024/10/11
+**Description:** DeepLabV3 training and inference with KerasHub.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb) •
+[**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/keras_hub/semantic_segmentation_deeplab_v3.py)
+
+
+
+
+
+---
+## Background
+Semantic segmentation is a type of computer vision task that involves assigning a
+class label such as "person", "bike", or "background" to each individual pixel
+of an image, effectively dividing the image into regions that correspond to
+different object classes or categories.
+
+
+
+
+
+KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic
+segmentation.
+
+This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed
+by Google for image semantic segmentation with KerasHub. Its architecture
+combines Atrous convolutions, contextual information aggregation, and powerful
+backbones to achieve accurate and detailed semantic segmentation.
+
+DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to
+refine the segmentation results, especially along object boundaries. Both models
+have achieved state-of-the-art results on a variety of image segmentation
+benchmarks.
+
+### References
+[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
+[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
+
+---
+## Setup and Imports
+
+Let's install the dependencies and import the necessary modules.
+
+To run this tutorial, you will need to install the following packages:
+
+* `keras-hub`
+* `keras`
+
+
+```python
+!pip install -q --upgrade keras-hub
+!pip install -q --upgrade keras
+```
+
+After installing `keras` and `keras-hub`, set the backend for `keras`.
+This guide can be run with any backend (TensorFlow, JAX, PyTorch).
+
+
+```python
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+import keras
+from keras import ops
+import keras_hub
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+```
+
+---
+## Perform semantic segmentation with a pretrained DeepLabv3+ model
+
+The highest-level API in KerasHub for semantic segmentation is the
+`keras_hub.models` API. This API includes fully pretrained semantic segmentation
+models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.
+
+Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC
+dataset.
+Also, define the preprocessing function for the model to preprocess images and
+labels.
+**Note:** By default, the `from_preset()` method in KerasHub loads the pretrained
+task weights with all the classes, 21 classes in this case.
+
+
+```python
+model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+ "deeplab_v3_plus_resnet50_pascalvoc"
+)
+
+image_converter = keras_hub.layers.DeepLabV3ImageConverter(
+ image_size=(512, 512),
+ interpolation="bilinear",
+)
+preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)
+```
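+
+A quick, purely illustrative check: the preprocessor resizes whatever image it
+is given to the converter's `image_size`, so a dummy input of any spatial size
+comes out with shape `(512, 512, 3)` (the dummy array below exists only for
+this demonstration).
+
+
+```python
+dummy_image = np.random.uniform(low=0, high=255, size=(300, 400, 3)).astype("float32")
+print(preprocessor(dummy_image).shape)
+```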
+
+Let us visualize the results of this pretrained model.
+
+
+```python
+filepath = keras.utils.get_file(
+ origin="https://storage.googleapis.com/keras-cv/pictures/dog.jpeg"
+)
+image = keras.utils.load_img(filepath)
+image = keras.utils.img_to_array(image)
+
+image = preprocessor(image)
+image = keras.ops.expand_dims(image, axis=0)
+preds = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)
+
+
+def plot_segmentation(original_image, predicted_mask):
+ original_image = np.squeeze(original_image, axis=0)
+ original_image = np.clip(original_image / 255.0, 0, 1)
+ predicted_mask = np.squeeze(predicted_mask, axis=0)
+ plt.figure(figsize=(5, 5))
+
+ plt.subplot(1, 2, 1)
+ plt.imshow(original_image)
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(predicted_mask, cmap="gray")
+ plt.axis("off")
+
+ plt.tight_layout()
+ plt.show()
+
+
+plot_segmentation(image, preds)
+```
+
+
+
+
+
+
+---
+## Train a custom semantic segmentation model
+In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3
+semantic segmentation model. This includes data loading, augmentation, training,
+metric evaluation, and inference!
+
+---
+## Download the data
+
+We download the Pascal VOC 2012 dataset with the additional annotations from
+[Semantic contours from inverse detectors](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz),
+and split it into a training dataset `train_ds` and an evaluation dataset `eval_ds`.
+
+
+```python
+# @title helper functions
+import logging
+import multiprocessing
+from builtins import open
+import os.path
+import random
+import xml
+
+import tensorflow_datasets as tfds
+
+VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
+
+SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+
+# Note that this list doesn't contain the background class. In the
+# classification use case, the label is 0 based (aeroplane -> 0), whereas in
+# segmentation use case, the 0 is reserved for background, so aeroplane maps to
+# 1.
+CLASSES = [
+ "aeroplane",
+ "bicycle",
+ "bird",
+ "boat",
+ "bottle",
+ "bus",
+ "car",
+ "cat",
+ "chair",
+ "cow",
+ "diningtable",
+ "dog",
+ "horse",
+ "motorbike",
+ "person",
+ "pottedplant",
+ "sheep",
+ "sofa",
+ "train",
+ "tvmonitor",
+]
+# This is used to map between string class to index.
+CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
+
+# For the mask data in the PNG file, the encoded raw pixel value need to be
+# converted to the proper class index. In the following map, [0, 0, 0] will be
+# convert to 0, and [128, 0, 0] will be converted to 1, so on so forth. Also
+# note that the mask class is 1 base since class 0 is reserved for the
+# background. The [128, 0, 0] (class 1) is mapped to `aeroplane`.
+VOC_PNG_COLOR_VALUE = [
+ [0, 0, 0],
+ [128, 0, 0],
+ [0, 128, 0],
+ [128, 128, 0],
+ [0, 0, 128],
+ [128, 0, 128],
+ [0, 128, 128],
+ [128, 128, 128],
+ [64, 0, 0],
+ [192, 0, 0],
+ [64, 128, 0],
+ [192, 128, 0],
+ [64, 0, 128],
+ [192, 0, 128],
+ [64, 128, 128],
+ [192, 128, 128],
+ [0, 64, 0],
+ [128, 64, 0],
+ [0, 192, 0],
+ [128, 192, 0],
+ [0, 64, 128],
+]
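+# As a worked example of this encoding: [128, 0, 0] (class 1, "aeroplane")
+# flattens to the index (128 * 256 + 0) * 256 + 0 = 8,388,608 in the lookup
+# table that is built lazily below.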
+# Will be populated by maybe_populate_voc_color_mapping() below.
+VOC_PNG_COLOR_MAPPING = None
+
+
+def maybe_populate_voc_color_mapping():
+ """Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory."""
+ global VOC_PNG_COLOR_MAPPING
+ if VOC_PNG_COLOR_MAPPING is None:
+ VOC_PNG_COLOR_MAPPING = [0] * (256**3)
+ for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):
+ VOC_PNG_COLOR_MAPPING[
+ (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]
+ ] = i
+ # There is a special mapping with [224, 224, 192] -> 255
+ VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255
+ VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)
+ return VOC_PNG_COLOR_MAPPING
+
+
+def parse_annotation_data(annotation_file_path):
+ """Parse the annotation XML file for the image.
+
+ The annotation contains the metadata, as well as the object bounding box
+ information.
+
+ """
+ with open(annotation_file_path, "r") as f:
+ root = xml.etree.ElementTree.parse(f).getroot()
+
+ size = root.find("size")
+ width = int(size.find("width").text)
+ height = int(size.find("height").text)
+
+ objects = []
+ for obj in root.findall("object"):
+ # Get object's label name.
+ label = CLASS_TO_INDEX[obj.find("name").text.lower()]
+ # Get objects' pose name.
+ pose = obj.find("pose").text.lower()
+ is_truncated = obj.find("truncated").text == "1"
+ is_difficult = obj.find("difficult").text == "1"
+ bndbox = obj.find("bndbox")
+ xmax = int(bndbox.find("xmax").text)
+ xmin = int(bndbox.find("xmin").text)
+ ymax = int(bndbox.find("ymax").text)
+ ymin = int(bndbox.find("ymin").text)
+ objects.append(
+ {
+ "label": label,
+ "pose": pose,
+ "bbox": [ymin, xmin, ymax, xmax],
+ "is_truncated": is_truncated,
+ "is_difficult": is_difficult,
+ }
+ )
+
+ return {"width": width, "height": height, "objects": objects}
+
+
+def get_image_ids(data_dir, split):
+ """To get image ids from the "train", "eval" or "trainval" files of VOC data."""
+ data_file_mapping = {
+ "train": "train.txt",
+ "eval": "val.txt",
+ "trainval": "trainval.txt",
+ }
+ with open(
+ os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),
+ "r",
+ ) as f:
+ image_ids = f.read().splitlines()
+ logging.info(f"Received {len(image_ids)} images for {split} dataset.")
+ return image_ids
+
+
+def get_sbd_image_ids(data_dir, split):
+ """To get image ids from the "sbd_train", "sbd_eval" from files of SBD data."""
+ data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}
+ with open(
+ os.path.join(data_dir, data_file_mapping[split]),
+ "r",
+ ) as f:
+ image_ids = f.read().splitlines()
+ logging.info(f"Received {len(image_ids)} images for {split} dataset.")
+ return image_ids
+
+
+def parse_single_image(image_file_path):
+ """Creates metadata of VOC images and path."""
+ data_dir, image_file_name = os.path.split(image_file_path)
+ data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
+ image_id, _ = os.path.splitext(image_file_name)
+ class_segmentation_file_path = os.path.join(
+ data_dir, "SegmentationClass", image_id + ".png"
+ )
+ object_segmentation_file_path = os.path.join(
+ data_dir, "SegmentationObject", image_id + ".png"
+ )
+ annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")
+ image_annotations = parse_annotation_data(annotation_file_path)
+
+ result = {
+ "image/filename": image_id + ".jpg",
+ "image/file_path": image_file_path,
+ "segmentation/class/file_path": class_segmentation_file_path,
+ "segmentation/object/file_path": object_segmentation_file_path,
+ }
+ result.update(image_annotations)
+ # Labels field should be same as the 'object.label'
+ labels = list(set([o["label"] for o in result["objects"]]))
+ result["labels"] = sorted(labels)
+ return result
+
+
+def parse_single_sbd_image(image_file_path):
+ """Creates metadata of SBD images and path."""
+ data_dir, image_file_name = os.path.split(image_file_path)
+ data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
+ image_id, _ = os.path.splitext(image_file_name)
+ class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat")
+ object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat")
+ result = {
+ "image/filename": image_id + ".jpg",
+ "image/file_path": image_file_path,
+ "segmentation/class/file_path": class_segmentation_file_path,
+ "segmentation/object/file_path": object_segmentation_file_path,
+ }
+ return result
+
+
+def build_metadata(data_dir, image_ids):
+ """Transpose the metadata which convert from list of dict to dict of list."""
+ # Parallel process all the images.
+ image_file_paths = [
+ os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids
+ ]
+ pool_size = 10 if len(image_ids) > 10 else len(image_ids)
+ with multiprocessing.Pool(pool_size) as p:
+ metadata = p.map(parse_single_image, image_file_paths)
+
+ keys = [
+ "image/filename",
+ "image/file_path",
+ "segmentation/class/file_path",
+ "segmentation/object/file_path",
+ "labels",
+ "width",
+ "height",
+ ]
+ result = {}
+ for key in keys:
+ values = [value[key] for value in metadata]
+ result[key] = values
+
+ # The ragged objects need some special handling
+ for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]:
+ values = []
+ objects = [value["objects"] for value in metadata]
+ for object in objects:
+ values.append([o[key] for o in object])
+ result["objects/" + key] = values
+ return result
+
+
+def build_sbd_metadata(data_dir, image_ids):
+ """Transpose the metadata which convert from list of dict to dict of list."""
+ # Parallel process all the images.
+ image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]
+ pool_size = 10 if len(image_ids) > 10 else len(image_ids)
+ with multiprocessing.Pool(pool_size) as p:
+ metadata = p.map(parse_single_sbd_image, image_file_paths)
+
+ keys = [
+ "image/filename",
+ "image/file_path",
+ "segmentation/class/file_path",
+ "segmentation/object/file_path",
+ ]
+ result = {}
+ for key in keys:
+ values = [value[key] for value in metadata]
+ result[key] = values
+ return result
+
+
+def decode_png_mask(mask):
+ """Decode the raw PNG image and convert it to 2D tensor with probably
+ class."""
+ # Cast the mask to int32 since the original uint8 will overflow when
+ # multiplied with 256
+ mask = tf.cast(mask, tf.int32)
+ mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
+ mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)
+ mask = tf.cast(mask, tf.uint8)
+ return mask
+
+
+def load_images(example):
+ """Loads VOC images for segmentation task from the provided paths"""
+ image_file_path = example.pop("image/file_path")
+ segmentation_class_file_path = example.pop("segmentation/class/file_path")
+ segmentation_object_file_path = example.pop("segmentation/object/file_path")
+ image = tf.io.read_file(image_file_path)
+ image = tf.image.decode_jpeg(image)
+
+ segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)
+ segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)
+ segmentation_class_mask = decode_png_mask(segmentation_class_mask)
+
+ segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)
+ segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)
+ segmentation_object_mask = decode_png_mask(segmentation_object_mask)
+
+ example.update(
+ {
+ "image": image,
+ "class_segmentation": segmentation_class_mask,
+ "object_segmentation": segmentation_object_mask,
+ }
+ )
+ return example
+
+
+def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
+ """Loads SBD images for segmentation task from the provided paths"""
+ image = tf.io.read_file(image_file_path)
+ image = tf.image.decode_jpeg(image)
+
+ segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)
+ segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0]
+ segmentation_class_mask = segmentation_class_mask[..., np.newaxis]
+
+ segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(
+ seg_obj_file_path
+ )
+ segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0]
+ segmentation_object_mask = segmentation_object_mask[..., np.newaxis]
+
+ return {
+ "image": image,
+ "class_segmentation": segmentation_class_mask,
+ "object_segmentation": segmentation_object_mask,
+ }
+
+
+def build_dataset_from_metadata(metadata):
+ """Builds TensorFlow dataset from the image metadata of VOC dataset."""
+ # The objects need some manual conversion to ragged tensor.
+ metadata["labels"] = tf.ragged.constant(metadata["labels"])
+ metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
+ metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])
+ metadata["objects/is_truncated"] = tf.ragged.constant(
+ metadata["objects/is_truncated"]
+ )
+ metadata["objects/is_difficult"] = tf.ragged.constant(
+ metadata["objects/is_difficult"]
+ )
+ metadata["objects/bbox"] = tf.ragged.constant(
+ metadata["objects/bbox"], ragged_rank=1
+ )
+
+ dataset = tf.data.Dataset.from_tensor_slices(metadata)
+ dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
+ return dataset
+
+
+def build_sbd_dataset_from_metadata(metadata):
+ """Builds TensorFlow dataset from the image metadata of SBD dataset."""
+ img_filepath = metadata["image/file_path"]
+ cls_filepath = metadata["segmentation/class/file_path"]
+ obj_filepath = metadata["segmentation/object/file_path"]
+
+ def md_gen():
+ c = list(zip(img_filepath, cls_filepath, obj_filepath))
+        # Random shuffling for each generator improves the shuffle quality.
+ random.shuffle(c)
+ for fp in c:
+ img_fp, cls_fp, obj_fp = fp
+ yield load_sbd_images(img_fp, cls_fp, obj_fp)
+
+ dataset = tf.data.Dataset.from_generator(
+ md_gen,
+ output_signature=(
+ {
+ "image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
+ "class_segmentation": tf.TensorSpec(
+ shape=(None, None, 1), dtype=tf.uint8
+ ),
+ "object_segmentation": tf.TensorSpec(
+ shape=(None, None, 1), dtype=tf.uint8
+ ),
+ }
+ ),
+ )
+
+ return dataset
+
+
+def load(
+ split="sbd_train",
+ data_dir=None,
+):
+ """Load the Pacal VOC 2012 dataset.
+
+ This function will download the data tar file from remote if needed, and
+ untar to the local `data_dir`, and build dataset from it.
+
+ It supports both VOC2012 and Semantic Boundaries Dataset (SBD).
+
+    The returned segmentation masks will be integers in the range
+    [0, num_classes), plus the value 255 which marks object boundaries.
+
+ Args:
+ split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or
+ 'sbd_eval'. 'sbd_train' represents the training dataset for SBD
+ dataset, while 'train' represents the training dataset for VOC2012
+ dataset. Defaults to `sbd_train`.
+ data_dir: string, local directory path for the loaded data. This will be
+ used to download the data file, and unzip. It will be used as a
+ cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`
+ will be used.
+ """
+ supported_split_value = [
+ "train",
+ "eval",
+ "trainval",
+ "sbd_train",
+ "sbd_eval",
+ ]
+ if split not in supported_split_value:
+ raise ValueError(
+ f"The support value for `split` are {supported_split_value}. "
+ f"Got: {split}"
+ )
+
+ if data_dir is not None:
+ data_dir = os.path.expanduser(data_dir)
+
+ if "sbd" in split:
+ return load_sbd(split, data_dir)
+ else:
+ return load_voc(split, data_dir)
+
+
+def load_voc(
+ split="train",
+ data_dir=None,
+):
+ """This function will download VOC data from a URL. If the data is already
+ present in the cache directory, it will load the data from that directory
+ instead.
+ """
+ extracted_dir = os.path.join("VOCdevkit", "VOC2012")
+ get_data = keras.utils.get_file(
+ fname=os.path.basename(VOC_URL),
+ origin=VOC_URL,
+ cache_dir=data_dir,
+ extract=True,
+ )
+ data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+ image_ids = get_image_ids(data_dir, split)
+ # len(metadata) = #samples, metadata[i] is a dict.
+ metadata = build_metadata(data_dir, image_ids)
+ maybe_populate_voc_color_mapping()
+ dataset = build_dataset_from_metadata(metadata)
+
+ return dataset
+
+
+def load_sbd(
+ split="sbd_train",
+ data_dir=None,
+):
+ """This function will download SBD data from a URL. If the data is already
+ present in the cache directory, it will load the data from that directory
+ instead.
+ """
+ extracted_dir = os.path.join("benchmark_RELEASE", "dataset")
+ get_data = keras.utils.get_file(
+ fname=os.path.basename(SBD_URL),
+ origin=SBD_URL,
+ cache_dir=data_dir,
+ extract=True,
+ )
+ data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+ image_ids = get_sbd_image_ids(data_dir, split)
+ # len(metadata) = #samples, metadata[i] is a dict.
+ metadata = build_sbd_metadata(data_dir, image_ids)
+
+ dataset = build_sbd_dataset_from_metadata(metadata)
+ return dataset
+
+```
+
+---
+## Load the dataset
+
+For training and evaluation, let's use "sbd_train" and "sbd_eval". You can also
+pass any of these splits to the `load` function: 'train', 'eval', 'trainval',
+'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the
+SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset.
+
+
+```python
+train_ds = load(split="sbd_train", data_dir="segmentation")
+eval_ds = load(split="sbd_eval", data_dir="segmentation")
+```
+
+---
+## Preprocess the data
+
+The `preprocess_inputs` utility function converts each input example into a
+dictionary with the keys `images` and `segmentation_masks`. Both the images and
+the segmentation masks are resized to 512x512. The resulting dataset is then
+batched into groups of four image and segmentation mask pairs.
+
+
+```python
+
+def preprocess_inputs(inputs):
+ def unpackage_inputs(inputs):
+ return {
+ "images": inputs["image"],
+ "segmentation_masks": inputs["class_segmentation"],
+ }
+
+ outputs = inputs.map(unpackage_inputs)
+ outputs = outputs.map(keras.layers.Resizing(height=512, width=512))
+ outputs = outputs.batch(4, drop_remainder=True)
+ return outputs
+
+
+train_ds = preprocess_inputs(train_ds)
+batch = train_ds.take(1).get_single_element()
+```
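+
+As an optional sanity check, the batch produced above should contain four
+512x512 RGB images and four matching single-channel masks:
+
+
+```python
+print(batch["images"].shape)  # Expected: (4, 512, 512, 3)
+print(batch["segmentation_masks"].shape)  # Expected: (4, 512, 512, 1)
+```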
+
+A batch of this preprocessed input training data can be visualized using the
+`plot_images_masks` function. This function takes a batch of images and
+segmentation masks (and, optionally, prediction masks) and displays them in a grid.
+
+
+```python
+
+def plot_images_masks(images, masks, pred_masks=None):
+ images = (images - np.min(images)) / (np.max(images) - np.min(images))
+ masks = (masks - np.min(masks)) / (np.max(masks) - np.min(masks))
+ if pred_masks is not None:
+ pred_masks = (pred_masks - pred_masks.min()) / (
+ pred_masks.max() - pred_masks.min()
+ )
+ num_images = len(images)
+ plt.figure(figsize=(8, 4))
+ rows = 3 if pred_masks is not None else 2
+
+ for i in range(num_images):
+ plt.subplot(rows, num_images, i + 1)
+ plt.imshow(images[i])
+ plt.axis("off")
+
+ plt.subplot(rows, num_images, num_images + i + 1)
+ plt.imshow(masks[i], cmap="gray")
+ plt.axis("off")
+
+ if pred_masks is not None:
+ plt.subplot(rows, num_images, i + 1 + 2 * num_images)
+ plt.imshow(pred_masks[i, ..., 0], cmap="gray")
+ plt.axis("off")
+
+ plt.show()
+
+
+plot_images_masks(batch["images"], batch["segmentation_masks"])
+```
+
+
+
+
+
+
+
+The preprocessing is applied to the evaluation dataset `eval_ds`.
+
+
+```python
+eval_ds = preprocess_inputs(eval_ds)
+```
+
+---
+## Data Augmentation
+
+Keras provides a variety of image augmentation options. In this example, we will
+use the `RandomFlip` augmentation to augment the training dataset. The
+`RandomFlip` augmentation randomly flips the images in the training dataset
+horizontally or vertically. This can help to improve the model's robustness to
+changes in the orientation of the objects in the images.
+
+
+```python
+train_ds = train_ds.map(keras.layers.RandomFlip())
+batch = train_ds.take(1).get_single_element()
+
+plot_images_masks(batch["images"], batch["segmentation_masks"])
+```
+
+
+
+
+
+
+
+---
+## Model Configuration
+
+Please feel free to modify the configurations for model training and note how the
+training results change. This is a great exercise to get a better
+understanding of the training pipeline.
+
+The learning rate schedule is used by the optimizer to calculate the learning
+rate for each epoch. The optimizer then uses the learning rate to update the
+weights of the model.
+In this case, the learning rate schedule uses a cosine decay function. A cosine
+decay function starts high and then decreases over time, eventually reaching
+zero. At a batch size of 4, the training split yields 2124 batches, so one
+epoch corresponds to 2124 steps. This step count matters for learning rate
+decay because it determines how long the schedule runs. The initial learning
+rate is scaled from a base of 0.007 by the batch size (`0.007 * BATCH_SIZE / 16`),
+and the decay steps are set to `EPOCHS * 2124`. This means that the learning
+rate will start at `INITIAL_LR` and then decrease to zero over the training steps.
+
+
+
+```python
+BATCH_SIZE = 4
+INITIAL_LR = 0.007 * BATCH_SIZE / 16
+EPOCHS = 1
+NUM_CLASSES = 21
+learning_rate = keras.optimizers.schedules.CosineDecay(
+ INITIAL_LR,
+ decay_steps=EPOCHS * 2124,
+)
+```
+
+Let's use the `resnet_50_imagenet` pretrained weights as the image encoder for
+the model. This implementation can be used either as DeepLabV3 or as DeepLabV3+
+(with the additional decoder block).
+For DeepLabV3+, we instantiate a `DeepLabV3Backbone` model and set
+`low_level_feature_key` to `P2`, a pyramid-level output of
+`resnet_50_imagenet` whose features feed the decoder block.
+To use this model as the DeepLabV3 architecture, omit the
+`low_level_feature_key` argument, which defaults to `None` (a minimal sketch of
+this variant follows the next code block).
+
+Then we create a `DeepLabV3ImageSegmenter` instance.
+The `num_classes` parameter specifies the number of classes that the model will
+be trained to segment, and the `preprocessor` argument applies preprocessing to
+the input images and masks.
+
+
+```python
+image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")
+
+deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
+ image_encoder=image_encoder,
+ low_level_feature_key="P2",
+ spatial_pyramid_pooling_key="P5",
+ dilation_rates=[6, 12, 18],
+ upsampling_size=8,
+)
+
+model = keras_hub.models.DeepLabV3ImageSegmenter(
+ backbone=deeplab_backbone,
+ num_classes=21,
+ activation="softmax",
+ preprocessor=preprocessor,
+)
+```
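+
+If you instead want the plain DeepLabV3 variant (no decoder branch), a minimal
+sketch is to build the backbone without `low_level_feature_key` and keep the
+rest of the configuration the same (depending on the encoder's output stride,
+you may want to adjust `upsampling_size` to recover full resolution). This
+variant is not used in the rest of the guide.
+
+
+```python
+deeplab_v3_plain_backbone = keras_hub.models.DeepLabV3Backbone(
+    image_encoder=image_encoder,
+    spatial_pyramid_pooling_key="P5",
+    dilation_rates=[6, 12, 18],
+    upsampling_size=8,
+)
+```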
+
+---
+## Compile the model
+
+The `model.compile()` function sets up the training process for the model. It defines:
+- the optimization algorithm - Stochastic Gradient Descent (SGD)
+- the loss function - categorical cross-entropy
+- the evaluation metrics - Mean IoU and categorical accuracy
+
+Semantic segmentation evaluation metrics:
+
+Mean Intersection over Union (MeanIoU):
+MeanIoU measures how well a semantic segmentation model accurately identifies
+and delineates different objects or regions in an image. It calculates the
+overlap between predicted and actual object boundaries, providing a score
+between 0 and 1, where 1 represents a perfect match.
+
+Categorical Accuracy:
+Categorical Accuracy measures the proportion of correctly classified pixels in
+an image. It gives a simple percentage indicating how accurately the model
+predicts the categories of pixels in the entire image.
+
+In essence, MeanIoU emphasizes the accuracy of identifying specific object
+boundaries, while Categorical Accuracy gives a broad overview of overall
+pixel-level correctness.
+
+
+```python
+model.compile(
+ optimizer=keras.optimizers.SGD(
+ learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
+ ),
+ loss=keras.losses.CategoricalCrossentropy(from_logits=False),
+ metrics=[
+ keras.metrics.MeanIoU(
+ num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
+ ),
+ keras.metrics.CategoricalAccuracy(),
+ ],
+)
+
+model.summary()
+```
+
+```
+Preprocessor: "deep_lab_v3_image_segmenter_preprocessor"
+
+Layer (type)                                             Config
+deep_lab_v3_image_converter (DeepLabV3ImageConverter)    Image size: (512, 512)
+
+Model: "deep_lab_v3_image_segmenter"
+
+Layer (type)                                   Output Shape               Param #
+inputs (InputLayer)                            (None, None, None, 3)            0
+deep_lab_v3_backbone (DeepLabV3Backbone)       (None, None, None, 256)  39,190,656
+segmentation_output (Conv2D)                   (None, None, None, 21)        5,376
+
+Total params: 39,196,032 (149.52 MB)
+Trainable params: 39,139,232 (149.30 MB)
+Non-trainable params: 56,800 (221.88 KB)
+```
+
+The utility function `dict_to_tuple` transforms the dictionaries of the
+training and validation datasets into tuples of images and one-hot encoded
+segmentation masks, which are used during training and evaluation of the
+DeepLabv3+ model.
+
+
+```python
+
+def dict_to_tuple(x):
+
+    return x["images"], tf.one_hot(
+        tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21
+    )
+
+
+train_ds = train_ds.map(dict_to_tuple)
+eval_ds = eval_ds.map(dict_to_tuple)
+
+model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)
+```
+