Skip to content

Commit 2346ca6

Browse files
committed
cleanup merge
1 parent c5b791d commit 2346ca6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ml4h/models/train_diffusion.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ def train_diffusion_model(args):
3737
for k in batch[1]:
3838
logging.info(f"label {k} {batch[1][k].shape}")
3939
checkpoint_path = f"{args.output_folder}{args.id}/{args.id}.weights.h5"
40-
if os.path.exists(checkpoint_path):
41-
model.load_weights(checkpoint_path)
42-
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
43-
else:
44-
logging.info(f'No checkpoint at: {checkpoint_path}')
4540
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
4641
filepath=checkpoint_path,
4742
save_weights_only=True,
@@ -87,6 +82,11 @@ def train_diffusion_model(args):
8782
model.normalizer.adapt(images)
8883
# (4) call the model once
8984
_ = model((images, noise_rates))
85+
if os.path.exists(checkpoint_path):
86+
model.load_weights(checkpoint_path)
87+
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
88+
else:
89+
logging.info(f'No checkpoint at: {checkpoint_path}')
9090
history = model.fit(
9191
generate_train,
9292
steps_per_epoch=args.training_steps,

0 commit comments

Comments
 (0)