11"""
22Title: Semantic Segmentation with KerasHub
3- Author: [Sachin Prasad](https://github.com/sachinprasad)<br>
4- Date created: 2024/10/11<br>
5- Last modified: 2024/10/11<br>
6- Description: Train and use DeepLabv3 and DeepLabv3+ segmentation model with KerasHub.
3+ Author: [Sachin Prasad](https://github.com/sachinprasad)
4+ Date created: 2024/10/11
5+ Last modified: 2024/10/11
6+ Description: DeepLabV3 training and inference with KerasHub
77Accelerator: GPU
88"""
99
1212
1313## Background
1414Semantic segmentation is a type of computer vision task that involves assigning a
15- class label such as person, bike, or background to each individual pixel of an
16- image, effectively dividing the image into regions that correspond to different
17- fobject classes or categories.
15+ class label such as "person", "bike", or "background" to each individual pixel
16+ of an image, effectively dividing the image into regions that correspond to
17+ different object classes or categories.
1818
1919
2020
2121
2222
23- KerasHub offers the DeepLabv3, DeepLabv3+, SegFormer etc models for semantic
23+ KerasHub offers models such as DeepLabv3, DeepLabv3+, and SegFormer for semantic
2424segmentation.
2525
26- This guide demonstrates how to finetune and use DeepLabv3+ model which is
27- devoloped by Google for image semantic segmentaion with KerasHub. Its
28- architecture that combines atrous convolutions, contextual information
29- aggregation, and powerful backbones to achieve accurate and detailed semantic
30- segmentation.
26+ This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed
27+ by Google for image semantic segmentation with KerasHub. Its architecture
28+ combines atrous convolutions, contextual information aggregation, and powerful
29+ backbones to achieve accurate and detailed semantic segmentation.
3130
32- DeepLabv3+, extends DeepLabv3 by adding a simple yet effective decoder module to
33- refine the segmentation results especially along object boundaries both these
34- models have achienved state-of-the-art results on a variety of image segmentation
31+ DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module to
32+ refine the segmentation results, especially along object boundaries. Both models
33+ have achieved state-of-the-art results on a variety of image segmentation
3534benchmarks.
3635
3736### References
@@ -84,7 +83,7 @@ class label such as person, bike, or background to each individual pixel of an
8483`keras_hub.models` API. This API includes fully pretrained semantic segmentation
8584models, such as `keras_hub.models.DeepLabV3ImageSegmenter`.
8685
87- Let's get started by constructing a DeepLabv3 pretrained on the pascalvoc
86+ Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC
8887dataset.
8988Also, define the preprocessing function for the model to preprocess images and
9089labels.
@@ -154,7 +153,6 @@ def plot_segmentation(original_image, predicted_mask):
154153import multiprocessing
155154import os .path
156155import random
157- import tarfile
158156import xml
159157
160158import tensorflow_datasets as tfds
@@ -221,11 +219,11 @@ def plot_segmentation(original_image, predicted_mask):
221219 [128 , 192 , 0 ],
222220 [0 , 64 , 128 ],
223221]
224- # Will be populated by _maybe_populate_voc_color_mapping () below.
222+ # Will be populated by maybe_populate_voc_color_mapping () below.
225223VOC_PNG_COLOR_MAPPING = None
226224
227225
228- def _maybe_populate_voc_color_mapping ():
226+ def maybe_populate_voc_color_mapping ():
229227 # Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory.
230228 global VOC_PNG_COLOR_MAPPING
231229 if VOC_PNG_COLOR_MAPPING is None :
@@ -240,52 +238,14 @@ def _maybe_populate_voc_color_mapping():
240238 return VOC_PNG_COLOR_MAPPING
241239
242240
243- def _download_data_file (
244- data_url , extracted_dir , local_dir_path = None , override_extract = False
245- ):
246- """Fetch the original VOC or Semantic Boundaries Dataset from remote URL.
247-
248- Args:
249- data_url: string, the URL for the data to be downloaded, should be in a
250- zipped tar package.
251- local_dir_path: string, the local directory path to save the data.
252- Returns:
253- the path to the folder of extracted data.
254- """
255- if not local_dir_path :
256- # download to ~/.keras/datasets/fname
257- cache_dir = os .path .join (os .path .expanduser ("~" ), ".keras/datasets" )
258- fname = os .path .join (os .path .basename (data_url ))
259- else :
260- # Make sure the directory exists
261- if not os .path .exists (local_dir_path ):
262- os .makedirs (local_dir_path , exist_ok = True )
263- # download to local_dir_path/fname
264- fname = os .path .join (os .path .basename (data_url ))
265- cache_dir = local_dir_path
266- data_directory = os .path .join (os .path .dirname (fname ), extracted_dir )
267- if not override_extract and os .path .exists (data_directory ):
268- logging .info ("data directory %s already exist" , data_directory )
269- return data_directory
270- data_file_path = keras .utils .get_file (
271- fname = fname , origin = data_url , cache_dir = cache_dir
272- )
273- # Extra the data into the same directory as the tar file.
274- data_directory = os .path .dirname (data_file_path )
275- logging .info ("Extract data into %s" , data_directory )
276- with tarfile .open (data_file_path ) as f :
277- f .extractall (data_directory )
278- return os .path .join (data_directory , extracted_dir )
279-
280-
281- def _parse_annotation_data (annotation_file_path ):
241+ def parse_annotation_data (annotation_file_path ):
282242 """Parse the annotation XML file for the image.
283243
284244 The annotation contains the metadata, as well as the object bounding box
285245 information.
286246
287247 """
288- with tf . io . gfile . GFile (annotation_file_path , "r" ) as f :
248+ with open (annotation_file_path , "r" ) as f :
289249 root = xml .etree .ElementTree .parse (f ).getroot ()
290250
291251 size = root .find ("size" )
@@ -318,15 +278,13 @@ def _parse_annotation_data(annotation_file_path):
318278 return {"width" : width , "height" : height , "objects" : objects }
319279
320280
321- def _get_image_ids (data_dir , split ):
281+ def get_image_ids (data_dir , split ):
322282 data_file_mapping = {
323283 "train" : "train.txt" ,
324284 "eval" : "val.txt" ,
325285 "trainval" : "trainval.txt" ,
326- # TODO(tanzhenyu): add diff dataset
327- # "diff": "diff.txt",
328286 }
329- with tf . io . gfile . GFile (
287+ with open (
330288 os .path .join (data_dir , "ImageSets" , "Segmentation" , data_file_mapping [split ]),
331289 "r" ,
332290 ) as f :
@@ -335,9 +293,9 @@ def _get_image_ids(data_dir, split):
335293 return image_ids
336294
337295
338- def _get_sbd_image_ids (data_dir , split ):
296+ def get_sbd_image_ids (data_dir , split ):
339297 data_file_mapping = {"sbd_train" : "train.txt" , "sbd_eval" : "val.txt" }
340- with tf . io . gfile . GFile (
298+ with open (
341299 os .path .join (data_dir , data_file_mapping [split ]),
342300 "r" ,
343301 ) as f :
@@ -346,7 +304,7 @@ def _get_sbd_image_ids(data_dir, split):
346304 return image_ids
347305
348306
349- def _parse_single_image (image_file_path ):
307+ def parse_single_image (image_file_path ):
350308 data_dir , image_file_name = os .path .split (image_file_path )
351309 data_dir = os .path .normpath (os .path .join (data_dir , os .path .pardir ))
352310 image_id , _ = os .path .splitext (image_file_name )
@@ -357,7 +315,7 @@ def _parse_single_image(image_file_path):
357315 data_dir , "SegmentationObject" , image_id + ".png"
358316 )
359317 annotation_file_path = os .path .join (data_dir , "Annotations" , image_id + ".xml" )
360- image_annotations = _parse_annotation_data (annotation_file_path )
318+ image_annotations = parse_annotation_data (annotation_file_path )
361319
362320 result = {
363321 "image/filename" : image_id + ".jpg" ,
@@ -372,7 +330,7 @@ def _parse_single_image(image_file_path):
372330 return result
373331
374332
375- def _parse_single_sbd_image (image_file_path ):
333+ def parse_single_sbd_image (image_file_path ):
376334 data_dir , image_file_name = os .path .split (image_file_path )
377335 data_dir = os .path .normpath (os .path .join (data_dir , os .path .pardir ))
378336 image_id , _ = os .path .splitext (image_file_name )
@@ -387,14 +345,14 @@ def _parse_single_sbd_image(image_file_path):
387345 return result
388346
389347
390- def _build_metadata (data_dir , image_ids ):
348+ def build_metadata (data_dir , image_ids ):
391349 # Parallel process all the images.
392350 image_file_paths = [
393351 os .path .join (data_dir , "JPEGImages" , i + ".jpg" ) for i in image_ids
394352 ]
395353 pool_size = 10 if len (image_ids ) > 10 else len (image_ids )
396354 with multiprocessing .Pool (pool_size ) as p :
397- metadata = p .map (_parse_single_image , image_file_paths )
355+ metadata = p .map (parse_single_image , image_file_paths )
398356
399357 # Transpose the metadata which convert from list of dict to dict of list.
400358 keys = [
@@ -421,12 +379,12 @@ def _build_metadata(data_dir, image_ids):
421379 return result
422380
423381
424- def _build_sbd_metadata (data_dir , image_ids ):
382+ def build_sbd_metadata (data_dir , image_ids ):
425383 # Parallel process all the images.
426384 image_file_paths = [os .path .join (data_dir , "img" , i + ".jpg" ) for i in image_ids ]
427385 pool_size = 10 if len (image_ids ) > 10 else len (image_ids )
428386 with multiprocessing .Pool (pool_size ) as p :
429- metadata = p .map (_parse_single_sbd_image , image_file_paths )
387+ metadata = p .map (parse_single_sbd_image , image_file_paths )
430388
431389 keys = [
432390 "image/filename" ,
@@ -441,8 +399,7 @@ def _build_sbd_metadata(data_dir, image_ids):
441399 return result
442400
443401
444- @tf .function (jit_compile = True )
445- def _decode_png_mask (mask ):
402+ def decode_png_mask (mask ):
446403 """Decode the raw PNG image and convert it to 2D tensor with probably
447404 class."""
448405 # Cast the mask to int32 since the original uint8 will overflow when
@@ -454,7 +411,7 @@ def _decode_png_mask(mask):
454411 return mask
455412
456413
457- def _load_images (example ):
414+ def load_images (example ):
458415 image_file_path = example .pop ("image/file_path" )
459416 segmentation_class_file_path = example .pop ("segmentation/class/file_path" )
460417 segmentation_object_file_path = example .pop ("segmentation/object/file_path" )
@@ -463,11 +420,11 @@ def _load_images(example):
463420
464421 segmentation_class_mask = tf .io .read_file (segmentation_class_file_path )
465422 segmentation_class_mask = tf .image .decode_png (segmentation_class_mask )
466- segmentation_class_mask = _decode_png_mask (segmentation_class_mask )
423+ segmentation_class_mask = decode_png_mask (segmentation_class_mask )
467424
468425 segmentation_object_mask = tf .io .read_file (segmentation_object_file_path )
469426 segmentation_object_mask = tf .image .decode_png (segmentation_object_mask )
470- segmentation_object_mask = _decode_png_mask (segmentation_object_mask )
427+ segmentation_object_mask = decode_png_mask (segmentation_object_mask )
471428
472429 example .update (
473430 {
@@ -479,7 +436,7 @@ def _load_images(example):
479436 return example
480437
481438
482- def _load_sbd_images (image_file_path , seg_cls_file_path , seg_obj_file_path ):
439+ def load_sbd_images (image_file_path , seg_cls_file_path , seg_obj_file_path ):
483440 image = tf .io .read_file (image_file_path )
484441 image = tf .image .decode_jpeg (image )
485442
@@ -500,7 +457,7 @@ def _load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
500457 }
501458
502459
503- def _build_dataset_from_metadata (metadata ):
460+ def build_dataset_from_metadata (metadata ):
504461 # The objects need some manual conversion to ragged tensor.
505462 metadata ["labels" ] = tf .ragged .constant (metadata ["labels" ])
506463 metadata ["objects/label" ] = tf .ragged .constant (metadata ["objects/label" ])
@@ -516,11 +473,11 @@ def _build_dataset_from_metadata(metadata):
516473 )
517474
518475 dataset = tf .data .Dataset .from_tensor_slices (metadata )
519- dataset = dataset .map (_load_images , num_parallel_calls = tf .data .AUTOTUNE )
476+ dataset = dataset .map (load_images , num_parallel_calls = tf .data .AUTOTUNE )
520477 return dataset
521478
522479
523- def _build_sbd_dataset_from_metadata (metadata ):
480+ def build_sbd_dataset_from_metadata (metadata ):
524481 img_filepath = metadata ["image/file_path" ]
525482 cls_filepath = metadata ["segmentation/class/file_path" ]
526483 obj_filepath = metadata ["segmentation/object/file_path" ]
@@ -531,7 +488,7 @@ def md_gen():
531488 random .shuffle (c )
532489 for fp in c :
533490 img_fp , cls_fp , obj_fp = fp
534- yield _load_sbd_images (img_fp , cls_fp , obj_fp )
491+ yield load_sbd_images (img_fp , cls_fp , obj_fp )
535492
536493 dataset = tf .data .Dataset .from_generator (
537494 md_gen ,
@@ -592,56 +549,70 @@ def load(
592549 data_dir = os .path .expanduser (data_dir )
593550
594551 if "sbd" in split :
595- return _load_sbd (split , data_dir )
552+ return load_sbd (split , data_dir )
596553 else :
597- return _load_voc (split , data_dir )
554+ return load_voc (split , data_dir )
598555
599556
600- def _load_voc (
557+ def load_voc (
601558 split = "train" ,
602559 data_dir = None ,
603560):
604561 extracted_dir = os .path .join ("VOCdevkit" , "VOC2012" )
605- data_dir = _download_data_file (
606- VOC_URL , extracted_dir = extracted_dir , local_dir_path = data_dir
562+ get_data = keras .utils .get_file (
563+ fname = os .path .basename (VOC_URL ),
564+ origin = VOC_URL ,
565+ cache_dir = data_dir ,
566+ extract = True ,
607567 )
608- image_ids = _get_image_ids (data_dir , split )
568+ data_dir = os .path .join (os .path .dirname (get_data ), extracted_dir )
569+ image_ids = get_image_ids (data_dir , split )
609570 # len(metadata) = #samples, metadata[i] is a dict.
610- metadata = _build_metadata (data_dir , image_ids )
611- _maybe_populate_voc_color_mapping ()
612- dataset = _build_dataset_from_metadata (metadata )
571+ metadata = build_metadata (data_dir , image_ids )
572+ maybe_populate_voc_color_mapping ()
573+ dataset = build_dataset_from_metadata (metadata )
613574
614575 return dataset
615576
616577
617- def _load_sbd (
578+ def load_sbd (
618579 split = "sbd_train" ,
619580 data_dir = None ,
620581):
621582 extracted_dir = os .path .join ("benchmark_RELEASE" , "dataset" )
622- data_dir = _download_data_file (
623- SBD_URL , extracted_dir = extracted_dir , local_dir_path = data_dir
583+ get_data = keras .utils .get_file (
584+ fname = os .path .basename (SBD_URL ),
585+ origin = SBD_URL ,
586+ cache_dir = data_dir ,
587+ extract = True ,
624588 )
625- image_ids = _get_sbd_image_ids (data_dir , split )
589+ data_dir = os .path .join (os .path .dirname (get_data ), extracted_dir )
590+ image_ids = get_sbd_image_ids (data_dir , split )
626591 # len(metadata) = #samples, metadata[i] is a dict.
627- metadata = _build_sbd_metadata (data_dir , image_ids )
628- dataset = _build_sbd_dataset_from_metadata (metadata )
592+ metadata = build_sbd_metadata (data_dir , image_ids )
593+
594+ dataset = build_sbd_dataset_from_metadata (metadata )
629595 return dataset
630596
631597
632598"""
633- Load the dataset for training and evaluation.
599+ ## Load the dataset
600+
601+ For training and evaluation, let's use `sbd_train` and `sbd_eval`. You can also
602+ pass any of these splits to the `load` function: `train`, `eval`, `trainval`,
603+ `sbd_train`, or `sbd_eval`. `sbd_train` is the training split of the SBD
604+ dataset, while `train` is the training split of the VOC2012 dataset.
634605"""
635606train_ds = load (split = "sbd_train" )
636607eval_ds = load (split = "sbd_eval" )
637608
638609"""
639610## Preprocess the data
640611
641- The ` preprocess_inputs` utility function preprocesses the inputs to a dictionary
642- of ` images` and ` segmentation_masks`. The images and segmentation masks are
643- resized to 512x512. The resulting dataset is then batched into groups of 4 image
644- and segmentation mask pairs.
612+ The `preprocess_inputs` utility function preprocesses inputs, converting them into
613+ a dictionary containing `images` and `segmentation_masks`. Both images and
614+ segmentation masks are resized to 512x512. The resulting dataset is then batched
615+ into groups of four image and segmentation mask pairs.
645616"""
646617
647618
0 commit comments