Skip to content

Commit b3f84bc

Browse files
committed
Fix format issues
1 parent fdabc58 commit b3f84bc

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

examples/vision/basnet_segmentation.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"""
3939

4040
import os
41+
4142
os.environ["KERAS_BACKEND"] = "tensorflow"
4243
import numpy as np
4344
from glob import glob
@@ -72,12 +73,22 @@ def load_paths(path, split_ratio):
7273
len_ = int(len(images) * split_ratio)
7374
return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
7475

76+
7577
class Dataset(keras.utils.PyDataset):
76-
def __init__(self, image_paths, mask_paths, img_size, out_classes, batch, shuffle=True, **kwargs):
78+
def __init__(
79+
self,
80+
image_paths,
81+
mask_paths,
82+
img_size,
83+
out_classes,
84+
batch,
85+
shuffle=True,
86+
**kwargs,
87+
):
7788
if shuffle:
7889
perm = np.random.permutation(len(image_paths))
79-
image_paths = [ image_paths[i] for i in perm ]
80-
mask_paths = [ mask_paths[i] for i in perm ]
90+
image_paths = [image_paths[i] for i in perm]
91+
mask_paths = [mask_paths[i] for i in perm]
8192
self.image_paths = image_paths
8293
self.mask_paths = mask_paths
8394
self.img_size = img_size
@@ -89,9 +100,11 @@ def __len__(self):
89100
return len(self.image_paths) // self.batch_size
90101

91102
def __getitem__(self, idx):
92-
batch_x, batch_y = [],[]
93-
for i in range(idx*self.batch_size, (idx+1)*self.batch_size):
94-
x,y = self.preprocess(self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes)
103+
batch_x, batch_y = [], []
104+
for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
105+
x, y = self.preprocess(
106+
self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes
107+
)
95108
batch_x.append(x)
96109
batch_y.append(y)
97110
batch_x = np.stack(batch_x, axis=0)
@@ -135,7 +148,7 @@ def display(display_list):
135148
plt.show()
136149

137150

138-
for (image, mask),_ in zip(val_dataset, range(1)):
151+
for (image, mask), _ in zip(val_dataset, range(1)):
139152
display([image[0], mask[0]])
140153

141154
"""
@@ -387,9 +400,7 @@ def calculate_iou(
387400
intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
388401
union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
389402
union = union - intersection
390-
return ops.mean(
391-
(intersection + self.smooth) / (union + self.smooth), axis=0
392-
)
403+
return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)
393404

394405
def call(self, y_true, y_pred):
395406
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
@@ -455,6 +466,6 @@ def normalize_output(prediction):
455466
### Make Predictions
456467
"""
457468

458-
for (image, mask),_ in zip(val_dataset,range(1)):
469+
for (image, mask), _ in zip(val_dataset, range(1)):
459470
pred_mask = basnet_model.predict(image)
460471
display([image[0], mask[0], normalize_output(pred_mask[0][0])])

0 commit comments

Comments
 (0)