3232structures common to real-world images in both foreground and background.
3333"""
3434
35- """shell
36- wget http://saliencydetection.net/duts/download/DUTS-TE.zip
37- unzip -q DUTS-TE.zip
38- """
39-
4035import os
4136
37+ # Because of the use of tf.image.ssim in the loss,
38+ # this example requires TensorFlow. The rest of the code
39+ # is backend-agnostic.
4240os .environ ["KERAS_BACKEND" ] = "tensorflow"
41+
4342import numpy as np
4443from glob import glob
4544import matplotlib .pyplot as plt
4948import keras
5049from keras import layers , ops
5150
51+ keras .config .disable_traceback_filtering ()
52+
5253"""
5354## Define Hyperparameters
5455"""
5758BATCH_SIZE = 4
5859OUT_CLASSES = 1
5960TRAIN_SPLIT_RATIO = 0.90
60- DATA_DIR = "./DUTS-TE/"
6161
6262"""
63- ## Create `PyDataset`s
63+ ## Create `PyDataset`s
6464
6565We will use `load_paths()` to load and split 140 paths into train and validation set, and
6666convert paths into `PyDataset` object.
6767"""
6868
69+ data_dir = keras .utils .get_file (
70+ origin = "http://saliencydetection.net/duts/download/DUTS-TE.zip" ,
71+ extract = True ,
72+ )
73+ data_dir = os .path .join (data_dir , "DUTS-TE" )
74+
6975
7076def load_paths (path , split_ratio ):
7177 images = sorted (glob (os .path .join (path , "DUTS-TE-Image/*" )))[:140 ]
@@ -103,7 +109,9 @@ def __getitem__(self, idx):
103109 batch_x , batch_y = [], []
104110 for i in range (idx * self .batch_size , (idx + 1 ) * self .batch_size ):
105111 x , y = self .preprocess (
106- self .image_paths [i ], self .mask_paths [i ], self .img_size , self .out_classes
112+ self .image_paths [i ],
113+ self .mask_paths [i ],
114+ self .img_size ,
107115 )
108116 batch_x .append (x )
109117 batch_y .append (y )
@@ -117,13 +125,13 @@ def read_image(self, path, size, mode):
117125 x = (x / 255.0 ).astype (np .float32 )
118126 return x
119127
120- def preprocess (self , x_batch , y_batch , img_size , out_classes ):
128+ def preprocess (self , x_batch , y_batch , img_size ):
121129 images = self .read_image (x_batch , (img_size , img_size ), mode = "rgb" ) # image
122130 masks = self .read_image (y_batch , (img_size , img_size ), mode = "grayscale" ) # mask
123131 return images , masks
124132
125133
126- train_paths , val_paths = load_paths (DATA_DIR , TRAIN_SPLIT_RATIO )
134+ train_paths , val_paths = load_paths (data_dir , TRAIN_SPLIT_RATIO )
127135
128136train_dataset = Dataset (
129137 train_paths [0 ], train_paths [1 ], IMAGE_SIZE , OUT_CLASSES , BATCH_SIZE , shuffle = True
@@ -148,8 +156,9 @@ def display(display_list):
148156 plt .show ()
149157
150158
151- for ( image , mask ), _ in zip ( val_dataset , range ( 1 )) :
159+ for image , mask in val_dataset :
152160 display ([image [0 ], mask [0 ]])
161+ break
153162
154163"""
155164## Analyze Mask
@@ -343,52 +352,37 @@ def basnet_rrm(base_model, out_classes):
343352 # ------------- refined = coarse + residual
344353 x = layers .Add ()([x_input , x ]) # Add prediction + refinement output
345354
346- return keras .models .Model (inputs = [ base_model .input ], outputs = [ x ] )
355+ return keras .models .Model (inputs = base_model .input [ 0 ], outputs = x )
347356
348357
349358"""
350359## Combine Predict and Refinement Module
351360"""
352361
353362
354- def basnet (input_shape , out_classes ):
355- """BASNet, it's a combination of two modules
356- Prediction Module and Residual Refinement Module(RRM)."""
357-
358- # Prediction model.
359- predict_model = basnet_predict (input_shape , out_classes )
360- # Refinement model.
361- refine_model = basnet_rrm (predict_model , out_classes )
362-
363- output = refine_model .outputs # Combine outputs.
364- output .extend (predict_model .output )
365-
366- output = [layers .Activation ("sigmoid" )(_ ) for _ in output ] # Activations.
367-
368- return keras .models .Model (inputs = [predict_model .input ], outputs = output )
369-
370-
371- """
372- ## Hybrid Loss
363+ class BASNet (keras .Model ):
364+ def __init__ (self , input_shape , out_classes ):
365+ """BASNet, it's a combination of two modules
366+ Prediction Module and Residual Refinement Module(RRM)."""
373367
374- Another important feature of BASNet is its hybrid loss function, which is a combination of
375- binary cross entropy, structural similarity and intersection-over-union losses, which guide
376- the network to learn three-level (i.e., pixel, patch and map level) hierarchy representations .
377- """
368+ # Prediction model.
369+ predict_model = basnet_predict ( input_shape , out_classes )
370+ # Refinement model .
371+ refine_model = basnet_rrm ( predict_model , out_classes )
378372
373+ output = refine_model .outputs # Combine outputs.
374+ output .extend (predict_model .output )
379375
380- class BasnetLoss (keras .losses .Loss ):
381- """BASNet hybrid loss."""
376+ # Activations.
377+ output = [layers .Activation ("sigmoid" )(x ) for x in output ]
378+ super ().__init__ (inputs = predict_model .input [0 ], outputs = output )
382379
383- def __init__ (self , ** kwargs ):
384- super ().__init__ (name = "basnet_loss" , ** kwargs )
385380 self .smooth = 1.0e-9
386-
387381 # Binary Cross Entropy loss.
388382 self .cross_entropy_loss = keras .losses .BinaryCrossentropy ()
389383 # Structural Similarity Index value.
390384 self .ssim_value = tf .image .ssim
391- # Jaccard / IoU loss.
385+ # Jaccard / IoU loss.
392386 self .iou_value = self .calculate_iou
393387
394388 def calculate_iou (
@@ -402,28 +396,39 @@ def calculate_iou(
402396 union = union - intersection
403397 return ops .mean ((intersection + self .smooth ) / (union + self .smooth ), axis = 0 )
404398
405- def call (self , y_true , y_pred ):
406- cross_entropy_loss = self .cross_entropy_loss (y_true , y_pred )
399+ def compute_loss (self , x , y_true , y_pred , sample_weight = None , training = False ):
400+ total = 0.0
401+ for y_pred_i in y_pred : # y_pred = refine_model.outputs + predict_model.output
402+ cross_entropy_loss = self .cross_entropy_loss (y_true , y_pred_i )
403+
404+ ssim_value = self .ssim_value (y_true , y_pred , max_val = 1 )
405+ ssim_loss = ops .mean (1 - ssim_value + self .smooth , axis = 0 )
406+
407+ iou_value = self .iou_value (y_true , y_pred )
408+ iou_loss = 1 - iou_value
407409
408- ssim_value = self .ssim_value (y_true , y_pred , max_val = 1 )
409- ssim_loss = ops .mean (1 - ssim_value + self .smooth , axis = 0 )
410+ # Add all three losses.
411+ total += cross_entropy_loss + ssim_loss + iou_loss
412+ return total
410413
411- iou_value = self .iou_value (y_true , y_pred )
412- iou_loss = 1 - iou_value
413414
414- # Add all three losses.
415- return cross_entropy_loss + ssim_loss + iou_loss
415+ """
416+ ## Hybrid Loss
417+
418+ Another important feature of BASNet is its hybrid loss function, which is a combination of
419+ binary cross entropy, structural similarity and intersection-over-union losses, which guide
420+ the network to learn three-level (i.e., pixel, patch and map level) hierarchy representations.
421+ """
416422
417423
418- basnet_model = basnet (
424+ basnet_model = BASNet (
419425 input_shape = [IMAGE_SIZE , IMAGE_SIZE , 3 ], out_classes = OUT_CLASSES
420426) # Create model.
421427basnet_model .summary () # Show model summary.
422428
423429optimizer = keras .optimizers .Adam (learning_rate = 1e-4 , epsilon = 1e-8 )
424430# Compile model.
425431basnet_model .compile (
426- loss = BasnetLoss (),
427432 optimizer = optimizer ,
428433 metrics = [keras .metrics .MeanAbsoluteError (name = "mae" ) for _ in basnet_model .outputs ],
429434)
0 commit comments