Skip to content

Commit 9fb8451

Browse files
committed
New Update to code
1 parent 7552568 commit 9fb8451

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

templates/Image classification_Tensorflow/code-template.py.jinja

+34-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import numpy as np
1717
import tensorflow as tf
1818
from tensorflow import keras
1919
from tensorflow.keras.preprocessing import image_dataset_from_directory
20+
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
2021
{% if data_format == "Image files" %}
2122
import urllib
2223
import zipfile
@@ -77,7 +78,7 @@ batch_size = {{ batch_size }}
7778
num_epochs = {{ num_epochs}}
7879

7980
{# TODO Add Image_Size #}
80-
img_size = (224,224)
81+
img_size = (160,160)
8182
img_shape = img_size + (3,)
8283

8384
# Set up logging.
@@ -103,9 +104,11 @@ def preprocess(data, name):
103104

104105
{% if data_format == "Image files" %}
105106
# Read image files to tensorflow dataset.
106-
dataset = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20, horizontal_flip=True)
107-
loader = dataset.flow_from_directory(data, target_size=img_size)
108-
{# TODO: Add more data_augmentation #}
107+
dataset = image_dataset_from_directory(data,
108+
shuffle=(name=="train"),
109+
image_size=img_size)
110+
111+
109112
{% elif data_format == "Numpy arrays" %}
110113
images, labels = data
111114

@@ -118,33 +121,41 @@ def preprocess(data, name):
118121
# If images are grayscale, convert to RGB by duplicating channels.
119122
if images.shape[1] == 1:
120123
images = np.stack((images[:, 0],) * 3, axis=1)
121-
)
124+
122125

123-
{# This code could be improved #}
124-
Pil_image = []
126+
images = np.rollaxis(images, 1, 4) # Reshape Image to channels_last
125127

126-
for i in range(len(data[0])):
127-
Pil_image.append(tf.keras.preprocessing.image.array_to_img(images[i],data_format ="channels_first"))
128-
Pil_image[i] = Pil_image[i].resize((224,224)
129-
Pil_image[i] = tf.keras.preprocessing.image.img_to_array(Pil_image[i])
128+
images = np.asarray([img_to_array(array_to_img(im, scale=False).resize((256,256))) for im in images])
130129

131-
loader = tf.convert_to_tensor(Pil_image)
130+
dataset = images, labels
132131
{% endif %}
133-
134-
return loader
132+
return dataset
135133

136134
train_loader = preprocess(train_data, "train")
137135
val_loader = preprocess(val_data, "val")
138136
test_loader = preprocess(test_data, "test")
139137

138+
{{ header("data augmentation") }}
139+
140+
data_augmentation = tf.keras.Sequential([
141+
tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
142+
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
143+
])
144+
140145
{{ header("Model") }}
141146

142-
model = tf.keras.applications.MobileNetV2(input_shape=img_shape,
147+
# Create the base model
148+
base_model = tf.keras.applications.{{model_func}}(input_shape=img_shape,
143149
include_top=True,
144-
weights='imagenet', classes=1000)
150+
weights='imagenet')
151+
145152

153+
# using the Keras Functional API
146154

147-
model.trainable = True
155+
inputs = tf.keras.Input(shape=img_shape)
156+
x = data_augmentation(inputs)
157+
x = tf.keras.applications.{{model_func}}.preprocess_input(x)
158+
model = base_model(x)
148159

149160
model.compile(optimizer = tf.keras.optimizers.{{ optimizer }}(lr={{lr}}),
150161
loss = "{{ loss }}",
@@ -153,12 +164,14 @@ model.compile(optimizer = tf.keras.optimizers.{{ optimizer }}(lr={{lr}}),
153164
model.fit(train_loader,
154165
batch_size={{batch_size}},
155166
epochs={{num_epochs}},
156-
validation_data=val_loader
167+
validation_data=val_loader,
157168
{% if visualization_tool == "Tensorboard" and checkpoint%}
158-
,callbacks = [tensorboard_callback, checkpoint_dir]
169+
callbacks = [tensorboard_callback, checkpoint_dir],
159170
{% elif checkpoint %}
160-
,callbacks = [checkpoint_dir]
171+
callbacks = [checkpoint_dir],
161172
{% elif visualization_tool == "Tensorboard" %}
162-
,callbacks = [tensorboard_callback]
173+
callbacks = [tensorboard_callback],
163174
{% endif %}
164175
)
176+
177+
model.evaluate(val_loader)

templates/Image classification_Tensorflow/sidebar.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,19 @@
66
# option 1: model -> code
77
# option 2 – if model has multiple variants: model -> model variant -> code
88
MODELS = {
9-
"AlexNet": "alexnet", # single model variant
9+
#TODO: Add more models
10+
"Xception": "Xception", # single model variant
1011
"ResNet": { # multiple model variants
11-
"ResNet 18": "resnet18",
12-
"ResNet 34": "resnet34",
13-
"ResNet 50": "resnet50",
14-
"ResNet 101": "resnet101",
15-
"ResNet 152": "resnet152",
12+
"ResNet 50": "ResNet50",
13+
"ResNet 101": "ResNet101",
14+
"ResNet 152": "ResNet152",
15+
"ResNet 50v2": "ResNet50V2",
16+
"ResNet 101v2": "ResNet101V2",
17+
"ResNet 152v2": "ResNet152V2",
1618
},
17-
"DenseNet": "densenet",
1819
"VGG": {
19-
"VGG11": "vgg11",
20-
"VGG11 with batch normalization": "vgg11_bn",
21-
"VGG13": "vgg13",
22-
"VGG13 with batch normalization": "vgg13_bn",
23-
"VGG16": "vgg16",
24-
"VGG16 with batch normalization": "vgg16_bn",
20+
"VGG16": "VGG16",
2521
"VGG19": "vgg19",
26-
"VGG19 with batch normalization": "vgg19_bn",
2722
},
2823
}
2924

0 commit comments

Comments
 (0)