
Commit 09642ec

address comments
1 parent 9d1b056 commit 09642ec


guides/keras_hub/semantic_segmentation_deeplab_v3.py

Lines changed: 74 additions & 103 deletions
@@ -1,9 +1,9 @@
 """
 Title: Semantic Segmentation with KerasHub
-Author: [Sachin Prasad](https://github.com/sachinprasad)<br>
-Date created: 2024/10/11<br>
-Last modified: 2024/10/11<br>
-Description: Train and use DeepLabv3 and DeepLabv3+ segmentation model with KerasHub.
+Author: [Sachin Prasad](https://github.com/sachinprasad)
+Date created: 2024/10/11
+Last modified: 2024/10/11
+Description: DeepLabV3 training and inference with KerasHub
 Accelerator: GPU
 """

@@ -12,26 +12,25 @@

 ## Background
 Semantic segmentation is a type of computer vision task that involves assigning a
-class label such as person, bike, or background to each individual pixel of an
-image, effectively dividing the image into regions that correspond to different
-fobject classes or categories.
+class label such as "person", "bike", or "background" to each individual pixel
+of an image, effectively dividing the image into regions that correspond to
+different object classes or categories.

 ![](https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png)



-KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer etc models for semantic
+KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer, etc., models for semantic
 segmentation.

-This guide demonstrates how to finetune and use DeepLabv3+ model which is
-devoloped by Google for image semantic segmentaion with KerasHub. Its
-architecture that combines atrous convolutions, contextual information
-aggregation, and powerful backbones to achieve accurate and detailed semantic
-segmentation.
+This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed
+by Google for image semantic segmentation with KerasHub. Its architecture
+combines Atrous convolutions, contextual information aggregation, and powerful
+backbones to achieve accurate and detailed semantic segmentation.

-DeepLabv3+, extends DeepLabv3 by adding a simple yet effective decoder module to
-refine the segmentation results especially along object boundaries both these
-models have achienved state-of-the-art results on a variety of image segmentation
+DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to
+refine the segmentation results, especially along object boundaries. Both models
+have achieved state-of-the-art results on a variety of image segmentation
 benchmarks.

 ### References
@@ -84,7 +83,7 @@ class label such as person, bike, or background to each individual pixel of an
 `keras_hub.models` API. This API includes fully pretrained semantic segmentation
 models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.

-Let's get started by constructing a DeepLabv3 pretrained on the pascalvoc
+Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC
 dataset.
 Also, define the preprocessing function for the model to preprocess images and
 labels.
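For orientation, a minimal sketch of constructing such a pretrained segmenter; the preset name below is an assumption and may differ from the identifier actually published for KerasHub.

import keras_hub

# NOTE: "deeplab_v3_plus_resnet50_pascalvoc" is a hypothetical preset name used
# only for illustration; check the KerasHub preset listing for the exact id.
model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc"
)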
@@ -154,7 +153,6 @@ def plot_segmentation(original_image, predicted_mask):
 import multiprocessing
 import os.path
 import random
-import tarfile
 import xml

 import tensorflow_datasets as tfds
@@ -221,11 +219,11 @@ def plot_segmentation(original_image, predicted_mask):
     [128, 192, 0],
     [0, 64, 128],
 ]
-# Will be populated by _maybe_populate_voc_color_mapping() below.
+# Will be populated by maybe_populate_voc_color_mapping() below.
 VOC_PNG_COLOR_MAPPING = None


-def _maybe_populate_voc_color_mapping():
+def maybe_populate_voc_color_mapping():
     # Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.
     global VOC_PNG_COLOR_MAPPING
     if VOC_PNG_COLOR_MAPPING is None:
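For context, the lookup table this function lazily builds maps each packed RGB value of the VOC palette to a class id, so a decoded PNG mask can be translated to class labels with a single gather. A minimal sketch of the idea, assuming the palette list is named VOC_PNG_COLOR_VALUE as in similar loaders:

# Sketch only: pack (r, g, b) as (r * 256 + g) * 256 + b and map it to the
# class index; the guide may build and store this table differently.
mapping = [0] * (256**3)
for class_id, (r, g, b) in enumerate(VOC_PNG_COLOR_VALUE):
    mapping[(r * 256 + g) * 256 + b] = class_id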
@@ -240,52 +238,14 @@ def _maybe_populate_voc_color_mapping():
     return VOC_PNG_COLOR_MAPPING


-def _download_data_file(
-    data_url, extracted_dir, local_dir_path=None, override_extract=False
-):
-    """Fetch the original VOC or Semantic Boundaries Dataset from remote URL.
-
-    Args:
-        data_url: string, the URL for the data to be downloaded, should be in a
-            zipped tar package.
-        local_dir_path: string, the local directory path to save the data.
-    Returns:
-        the path to the folder of extracted data.
-    """
-    if not local_dir_path:
-        # download to ~/.keras/datasets/fname
-        cache_dir = os.path.join(os.path.expanduser("~"), ".keras/datasets")
-        fname = os.path.join(os.path.basename(data_url))
-    else:
-        # Make sure the directory exists
-        if not os.path.exists(local_dir_path):
-            os.makedirs(local_dir_path, exist_ok=True)
-        # download to local_dir_path/fname
-        fname = os.path.join(os.path.basename(data_url))
-        cache_dir = local_dir_path
-    data_directory = os.path.join(os.path.dirname(fname), extracted_dir)
-    if not override_extract and os.path.exists(data_directory):
-        logging.info("data directory %s already exist", data_directory)
-        return data_directory
-    data_file_path = keras.utils.get_file(
-        fname=fname, origin=data_url, cache_dir=cache_dir
-    )
-    # Extra the data into the same directory as the tar file.
-    data_directory = os.path.dirname(data_file_path)
-    logging.info("Extract data into %s", data_directory)
-    with tarfile.open(data_file_path) as f:
-        f.extractall(data_directory)
-    return os.path.join(data_directory, extracted_dir)
-
-
-def _parse_annotation_data(annotation_file_path):
+def parse_annotation_data(annotation_file_path):
     """Parse the annotation XML file for the image.

     The annotation contains the metadata, as well as the object bounding box
     information.

     """
-    with tf.io.gfile.GFile(annotation_file_path, "r") as f:
+    with open(annotation_file_path, "r") as f:
         root = xml.etree.ElementTree.parse(f).getroot()

     size = root.find("size")
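For readers unfamiliar with the Pascal VOC annotation layout, here is a small self-contained sketch of the same parsing pattern; the annotation path is made up, and only the `size` element is read:

import xml.etree.ElementTree

# Hypothetical annotation path, shown for illustration only.
with open("Annotations/2007_000027.xml", "r") as f:
    root = xml.etree.ElementTree.parse(f).getroot()
size = root.find("size")
width = int(size.find("width").text)
height = int(size.find("height").text)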
@@ -318,15 +278,13 @@ def _parse_annotation_data(annotation_file_path):
     return {"width": width, "height": height, "objects": objects}


-def _get_image_ids(data_dir, split):
+def get_image_ids(data_dir, split):
     data_file_mapping = {
         "train": "train.txt",
         "eval": "val.txt",
         "trainval": "trainval.txt",
-        # TODO(tanzhenyu): add diff dataset
-        # "diff": "diff.txt",
     }
-    with tf.io.gfile.GFile(
+    with open(
         os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),
         "r",
     ) as f:
@@ -335,9 +293,9 @@ def _get_image_ids(data_dir, split):
     return image_ids


-def _get_sbd_image_ids(data_dir, split):
+def get_sbd_image_ids(data_dir, split):
     data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}
-    with tf.io.gfile.GFile(
+    with open(
         os.path.join(data_dir, data_file_mapping[split]),
         "r",
     ) as f:
@@ -346,7 +304,7 @@ def _get_sbd_image_ids(data_dir, split):
     return image_ids


-def _parse_single_image(image_file_path):
+def parse_single_image(image_file_path):
     data_dir, image_file_name = os.path.split(image_file_path)
     data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
     image_id, _ = os.path.splitext(image_file_name)
@@ -357,7 +315,7 @@ def _parse_single_image(image_file_path):
         data_dir, "SegmentationObject", image_id + ".png"
     )
     annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")
-    image_annotations = _parse_annotation_data(annotation_file_path)
+    image_annotations = parse_annotation_data(annotation_file_path)

     result = {
         "image/filename": image_id + ".jpg",
@@ -372,7 +330,7 @@ def _parse_single_image(image_file_path):
     return result


-def _parse_single_sbd_image(image_file_path):
+def parse_single_sbd_image(image_file_path):
     data_dir, image_file_name = os.path.split(image_file_path)
     data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
     image_id, _ = os.path.splitext(image_file_name)
@@ -387,14 +345,14 @@ def _parse_single_sbd_image(image_file_path):
     return result


-def _build_metadata(data_dir, image_ids):
+def build_metadata(data_dir, image_ids):
     # Parallel process all the images.
     image_file_paths = [
         os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids
     ]
     pool_size = 10 if len(image_ids) > 10 else len(image_ids)
     with multiprocessing.Pool(pool_size) as p:
-        metadata = p.map(_parse_single_image, image_file_paths)
+        metadata = p.map(parse_single_image, image_file_paths)

     # Transpose the metadata which convert from list of dict to dict of list.
     keys = [
@@ -421,12 +379,12 @@ def _build_metadata(data_dir, image_ids):
     return result


-def _build_sbd_metadata(data_dir, image_ids):
+def build_sbd_metadata(data_dir, image_ids):
     # Parallel process all the images.
     image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]
     pool_size = 10 if len(image_ids) > 10 else len(image_ids)
     with multiprocessing.Pool(pool_size) as p:
-        metadata = p.map(_parse_single_sbd_image, image_file_paths)
+        metadata = p.map(parse_single_sbd_image, image_file_paths)

     keys = [
         "image/filename",
@@ -441,8 +399,7 @@ def _build_sbd_metadata(data_dir, image_ids):
     return result


-@tf.function(jit_compile=True)
-def _decode_png_mask(mask):
+def decode_png_mask(mask):
     """Decode the raw PNG image and convert it to 2D tensor with probably
     class."""
     # Cast the mask to int32 since the original uint8 will overflow when
@@ -454,7 +411,7 @@ def _decode_png_mask(mask):
     return mask


-def _load_images(example):
+def load_images(example):
     image_file_path = example.pop("image/file_path")
     segmentation_class_file_path = example.pop("segmentation/class/file_path")
     segmentation_object_file_path = example.pop("segmentation/object/file_path")
@@ -463,11 +420,11 @@ def _load_images(example):

     segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)
     segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)
-    segmentation_class_mask = _decode_png_mask(segmentation_class_mask)
+    segmentation_class_mask = decode_png_mask(segmentation_class_mask)

     segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)
     segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)
-    segmentation_object_mask = _decode_png_mask(segmentation_object_mask)
+    segmentation_object_mask = decode_png_mask(segmentation_object_mask)

     example.update(
         {
@@ -479,7 +436,7 @@ def _load_images(example):
     return example


-def _load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
+def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
     image = tf.io.read_file(image_file_path)
     image = tf.image.decode_jpeg(image)

@@ -500,7 +457,7 @@ def _load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
     }


-def _build_dataset_from_metadata(metadata):
+def build_dataset_from_metadata(metadata):
     # The objects need some manual conversion to ragged tensor.
     metadata["labels"] = tf.ragged.constant(metadata["labels"])
     metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
@@ -516,11 +473,11 @@ def _build_dataset_from_metadata(metadata):
     )

     dataset = tf.data.Dataset.from_tensor_slices(metadata)
-    dataset = dataset.map(_load_images, num_parallel_calls=tf.data.AUTOTUNE)
+    dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
     return dataset


-def _build_sbd_dataset_from_metadata(metadata):
+def build_sbd_dataset_from_metadata(metadata):
     img_filepath = metadata["image/file_path"]
     cls_filepath = metadata["segmentation/class/file_path"]
     obj_filepath = metadata["segmentation/object/file_path"]
@@ -531,7 +488,7 @@ def md_gen():
         random.shuffle(c)
         for fp in c:
             img_fp, cls_fp, obj_fp = fp
-            yield _load_sbd_images(img_fp, cls_fp, obj_fp)
+            yield load_sbd_images(img_fp, cls_fp, obj_fp)

     dataset = tf.data.Dataset.from_generator(
         md_gen,
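The SBD pipeline above is generator-based; for reference, a minimal, generic sketch of the `tf.data.Dataset.from_generator` pattern (the real `output_signature` in the guide describes image and mask tensors, not this toy dict):

import tensorflow as tf

def toy_gen():
    # Stand-in generator; the guide's md_gen yields loaded image/mask dicts.
    for i in range(3):
        yield {"value": i}

toy_ds = tf.data.Dataset.from_generator(
    toy_gen,
    output_signature={"value": tf.TensorSpec(shape=(), dtype=tf.int32)},
)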
@@ -592,56 +549,70 @@ def load(
     data_dir = os.path.expanduser(data_dir)

     if "sbd" in split:
-        return _load_sbd(split, data_dir)
+        return load_sbd(split, data_dir)
     else:
-        return _load_voc(split, data_dir)
+        return load_voc(split, data_dir)


-def _load_voc(
+def load_voc(
     split="train",
     data_dir=None,
 ):
     extracted_dir = os.path.join("VOCdevkit", "VOC2012")
-    data_dir = _download_data_file(
-        VOC_URL, extracted_dir=extracted_dir, local_dir_path=data_dir
+    get_data = keras.utils.get_file(
+        fname=os.path.basename(VOC_URL),
+        origin=VOC_URL,
+        cache_dir=data_dir,
+        extract=True,
     )
-    image_ids = _get_image_ids(data_dir, split)
+    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+    image_ids = get_image_ids(data_dir, split)
     # len(metadata) = #samples, metadata[i] is a dict.
-    metadata = _build_metadata(data_dir, image_ids)
-    _maybe_populate_voc_color_mapping()
-    dataset = _build_dataset_from_metadata(metadata)
+    metadata = build_metadata(data_dir, image_ids)
+    maybe_populate_voc_color_mapping()
+    dataset = build_dataset_from_metadata(metadata)

     return dataset


-def _load_sbd(
+def load_sbd(
     split="sbd_train",
     data_dir=None,
 ):
     extracted_dir = os.path.join("benchmark_RELEASE", "dataset")
-    data_dir = _download_data_file(
-        SBD_URL, extracted_dir=extracted_dir, local_dir_path=data_dir
+    get_data = keras.utils.get_file(
+        fname=os.path.basename(SBD_URL),
+        origin=SBD_URL,
+        cache_dir=data_dir,
+        extract=True,
     )
-    image_ids = _get_sbd_image_ids(data_dir, split)
+    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
+    image_ids = get_sbd_image_ids(data_dir, split)
     # len(metadata) = #samples, metadata[i] is a dict.
-    metadata = _build_sbd_metadata(data_dir, image_ids)
-    dataset = _build_sbd_dataset_from_metadata(metadata)
+    metadata = build_sbd_metadata(data_dir, image_ids)
+
+    dataset = build_sbd_dataset_from_metadata(metadata)
     return dataset


 """
-Load the dataset for training and evaluation.
+## Load the dataset
+
+For training and evaluation, let's use "sbd_train" and "sbd_eval." You can also
+choose any of these datasets for the `load` function: 'train', 'eval', 'trainval',
+'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the
+SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset.
 """
 train_ds = load(split="sbd_train")
 eval_ds = load(split="sbd_eval")

 """
 ## Preprocess the data

-The `preprocess_inputs` utility function preprocesses the inputs to a dictionary
-of `images` and `segmentation_masks`. The images and segmentation masks are
-resized to 512x512. The resulting dataset is then batched into groups of 4 image
-and segmentation mask pairs.
+The preprocess_inputs utility function preprocesses inputs, converting them into
+a dictionary containing images and segmentation_masks. Both images and
+segmentation masks are resized to 512x512. The resulting dataset is then batched
+into groups of four image and segmentation mask pairs.
 """

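To make the preprocessing description above concrete, here is a hedged sketch of what such a step could look like; the helper below and the key names "images" and "segmentation_masks" are illustrative, not the guide's exact implementation:

IMAGE_SIZE = (512, 512)
BATCH_SIZE = 4

def preprocess_inputs(images, segmentation_masks):
    # Nearest-neighbor resizing keeps the integer class ids in the mask intact.
    images = tf.image.resize(images, IMAGE_SIZE)
    segmentation_masks = tf.image.resize(
        segmentation_masks, IMAGE_SIZE, method="nearest"
    )
    return {"images": images, "segmentation_masks": segmentation_masks}

As the prose above describes, the guide then applies this preprocessing over the loaded datasets and batches them into groups of four pairs.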