
Commit 9fdad44

Migrate "Highly accurate boundaries segmentation using BASNet" to Keras 3.0 (TensorFlow backend only) (#1942)
* Keras 3 migration of the BASNet segmentation example
* Fix format issues
1 parent 07b6a7e commit 9fdad44

File tree

1 file changed (+66 -53 lines changed)


examples/vision/basnet_segmentation.py

Lines changed: 66 additions & 53 deletions
@@ -2,7 +2,7 @@
 Title: Highly accurate boundaries segmentation using BASNet
 Author: [Hamid Ali](https://github.com/hamidriasat)
 Date created: 2023/05/30
-Last modified: 2023/07/13
+Last modified: 2024/10/02
 Description: Boundaries aware segmentation model trained on the DUTS dataset.
 Accelerator: GPU
 """
@@ -38,14 +38,16 @@
 """
 
 import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
 import numpy as np
 from glob import glob
 import matplotlib.pyplot as plt
 
 import keras_cv
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers, backend
+import keras
+from keras import layers, ops
 
 """
 ## Define Hyperparameters
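Aside (not part of the diff): in Keras 3 the backend is selected once, when `keras` is first imported, which is why `KERAS_BACKEND` is exported before any Keras import. A minimal sketch of that ordering:

```python
# Minimal sketch, not from the commit: the backend must be chosen before
# the first `import keras`, because Keras 3 reads KERAS_BACKEND at import time.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # set this first

import keras

print(keras.backend.backend())  # prints "tensorflow"
```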
@@ -58,10 +60,10 @@
 DATA_DIR = "./DUTS-TE/"
 
 """
-## Create TensorFlow Dataset
+## Create `PyDataset`s
 
 We will use `load_paths()` to load and split 140 paths into train and validation set, and
-`load_dataset()` to convert paths into `tf.data.Dataset` object.
+convert paths into `PyDataset` object.
 """
 
 
@@ -72,51 +74,64 @@ def load_paths(path, split_ratio):
     return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
 
 
-def read_image(path, size, mode):
-    x = keras.utils.load_img(path, target_size=size, color_mode=mode)
-    x = keras.utils.img_to_array(x)
-    x = (x / 255.0).astype(np.float32)
-    return x
-
-
-def preprocess(x_batch, y_batch, img_size, out_classes):
-    def f(_x, _y):
-        _x, _y = _x.decode(), _y.decode()
-        _x = read_image(_x, (img_size, img_size), mode="rgb")  # image
-        _y = read_image(_y, (img_size, img_size), mode="grayscale")  # mask
-        return _x, _y
-
-    images, masks = tf.numpy_function(f, [x_batch, y_batch], [tf.float32, tf.float32])
-    images.set_shape([img_size, img_size, 3])
-    masks.set_shape([img_size, img_size, out_classes])
-    return images, masks
-
-
-def load_dataset(image_paths, mask_paths, img_size, out_classes, batch, shuffle=True):
-    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
-    if shuffle:
-        dataset = dataset.cache().shuffle(buffer_size=1000)
-    dataset = dataset.map(
-        lambda x, y: preprocess(x, y, img_size, out_classes),
-        num_parallel_calls=tf.data.AUTOTUNE,
-    )
-    dataset = dataset.batch(batch)
-    dataset = dataset.prefetch(tf.data.AUTOTUNE)
-    return dataset
+class Dataset(keras.utils.PyDataset):
+    def __init__(
+        self,
+        image_paths,
+        mask_paths,
+        img_size,
+        out_classes,
+        batch,
+        shuffle=True,
+        **kwargs,
+    ):
+        if shuffle:
+            perm = np.random.permutation(len(image_paths))
+            image_paths = [image_paths[i] for i in perm]
+            mask_paths = [mask_paths[i] for i in perm]
+        self.image_paths = image_paths
+        self.mask_paths = mask_paths
+        self.img_size = img_size
+        self.out_classes = out_classes
+        self.batch_size = batch
+        super().__init__(*kwargs)
+
+    def __len__(self):
+        return len(self.image_paths) // self.batch_size
+
+    def __getitem__(self, idx):
+        batch_x, batch_y = [], []
+        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
+            x, y = self.preprocess(
+                self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes
+            )
+            batch_x.append(x)
+            batch_y.append(y)
+        batch_x = np.stack(batch_x, axis=0)
+        batch_y = np.stack(batch_y, axis=0)
+        return batch_x, batch_y
+
+    def read_image(self, path, size, mode):
+        x = keras.utils.load_img(path, target_size=size, color_mode=mode)
+        x = keras.utils.img_to_array(x)
+        x = (x / 255.0).astype(np.float32)
+        return x
+
+    def preprocess(self, x_batch, y_batch, img_size, out_classes):
+        images = self.read_image(x_batch, (img_size, img_size), mode="rgb")  # image
+        masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale")  # mask
+        return images, masks
 
 
 train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
 
-train_dataset = load_dataset(
+train_dataset = Dataset(
     train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
 )
-val_dataset = load_dataset(
+val_dataset = Dataset(
     val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
 )
 
-print(f"Train Dataset: {train_dataset}")
-print(f"Validation Dataset: {val_dataset}")
-
 """
 ## Visualize Data
 """
@@ -133,7 +148,7 @@ def display(display_list):
     plt.show()
 
 
-for image, mask in val_dataset.take(1):
+for (image, mask), _ in zip(val_dataset, range(1)):
     display([image[0], mask[0]])
 
 """
@@ -265,7 +280,7 @@ def basnet_predict(input_shape, out_classes):
     decoder_blocks = []
     for i in reversed(range(num_stages)):
         if i != (num_stages - 1):  # Except first, scale other decoder stages.
-            shape = keras.backend.int_shape(x)
+            shape = x.shape
             x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
 
         x = layers.concatenate([encoder_blocks[i], x], axis=-1)
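For reference, a small standalone sketch (not from the commit) of the `keras.backend.int_shape(x)` to `x.shape` replacement: in Keras 3, symbolic tensors expose their static shape directly as a tuple.

```python
import keras
from keras import layers

inputs = keras.Input(shape=(288, 288, 64))
x = layers.MaxPool2D()(inputs)

shape = x.shape  # (None, 144, 144, 64) -- batch dim is None
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
print(x.shape)  # (None, 288, 288, 64)
```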
@@ -318,7 +333,7 @@ def basnet_rrm(base_model, out_classes):
 
     # -------------Decoder--------------
     for i in reversed(range(num_stages)):
-        shape = keras.backend.int_shape(x)
+        shape = x.shape
         x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
         x = layers.concatenate([encoder_blocks[i], x], axis=-1)
         x = convolution_block(x, filters=filters)
@@ -345,7 +360,7 @@ def basnet(input_shape, out_classes):
     # Refinement model.
     refine_model = basnet_rrm(predict_model, out_classes)
 
-    output = [refine_model.output]  # Combine outputs.
+    output = refine_model.outputs  # Combine outputs.
     output.extend(predict_model.output)
 
     output = [layers.Activation("sigmoid")(_) for _ in output]  # Activations.
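A short note on the `output` to `outputs` switch: `Model.outputs` is always a list of output tensors, so it can be extended in place, whereas `Model.output` is the bare tensor for a single-output model. A tiny sketch (not part of the commit):

```python
import keras
from keras import layers

inp = keras.Input(shape=(4,))
model = keras.Model(inp, layers.Dense(1)(inp))

print(isinstance(model.outputs, list), len(model.outputs))  # True 1
```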
@@ -382,18 +397,16 @@ def calculate_iou(
         y_pred,
     ):
         """Calculate intersection over union (IoU) between images."""
-        intersection = backend.sum(backend.abs(y_true * y_pred), axis=[1, 2, 3])
-        union = backend.sum(y_true, [1, 2, 3]) + backend.sum(y_pred, [1, 2, 3])
+        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
+        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
         union = union - intersection
-        return backend.mean(
-            (intersection + self.smooth) / (union + self.smooth), axis=0
-        )
+        return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)
 
     def call(self, y_true, y_pred):
         cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
 
         ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
-        ssim_loss = backend.mean(1 - ssim_value + self.smooth, axis=0)
+        ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
 
         iou_value = self.iou_value(y_true, y_pred)
         iou_loss = 1 - iou_value
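The loss math is unchanged; only the `tensorflow.keras.backend` reductions become backend-agnostic `keras.ops` calls. A standalone check of the IoU arithmetic on small NumPy arrays (illustrative only; `smooth` here is a stand-in for `self.smooth`):

```python
import numpy as np
from keras import ops

y_true = np.ones((2, 4, 4, 1), dtype="float32")
y_pred = np.full((2, 4, 4, 1), 0.5, dtype="float32")
smooth = 1e-8

intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])  # 8.0 per sample
union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3]) - intersection  # 16.0
iou = ops.mean((intersection + smooth) / (union + smooth), axis=0)
print(float(iou))  # 0.5
```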
@@ -412,7 +425,7 @@ def call(self, y_true, y_pred):
 basnet_model.compile(
     loss=BasnetLoss(),
     optimizer=optimizer,
-    metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
+    metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
 )
 
 """
@@ -453,6 +466,6 @@ def normalize_output(prediction):
 ### Make Predictions
 """
 
-for image, mask in val_dataset.take(1):
+for (image, mask), _ in zip(val_dataset, range(1)):
     pred_mask = basnet_model.predict(image)
     display([image[0], mask[0], normalize_output(pred_mask[0][0])])
