Skip to content

Commit 68ef11e

Browse files
committed
cleanup merge
1 parent f0fd518 commit 68ef11e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ml4h/models/train_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def train_diffusion_model(args):
2828
optimizer=tf.keras.optimizers.AdamW(
2929
learning_rate=args.learning_rate, weight_decay=1e-4,
3030
),
31-
loss=keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError,
31+
loss=tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError,
3232
)
3333
batch = next(iter(generate_train))
3434
for k in batch[0]:
@@ -223,7 +223,7 @@ def train_diffusion_control_model(args, supervised=False):
223223
args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model,
224224
)
225225

226-
loss = keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError
226+
loss = tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError
227227
model.compile(
228228
optimizer=tf.keras.optimizers.AdamW(
229229
learning_rate=args.learning_rate, weight_decay=1e-4,
@@ -348,7 +348,7 @@ def test_diffusion_control_model(args, unconditioned=False, supervised=False):
348348
args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model,
349349
)
350350

351-
loss = keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError
351+
loss = tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError
352352
model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=args.learning_rate, weight_decay=1e-4), loss=loss)
353353
checkpoint_path = f"{args.output_folder}{args.id}/{args.id}"
354354
if os.path.exists(checkpoint_path+'.index'):

0 commit comments

Comments
 (0)