Skip to content

Commit 3117146

Browse files
committed
Update basnet segmentation example
1 parent 9fdad44 commit 3117146

File tree

2 files changed

+135
-120
lines changed

2 files changed

+135
-120
lines changed

examples/vision/basnet_segmentation.py

Lines changed: 56 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@
3232
structures common to real-world images in both foreground and background.
3333
"""
3434

35-
"""shell
36-
wget http://saliencydetection.net/duts/download/DUTS-TE.zip
37-
unzip -q DUTS-TE.zip
38-
"""
39-
4035
import os
4136

37+
# Because of the use of tf.image.ssim in the loss,
38+
# this example requires TensorFlow. The rest of the code
39+
# is backend-agnostic.
4240
os.environ["KERAS_BACKEND"] = "tensorflow"
41+
4342
import numpy as np
4443
from glob import glob
4544
import matplotlib.pyplot as plt
@@ -49,6 +48,8 @@
4948
import keras
5049
from keras import layers, ops
5150

51+
keras.config.disable_traceback_filtering()
52+
5253
"""
5354
## Define Hyperparameters
5455
"""
@@ -57,15 +58,20 @@
5758
BATCH_SIZE = 4
5859
OUT_CLASSES = 1
5960
TRAIN_SPLIT_RATIO = 0.90
60-
DATA_DIR = "./DUTS-TE/"
6161

6262
"""
63-
## Create `PyDataset`s
63+
## Create `PyDataset`s
6464
6565
We will use `load_paths()` to load and split 140 paths into train and validation set, and
6666
convert paths into `PyDataset` object.
6767
"""
6868

69+
data_dir = keras.utils.get_file(
70+
origin="http://saliencydetection.net/duts/download/DUTS-TE.zip",
71+
extract=True,
72+
)
73+
data_dir = os.path.join(data_dir, "DUTS-TE")
74+
6975

7076
def load_paths(path, split_ratio):
7177
images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
@@ -103,7 +109,9 @@ def __getitem__(self, idx):
103109
batch_x, batch_y = [], []
104110
for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
105111
x, y = self.preprocess(
106-
self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes
112+
self.image_paths[i],
113+
self.mask_paths[i],
114+
self.img_size,
107115
)
108116
batch_x.append(x)
109117
batch_y.append(y)
@@ -117,13 +125,13 @@ def read_image(self, path, size, mode):
117125
x = (x / 255.0).astype(np.float32)
118126
return x
119127

120-
def preprocess(self, x_batch, y_batch, img_size, out_classes):
128+
def preprocess(self, x_batch, y_batch, img_size):
121129
images = self.read_image(x_batch, (img_size, img_size), mode="rgb") # image
122130
masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale") # mask
123131
return images, masks
124132

125133

126-
train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
134+
train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)
127135

128136
train_dataset = Dataset(
129137
train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
@@ -148,8 +156,9 @@ def display(display_list):
148156
plt.show()
149157

150158

151-
for (image, mask), _ in zip(val_dataset, range(1)):
159+
for image, mask in val_dataset:
152160
display([image[0], mask[0]])
161+
break
153162

154163
"""
155164
## Analyze Mask
@@ -343,52 +352,37 @@ def basnet_rrm(base_model, out_classes):
343352
# ------------- refined = coarse + residual
344353
x = layers.Add()([x_input, x]) # Add prediction + refinement output
345354

346-
return keras.models.Model(inputs=[base_model.input], outputs=[x])
355+
return keras.models.Model(inputs=base_model.input[0], outputs=x)
347356

348357

349358
"""
350359
## Combine Predict and Refinement Module
351360
"""
352361

353362

354-
def basnet(input_shape, out_classes):
355-
"""BASNet, it's a combination of two modules
356-
Prediction Module and Residual Refinement Module(RRM)."""
357-
358-
# Prediction model.
359-
predict_model = basnet_predict(input_shape, out_classes)
360-
# Refinement model.
361-
refine_model = basnet_rrm(predict_model, out_classes)
362-
363-
output = refine_model.outputs # Combine outputs.
364-
output.extend(predict_model.output)
365-
366-
output = [layers.Activation("sigmoid")(_) for _ in output] # Activations.
367-
368-
return keras.models.Model(inputs=[predict_model.input], outputs=output)
369-
370-
371-
"""
372-
## Hybrid Loss
363+
class BASNet(keras.Model):
364+
def __init__(self, input_shape, out_classes):
365+
"""BASNet, it's a combination of two modules
366+
Prediction Module and Residual Refinement Module(RRM)."""
373367

374-
Another important feature of BASNet is its hybrid loss function, which is a combination of
375-
binary cross entropy, structural similarity and intersection-over-union losses, which guide
376-
the network to learn three-level (i.e., pixel, patch and map level) hierarchy representations.
377-
"""
368+
# Prediction model.
369+
predict_model = basnet_predict(input_shape, out_classes)
370+
# Refinement model.
371+
refine_model = basnet_rrm(predict_model, out_classes)
378372

373+
output = refine_model.outputs # Combine outputs.
374+
output.extend(predict_model.output)
379375

380-
class BasnetLoss(keras.losses.Loss):
381-
"""BASNet hybrid loss."""
376+
# Activations.
377+
output = [layers.Activation("sigmoid")(x) for x in output]
378+
super().__init__(inputs=predict_model.input[0], outputs=output)
382379

383-
def __init__(self, **kwargs):
384-
super().__init__(name="basnet_loss", **kwargs)
385380
self.smooth = 1.0e-9
386-
387381
# Binary Cross Entropy loss.
388382
self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
389383
# Structural Similarity Index value.
390384
self.ssim_value = tf.image.ssim
391-
# Jaccard / IoU loss.
385+
# Jaccard / IoU loss.
392386
self.iou_value = self.calculate_iou
393387

394388
def calculate_iou(
@@ -402,28 +396,39 @@ def calculate_iou(
402396
union = union - intersection
403397
return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)
404398

405-
def call(self, y_true, y_pred):
406-
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
399+
def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
400+
total = 0.0
401+
for y_pred_i in y_pred: # y_pred = refine_model.outputs + predict_model.output
402+
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)
403+
404+
ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
405+
ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
406+
407+
iou_value = self.iou_value(y_true, y_pred)
408+
iou_loss = 1 - iou_value
407409

408-
ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
409-
ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
410+
# Add all three losses.
411+
total += cross_entropy_loss + ssim_loss + iou_loss
412+
return total
410413

411-
iou_value = self.iou_value(y_true, y_pred)
412-
iou_loss = 1 - iou_value
413414

414-
# Add all three losses.
415-
return cross_entropy_loss + ssim_loss + iou_loss
415+
"""
416+
## Hybrid Loss
417+
418+
Another important feature of BASNet is its hybrid loss function, which is a combination of
419+
binary cross entropy, structural similarity and intersection-over-union losses, which guide
420+
the network to learn three-level (i.e., pixel, patch and map level) hierarchy representations.
421+
"""
416422

417423

418-
basnet_model = basnet(
424+
basnet_model = BASNet(
419425
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
420426
) # Create model.
421427
basnet_model.summary() # Show model summary.
422428

423429
optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
424430
# Compile model.
425431
basnet_model.compile(
426-
loss=BasnetLoss(),
427432
optimizer=optimizer,
428433
metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
429434
)

0 commit comments

Comments
 (0)