@@ -78,19 +78,21 @@ num_epochs = {{ num_epochs}}
78
78
79
79
{# TODO Add Image_Size #}
80
80
img_size = (224,224)
81
+ img_shape = img_size + (3,)
81
82
82
83
# Set up logging.
83
84
84
85
{% if visualization_tool == "Tensorboard" or checkpoint %}
85
86
experiment_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
87
+ {% endif %}
88
+ {% if visualization_tool == "Tensorboard" %}
86
89
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=experiment_id, histogram_freq=1)
87
90
{% endif %}
88
91
{% if visualization_tool == "comet.ml" %}
89
92
experiment = Experiment("{{ comet_api_key }}"{% if comet_project %} , project_name="{{ comet_project }}"{% endif %} )
90
93
{% endif %}
91
94
{% if checkpoint %}
92
- checkpoint_dir = Path(f"checkpoints/{experiment_id}")
93
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
95
+ checkpoint_dir = tf.keras.callbacks.ModelCheckpoint(filepath='checkpoints/{experiment_id}/model.{epoch:02d}-{val_loss:.2f}.h5')
94
96
{% endif %}
95
97
print_every = {{ print_every }} # batches
96
98
@@ -136,19 +138,27 @@ val_loader = preprocess(val_data, "val")
136
138
test_loader = preprocess(test_data, "test")
137
139
138
140
{{ header("Model") }}
139
- IMG_SHAPE = img_size + (3,)
140
- model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE ,
141
+
142
+ model = tf.keras.applications.MobileNetV2(input_shape=img_shape ,
141
143
include_top=True,
142
144
weights='imagenet', classes=1000)
143
145
144
146
145
147
model.trainable = True
146
148
147
- model.compile(optimizer=tf.keras.optimizers.Adam(lr={{lr}}),
148
- loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))
149
+ model.compile(optimizer = tf.keras.optimizers.{{ optimizer }}(lr={{lr}}),
150
+ loss = "{{ loss }}",
151
+ metrics = ["accuracy"])
149
152
150
153
model.fit(train_loader,
151
154
batch_size={{batch_size}},
152
155
epochs={{num_epochs}},
153
- validation_data=val_loader,
154
- callbacks=[tensorboard_callback])
156
+ validation_data=val_loader
157
+ {% if visualization_tool == "Tensorboard" and checkpoint %}
158
+ ,callbacks = [tensorboard_callback, checkpoint_dir]
159
+ {% elif checkpoint %}
160
+ ,callbacks = [checkpoint_dir]
161
+ {% elif visualization_tool == "Tensorboard" %}
162
+ ,callbacks = [tensorboard_callback]
163
+ {% endif %}
164
+ )
0 commit comments