@@ -28,7 +28,7 @@ def train_diffusion_model(args):
2828 optimizer = tf .keras .optimizers .AdamW (
2929 learning_rate = args .learning_rate , weight_decay = 1e-4 ,
3030 ),
31- loss = tf . keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else tf . keras .losses .MeanSquaredError ,
31+ loss = keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else keras .losses .MeanSquaredError ,
3232 )
3333 batch = next (iter (generate_train ))
3434 for k in batch [0 ]:
@@ -223,7 +223,7 @@ def train_diffusion_control_model(args, supervised=False):
223223 args .sigmoid_beta , args .diffusion_condition_strategy , args .inspect_model ,
224224 )
225225
226- loss = tf . keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else tf . keras .losses .MeanSquaredError
226+ loss = keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else keras .losses .MeanSquaredError
227227 model .compile (
228228 optimizer = tf .keras .optimizers .AdamW (
229229 learning_rate = args .learning_rate , weight_decay = 1e-4 ,
@@ -348,7 +348,7 @@ def test_diffusion_control_model(args, unconditioned=False, supervised=False):
348348 args .sigmoid_beta , args .diffusion_condition_strategy , args .inspect_model ,
349349 )
350350
351- loss = tf . keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else tf . keras .losses .MeanSquaredError
351+ loss = keras .losses .MeanAbsoluteError if args .diffusion_loss == 'mean_absolute_error' else keras .losses .MeanSquaredError
352352 model .compile (optimizer = tf .keras .optimizers .AdamW (learning_rate = args .learning_rate , weight_decay = 1e-4 ), loss = loss )
353353 checkpoint_path = f"{ args .output_folder } { args .id } /{ args .id } "
354354 if os .path .exists (checkpoint_path + '.index' ):
0 commit comments