We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 822491b commit 6b8dad5Copy full SHA for 6b8dad5
ml4h/models/train_diffusion.py
@@ -85,6 +85,12 @@ def train_diffusion_model(args):
85
if os.path.exists(checkpoint_path):
86
model = tf.keras.models.load_model(checkpoint_path)
87
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
88
+ model.compile(
89
+ optimizer=tf.keras.optimizers.AdamW(
90
+ learning_rate=args.learning_rate, weight_decay=1e-4,
91
+ ),
92
+ loss=keras.losses.MeanAbsoluteError() if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError(),
93
+ )
94
else:
95
logging.info(f'No checkpoint at: {checkpoint_path}')
96
history = model.fit(
0 commit comments