
Commit 6681f9e

Basnet keras-3 migration compatibility with keras-hub (#2038)
1 parent ba30e94 commit 6681f9e

File tree

5 files changed: +1428 -809 lines


examples/vision/basnet_segmentation.py

Lines changed: 26 additions & 17 deletions
@@ -43,7 +43,7 @@
 from glob import glob
 import matplotlib.pyplot as plt
 
-import keras_cv
+import keras_hub
 import tensorflow as tf
 import keras
 from keras import layers, ops
@@ -228,15 +228,19 @@ def segmentation_head(x_input, out_classes, final_size):
     return x
 
 
-def get_resnet_block(_resnet, block_num):
-    """Extract and return ResNet-34 block."""
-    resnet_layers = [3, 4, 6, 3]  # ResNet-34 layer sizes at different block.
+def get_resnet_block(resnet, block_num):
+    """Extract and return a ResNet-34 block."""
+    extractor_levels = ["P2", "P3", "P4", "P5"]
+    num_blocks = resnet.stackwise_num_blocks
+    if block_num == 0:
+        x = resnet.get_layer("pool1_pool").output
+    else:
+        x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
+    y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
     return keras.models.Model(
-        inputs=_resnet.get_layer(f"v2_stack_{block_num}_block1_1_conv").input,
-        outputs=_resnet.get_layer(
-            f"v2_stack_{block_num}_block{resnet_layers[block_num]}_add"
-        ).output,
-        name=f"resnet34_block{block_num + 1}",
+        inputs=x,
+        outputs=y,
+        name=f"resnet_block{block_num + 1}",
     )
 
 
@@ -262,8 +266,13 @@ def basnet_predict(input_shape, out_classes):
     # -------------Encoder--------------
     x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)
 
-    resnet = keras_cv.models.ResNet34Backbone(
-        include_rescaling=False,
+    resnet = keras_hub.models.ResNetBackbone(
+        input_conv_filters=[64],
+        input_conv_kernel_sizes=[7],
+        stackwise_num_filters=[64, 128, 256, 512],
+        stackwise_num_blocks=[3, 4, 6, 3],
+        stackwise_num_strides=[1, 2, 2, 2],
+        block_type="basic_block",
     )
 
     encoder_blocks = []
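Note: the keras_cv.models.ResNet34Backbone call above is replaced by a keras_hub.models.ResNetBackbone configured to match ResNet-34 (four stacks of 3/4/6/3 basic blocks). A minimal usage sketch of that backbone together with the get_resnet_block helper rewritten earlier in this diff; it assumes the example's other imports and is illustrative, not part of the commit:

import keras_hub

# Same configuration as the hunk above: ResNet-34 = basic (non-bottleneck)
# blocks stacked 3/4/6/3 with strides 1/2/2/2.
resnet = keras_hub.models.ResNetBackbone(
    input_conv_filters=[64],
    input_conv_kernel_sizes=[7],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[3, 4, 6, 3],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="basic_block",
)

# get_resnet_block (defined earlier in this diff) slices the backbone into four
# encoder stages: block 0 starts at the stem pooling layer "pool1_pool", blocks
# 1-3 start at the pyramid outputs "P2"/"P3"/"P4", and each stage ends at the
# last residual add of its stack.
encoder_blocks = [get_resnet_block(resnet, block_num) for block_num in range(4)]
for block in encoder_blocks:
    print(block.name)  # resnet_block1 ... resnet_block4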
@@ -307,7 +316,7 @@ def basnet_predict(input_shape, out_classes):
         for decoder_block in decoder_blocks
     ]
 
-    return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
+    return keras.models.Model(inputs=x_input, outputs=decoder_blocks)
 
 
 """
@@ -352,7 +361,7 @@ def basnet_rrm(base_model, out_classes):
     # ------------- refined = coarse + residual
     x = layers.Add()([x_input, x])  # Add prediction + refinement output
 
-    return keras.models.Model(inputs=base_model.input[0], outputs=x)
+    return keras.models.Model(inputs=[base_model.input], outputs=[x])
 
 
 """
@@ -375,7 +384,7 @@ def __init__(self, input_shape, out_classes):
 
         # Activations.
         output = [layers.Activation("sigmoid")(x) for x in output]
-        super().__init__(inputs=predict_model.input[0], outputs=output)
+        super().__init__(inputs=predict_model.input, outputs=output)
 
         self.smooth = 1.0e-9
         # Binary Cross Entropy loss.
@@ -453,9 +462,9 @@ def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
 trainings parameters please check given link.
 """
 
-"""shell
-!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg
-"""
+import gdown
+
+gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")
 
 
 def normalize_output(prediction):
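With the shell magic replaced by the gdown Python API above, the pretrained weights land in basnet_weights.h5 in the working directory. A hypothetical follow-up step, not shown in this hunk, would restore them into the model built earlier (keras.Model.load_weights accepts the .h5 file path directly):

# Hypothetical sketch: load the weights downloaded above before running inference.
basnet_model.load_weights("basnet_weights.h5")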
2 changed image files (55.5 KB and 104 KB): binary previews not shown.

examples/vision/ipynb/basnet_segmentation.ipynb

Lines changed: 88 additions & 89 deletions
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [Hamid Ali](https://github.com/hamidriasat)<br>\n",
 "**Date created:** 2023/05/30<br>\n",
-"**Last modified:** 2024/10/02<br>\n",
+"**Last modified:** 2025/01/24<br>\n",
 "**Description:** Boundaries aware segmentation model trained on the DUTS dataset."
 ]
 },
@@ -68,10 +68,12 @@
 "from glob import glob\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
-"import keras_cv\n",
+"import keras_hub\n",
 "import tensorflow as tf\n",
 "import keras\n",
-"from keras import layers, ops"
+"from keras import layers, ops\n",
+"\n",
+"keras.config.disable_traceback_filtering()"
 ]
 },
 {
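The notebook also adds keras.config.disable_traceback_filtering(), which switches off Keras 3's traceback trimming so full stack traces surface while debugging the migrated code. A small sketch of the related config calls (standard Keras 3 config API, not part of this commit):

import keras

keras.config.disable_traceback_filtering()
print(keras.config.is_traceback_filtering_enabled())  # False
keras.config.enable_traceback_filtering()  # restore the default behaviour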
@@ -117,10 +119,11 @@
 },
 "outputs": [],
 "source": [
-"DATA_DIR = keras.utils.get_file(\n",
+"data_dir = keras.utils.get_file(\n",
 "    origin=\"http://saliencydetection.net/duts/download/DUTS-TE.zip\",\n",
 "    extract=True,\n",
 ")\n",
+"data_dir = os.path.join(data_dir, \"DUTS-TE\")\n",
 "\n",
 "\n",
 "def load_paths(path, split_ratio):\n",
@@ -159,7 +162,9 @@
 "        batch_x, batch_y = [], []\n",
 "        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):\n",
 "            x, y = self.preprocess(\n",
-"                self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes\n",
+"                self.image_paths[i],\n",
+"                self.mask_paths[i],\n",
+"                self.img_size,\n",
 "            )\n",
 "            batch_x.append(x)\n",
 "            batch_y.append(y)\n",
@@ -173,13 +178,13 @@
 "        x = (x / 255.0).astype(np.float32)\n",
 "        return x\n",
 "\n",
-"    def preprocess(self, x_batch, y_batch, img_size, out_classes):\n",
+"    def preprocess(self, x_batch, y_batch, img_size):\n",
 "        images = self.read_image(x_batch, (img_size, img_size), mode=\"rgb\")  # image\n",
 "        masks = self.read_image(y_batch, (img_size, img_size), mode=\"grayscale\")  # mask\n",
 "        return images, masks\n",
 "\n",
 "\n",
-"train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)\n",
+"train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)\n",
 "\n",
 "train_dataset = Dataset(\n",
 "    train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True\n",
@@ -318,17 +323,20 @@
 "    return x\n",
 "\n",
 "\n",
-"def get_resnet_block(_resnet, block_num):\n",
-"    \"\"\"Extract and return ResNet-34 block.\"\"\"\n",
-"    resnet_layers = [3, 4, 6, 3]  # ResNet-34 layer sizes at different block.\n",
+"def get_resnet_block(resnet, block_num):\n",
+"    \"\"\"Extract and return a ResNet-34 block.\"\"\"\n",
+"    extractor_levels = [\"P2\", \"P3\", \"P4\", \"P5\"]\n",
+"    num_blocks = resnet.stackwise_num_blocks\n",
+"    if block_num == 0:\n",
+"        x = resnet.get_layer(\"pool1_pool\").output\n",
+"    else:\n",
+"        x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]\n",
+"    y = resnet.get_layer(f\"stack{block_num}_block{num_blocks[block_num]-1}_add\").output\n",
 "    return keras.models.Model(\n",
-"        inputs=_resnet.get_layer(f\"v2_stack_{block_num}_block1_1_conv\").input,\n",
-"        outputs=_resnet.get_layer(\n",
-"            f\"v2_stack_{block_num}_block{resnet_layers[block_num]}_add\"\n",
-"        ).output,\n",
-"        name=f\"resnet34_block{block_num + 1}\",\n",
-"    )\n",
-""
+"        inputs=x,\n",
+"        outputs=y,\n",
+"        name=f\"resnet_block{block_num + 1}\",\n",
+"    )\n"
 ]
 },
 {
@@ -366,8 +374,13 @@
 "    # -------------Encoder--------------\n",
 "    x = layers.Conv2D(filters, kernel_size=(3, 3), padding=\"same\")(x_input)\n",
 "\n",
-"    resnet = keras_cv.models.ResNet34Backbone(\n",
-"        include_rescaling=False,\n",
+"    resnet = keras_hub.models.ResNetBackbone(\n",
+"        input_conv_filters=[64],\n",
+"        input_conv_kernel_sizes=[7],\n",
+"        stackwise_num_filters=[64, 128, 256, 512],\n",
+"        stackwise_num_blocks=[3, 4, 6, 3],\n",
+"        stackwise_num_strides=[1, 2, 2, 2],\n",
+"        block_type=\"basic_block\",\n",
 "    )\n",
 "\n",
 "    encoder_blocks = []\n",
@@ -411,8 +424,7 @@
 "        for decoder_block in decoder_blocks\n",
 "    ]\n",
 "\n",
-"    return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)\n",
-""
+"    return keras.models.Model(inputs=x_input, outputs=decoder_blocks)\n"
 ]
 },
 {
@@ -470,8 +482,7 @@
 "    # ------------- refined = coarse + residual\n",
 "    x = layers.Add()([x_input, x])  # Add prediction + refinement output\n",
 "\n",
-"    return keras.models.Model(inputs=[base_model.input], outputs=[x])\n",
-""
+"    return keras.models.Model(inputs=[base_model.input], outputs=[x])\n"
 ]
 },
 {
@@ -492,22 +503,56 @@
 "outputs": [],
 "source": [
 "\n",
-"def basnet(input_shape, out_classes):\n",
-"    \"\"\"BASNet, it's a combination of two modules\n",
-"    Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
+"class BASNet(keras.Model):\n",
+"    def __init__(self, input_shape, out_classes):\n",
+"        \"\"\"BASNet, it's a combination of two modules\n",
+"        Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
+"\n",
+"        # Prediction model.\n",
+"        predict_model = basnet_predict(input_shape, out_classes)\n",
+"        # Refinement model.\n",
+"        refine_model = basnet_rrm(predict_model, out_classes)\n",
+"\n",
+"        output = refine_model.outputs  # Combine outputs.\n",
+"        output.extend(predict_model.output)\n",
+"\n",
+"        # Activations.\n",
+"        output = [layers.Activation(\"sigmoid\")(x) for x in output]\n",
+"        super().__init__(inputs=predict_model.input, outputs=output)\n",
+"\n",
+"        self.smooth = 1.0e-9\n",
+"        # Binary Cross Entropy loss.\n",
+"        self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
+"        # Structural Similarity Index value.\n",
+"        self.ssim_value = tf.image.ssim\n",
+"        # Jaccard / IoU loss.\n",
+"        self.iou_value = self.calculate_iou\n",
+"\n",
+"    def calculate_iou(\n",
+"        self,\n",
+"        y_true,\n",
+"        y_pred,\n",
+"    ):\n",
+"        \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
+"        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
+"        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
+"        union = union - intersection\n",
+"        return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
 "\n",
-"    # Prediction model.\n",
-"    predict_model = basnet_predict(input_shape, out_classes)\n",
-"    # Refinement model.\n",
-"    refine_model = basnet_rrm(predict_model, out_classes)\n",
+"    def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):\n",
+"        total = 0.0\n",
+"        for y_pred_i in y_pred:  # y_pred = refine_model.outputs + predict_model.output\n",
+"            cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)\n",
 "\n",
-"    output = refine_model.outputs  # Combine outputs.\n",
-"    output.extend(predict_model.output)\n",
+"            ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
+"            ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
 "\n",
-"    output = [layers.Activation(\"sigmoid\")(_) for _ in output]  # Activations.\n",
+"            iou_value = self.iou_value(y_true, y_pred)\n",
+"            iou_loss = 1 - iou_value\n",
 "\n",
-"    return keras.models.Model(inputs=[predict_model.input], outputs=output)\n",
-""
+"            # Add all three losses.\n",
+"            total += cross_entropy_loss + ssim_loss + iou_loss\n",
+"        return total\n"
 ]
 },
 {
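The old BasnetLoss class is folded into compute_loss above: for every model output it sums binary cross-entropy, an SSIM term, and an IoU term. A standalone sketch of those three pieces on dummy tensors, assuming the TensorFlow backend the example already relies on for tf.image.ssim (shapes are illustrative, not from the commit):

import numpy as np
import tensorflow as tf
import keras
from keras import ops

smooth = 1.0e-9
y_true = np.random.rand(2, 64, 64, 1).astype("float32")
y_pred = np.random.rand(2, 64, 64, 1).astype("float32")

# Binary cross-entropy between the target mask and one predicted mask.
bce_loss = keras.losses.BinaryCrossentropy()(y_true, y_pred)

# Structural-similarity term: 1 - SSIM, averaged over the batch.
ssim_loss = ops.mean(1 - tf.image.ssim(y_true, y_pred, max_val=1) + smooth, axis=0)

# IoU (Jaccard) term, matching calculate_iou above.
intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3]) - intersection
iou_loss = 1 - ops.mean((intersection + smooth) / (union + smooth), axis=0)

total = bce_loss + ssim_loss + iou_loss  # compute_loss sums this per model output
print(float(total))

Because the loss now lives inside the model, the compile call later in this notebook only sets the optimizer and per-output MAE metrics, and fit() is invoked as usual.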
@@ -532,53 +577,14 @@
 "outputs": [],
 "source": [
 "\n",
-"class BasnetLoss(keras.losses.Loss):\n",
-"    \"\"\"BASNet hybrid loss.\"\"\"\n",
-"\n",
-"    def __init__(self, **kwargs):\n",
-"        super().__init__(name=\"basnet_loss\", **kwargs)\n",
-"        self.smooth = 1.0e-9\n",
-"\n",
-"        # Binary Cross Entropy loss.\n",
-"        self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
-"        # Structural Similarity Index value.\n",
-"        self.ssim_value = tf.image.ssim\n",
-"        # Jaccard / IoU loss.\n",
-"        self.iou_value = self.calculate_iou\n",
-"\n",
-"    def calculate_iou(\n",
-"        self,\n",
-"        y_true,\n",
-"        y_pred,\n",
-"    ):\n",
-"        \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
-"        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
-"        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
-"        union = union - intersection\n",
-"        return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
-"\n",
-"    def call(self, y_true, y_pred):\n",
-"        cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)\n",
-"\n",
-"        ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
-"        ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
-"\n",
-"        iou_value = self.iou_value(y_true, y_pred)\n",
-"        iou_loss = 1 - iou_value\n",
-"\n",
-"        # Add all three losses.\n",
-"        return cross_entropy_loss + ssim_loss + iou_loss\n",
-"\n",
-"\n",
-"basnet_model = basnet(\n",
+"basnet_model = BASNet(\n",
 "    input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES\n",
 ")  # Create model.\n",
 "basnet_model.summary()  # Show model summary.\n",
 "\n",
 "optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)\n",
 "# Compile model.\n",
 "basnet_model.compile(\n",
-"    loss=BasnetLoss(),\n",
 "    optimizer=optimizer,\n",
 "    metrics=[keras.metrics.MeanAbsoluteError(name=\"mae\") for _ in basnet_model.outputs],\n",
 ")"
@@ -631,17 +637,10 @@
 },
 "outputs": [],
 "source": [
-"!!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 0,
-"metadata": {
-"colab_type": "code"
-},
-"outputs": [],
-"source": [
+"import gdown\n",
+"\n",
+"gdown.download(id=\"1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg\", output=\"basnet_weights.h5\")\n",
+"\n",
 "\n",
 "def normalize_output(prediction):\n",
 "    max_value = np.max(prediction)\n",
@@ -686,7 +685,7 @@
 "toc_visible": true
 },
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "evn1",
 "language": "python",
 "name": "python3"
 },
@@ -700,9 +699,9 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.9.6"
 }
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
