Skip to content

Commit 071eb26

Browse files
committed
cleanup merge
1 parent 9434b22 commit 071eb26

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

ml4h/models/train_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def train_diffusion_model(args):
7676
plot_partial = partial(model.plot_images, reseed=args.random_seed, prefix=prefix_value)
7777
callbacks.append(keras.callbacks.LambdaCallback(on_epoch_end=plot_partial))
7878

79+
sample_input = next(iter(generate_train))[0][model.tensor_map.input_name()]
80+
model.normalizer.adapt(sample_input)
81+
model(sample_input)
7982
history = model.fit(
8083
generate_train,
8184
steps_per_epoch=args.training_steps,

0 commit comments

Comments
 (0)