Skip to content

Commit 7552568

Browse files
committed
Added losses, optimizers and callbacks
1 parent b4e0e33 commit 7552568

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

templates/Image classification_Tensorflow/code-template.py.jinja

+18-8
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,21 @@ num_epochs = {{ num_epochs}}
7878

7979
{# TODO Add Image_Size #}
8080
img_size = (224,224)
81+
img_shape = img_size + (3,)
8182

8283
# Set up logging.
8384

8485
{% if visualization_tool == "Tensorboard" or checkpoint %}
8586
experiment_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
87+
{% endif %}
88+
{% if visualization_tool == "Tensorboard" %}
8689
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=experiment_id, histogram_freq=1)
8790
{% endif %}
8891
{% if visualization_tool == "comet.ml" %}
8992
experiment = Experiment("{{ comet_api_key }}"{% if comet_project %}, project_name="{{ comet_project }}"{% endif %})
9093
{% endif %}
9194
{% if checkpoint %}
92-
checkpoint_dir = Path(f"checkpoints/{experiment_id}")
93-
checkpoint_dir.mkdir(parents=True, exist_ok=True)
95+
checkpoint_dir = tf.keras.callbacks.ModelCheckpoint(filepath='checkpoints/{experiment_id}/model.{epoch:02d}-{val_loss:.2f}.h5')
9496
{% endif %}
9597
print_every = {{ print_every }} # batches
9698

@@ -136,19 +138,27 @@ val_loader = preprocess(val_data, "val")
136138
test_loader = preprocess(test_data, "test")
137139

138140
{{ header("Model") }}
139-
IMG_SHAPE = img_size + (3,)
140-
model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
141+
142+
model = tf.keras.applications.MobileNetV2(input_shape=img_shape,
141143
include_top=True,
142144
weights='imagenet', classes=1000)
143145

144146

145147
model.trainable = True
146148

147-
model.compile(optimizer=tf.keras.optimizers.Adam(lr={{lr}}),
148-
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))
149+
model.compile(optimizer = tf.keras.optimizers.{{ optimizer }}(lr={{lr}}),
150+
loss = "{{ loss }}",
151+
metrics = ["accuracy"])
149152

150153
model.fit(train_loader,
151154
batch_size={{batch_size}},
152155
epochs={{num_epochs}},
153-
validation_data=val_loader,
154-
callbacks=[tensorboard_callback])
156+
validation_data=val_loader
157+
{% if visualization_tool == "Tensorboard" and checkpoint%}
158+
,callbacks = [tensorboard_callback, checkpoint_dir]
159+
{% elif checkpoint %}
160+
,callbacks = [checkpoint_dir]
161+
{% elif visualization_tool == "Tensorboard" %}
162+
,callbacks = [tensorboard_callback]
163+
{% endif %}
164+
)

templates/Image classification_Tensorflow/sidebar.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,15 @@ def show():
108108
st.write("Scale mean and std for pre-trained model")
109109

110110
st.write("## Training")
111-
inputs["gpu"] = st.checkbox("Use GPU if available", True)
111+
#inputs["gpu"] = st.checkbox("Use GPU if available", True)
112112
inputs["checkpoint"] = st.checkbox("Save model checkpoint each epoch")
113113
if inputs["checkpoint"]:
114114
st.markdown(
115115
"<sup>Checkpoints are saved to timestamped dir in `./checkpoints`. They may consume a lot of storage!</sup>",
116116
unsafe_allow_html=True,
117117
)
118118
inputs["loss"] = st.selectbox(
119-
"Loss function", ("CrossEntropyLoss", "BCEWithLogitsLoss")
119+
"Loss function", ("sparse_categorical_crossentropy", "binary_crossentropy")
120120
)
121121
inputs["optimizer"] = st.selectbox("Optimizer", list(OPTIMIZERS.keys()))
122122
default_lr = OPTIMIZERS[inputs["optimizer"]]

0 commit comments

Comments
 (0)