@@ -2,7 +2,7 @@
 Title: Highly accurate boundaries segmentation using BASNet
 Author: [Hamid Ali](https://github.com/hamidriasat)
 Date created: 2023/05/30
-Last modified: 2023/07/13
+Last modified: 2024/10/02
 Description: Boundaries aware segmentation model trained on the DUTS dataset.
 Accelerator: GPU
 """
3838"""
3939
4040import os
41+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
4142import numpy as np
4243from glob import glob
4344import matplotlib .pyplot as plt
4445
4546import keras_cv
4647import tensorflow as tf
47- from tensorflow import keras
48- from tensorflow . keras import layers , backend
48+ import keras
49+ from keras import layers , ops
4950
5051"""
5152## Define Hyperparameters
@@ -58,10 +59,10 @@
 DATA_DIR = "./DUTS-TE/"
 
 """
-## Create TensorFlow Dataset
+## Create `PyDataset`s
 
 We will use `load_paths()` to load and split 140 paths into train and validation sets, and
-`load_dataset()` to convert the paths into a `tf.data.Dataset` object.
+convert them into a `PyDataset` object.
 """
 
 
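The body of `load_paths()` sits outside the hunks shown below; for orientation, here is a minimal sketch of what it plausibly looks like (the `DUTS-TE-Image`/`DUTS-TE-Mask` directory names and the 140-path cap are assumptions inferred from the DUTS-TE layout and the prose above, not part of this diff; `glob` is imported near the top of the script):

def load_paths(path, split_ratio):
    # Assumed layout: images and masks live in sibling DUTS-TE-* folders.
    images = sorted(glob(f"{path}/DUTS-TE-Image/*"))[:140]
    masks = sorted(glob(f"{path}/DUTS-TE-Mask/*"))[:140]
    len_ = int(len(images) * split_ratio)
    return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])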
@@ -71,52 +72,53 @@ def load_paths(path, split_ratio):
     len_ = int(len(images) * split_ratio)
     return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
 
-
-def read_image(path, size, mode):
-    x = keras.utils.load_img(path, target_size=size, color_mode=mode)
-    x = keras.utils.img_to_array(x)
-    x = (x / 255.0).astype(np.float32)
-    return x
-
-
-def preprocess(x_batch, y_batch, img_size, out_classes):
-    def f(_x, _y):
-        _x, _y = _x.decode(), _y.decode()
-        _x = read_image(_x, (img_size, img_size), mode="rgb")  # image
-        _y = read_image(_y, (img_size, img_size), mode="grayscale")  # mask
-        return _x, _y
-
-    images, masks = tf.numpy_function(f, [x_batch, y_batch], [tf.float32, tf.float32])
-    images.set_shape([img_size, img_size, 3])
-    masks.set_shape([img_size, img_size, out_classes])
-    return images, masks
-
-
-def load_dataset(image_paths, mask_paths, img_size, out_classes, batch, shuffle=True):
-    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
-    if shuffle:
-        dataset = dataset.cache().shuffle(buffer_size=1000)
-    dataset = dataset.map(
-        lambda x, y: preprocess(x, y, img_size, out_classes),
-        num_parallel_calls=tf.data.AUTOTUNE,
-    )
-    dataset = dataset.batch(batch)
-    dataset = dataset.prefetch(tf.data.AUTOTUNE)
-    return dataset
+class Dataset(keras.utils.PyDataset):
+    def __init__(self, image_paths, mask_paths, img_size, out_classes, batch, shuffle=True, **kwargs):
+        if shuffle:
+            perm = np.random.permutation(len(image_paths))
+            image_paths = [image_paths[i] for i in perm]
+            mask_paths = [mask_paths[i] for i in perm]
+        self.image_paths = image_paths
+        self.mask_paths = mask_paths
+        self.img_size = img_size
+        self.out_classes = out_classes
+        self.batch_size = batch
+        super().__init__(**kwargs)
+
+    def __len__(self):
+        return len(self.image_paths) // self.batch_size
+
+    def __getitem__(self, idx):
+        batch_x, batch_y = [], []
+        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
+            x, y = self.preprocess(self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes)
+            batch_x.append(x)
+            batch_y.append(y)
+        batch_x = np.stack(batch_x, axis=0)
+        batch_y = np.stack(batch_y, axis=0)
+        return batch_x, batch_y
+
+    def read_image(self, path, size, mode):
+        x = keras.utils.load_img(path, target_size=size, color_mode=mode)
+        x = keras.utils.img_to_array(x)
+        x = (x / 255.0).astype(np.float32)
+        return x
+
+    def preprocess(self, x_batch, y_batch, img_size, out_classes):
+        images = self.read_image(x_batch, (img_size, img_size), mode="rgb")  # image
+        masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale")  # mask
+        return images, masks
 
 
 train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
 
-train_dataset = load_dataset(
+train_dataset = Dataset(
     train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
 )
-val_dataset = load_dataset(
+val_dataset = Dataset(
     val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
 )
 
-print(f"Train Dataset: {train_dataset}")
-print(f"Validation Dataset: {val_dataset}")
-
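Since a `PyDataset` is an indexable sequence, a quick shape check on one batch can stand in for the removed `print` statements; a minimal sketch:

x_batch, y_batch = train_dataset[0]
print(x_batch.shape)  # (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3)
print(y_batch.shape)  # (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, OUT_CLASSES)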
120122"""
121123## Visualize Data
122124"""
@@ -133,7 +135,7 @@ def display(display_list):
     plt.show()
 
 
-for image, mask in val_dataset.take(1):
+for (image, mask), _ in zip(val_dataset, range(1)):
    display([image[0], mask[0]])
 
 """
@@ -265,7 +267,7 @@ def basnet_predict(input_shape, out_classes):
     decoder_blocks = []
     for i in reversed(range(num_stages)):
         if i != (num_stages - 1):  # Except first, scale other decoder stages.
-            shape = keras.backend.int_shape(x)
+            shape = x.shape
             x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
 
         x = layers.concatenate([encoder_blocks[i], x], axis=-1)
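`keras.backend.int_shape()` was tf.keras-only; in Keras 3, symbolic tensors expose a `.shape` tuple directly. A minimal self-contained sketch of the resize-and-double pattern above (the 16x16x64 sizes are arbitrary placeholders):

import keras
from keras import layers

inputs = keras.Input(shape=(16, 16, 64))  # placeholder sizes
shape = inputs.shape                      # (None, 16, 16, 64)
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(inputs)
print(x.shape)                            # (None, 32, 32, 64)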
@@ -318,7 +320,7 @@ def basnet_rrm(base_model, out_classes):
 
     # -------------Decoder--------------
     for i in reversed(range(num_stages)):
-        shape = keras.backend.int_shape(x)
+        shape = x.shape
         x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
         x = layers.concatenate([encoder_blocks[i], x], axis=-1)
         x = convolution_block(x, filters=filters)
@@ -345,7 +347,7 @@ def basnet(input_shape, out_classes):
     # Refinement model.
     refine_model = basnet_rrm(predict_model, out_classes)
 
-    output = [refine_model.output]  # Combine outputs.
+    output = refine_model.outputs  # Combine outputs.
     output.extend(predict_model.output)
 
     output = [layers.Activation("sigmoid")(_) for _ in output]  # Activations.
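`Model.outputs` is always a list of output tensors, so it feeds `output.extend(...)` directly, whereas hand-wrapping `Model.output` in a list only works for single-output models. A tiny sketch of the invariant, using a throwaway model:

import keras
from keras import layers

inp = keras.Input(shape=(4,))
single = keras.Model(inp, layers.Dense(2)(inp))
assert single.outputs[0] is single.output  # .outputs wraps the same tensor in a list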
@@ -382,18 +384,18 @@ def calculate_iou(
         y_pred,
     ):
         """Calculate intersection over union (IoU) between images."""
-        intersection = backend.sum(backend.abs(y_true * y_pred), axis=[1, 2, 3])
-        union = backend.sum(y_true, [1, 2, 3]) + backend.sum(y_pred, [1, 2, 3])
+        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
+        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
         union = union - intersection
-        return backend.mean(
+        return ops.mean(
             (intersection + self.smooth) / (union + self.smooth), axis=0
         )
 
     def call(self, y_true, y_pred):
         cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
 
         ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
-        ssim_loss = backend.mean(1 - ssim_value + self.smooth, axis=0)
+        ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
 
         iou_value = self.iou_value(y_true, y_pred)
         iou_loss = 1 - iou_value
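To make the soft-IoU computation concrete, here is a small NumPy check of the formula above (the `smooth` value is an assumed small constant; the real `self.smooth` is defined outside this diff):

import numpy as np

smooth = 1e-8  # assumed; self.smooth is set elsewhere in BasnetLoss
y_true = np.zeros((1, 2, 2, 1), dtype="float32")
y_true[0, 0, :, 0] = 1.0  # two foreground pixels
y_pred = np.zeros((1, 2, 2, 1), dtype="float32")
y_pred[0, 0, 0, 0] = 1.0  # one of them predicted

intersection = np.sum(np.abs(y_true * y_pred), axis=(1, 2, 3))  # [1.0]
union = np.sum(y_true, (1, 2, 3)) + np.sum(y_pred, (1, 2, 3)) - intersection  # [2.0]
print(np.mean((intersection + smooth) / (union + smooth)))  # 0.5, so IoU loss = 0.5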
@@ -412,7 +414,7 @@ def call(self, y_true, y_pred):
 basnet_model.compile(
     loss=BasnetLoss(),
     optimizer=optimizer,
-    metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
+    metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
 )
 
 """
@@ -453,6 +455,6 @@ def normalize_output(prediction):
 ### Make Predictions
 """
 
-for image, mask in val_dataset.take(1):
+for (image, mask), _ in zip(val_dataset, range(1)):
     pred_mask = basnet_model.predict(image)
     display([image[0], mask[0], normalize_output(pred_mask[0][0])])
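`normalize_output()` is defined outside the hunks shown here; given its use on a raw network output for display, a plausible sketch (an assumption, not the commit's code) is a min-max rescale:

import numpy as np

def normalize_output(prediction):
    # Assumed implementation: min-max scale to [0, 1] for display.
    max_value = np.max(prediction)
    min_value = np.min(prediction)
    return (prediction - min_value) / (max_value - min_value)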