3838"""
3939
4040import os
41+
4142os .environ ["KERAS_BACKEND" ] = "tensorflow"
4243import numpy as np
4344from glob import glob
@@ -72,12 +73,22 @@ def load_paths(path, split_ratio):
7273 len_ = int (len (images ) * split_ratio )
7374 return (images [:len_ ], masks [:len_ ]), (images [len_ :], masks [len_ :])
7475
76+
7577class Dataset (keras .utils .PyDataset ):
76- def __init__ (self , image_paths , mask_paths , img_size , out_classes , batch , shuffle = True , ** kwargs ):
78+ def __init__ (
79+ self ,
80+ image_paths ,
81+ mask_paths ,
82+ img_size ,
83+ out_classes ,
84+ batch ,
85+ shuffle = True ,
86+ ** kwargs ,
87+ ):
7788 if shuffle :
7889 perm = np .random .permutation (len (image_paths ))
79- image_paths = [ image_paths [i ] for i in perm ]
80- mask_paths = [ mask_paths [i ] for i in perm ]
90+ image_paths = [image_paths [i ] for i in perm ]
91+ mask_paths = [mask_paths [i ] for i in perm ]
8192 self .image_paths = image_paths
8293 self .mask_paths = mask_paths
8394 self .img_size = img_size
@@ -89,9 +100,11 @@ def __len__(self):
89100 return len (self .image_paths ) // self .batch_size
90101
91102 def __getitem__ (self , idx ):
92- batch_x , batch_y = [],[]
93- for i in range (idx * self .batch_size , (idx + 1 )* self .batch_size ):
94- x ,y = self .preprocess (self .image_paths [i ], self .mask_paths [i ], self .img_size , self .out_classes )
103+ batch_x , batch_y = [], []
104+ for i in range (idx * self .batch_size , (idx + 1 ) * self .batch_size ):
105+ x , y = self .preprocess (
106+ self .image_paths [i ], self .mask_paths [i ], self .img_size , self .out_classes
107+ )
95108 batch_x .append (x )
96109 batch_y .append (y )
97110 batch_x = np .stack (batch_x , axis = 0 )
@@ -135,7 +148,7 @@ def display(display_list):
135148 plt .show ()
136149
137150
138- for (image , mask ),_ in zip (val_dataset , range (1 )):
151+ for (image , mask ), _ in zip (val_dataset , range (1 )):
139152 display ([image [0 ], mask [0 ]])
140153
141154"""
@@ -387,9 +400,7 @@ def calculate_iou(
387400 intersection = ops .sum (ops .abs (y_true * y_pred ), axis = [1 , 2 , 3 ])
388401 union = ops .sum (y_true , [1 , 2 , 3 ]) + ops .sum (y_pred , [1 , 2 , 3 ])
389402 union = union - intersection
390- return ops .mean (
391- (intersection + self .smooth ) / (union + self .smooth ), axis = 0
392- )
403+ return ops .mean ((intersection + self .smooth ) / (union + self .smooth ), axis = 0 )
393404
394405 def call (self , y_true , y_pred ):
395406 cross_entropy_loss = self .cross_entropy_loss (y_true , y_pred )
@@ -455,6 +466,6 @@ def normalize_output(prediction):
455466### Make Predictions
456467"""
457468
458- for (image , mask ),_ in zip (val_dataset ,range (1 )):
469+ for (image , mask ), _ in zip (val_dataset , range (1 )):
459470 pred_mask = basnet_model .predict (image )
460471 display ([image [0 ], mask [0 ], normalize_output (pred_mask [0 ][0 ])])
0 commit comments