@@ -261,30 +261,30 @@ def train_diffusion_control_model(args, supervised=False):
261261 model .normalizer .adapt (feature_batch )
262262 if args .inspect_model :
263263 model .network .summary (print_fn = logging .info , expand_nested = True )
264- tf .keras .utils .plot_model (
265- model .network ,
266- to_file = f"{ args .output_folder } /{ args .id } /architecture_{ args .id } _unet.png" ,
267- show_shapes = True ,
268- show_dtype = False ,
269- show_layer_names = True ,
270- rankdir = "TB" ,
271- expand_nested = True ,
272- dpi = args .dpi ,
273- layer_range = None ,
274- show_layer_activations = False ,
275- )
276- tf .keras .utils .plot_model (
277- model .control_embed_model ,
278- to_file = f"{ args .output_folder } /{ args .id } /architecture_{ args .id } _control_embed.png" ,
279- show_shapes = True ,
280- show_dtype = False ,
281- show_layer_names = True ,
282- rankdir = "TB" ,
283- expand_nested = True ,
284- dpi = args .dpi ,
285- layer_range = None ,
286- show_layer_activations = True ,
287- )
264+ # tf.keras.utils.plot_model(
265+ # model.network,
266+ # to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_unet.png",
267+ # show_shapes=True,
268+ # show_dtype=False,
269+ # show_layer_names=True,
270+ # rankdir="TB",
271+ # expand_nested=True,
272+ # dpi=args.dpi,
273+ # layer_range=None,
274+ # show_layer_activations=False,
275+ # )
276+ # tf.keras.utils.plot_model(
277+ # model.control_embed_model,
278+ # to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_control_embed.png",
279+ # show_shapes=True,
280+ # show_dtype=False,
281+ # show_layer_names=True,
282+ # rankdir="TB",
283+ # expand_nested=True,
284+ # dpi=args.dpi,
285+ # layer_range=None,
286+ # show_layer_activations=True,
287+ # )
288288 prefix_value = f'{ args .output_folder } { args .id } /learning_generations/'
289289 # Create a partial function with reseed and prefix pre-filled
290290 if model .input_map .axes () == 2 :
0 commit comments