Skip to content

Commit 6b8dad5

Browse files
committed
add license
1 parent 822491b commit 6b8dad5

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

ml4h/models/train_diffusion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def train_diffusion_model(args):
8585
if os.path.exists(checkpoint_path):
8686
model = tf.keras.models.load_model(checkpoint_path)
8787
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
88+
model.compile(
89+
optimizer=tf.keras.optimizers.AdamW(
90+
learning_rate=args.learning_rate, weight_decay=1e-4,
91+
),
92+
loss=keras.losses.MeanAbsoluteError() if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError(),
93+
)
8894
else:
8995
logging.info(f'No checkpoint at: {checkpoint_path}')
9096
history = model.fit(

0 commit comments

Comments
 (0)