Skip to content

Commit 784ac6d

Browse files
committed
add license
1 parent f2a3d62 commit 784ac6d

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

ml4h/models/train_diffusion.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -261,30 +261,30 @@ def train_diffusion_control_model(args, supervised=False):
261261
model.normalizer.adapt(feature_batch)
262262
if args.inspect_model:
263263
model.network.summary(print_fn=logging.info, expand_nested=True)
264-
tf.keras.utils.plot_model(
265-
model.network,
266-
to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_unet.png",
267-
show_shapes=True,
268-
show_dtype=False,
269-
show_layer_names=True,
270-
rankdir="TB",
271-
expand_nested=True,
272-
dpi=args.dpi,
273-
layer_range=None,
274-
show_layer_activations=False,
275-
)
276-
tf.keras.utils.plot_model(
277-
model.control_embed_model,
278-
to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_control_embed.png",
279-
show_shapes=True,
280-
show_dtype=False,
281-
show_layer_names=True,
282-
rankdir="TB",
283-
expand_nested=True,
284-
dpi=args.dpi,
285-
layer_range=None,
286-
show_layer_activations=True,
287-
)
264+
# tf.keras.utils.plot_model(
265+
# model.network,
266+
# to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_unet.png",
267+
# show_shapes=True,
268+
# show_dtype=False,
269+
# show_layer_names=True,
270+
# rankdir="TB",
271+
# expand_nested=True,
272+
# dpi=args.dpi,
273+
# layer_range=None,
274+
# show_layer_activations=False,
275+
# )
276+
# tf.keras.utils.plot_model(
277+
# model.control_embed_model,
278+
# to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_control_embed.png",
279+
# show_shapes=True,
280+
# show_dtype=False,
281+
# show_layer_names=True,
282+
# rankdir="TB",
283+
# expand_nested=True,
284+
# dpi=args.dpi,
285+
# layer_range=None,
286+
# show_layer_activations=True,
287+
# )
288288
prefix_value = f'{args.output_folder}{args.id}/learning_generations/'
289289
# Create a partial function with reseed and prefix pre-filled
290290
if model.input_map.axes() == 2:

0 commit comments

Comments
 (0)