1010 " \n " ,
1111 " **Author:** [Hamid Ali](https://github.com/hamidriasat)<br>\n " ,
1212 " **Date created:** 2023/05/30<br>\n " ,
13- " **Last modified:** 2024/10/02 <br>\n " ,
13+ " **Last modified:** 2025/01/24 <br>\n " ,
1414 " **Description:** Boundaries aware segmentation model trained on the DUTS dataset."
1515 ]
1616 },
6868 " from glob import glob\n " ,
6969 " import matplotlib.pyplot as plt\n " ,
7070 " \n " ,
71- " import keras_cv \n " ,
71+ " import keras_hub \n " ,
7272 " import tensorflow as tf\n " ,
7373 " import keras\n " ,
74- " from keras import layers, ops"
74+ " from keras import layers, ops\n " ,
75+ " \n " ,
76+ " keras.config.disable_traceback_filtering()"
7577 ]
7678 },
7779 {
117119 },
118120 "outputs" : [],
119121 "source" : [
120- " DATA_DIR = keras.utils.get_file(\n " ,
122+ " data_dir = keras.utils.get_file(\n " ,
121123 " origin=\" http://saliencydetection.net/duts/download/DUTS-TE.zip\" ,\n " ,
122124 " extract=True,\n " ,
123125 " )\n " ,
126+ " data_dir = os.path.join(data_dir, \" DUTS-TE\" )\n " ,
124127 " \n " ,
125128 " \n " ,
126129 " def load_paths(path, split_ratio):\n " ,
159162 " batch_x, batch_y = [], []\n " ,
160163 " for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):\n " ,
161164 " x, y = self.preprocess(\n " ,
162- " self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes\n " ,
165+ " self.image_paths[i],\n " ,
166+ " self.mask_paths[i],\n " ,
167+ " self.img_size,\n " ,
163168 " )\n " ,
164169 " batch_x.append(x)\n " ,
165170 " batch_y.append(y)\n " ,
173178 " x = (x / 255.0).astype(np.float32)\n " ,
174179 " return x\n " ,
175180 " \n " ,
176- " def preprocess(self, x_batch, y_batch, img_size, out_classes ):\n " ,
181+ " def preprocess(self, x_batch, y_batch, img_size):\n " ,
177182 " images = self.read_image(x_batch, (img_size, img_size), mode=\" rgb\" ) # image\n " ,
178183 " masks = self.read_image(y_batch, (img_size, img_size), mode=\" grayscale\" ) # mask\n " ,
179184 " return images, masks\n " ,
180185 " \n " ,
181186 " \n " ,
182- " train_paths, val_paths = load_paths(DATA_DIR , TRAIN_SPLIT_RATIO)\n " ,
187+ " train_paths, val_paths = load_paths(data_dir , TRAIN_SPLIT_RATIO)\n " ,
183188 " \n " ,
184189 " train_dataset = Dataset(\n " ,
185190 " train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True\n " ,
318323 " return x\n " ,
319324 " \n " ,
320325 " \n " ,
321- " def get_resnet_block(_resnet, block_num):\n " ,
322- " \"\"\" Extract and return ResNet-34 block.\"\"\"\n " ,
323- " resnet_layers = [3, 4, 6, 3] # ResNet-34 layer sizes at different block.\n " ,
326+ " def get_resnet_block(resnet, block_num):\n " ,
327+ " \"\"\" Extract and return a ResNet-34 block.\"\"\"\n " ,
328+ " extractor_levels = [\" P2\" , \" P3\" , \" P4\" , \" P5\" ]\n " ,
329+ " num_blocks = resnet.stackwise_num_blocks\n " ,
330+ " if block_num == 0:\n " ,
331+ " x = resnet.get_layer(\" pool1_pool\" ).output\n " ,
332+ " else:\n " ,
333+ " x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]\n " ,
334+ " y = resnet.get_layer(f\" stack{block_num}_block{num_blocks[block_num]-1}_add\" ).output\n " ,
324335 " return keras.models.Model(\n " ,
325- " inputs=_resnet.get_layer(f\" v2_stack_{block_num}_block1_1_conv\" ).input,\n " ,
326- " outputs=_resnet.get_layer(\n " ,
327- " f\" v2_stack_{block_num}_block{resnet_layers[block_num]}_add\"\n " ,
328- " ).output,\n " ,
329- " name=f\" resnet34_block{block_num + 1}\" ,\n " ,
330- " )\n " ,
331- " "
336+ " inputs=x,\n " ,
337+ " outputs=y,\n " ,
338+ " name=f\" resnet_block{block_num + 1}\" ,\n " ,
339+ " )\n "
332340 ]
333341 },
334342 {
366374 " # -------------Encoder--------------\n " ,
367375 " x = layers.Conv2D(filters, kernel_size=(3, 3), padding=\" same\" )(x_input)\n " ,
368376 " \n " ,
369- " resnet = keras_cv.models.ResNet34Backbone(\n " ,
370- " include_rescaling=False,\n " ,
377+ " resnet = keras_hub.models.ResNetBackbone(\n " ,
378+ " input_conv_filters=[64],\n " ,
379+ " input_conv_kernel_sizes=[7],\n " ,
380+ " stackwise_num_filters=[64, 128, 256, 512],\n " ,
381+ " stackwise_num_blocks=[3, 4, 6, 3],\n " ,
382+ " stackwise_num_strides=[1, 2, 2, 2],\n " ,
383+ " block_type=\" basic_block\" ,\n " ,
371384 " )\n " ,
372385 " \n " ,
373386 " encoder_blocks = []\n " ,
411424 " for decoder_block in decoder_blocks\n " ,
412425 " ]\n " ,
413426 " \n " ,
414- " return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)\n " ,
415- " "
427+ " return keras.models.Model(inputs=x_input, outputs=decoder_blocks)\n "
416428 ]
417429 },
418430 {
470482 " # ------------- refined = coarse + residual\n " ,
471483 " x = layers.Add()([x_input, x]) # Add prediction + refinement output\n " ,
472484 " \n " ,
473- " return keras.models.Model(inputs=[base_model.input], outputs=[x])\n " ,
474- " "
485+ " return keras.models.Model(inputs=[base_model.input], outputs=[x])\n "
475486 ]
476487 },
477488 {
492503 "outputs" : [],
493504 "source" : [
494505 " \n " ,
495- " def basnet(input_shape, out_classes):\n " ,
496- " \"\"\" BASNet, it's a combination of two modules\n " ,
497- " Prediction Module and Residual Refinement Module(RRM).\"\"\"\n " ,
506+ " class BASNet(keras.Model):\n " ,
507+ " def __init__(self, input_shape, out_classes):\n " ,
508+ " \"\"\" BASNet, it's a combination of two modules\n " ,
509+ " Prediction Module and Residual Refinement Module(RRM).\"\"\"\n " ,
510+ " \n " ,
511+ " # Prediction model.\n " ,
512+ " predict_model = basnet_predict(input_shape, out_classes)\n " ,
513+ " # Refinement model.\n " ,
514+ " refine_model = basnet_rrm(predict_model, out_classes)\n " ,
515+ " \n " ,
516+ " output = refine_model.outputs # Combine outputs.\n " ,
517+ " output.extend(predict_model.output)\n " ,
518+ " \n " ,
519+ " # Activations.\n " ,
520+ " output = [layers.Activation(\" sigmoid\" )(x) for x in output]\n " ,
521+ " super().__init__(inputs=predict_model.input, outputs=output)\n " ,
522+ " \n " ,
523+ " self.smooth = 1.0e-9\n " ,
524+ " # Binary Cross Entropy loss.\n " ,
525+ " self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n " ,
526+ " # Structural Similarity Index value.\n " ,
527+ " self.ssim_value = tf.image.ssim\n " ,
528+ " # Jaccard / IoU loss.\n " ,
529+ " self.iou_value = self.calculate_iou\n " ,
530+ " \n " ,
531+ " def calculate_iou(\n " ,
532+ " self,\n " ,
533+ " y_true,\n " ,
534+ " y_pred,\n " ,
535+ " ):\n " ,
536+ " \"\"\" Calculate intersection over union (IoU) between images.\"\"\"\n " ,
537+ " intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n " ,
538+ " union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n " ,
539+ " union = union - intersection\n " ,
540+ " return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n " ,
498541 " \n " ,
499- " # Prediction model. \n " ,
500- " predict_model = basnet_predict(input_shape, out_classes) \n " ,
501- " # Refinement model. \n " ,
502- " refine_model = basnet_rrm(predict_model, out_classes )\n " ,
542+ " def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False): \n " ,
543+ " total = 0.0 \n " ,
544+ " for y_pred_i in y_pred: # y_pred = refine_model.outputs + predict_model.output \n " ,
545+ " cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i )\n " ,
503546 " \n " ,
504- " output = refine_model.outputs # Combine outputs. \n " ,
505- " output.extend(predict_model.output )\n " ,
547+ " ssim_value = self.ssim_value(y_true, y_pred, max_val=1) \n " ,
548+ " ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0 )\n " ,
506549 " \n " ,
507- " output = [layers.Activation(\" sigmoid\" )(_) for _ in output] # Activations.\n " ,
550+ " iou_value = self.iou_value(y_true, y_pred)\n " ,
551+ " iou_loss = 1 - iou_value\n " ,
508552 " \n " ,
509- " return keras.models.Model(inputs=[predict_model.input], outputs=output)\n " ,
510- " "
553+ " # Add all three losses.\n " ,
554+ " total += cross_entropy_loss + ssim_loss + iou_loss\n " ,
555+ " return total\n "
511556 ]
512557 },
513558 {
532577 "outputs" : [],
533578 "source" : [
534579 " \n " ,
535- " class BasnetLoss(keras.losses.Loss):\n " ,
536- " \"\"\" BASNet hybrid loss.\"\"\"\n " ,
537- " \n " ,
538- " def __init__(self, **kwargs):\n " ,
539- " super().__init__(name=\" basnet_loss\" , **kwargs)\n " ,
540- " self.smooth = 1.0e-9\n " ,
541- " \n " ,
542- " # Binary Cross Entropy loss.\n " ,
543- " self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n " ,
544- " # Structural Similarity Index value.\n " ,
545- " self.ssim_value = tf.image.ssim\n " ,
546- " # Jaccard / IoU loss.\n " ,
547- " self.iou_value = self.calculate_iou\n " ,
548- " \n " ,
549- " def calculate_iou(\n " ,
550- " self,\n " ,
551- " y_true,\n " ,
552- " y_pred,\n " ,
553- " ):\n " ,
554- " \"\"\" Calculate intersection over union (IoU) between images.\"\"\"\n " ,
555- " intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n " ,
556- " union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n " ,
557- " union = union - intersection\n " ,
558- " return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n " ,
559- " \n " ,
560- " def call(self, y_true, y_pred):\n " ,
561- " cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)\n " ,
562- " \n " ,
563- " ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n " ,
564- " ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n " ,
565- " \n " ,
566- " iou_value = self.iou_value(y_true, y_pred)\n " ,
567- " iou_loss = 1 - iou_value\n " ,
568- " \n " ,
569- " # Add all three losses.\n " ,
570- " return cross_entropy_loss + ssim_loss + iou_loss\n " ,
571- " \n " ,
572- " \n " ,
573- " basnet_model = basnet(\n " ,
580+ " basnet_model = BASNet(\n " ,
574581 " input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES\n " ,
575582 " ) # Create model.\n " ,
576583 " basnet_model.summary() # Show model summary.\n " ,
577584 " \n " ,
578585 " optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)\n " ,
579586 " # Compile model.\n " ,
580587 " basnet_model.compile(\n " ,
581- " loss=BasnetLoss(),\n " ,
582588 " optimizer=optimizer,\n " ,
583589 " metrics=[keras.metrics.MeanAbsoluteError(name=\" mae\" ) for _ in basnet_model.outputs],\n " ,
584590 " )"
631637 },
632638 "outputs" : [],
633639 "source" : [
634- " !!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg"
635- ]
636- },
637- {
638- "cell_type" : " code" ,
639- "execution_count" : 0 ,
640- "metadata" : {
641- "colab_type" : " code"
642- },
643- "outputs" : [],
644- "source" : [
640+ " import gdown\n " ,
641+ " \n " ,
642+ " gdown.download(id=\" 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg\" , output=\" basnet_weights.h5\" )\n " ,
643+ " \n " ,
645644 " \n " ,
646645 " def normalize_output(prediction):\n " ,
647646 " max_value = np.max(prediction)\n " ,
686685 "toc_visible" : true
687686 },
688687 "kernelspec" : {
689- "display_name" : " Python 3 " ,
688+ "display_name" : " evn1 " ,
690689 "language" : " python" ,
691690 "name" : " python3"
692691 },
700699 "name" : " python" ,
701700 "nbconvert_exporter" : " python" ,
702701 "pygments_lexer" : " ipython3" ,
703- "version" : " 3.7.0 "
702+ "version" : " 3.9.6 "
704703 }
705704 },
706705 "nbformat" : 4 ,
707706 "nbformat_minor" : 0
708- }
707+ }