Skip to content

Commit 7700077

Browse files
committed
cleanup merge
1 parent 68ef11e commit 7700077

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ml4h/models/diffusion_blocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,8 @@ def train_step(self, images_original):
498498
noisy_images, noise_rates, signal_rates, training=True,
499499
)
500500

501-
noise_loss = self.loss(noises, pred_noises) # used for training
502-
image_loss = self.loss(images, pred_images) # only used as metric
501+
noise_loss = tf.reduce_mean(self.loss(noises, pred_noises)) # used for training
502+
image_loss = tf.reduce_mean(self.loss(images, pred_images)) # only used as metric
503503
if self.use_sigmoid_loss:
504504
signal_rates_squared = tf.square(signal_rates)
505505
noise_rates_squared = tf.square(noise_rates)

ml4h/models/train_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def train_diffusion_model(args):
2828
optimizer=tf.keras.optimizers.AdamW(
2929
learning_rate=args.learning_rate, weight_decay=1e-4,
3030
),
31-
loss=tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError,
31+
loss=keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError,
3232
)
3333
batch = next(iter(generate_train))
3434
for k in batch[0]:
@@ -223,7 +223,7 @@ def train_diffusion_control_model(args, supervised=False):
223223
args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model,
224224
)
225225

226-
loss = tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError
226+
loss = keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError
227227
model.compile(
228228
optimizer=tf.keras.optimizers.AdamW(
229229
learning_rate=args.learning_rate, weight_decay=1e-4,
@@ -348,7 +348,7 @@ def test_diffusion_control_model(args, unconditioned=False, supervised=False):
348348
args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model,
349349
)
350350

351-
loss = tf.keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else tf.keras.losses.MeanSquaredError
351+
loss = keras.losses.MeanAbsoluteError if args.diffusion_loss == 'mean_absolute_error' else keras.losses.MeanSquaredError
352352
model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=args.learning_rate, weight_decay=1e-4), loss=loss)
353353
checkpoint_path = f"{args.output_folder}{args.id}/{args.id}"
354354
if os.path.exists(checkpoint_path+'.index'):

0 commit comments

Comments
 (0)