Skip to content

Commit c457376

Browse files
committed
cleanup merge
1 parent 7700077 commit c457376

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ml4h/models/diffusion_blocks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def generate(self, num_images, diffusion_steps, reseed=None):
474474
generated_images = self.denormalize(generated_images)
475475
return generated_images
476476

477+
@tf.function
477478
def train_step(self, images_original):
478479
# normalize images to have standard deviation of 1, like the noises
479480
images = images_original[0][self.tensor_map.input_name()]
@@ -498,8 +499,8 @@ def train_step(self, images_original):
498499
noisy_images, noise_rates, signal_rates, training=True,
499500
)
500501

501-
noise_loss = tf.reduce_mean(self.loss(noises, pred_noises)) # used for training
502-
image_loss = tf.reduce_mean(self.loss(images, pred_images)) # only used as metric
502+
noise_loss = self.loss(noises, pred_noises) # used for training
503+
image_loss = self.loss(images, pred_images) # only used as metric
503504
if self.use_sigmoid_loss:
504505
signal_rates_squared = tf.square(signal_rates)
505506
noise_rates_squared = tf.square(noise_rates)

0 commit comments

Comments
 (0)