22Title: Highly accurate boundaries segmentation using BASNet
33Author: [Hamid Ali](https://github.com/hamidriasat)
44Date created: 2023/05/30
5- Last modified: 2023/07/13
5+ Last modified: 2024/10/02
66Description: Boundaries aware segmentation model trained on the DUTS dataset.
77Accelerator: GPU
88"""
3838"""
3939
4040import os
41+
42+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
4143import numpy as np
4244from glob import glob
4345import matplotlib .pyplot as plt
4446
4547import keras_cv
4648import tensorflow as tf
47- from tensorflow import keras
48- from tensorflow . keras import layers , backend
49+ import keras
50+ from keras import layers , ops
4951
5052"""
5153## Define Hyperparameters
5860DATA_DIR = "./DUTS-TE/"
5961
6062"""
61- ## Create TensorFlow Dataset
63+ ## Create `PyDataset`s
6264
6365We will use `load_paths()` to load and split 140 paths into train and validation set, and
64- `load_dataset()` to convert paths into `tf.data.Dataset ` object.
66+ convert paths into `PyDataset ` object.
6567"""
6668
6769
@@ -72,51 +74,64 @@ def load_paths(path, split_ratio):
7274 return (images [:len_ ], masks [:len_ ]), (images [len_ :], masks [len_ :])
7375
7476
75- def read_image (path , size , mode ):
76- x = keras .utils .load_img (path , target_size = size , color_mode = mode )
77- x = keras .utils .img_to_array (x )
78- x = (x / 255.0 ).astype (np .float32 )
79- return x
80-
81-
82- def preprocess (x_batch , y_batch , img_size , out_classes ):
83- def f (_x , _y ):
84- _x , _y = _x .decode (), _y .decode ()
85- _x = read_image (_x , (img_size , img_size ), mode = "rgb" ) # image
86- _y = read_image (_y , (img_size , img_size ), mode = "grayscale" ) # mask
87- return _x , _y
88-
89- images , masks = tf .numpy_function (f , [x_batch , y_batch ], [tf .float32 , tf .float32 ])
90- images .set_shape ([img_size , img_size , 3 ])
91- masks .set_shape ([img_size , img_size , out_classes ])
92- return images , masks
93-
94-
95- def load_dataset (image_paths , mask_paths , img_size , out_classes , batch , shuffle = True ):
96- dataset = tf .data .Dataset .from_tensor_slices ((image_paths , mask_paths ))
97- if shuffle :
98- dataset = dataset .cache ().shuffle (buffer_size = 1000 )
99- dataset = dataset .map (
100- lambda x , y : preprocess (x , y , img_size , out_classes ),
101- num_parallel_calls = tf .data .AUTOTUNE ,
102- )
103- dataset = dataset .batch (batch )
104- dataset = dataset .prefetch (tf .data .AUTOTUNE )
105- return dataset
77+ class Dataset (keras .utils .PyDataset ):
78+ def __init__ (
79+ self ,
80+ image_paths ,
81+ mask_paths ,
82+ img_size ,
83+ out_classes ,
84+ batch ,
85+ shuffle = True ,
86+ ** kwargs ,
87+ ):
88+ if shuffle :
89+ perm = np .random .permutation (len (image_paths ))
90+ image_paths = [image_paths [i ] for i in perm ]
91+ mask_paths = [mask_paths [i ] for i in perm ]
92+ self .image_paths = image_paths
93+ self .mask_paths = mask_paths
94+ self .img_size = img_size
95+ self .out_classes = out_classes
96+ self .batch_size = batch
97+ super ().__init__ (* kwargs )
98+
99+ def __len__ (self ):
100+ return len (self .image_paths ) // self .batch_size
101+
102+ def __getitem__ (self , idx ):
103+ batch_x , batch_y = [], []
104+ for i in range (idx * self .batch_size , (idx + 1 ) * self .batch_size ):
105+ x , y = self .preprocess (
106+ self .image_paths [i ], self .mask_paths [i ], self .img_size , self .out_classes
107+ )
108+ batch_x .append (x )
109+ batch_y .append (y )
110+ batch_x = np .stack (batch_x , axis = 0 )
111+ batch_y = np .stack (batch_y , axis = 0 )
112+ return batch_x , batch_y
113+
114+ def read_image (self , path , size , mode ):
115+ x = keras .utils .load_img (path , target_size = size , color_mode = mode )
116+ x = keras .utils .img_to_array (x )
117+ x = (x / 255.0 ).astype (np .float32 )
118+ return x
119+
120+ def preprocess (self , x_batch , y_batch , img_size , out_classes ):
121+ images = self .read_image (x_batch , (img_size , img_size ), mode = "rgb" ) # image
122+ masks = self .read_image (y_batch , (img_size , img_size ), mode = "grayscale" ) # mask
123+ return images , masks
106124
107125
108126train_paths , val_paths = load_paths (DATA_DIR , TRAIN_SPLIT_RATIO )
109127
110- train_dataset = load_dataset (
128+ train_dataset = Dataset (
111129 train_paths [0 ], train_paths [1 ], IMAGE_SIZE , OUT_CLASSES , BATCH_SIZE , shuffle = True
112130)
113- val_dataset = load_dataset (
131+ val_dataset = Dataset (
114132 val_paths [0 ], val_paths [1 ], IMAGE_SIZE , OUT_CLASSES , BATCH_SIZE , shuffle = False
115133)
116134
117- print (f"Train Dataset: { train_dataset } " )
118- print (f"Validation Dataset: { val_dataset } " )
119-
120135"""
121136## Visualize Data
122137"""
@@ -133,7 +148,7 @@ def display(display_list):
133148 plt .show ()
134149
135150
136- for image , mask in val_dataset . take ( 1 ):
151+ for ( image , mask ), _ in zip ( val_dataset , range ( 1 ) ):
137152 display ([image [0 ], mask [0 ]])
138153
139154"""
@@ -265,7 +280,7 @@ def basnet_predict(input_shape, out_classes):
265280 decoder_blocks = []
266281 for i in reversed (range (num_stages )):
267282 if i != (num_stages - 1 ): # Except first, scale other decoder stages.
268- shape = keras . backend . int_shape ( x )
283+ shape = x . shape
269284 x = layers .Resizing (shape [1 ] * 2 , shape [2 ] * 2 )(x )
270285
271286 x = layers .concatenate ([encoder_blocks [i ], x ], axis = - 1 )
@@ -318,7 +333,7 @@ def basnet_rrm(base_model, out_classes):
318333
319334 # -------------Decoder--------------
320335 for i in reversed (range (num_stages )):
321- shape = keras . backend . int_shape ( x )
336+ shape = x . shape
322337 x = layers .Resizing (shape [1 ] * 2 , shape [2 ] * 2 )(x )
323338 x = layers .concatenate ([encoder_blocks [i ], x ], axis = - 1 )
324339 x = convolution_block (x , filters = filters )
@@ -345,7 +360,7 @@ def basnet(input_shape, out_classes):
345360 # Refinement model.
346361 refine_model = basnet_rrm (predict_model , out_classes )
347362
348- output = [ refine_model .output ] # Combine outputs.
363+ output = refine_model .outputs # Combine outputs.
349364 output .extend (predict_model .output )
350365
351366 output = [layers .Activation ("sigmoid" )(_ ) for _ in output ] # Activations.
@@ -382,18 +397,16 @@ def calculate_iou(
382397 y_pred ,
383398 ):
384399 """Calculate intersection over union (IoU) between images."""
385- intersection = backend .sum (backend .abs (y_true * y_pred ), axis = [1 , 2 , 3 ])
386- union = backend .sum (y_true , [1 , 2 , 3 ]) + backend .sum (y_pred , [1 , 2 , 3 ])
400+ intersection = ops .sum (ops .abs (y_true * y_pred ), axis = [1 , 2 , 3 ])
401+ union = ops .sum (y_true , [1 , 2 , 3 ]) + ops .sum (y_pred , [1 , 2 , 3 ])
387402 union = union - intersection
388- return backend .mean (
389- (intersection + self .smooth ) / (union + self .smooth ), axis = 0
390- )
403+ return ops .mean ((intersection + self .smooth ) / (union + self .smooth ), axis = 0 )
391404
392405 def call (self , y_true , y_pred ):
393406 cross_entropy_loss = self .cross_entropy_loss (y_true , y_pred )
394407
395408 ssim_value = self .ssim_value (y_true , y_pred , max_val = 1 )
396- ssim_loss = backend .mean (1 - ssim_value + self .smooth , axis = 0 )
409+ ssim_loss = ops .mean (1 - ssim_value + self .smooth , axis = 0 )
397410
398411 iou_value = self .iou_value (y_true , y_pred )
399412 iou_loss = 1 - iou_value
@@ -412,7 +425,7 @@ def call(self, y_true, y_pred):
412425basnet_model .compile (
413426 loss = BasnetLoss (),
414427 optimizer = optimizer ,
415- metrics = [keras .metrics .MeanAbsoluteError (name = "mae" )],
428+ metrics = [keras .metrics .MeanAbsoluteError (name = "mae" ) for _ in basnet_model . outputs ],
416429)
417430
418431"""
@@ -453,6 +466,6 @@ def normalize_output(prediction):
453466### Make Predictions
454467"""
455468
456- for image , mask in val_dataset . take ( 1 ):
469+ for ( image , mask ), _ in zip ( val_dataset , range ( 1 ) ):
457470 pred_mask = basnet_model .predict (image )
458471 display ([image [0 ], mask [0 ], normalize_output (pred_mask [0 ][0 ])])
0 commit comments