We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9434b22 commit 071eb26Copy full SHA for 071eb26
ml4h/models/train_diffusion.py
@@ -76,6 +76,9 @@ def train_diffusion_model(args):
76
plot_partial = partial(model.plot_images, reseed=args.random_seed, prefix=prefix_value)
77
callbacks.append(keras.callbacks.LambdaCallback(on_epoch_end=plot_partial))
78
79
+ sample_input = next(iter(generate_train))[0][model.tensor_map.input_name()]
80
+ model.normalizer.adapt(sample_input)
81
+ model(sample_input)
82
history = model.fit(
83
generate_train,
84
steps_per_epoch=args.training_steps,
0 commit comments