Skip to content

Commit fdabc58

Browse files
committed
Keras 3 migration basnet segmentation
1 parent 773263c commit fdabc58

File tree

1 file changed

+54
-52
lines changed

1 file changed

+54
-52
lines changed

examples/vision/basnet_segmentation.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: Highly accurate boundaries segmentation using BASNet
33
Author: [Hamid Ali](https://github.com/hamidriasat)
44
Date created: 2023/05/30
5-
Last modified: 2023/07/13
5+
Last modified: 2024/10/02
66
Description: Boundaries aware segmentation model trained on the DUTS dataset.
77
Accelerator: GPU
88
"""
@@ -38,14 +38,15 @@
3838
"""
3939

4040
import os
41+
os.environ["KERAS_BACKEND"] = "tensorflow"
4142
import numpy as np
4243
from glob import glob
4344
import matplotlib.pyplot as plt
4445

4546
import keras_cv
4647
import tensorflow as tf
47-
from tensorflow import keras
48-
from tensorflow.keras import layers, backend
48+
import keras
49+
from keras import layers, ops
4950

5051
"""
5152
## Define Hyperparameters
@@ -58,10 +59,10 @@
5859
DATA_DIR = "./DUTS-TE/"
5960

6061
"""
61-
## Create TensorFlow Dataset
62+
## Create `PyDataset`s
6263
6364
We will use `load_paths()` to load and split 140 paths into train and validation set, and
64-
`load_dataset()` to convert paths into `tf.data.Dataset` object.
65+
convert paths into `PyDataset` object.
6566
"""
6667

6768

@@ -71,52 +72,53 @@ def load_paths(path, split_ratio):
7172
len_ = int(len(images) * split_ratio)
7273
return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
7374

74-
75-
def read_image(path, size, mode):
76-
x = keras.utils.load_img(path, target_size=size, color_mode=mode)
77-
x = keras.utils.img_to_array(x)
78-
x = (x / 255.0).astype(np.float32)
79-
return x
80-
81-
82-
def preprocess(x_batch, y_batch, img_size, out_classes):
83-
def f(_x, _y):
84-
_x, _y = _x.decode(), _y.decode()
85-
_x = read_image(_x, (img_size, img_size), mode="rgb") # image
86-
_y = read_image(_y, (img_size, img_size), mode="grayscale") # mask
87-
return _x, _y
88-
89-
images, masks = tf.numpy_function(f, [x_batch, y_batch], [tf.float32, tf.float32])
90-
images.set_shape([img_size, img_size, 3])
91-
masks.set_shape([img_size, img_size, out_classes])
92-
return images, masks
93-
94-
95-
def load_dataset(image_paths, mask_paths, img_size, out_classes, batch, shuffle=True):
96-
dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
97-
if shuffle:
98-
dataset = dataset.cache().shuffle(buffer_size=1000)
99-
dataset = dataset.map(
100-
lambda x, y: preprocess(x, y, img_size, out_classes),
101-
num_parallel_calls=tf.data.AUTOTUNE,
102-
)
103-
dataset = dataset.batch(batch)
104-
dataset = dataset.prefetch(tf.data.AUTOTUNE)
105-
return dataset
75+
class Dataset(keras.utils.PyDataset):
76+
def __init__(self, image_paths, mask_paths, img_size, out_classes, batch, shuffle=True, **kwargs):
77+
if shuffle:
78+
perm = np.random.permutation(len(image_paths))
79+
image_paths = [ image_paths[i] for i in perm ]
80+
mask_paths = [ mask_paths[i] for i in perm ]
81+
self.image_paths = image_paths
82+
self.mask_paths = mask_paths
83+
self.img_size = img_size
84+
self.out_classes = out_classes
85+
self.batch_size = batch
86+
super().__init__(*kwargs)
87+
88+
def __len__(self):
89+
return len(self.image_paths) // self.batch_size
90+
91+
def __getitem__(self, idx):
92+
batch_x, batch_y = [],[]
93+
for i in range(idx*self.batch_size, (idx+1)*self.batch_size):
94+
x,y = self.preprocess(self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes)
95+
batch_x.append(x)
96+
batch_y.append(y)
97+
batch_x = np.stack(batch_x, axis=0)
98+
batch_y = np.stack(batch_y, axis=0)
99+
return batch_x, batch_y
100+
101+
def read_image(self, path, size, mode):
102+
x = keras.utils.load_img(path, target_size=size, color_mode=mode)
103+
x = keras.utils.img_to_array(x)
104+
x = (x / 255.0).astype(np.float32)
105+
return x
106+
107+
def preprocess(self, x_batch, y_batch, img_size, out_classes):
108+
images = self.read_image(x_batch, (img_size, img_size), mode="rgb") # image
109+
masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale") # mask
110+
return images, masks
106111

107112

108113
train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
109114

110-
train_dataset = load_dataset(
115+
train_dataset = Dataset(
111116
train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
112117
)
113-
val_dataset = load_dataset(
118+
val_dataset = Dataset(
114119
val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
115120
)
116121

117-
print(f"Train Dataset: {train_dataset}")
118-
print(f"Validation Dataset: {val_dataset}")
119-
120122
"""
121123
## Visualize Data
122124
"""
@@ -133,7 +135,7 @@ def display(display_list):
133135
plt.show()
134136

135137

136-
for image, mask in val_dataset.take(1):
138+
for (image, mask),_ in zip(val_dataset, range(1)):
137139
display([image[0], mask[0]])
138140

139141
"""
@@ -265,7 +267,7 @@ def basnet_predict(input_shape, out_classes):
265267
decoder_blocks = []
266268
for i in reversed(range(num_stages)):
267269
if i != (num_stages - 1): # Except first, scale other decoder stages.
268-
shape = keras.backend.int_shape(x)
270+
shape = x.shape
269271
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
270272

271273
x = layers.concatenate([encoder_blocks[i], x], axis=-1)
@@ -318,7 +320,7 @@ def basnet_rrm(base_model, out_classes):
318320

319321
# -------------Decoder--------------
320322
for i in reversed(range(num_stages)):
321-
shape = keras.backend.int_shape(x)
323+
shape = x.shape
322324
x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
323325
x = layers.concatenate([encoder_blocks[i], x], axis=-1)
324326
x = convolution_block(x, filters=filters)
@@ -345,7 +347,7 @@ def basnet(input_shape, out_classes):
345347
# Refinement model.
346348
refine_model = basnet_rrm(predict_model, out_classes)
347349

348-
output = [refine_model.output] # Combine outputs.
350+
output = refine_model.outputs # Combine outputs.
349351
output.extend(predict_model.output)
350352

351353
output = [layers.Activation("sigmoid")(_) for _ in output] # Activations.
@@ -382,18 +384,18 @@ def calculate_iou(
382384
y_pred,
383385
):
384386
"""Calculate intersection over union (IoU) between images."""
385-
intersection = backend.sum(backend.abs(y_true * y_pred), axis=[1, 2, 3])
386-
union = backend.sum(y_true, [1, 2, 3]) + backend.sum(y_pred, [1, 2, 3])
387+
intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
388+
union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
387389
union = union - intersection
388-
return backend.mean(
390+
return ops.mean(
389391
(intersection + self.smooth) / (union + self.smooth), axis=0
390392
)
391393

392394
def call(self, y_true, y_pred):
393395
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
394396

395397
ssim_value = self.ssim_value(y_true, y_pred, max_val=1)
396-
ssim_loss = backend.mean(1 - ssim_value + self.smooth, axis=0)
398+
ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
397399

398400
iou_value = self.iou_value(y_true, y_pred)
399401
iou_loss = 1 - iou_value
@@ -412,7 +414,7 @@ def call(self, y_true, y_pred):
412414
basnet_model.compile(
413415
loss=BasnetLoss(),
414416
optimizer=optimizer,
415-
metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
417+
metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
416418
)
417419

418420
"""
@@ -453,6 +455,6 @@ def normalize_output(prediction):
453455
### Make Predictions
454456
"""
455457

456-
for image, mask in val_dataset.take(1):
458+
for (image, mask),_ in zip(val_dataset,range(1)):
457459
pred_mask = basnet_model.predict(image)
458460
display([image[0], mask[0], normalize_output(pred_mask[0][0])])

0 commit comments

Comments
 (0)