diff --git a/examples/generative/cyclegan.py b/examples/generative/cyclegan.py index ca208efa88..bef76f5d9b 100644 --- a/examples/generative/cyclegan.py +++ b/examples/generative/cyclegan.py @@ -2,7 +2,7 @@ Title: CycleGAN Author: [A_K_Nain](https://twitter.com/A_K_Nain) Date created: 2020/08/12 -Last modified: 2020/08/12 +Last modified: 2024/09/30 Description: Implementation of CycleGAN. Accelerator: GPU """ @@ -17,7 +17,7 @@ CycleGAN tries to learn this mapping without requiring paired input-output images, using cycle-consistent adversarial networks. -- [Paper](https://arxiv.org/pdf/1703.10593.pdf) +- [Paper](https://arxiv.org/abs/1703.10593) - [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) """ @@ -25,18 +25,18 @@ ## Setup """ - +import os import numpy as np import matplotlib.pyplot as plt import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -import tensorflow_addons as tfa +import keras +from keras import layers, ops import tensorflow_datasets as tfds tfds.disable_progress_bar() autotune = tf.data.AUTOTUNE +os.environ["KERAS_BACKEND"] = "tensorflow" """ ## Prepare the dataset @@ -47,7 +47,7 @@ """ # Load the horse-zebra dataset using tensorflow-datasets. -dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True) +dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True) train_horses, train_zebras = dataset["trainA"], dataset["trainB"] test_horses, test_zebras = dataset["testA"], dataset["testB"] @@ -65,7 +65,7 @@ def normalize_img(img): - img = tf.cast(img, dtype=tf.float32) + img = ops.cast(img, dtype=tf.float32) # Map values in the range [-1, 1] return (img / 127.5) - 1.0 @@ -74,7 +74,7 @@ def preprocess_train_image(img, label): # Random flip img = tf.image.random_flip_left_right(img) # Resize to the original size first - img = tf.image.resize(img, [*orig_img_size]) + img = ops.image.resize(img, [*orig_img_size]) # Random crop to 256X256 img = tf.image.random_crop(img, size=[*input_img_size]) # Normalize the pixel values in the range [-1, 1] @@ -84,7 +84,7 @@ def preprocess_train_image(img, label): def preprocess_test_image(img, label): # Only resizing and normalization for the test images. - img = tf.image.resize(img, [input_img_size[0], input_img_size[1]]) + img = ops.image.resize(img, [input_img_size[0], input_img_size[1]]) img = normalize_img(img) return img @@ -165,7 +165,7 @@ def call(self, input_tensor, mask=None): [padding_width, padding_width], [0, 0], ] - return tf.pad(input_tensor, padding_tensor, mode="REFLECT") + return ops.pad(input_tensor, padding_tensor, mode="REFLECT") def residual_block( @@ -190,7 +190,9 @@ def residual_block( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = activation(x) x = ReflectionPadding2D()(x) @@ -202,7 +204,9 @@ def residual_block( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = layers.add([input_tensor, x]) return x @@ -226,7 +230,9 @@ def downsample( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) if activation: x = activation(x) return x @@ -251,7 +257,9 @@ def upsample( kernel_initializer=kernel_initializer, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) if activation: x = activation(x) return x @@ -298,7 +306,9 @@ def get_resnet_generator( x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)( x ) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = layers.Activation("relu")(x) # Downsampling @@ -581,14 +591,14 @@ def on_epoch_end(self, epoch, logs=None): def generator_loss_fn(fake): - fake_loss = adv_loss_fn(tf.ones_like(fake), fake) + fake_loss = adv_loss_fn(ops.ones_like(fake), fake) return fake_loss # Define the loss function for the discriminators def discriminator_loss_fn(real, fake): - real_loss = adv_loss_fn(tf.ones_like(real), real) - fake_loss = adv_loss_fn(tf.zeros_like(fake), fake) + real_loss = adv_loss_fn(ops.ones_like(real), real) + fake_loss = adv_loss_fn(ops.zeros_like(fake), fake) return (real_loss + fake_loss) * 0.5 @@ -599,16 +609,16 @@ def discriminator_loss_fn(real, fake): # Compile the model cycle_gan_model.compile( - gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), + gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), gen_loss_fn=generator_loss_fn, disc_loss_fn=discriminator_loss_fn, ) # Callbacks plotter = GANMonitor() -checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}" +checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5" model_checkpoint_callback = keras.callbacks.ModelCheckpoint( filepath=checkpoint_filepath, save_weights_only=True ) @@ -617,31 +627,20 @@ def discriminator_loss_fn(real, fake): # 7 minutes on a single P100 backed machine. cycle_gan_model.fit( tf.data.Dataset.zip((train_horses, train_zebras)), - epochs=1, + epochs=90, callbacks=[plotter, model_checkpoint_callback], ) """ Test the performance of the model. - -You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN) -and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN). """ -# This model was trained for 90 epochs. We will be loading those weights -# here. Once the weights are loaded, we will take a few samples from the test -# data and check the model's performance. - -"""shell -curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip -unzip -qq saved_checkpoints.zip -""" +# Once the weights are loaded, we will take a few samples from the test data and check the model's performance. # Load the checkpoints -weight_file = "./saved_checkpoints/cyclegan_checkpoints.090" -cycle_gan_model.load_weights(weight_file).expect_partial() +cycle_gan_model.load_weights(checkpoint_filepath) print("Weights loaded successfully") _, ax = plt.subplots(4, 2, figsize=(10, 15)) diff --git a/examples/generative/img/cyclegan/cyclegan_21_1069.png b/examples/generative/img/cyclegan/cyclegan_21_1069.png new file mode 100644 index 0000000000..0f077a8568 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_1069.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_10699.png b/examples/generative/img/cyclegan/cyclegan_21_10699.png new file mode 100644 index 0000000000..88f3bf2574 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_10699.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_11769.png b/examples/generative/img/cyclegan/cyclegan_21_11769.png new file mode 100644 index 0000000000..9a92e1502f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_11769.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_12839.png b/examples/generative/img/cyclegan/cyclegan_21_12839.png new file mode 100644 index 0000000000..a3568b2fc2 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_12839.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_13909.png b/examples/generative/img/cyclegan/cyclegan_21_13909.png new file mode 100644 index 0000000000..f336b92899 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_13909.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_14979.png b/examples/generative/img/cyclegan/cyclegan_21_14979.png new file mode 100644 index 0000000000..d3fbad6fd7 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_14979.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_16049.png b/examples/generative/img/cyclegan/cyclegan_21_16049.png new file mode 100644 index 0000000000..2371b01b6f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_16049.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_17119.png b/examples/generative/img/cyclegan/cyclegan_21_17119.png new file mode 100644 index 0000000000..5a5804c8be Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_17119.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_18189.png b/examples/generative/img/cyclegan/cyclegan_21_18189.png new file mode 100644 index 0000000000..6a511b56d6 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_18189.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_19259.png b/examples/generative/img/cyclegan/cyclegan_21_19259.png new file mode 100644 index 0000000000..595bddc345 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_19259.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_20329.png b/examples/generative/img/cyclegan/cyclegan_21_20329.png new file mode 100644 index 0000000000..2f5879f971 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_20329.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_2139.png b/examples/generative/img/cyclegan/cyclegan_21_2139.png new file mode 100644 index 0000000000..e214d8b0c3 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_2139.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_21399.png b/examples/generative/img/cyclegan/cyclegan_21_21399.png new file mode 100644 index 0000000000..923ed3098c Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_21399.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_22469.png b/examples/generative/img/cyclegan/cyclegan_21_22469.png new file mode 100644 index 0000000000..597a6dc34d Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_22469.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_23539.png b/examples/generative/img/cyclegan/cyclegan_21_23539.png new file mode 100644 index 0000000000..107b7373f0 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_23539.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_24609.png b/examples/generative/img/cyclegan/cyclegan_21_24609.png new file mode 100644 index 0000000000..22ce7f84b0 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_24609.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_25679.png b/examples/generative/img/cyclegan/cyclegan_21_25679.png new file mode 100644 index 0000000000..544843541f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_25679.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_26749.png b/examples/generative/img/cyclegan/cyclegan_21_26749.png new file mode 100644 index 0000000000..5b1fcbb77a Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_26749.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_27819.png b/examples/generative/img/cyclegan/cyclegan_21_27819.png new file mode 100644 index 0000000000..6be4bd72d9 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_27819.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_28889.png b/examples/generative/img/cyclegan/cyclegan_21_28889.png new file mode 100644 index 0000000000..fc94fb0243 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_28889.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_29959.png b/examples/generative/img/cyclegan/cyclegan_21_29959.png new file mode 100644 index 0000000000..5c460e436e Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_29959.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_31029.png b/examples/generative/img/cyclegan/cyclegan_21_31029.png new file mode 100644 index 0000000000..2c8d40037f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_31029.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_3209.png b/examples/generative/img/cyclegan/cyclegan_21_3209.png new file mode 100644 index 0000000000..2bd7d0bf1d Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_3209.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_32099.png b/examples/generative/img/cyclegan/cyclegan_21_32099.png new file mode 100644 index 0000000000..98d12e6469 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_32099.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_33169.png b/examples/generative/img/cyclegan/cyclegan_21_33169.png new file mode 100644 index 0000000000..9837dcf2e4 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_33169.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_34239.png b/examples/generative/img/cyclegan/cyclegan_21_34239.png new file mode 100644 index 0000000000..6c79d1ba3f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_34239.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_35309.png b/examples/generative/img/cyclegan/cyclegan_21_35309.png new file mode 100644 index 0000000000..0c45d26f0a Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_35309.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_36379.png b/examples/generative/img/cyclegan/cyclegan_21_36379.png new file mode 100644 index 0000000000..d114388196 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_36379.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_37449.png b/examples/generative/img/cyclegan/cyclegan_21_37449.png new file mode 100644 index 0000000000..c766263b48 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_37449.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_38519.png b/examples/generative/img/cyclegan/cyclegan_21_38519.png new file mode 100644 index 0000000000..88b535d5c4 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_38519.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_39589.png b/examples/generative/img/cyclegan/cyclegan_21_39589.png new file mode 100644 index 0000000000..7667223689 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_39589.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_40659.png b/examples/generative/img/cyclegan/cyclegan_21_40659.png new file mode 100644 index 0000000000..2ca1af93ec Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_40659.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_41729.png b/examples/generative/img/cyclegan/cyclegan_21_41729.png new file mode 100644 index 0000000000..70a388f8de Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_41729.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_4279.png b/examples/generative/img/cyclegan/cyclegan_21_4279.png new file mode 100644 index 0000000000..387aaba129 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_4279.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_42799.png b/examples/generative/img/cyclegan/cyclegan_21_42799.png new file mode 100644 index 0000000000..c2d8a5431b Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_42799.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_43869.png b/examples/generative/img/cyclegan/cyclegan_21_43869.png new file mode 100644 index 0000000000..85d642c066 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_43869.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_44939.png b/examples/generative/img/cyclegan/cyclegan_21_44939.png new file mode 100644 index 0000000000..803e6d6d5f Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_44939.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_46009.png b/examples/generative/img/cyclegan/cyclegan_21_46009.png new file mode 100644 index 0000000000..749c639f0e Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_46009.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_47079.png b/examples/generative/img/cyclegan/cyclegan_21_47079.png new file mode 100644 index 0000000000..4c0a595269 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_47079.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_48149.png b/examples/generative/img/cyclegan/cyclegan_21_48149.png new file mode 100644 index 0000000000..d703272a28 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_48149.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_49219.png b/examples/generative/img/cyclegan/cyclegan_21_49219.png new file mode 100644 index 0000000000..aaf8b6ad51 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_49219.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_50289.png b/examples/generative/img/cyclegan/cyclegan_21_50289.png new file mode 100644 index 0000000000..e7755b9cd5 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_50289.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_51359.png b/examples/generative/img/cyclegan/cyclegan_21_51359.png new file mode 100644 index 0000000000..3be7ada979 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_51359.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_52429.png b/examples/generative/img/cyclegan/cyclegan_21_52429.png new file mode 100644 index 0000000000..d0875380ed Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_52429.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_5349.png b/examples/generative/img/cyclegan/cyclegan_21_5349.png new file mode 100644 index 0000000000..7aa34fe34c Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_5349.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_53499.png b/examples/generative/img/cyclegan/cyclegan_21_53499.png new file mode 100644 index 0000000000..f39210dc62 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_53499.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_54569.png b/examples/generative/img/cyclegan/cyclegan_21_54569.png new file mode 100644 index 0000000000..37ea4a9b6c Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_54569.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_55639.png b/examples/generative/img/cyclegan/cyclegan_21_55639.png new file mode 100644 index 0000000000..55b582ee87 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_55639.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_56709.png b/examples/generative/img/cyclegan/cyclegan_21_56709.png new file mode 100644 index 0000000000..267ab05977 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_56709.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_57779.png b/examples/generative/img/cyclegan/cyclegan_21_57779.png new file mode 100644 index 0000000000..44a34cc8c2 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_57779.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_58849.png b/examples/generative/img/cyclegan/cyclegan_21_58849.png new file mode 100644 index 0000000000..75846f78bb Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_58849.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_59919.png b/examples/generative/img/cyclegan/cyclegan_21_59919.png new file mode 100644 index 0000000000..00f1053fe5 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_59919.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_60989.png b/examples/generative/img/cyclegan/cyclegan_21_60989.png new file mode 100644 index 0000000000..1ddcb5ec7d Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_60989.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_62059.png b/examples/generative/img/cyclegan/cyclegan_21_62059.png new file mode 100644 index 0000000000..895be698e9 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_62059.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_63129.png b/examples/generative/img/cyclegan/cyclegan_21_63129.png new file mode 100644 index 0000000000..76afc7f569 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_63129.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_6419.png b/examples/generative/img/cyclegan/cyclegan_21_6419.png new file mode 100644 index 0000000000..4f7dcb19f7 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_6419.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_64199.png b/examples/generative/img/cyclegan/cyclegan_21_64199.png new file mode 100644 index 0000000000..626fd2efc8 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_64199.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_65269.png b/examples/generative/img/cyclegan/cyclegan_21_65269.png new file mode 100644 index 0000000000..1fd873391b Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_65269.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_66339.png b/examples/generative/img/cyclegan/cyclegan_21_66339.png new file mode 100644 index 0000000000..4d42bfba05 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_66339.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_67409.png b/examples/generative/img/cyclegan/cyclegan_21_67409.png new file mode 100644 index 0000000000..8f5b6cfafb Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_67409.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_68479.png b/examples/generative/img/cyclegan/cyclegan_21_68479.png new file mode 100644 index 0000000000..b8b60c62c4 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_68479.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_69549.png b/examples/generative/img/cyclegan/cyclegan_21_69549.png new file mode 100644 index 0000000000..109ce94e16 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_69549.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_70619.png b/examples/generative/img/cyclegan/cyclegan_21_70619.png new file mode 100644 index 0000000000..8ada414e32 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_70619.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_71689.png b/examples/generative/img/cyclegan/cyclegan_21_71689.png new file mode 100644 index 0000000000..6c4357b1c4 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_71689.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_72759.png b/examples/generative/img/cyclegan/cyclegan_21_72759.png new file mode 100644 index 0000000000..2ef23d0981 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_72759.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_73829.png b/examples/generative/img/cyclegan/cyclegan_21_73829.png new file mode 100644 index 0000000000..51e6f49dc9 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_73829.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_7489.png b/examples/generative/img/cyclegan/cyclegan_21_7489.png new file mode 100644 index 0000000000..7083a5f7c1 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_7489.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_74899.png b/examples/generative/img/cyclegan/cyclegan_21_74899.png new file mode 100644 index 0000000000..f1791950ca Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_74899.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_75969.png b/examples/generative/img/cyclegan/cyclegan_21_75969.png new file mode 100644 index 0000000000..161c1faed9 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_75969.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_77039.png b/examples/generative/img/cyclegan/cyclegan_21_77039.png new file mode 100644 index 0000000000..e83f908358 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_77039.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_78109.png b/examples/generative/img/cyclegan/cyclegan_21_78109.png new file mode 100644 index 0000000000..ffd1e45d0a Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_78109.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_79179.png b/examples/generative/img/cyclegan/cyclegan_21_79179.png new file mode 100644 index 0000000000..e86690ce1a Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_79179.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_80249.png b/examples/generative/img/cyclegan/cyclegan_21_80249.png new file mode 100644 index 0000000000..27203d37bb Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_80249.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_81319.png b/examples/generative/img/cyclegan/cyclegan_21_81319.png new file mode 100644 index 0000000000..dbcd08015c Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_81319.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_82389.png b/examples/generative/img/cyclegan/cyclegan_21_82389.png new file mode 100644 index 0000000000..f62b2ed717 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_82389.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_83459.png b/examples/generative/img/cyclegan/cyclegan_21_83459.png new file mode 100644 index 0000000000..074535c4ba Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_83459.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_84529.png b/examples/generative/img/cyclegan/cyclegan_21_84529.png new file mode 100644 index 0000000000..c82a10c8d5 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_84529.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_8559.png b/examples/generative/img/cyclegan/cyclegan_21_8559.png new file mode 100644 index 0000000000..bb60de0cc3 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_8559.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_85599.png b/examples/generative/img/cyclegan/cyclegan_21_85599.png new file mode 100644 index 0000000000..0ec3509c94 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_85599.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_86669.png b/examples/generative/img/cyclegan/cyclegan_21_86669.png new file mode 100644 index 0000000000..0a394b12bb Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_86669.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_87739.png b/examples/generative/img/cyclegan/cyclegan_21_87739.png new file mode 100644 index 0000000000..aeb540e795 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_87739.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_88809.png b/examples/generative/img/cyclegan/cyclegan_21_88809.png new file mode 100644 index 0000000000..dbb4a79ff6 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_88809.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_89879.png b/examples/generative/img/cyclegan/cyclegan_21_89879.png new file mode 100644 index 0000000000..705f371f96 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_89879.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_90949.png b/examples/generative/img/cyclegan/cyclegan_21_90949.png new file mode 100644 index 0000000000..1c538e5cd8 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_90949.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_92019.png b/examples/generative/img/cyclegan/cyclegan_21_92019.png new file mode 100644 index 0000000000..7d57a0a90c Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_92019.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_93089.png b/examples/generative/img/cyclegan/cyclegan_21_93089.png new file mode 100644 index 0000000000..f4199f24ff Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_93089.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_94159.png b/examples/generative/img/cyclegan/cyclegan_21_94159.png new file mode 100644 index 0000000000..a7a1477786 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_94159.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_95229.png b/examples/generative/img/cyclegan/cyclegan_21_95229.png new file mode 100644 index 0000000000..3ac98c2a6a Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_95229.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_9629.png b/examples/generative/img/cyclegan/cyclegan_21_9629.png new file mode 100644 index 0000000000..a15295f900 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_9629.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_21_96299.png b/examples/generative/img/cyclegan/cyclegan_21_96299.png new file mode 100644 index 0000000000..e2abf8a803 Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_96299.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_23_1.png b/examples/generative/img/cyclegan/cyclegan_23_1.png new file mode 100644 index 0000000000..52f4b7e32b Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_23_1.png differ diff --git a/examples/generative/img/cyclegan/cyclegan_9_0.png b/examples/generative/img/cyclegan/cyclegan_9_0.png index 8c03be0e98..886066b8e2 100644 Binary files a/examples/generative/img/cyclegan/cyclegan_9_0.png and b/examples/generative/img/cyclegan/cyclegan_9_0.png differ diff --git a/examples/generative/ipynb/cyclegan.ipynb b/examples/generative/ipynb/cyclegan.ipynb index 20841f8d5a..616f1521c1 100644 --- a/examples/generative/ipynb/cyclegan.ipynb +++ b/examples/generative/ipynb/cyclegan.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)
\n", "**Date created:** 2020/08/12
\n", - "**Last modified:** 2020/08/12
\n", + "**Last modified:** 2024/09/30
\n", "**Description:** Implementation of CycleGAN." ] }, @@ -29,7 +29,7 @@ "CycleGAN tries to learn this mapping without requiring paired input-output images,\n", "using cycle-consistent adversarial networks.\n", "\n", - "- [Paper](https://arxiv.org/pdf/1703.10593.pdf)\n", + "- [Paper](https://arxiv.org/abs/1703.10593)\n", "- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)" ] }, @@ -50,18 +50,18 @@ }, "outputs": [], "source": [ + "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "\n", "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow.keras import layers\n", - "\n", - "import tensorflow_addons as tfa\n", + "import keras\n", + "from keras import layers, ops\n", "import tensorflow_datasets as tfds\n", "\n", "tfds.disable_progress_bar()\n", - "autotune = tf.data.AUTOTUNE\n" + "autotune = tf.data.AUTOTUNE\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" ] }, { @@ -86,7 +86,7 @@ "outputs": [], "source": [ "# Load the horse-zebra dataset using tensorflow-datasets.\n", - "dataset, _ = tfds.load(\"cycle_gan/horse2zebra\", with_info=True, as_supervised=True)\n", + "dataset, _ = tfds.load(name=\"cycle_gan/horse2zebra\", with_info=True, as_supervised=True)\n", "train_horses, train_zebras = dataset[\"trainA\"], dataset[\"trainB\"]\n", "test_horses, test_zebras = dataset[\"testA\"], dataset[\"testB\"]\n", "\n", @@ -104,7 +104,7 @@ "\n", "\n", "def normalize_img(img):\n", - " img = tf.cast(img, dtype=tf.float32)\n", + " img = ops.cast(img, dtype=tf.float32)\n", " # Map values in the range [-1, 1]\n", " return (img / 127.5) - 1.0\n", "\n", @@ -113,7 +113,7 @@ " # Random flip\n", " img = tf.image.random_flip_left_right(img)\n", " # Resize to the original size first\n", - " img = tf.image.resize(img, [*orig_img_size])\n", + " img = ops.image.resize(img, [*orig_img_size])\n", " # Random crop to 256X256\n", " img = tf.image.random_crop(img, size=[*input_img_size])\n", " # Normalize the pixel values in the range [-1, 1]\n", @@ -123,7 +123,7 @@ "\n", "def preprocess_test_image(img, label):\n", " # Only resizing and normalization for the test images.\n", - " img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])\n", + " img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])\n", " img = normalize_img(img)\n", " return img\n" ] @@ -243,7 +243,7 @@ " [padding_width, padding_width],\n", " [0, 0],\n", " ]\n", - " return tf.pad(input_tensor, padding_tensor, mode=\"REFLECT\")\n", + " return ops.pad(input_tensor, padding_tensor, mode=\"REFLECT\")\n", "\n", "\n", "def residual_block(\n", @@ -268,7 +268,9 @@ " padding=padding,\n", " use_bias=use_bias,\n", " )(x)\n", - " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n", + " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n", + " x\n", + " )\n", " x = activation(x)\n", "\n", " x = ReflectionPadding2D()(x)\n", @@ -280,7 +282,9 @@ " padding=padding,\n", " use_bias=use_bias,\n", " )(x)\n", - " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n", + " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n", + " x\n", + " )\n", " x = layers.add([input_tensor, x])\n", " return x\n", "\n", @@ -304,7 +308,9 @@ " padding=padding,\n", " use_bias=use_bias,\n", " )(x)\n", - " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n", + " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n", + " x\n", + " )\n", " if activation:\n", " x = activation(x)\n", " return x\n", @@ -329,7 +335,9 @@ " kernel_initializer=kernel_initializer,\n", " use_bias=use_bias,\n", " )(x)\n", - " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n", + " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n", + " x\n", + " )\n", " if activation:\n", " x = activation(x)\n", " return x\n" @@ -389,7 +397,9 @@ " x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(\n", " x\n", " )\n", - " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n", + " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n", + " x\n", + " )\n", " x = layers.Activation(\"relu\")(x)\n", "\n", " # Downsampling\n", @@ -721,15 +731,17 @@ "adv_loss_fn = keras.losses.MeanSquaredError()\n", "\n", "# Define the loss function for the generators\n", + "\n", + "\n", "def generator_loss_fn(fake):\n", - " fake_loss = adv_loss_fn(tf.ones_like(fake), fake)\n", + " fake_loss = adv_loss_fn(ops.ones_like(fake), fake)\n", " return fake_loss\n", "\n", "\n", "# Define the loss function for the discriminators\n", "def discriminator_loss_fn(real, fake):\n", - " real_loss = adv_loss_fn(tf.ones_like(real), real)\n", - " fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)\n", + " real_loss = adv_loss_fn(ops.ones_like(real), real)\n", + " fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)\n", " return (real_loss + fake_loss) * 0.5\n", "\n", "\n", @@ -740,26 +752,25 @@ "\n", "# Compile the model\n", "cycle_gan_model.compile(\n", - " gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n", - " gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n", - " disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n", - " disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n", + " gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n", + " gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n", + " disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n", + " disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n", " gen_loss_fn=generator_loss_fn,\n", " disc_loss_fn=discriminator_loss_fn,\n", ")\n", "# Callbacks\n", "plotter = GANMonitor()\n", - "checkpoint_filepath = \"./model_checkpoints/cyclegan_checkpoints.{epoch:03d}\"\n", + "checkpoint_filepath = \"./model_checkpoints/cyclegan_checkpoints.weights.h5\"\n", "model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\n", - " filepath=checkpoint_filepath,\n", - " save_weights_only=True\n", + " filepath=checkpoint_filepath, save_weights_only=True\n", ")\n", "\n", "# Here we will train the model for just one epoch as each epoch takes around\n", "# 7 minutes on a single P100 backed machine.\n", "cycle_gan_model.fit(\n", " tf.data.Dataset.zip((train_horses, train_zebras)),\n", - " epochs=1,\n", + " epochs=90,\n", " callbacks=[plotter, model_checkpoint_callback],\n", ")" ] @@ -770,10 +781,7 @@ "colab_type": "text" }, "source": [ - "Test the performance of the model.\n", - "\n", - "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)\n", - "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN)." + "Test the performance of the model." ] }, { @@ -785,35 +793,11 @@ "outputs": [], "source": [ "\n", - "# This model was trained for 90 epochs. We will be loading those weights\n", - "# here. Once the weights are loaded, we will take a few samples from the test\n", - "# data and check the model's performance." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "!curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip\n", - "!unzip -qq saved_checkpoints.zip" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ + "# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.\n", + "\n", "\n", "# Load the checkpoints\n", - "weight_file = \"./saved_checkpoints/cyclegan_checkpoints.090\"\n", - "cycle_gan_model.load_weights(weight_file).expect_partial()\n", + "cycle_gan_model.load_weights(checkpoint_filepath)\n", "print(\"Weights loaded successfully\")\n", "\n", "_, ax = plt.subplots(4, 2, figsize=(10, 15))\n", diff --git a/examples/generative/md/cyclegan.md b/examples/generative/md/cyclegan.md index a9857ed082..b27b7f0f54 100644 --- a/examples/generative/md/cyclegan.md +++ b/examples/generative/md/cyclegan.md @@ -2,7 +2,7 @@ **Author:** [A_K_Nain](https://twitter.com/A_K_Nain)
**Date created:** 2020/08/12
-**Last modified:** 2020/08/12
+**Last modified:** 2024/09/30
**Description:** Implementation of CycleGAN. @@ -20,7 +20,7 @@ aligned image pairs. However, obtaining paired examples isn't always feasible. CycleGAN tries to learn this mapping without requiring paired input-output images, using cycle-consistent adversarial networks. -- [Paper](https://arxiv.org/pdf/1703.10593.pdf) +- [Paper](https://arxiv.org/abs/1703.10593) - [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) --- @@ -28,19 +28,18 @@ using cycle-consistent adversarial networks. ```python +import os import numpy as np import matplotlib.pyplot as plt - import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers - -import tensorflow_addons as tfa +import keras +from keras import layers, ops import tensorflow_datasets as tfds tfds.disable_progress_bar() autotune = tf.data.AUTOTUNE +os.environ["KERAS_BACKEND"] = "tensorflow" ``` --- @@ -53,7 +52,7 @@ dataset. ```python # Load the horse-zebra dataset using tensorflow-datasets. -dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True) +dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True) train_horses, train_zebras = dataset["trainA"], dataset["trainB"] test_horses, test_zebras = dataset["testA"], dataset["testB"] @@ -71,7 +70,7 @@ batch_size = 1 def normalize_img(img): - img = tf.cast(img, dtype=tf.float32) + img = ops.cast(img, dtype=tf.float32) # Map values in the range [-1, 1] return (img / 127.5) - 1.0 @@ -80,7 +79,7 @@ def preprocess_train_image(img, label): # Random flip img = tf.image.random_flip_left_right(img) # Resize to the original size first - img = tf.image.resize(img, [*orig_img_size]) + img = ops.image.resize(img, [*orig_img_size]) # Random crop to 256X256 img = tf.image.random_crop(img, size=[*input_img_size]) # Normalize the pixel values in the range [-1, 1] @@ -90,7 +89,7 @@ def preprocess_train_image(img, label): def preprocess_test_image(img, label): # Only resizing and normalization for the test images. - img = tf.image.resize(img, [input_img_size[0], input_img_size[1]]) + img = ops.image.resize(img, [input_img_size[0], input_img_size[1]]) img = normalize_img(img) return img @@ -149,7 +148,9 @@ plt.show() ``` + ![png](/img/examples/generative/cyclegan/cyclegan_9_0.png) + --- @@ -181,7 +182,7 @@ class ReflectionPadding2D(layers.Layer): [padding_width, padding_width], [0, 0], ] - return tf.pad(input_tensor, padding_tensor, mode="REFLECT") + return ops.pad(input_tensor, padding_tensor, mode="REFLECT") def residual_block( @@ -206,7 +207,9 @@ def residual_block( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = activation(x) x = ReflectionPadding2D()(x) @@ -218,7 +221,9 @@ def residual_block( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = layers.add([input_tensor, x]) return x @@ -242,7 +247,9 @@ def downsample( padding=padding, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) if activation: x = activation(x) return x @@ -267,7 +274,9 @@ def upsample( kernel_initializer=kernel_initializer, use_bias=use_bias, )(x) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) if activation: x = activation(x) return x @@ -316,7 +325,9 @@ def get_resnet_generator( x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)( x ) - x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x) + x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)( + x + ) x = layers.Activation("relu")(x) # Downsampling @@ -430,6 +441,14 @@ class CycleGan(keras.Model): self.lambda_cycle = lambda_cycle self.lambda_identity = lambda_identity + def call(self, inputs): + return ( + self.disc_X(inputs), + self.disc_Y(inputs), + self.gen_G(inputs), + self.gen_F(inputs), + ) + def compile( self, gen_G_optimizer, @@ -596,15 +615,17 @@ class GANMonitor(keras.callbacks.Callback): adv_loss_fn = keras.losses.MeanSquaredError() # Define the loss function for the generators + + def generator_loss_fn(fake): - fake_loss = adv_loss_fn(tf.ones_like(fake), fake) + fake_loss = adv_loss_fn(ops.ones_like(fake), fake) return fake_loss # Define the loss function for the discriminators def discriminator_loss_fn(real, fake): - real_loss = adv_loss_fn(tf.ones_like(real), real) - fake_loss = adv_loss_fn(tf.zeros_like(fake), fake) + real_loss = adv_loss_fn(ops.ones_like(real), real) + fake_loss = adv_loss_fn(ops.zeros_like(fake), fake) return (real_loss + fake_loss) * 0.5 @@ -615,16 +636,16 @@ cycle_gan_model = CycleGan( # Compile the model cycle_gan_model.compile( - gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), - disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5), + gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), + disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5), gen_loss_fn=generator_loss_fn, disc_loss_fn=discriminator_loss_fn, ) # Callbacks plotter = GANMonitor() -checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}" +checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5" model_checkpoint_callback = keras.callbacks.ModelCheckpoint( filepath=checkpoint_filepath, save_weights_only=True ) @@ -633,49 +654,21 @@ model_checkpoint_callback = keras.callbacks.ModelCheckpoint( # 7 minutes on a single P100 backed machine. cycle_gan_model.fit( tf.data.Dataset.zip((train_horses, train_zebras)), - epochs=1, + epochs=90, callbacks=[plotter, model_checkpoint_callback], ) ``` -
-``` -1067/1067 [==============================] - ETA: 0s - G_loss: 4.4794 - F_loss: 4.1048 - D_X_loss: 0.1584 - D_Y_loss: 0.1233 - -``` -
-![png](/img/examples/generative/cyclegan/cyclegan_21_1.png) - - -
-``` -1067/1067 [==============================] - 390s 366ms/step - G_loss: 4.4783 - F_loss: 4.1035 - D_X_loss: 0.1584 - D_Y_loss: 0.1232 - - - -``` -
Test the performance of the model. -You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN). -```python - -# This model was trained for 90 epochs. We will be loading those weights -# here. Once the weights are loaded, we will take a few samples from the test -# data and check the model's performance. -``` - ```python -!curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip -!unzip -qq saved_checkpoints.zip -``` -```python +# Once the weights are loaded, we will take a few samples from the test data and check the model's performance. + # Load the checkpoints -weight_file = "./saved_checkpoints/cyclegan_checkpoints.090" -cycle_gan_model.load_weights(weight_file).expect_partial() +cycle_gan_model.load_weights(checkpoint_filepath) print("Weights loaded successfully") _, ax = plt.subplots(4, 2, figsize=(10, 15)) @@ -697,15 +690,14 @@ for i, img in enumerate(test_horses.take(4)): plt.tight_layout() plt.show() ``` +
``` - % Total % Received % Xferd Average Speed Time Time Time Current - Dload Upload Total Spent Left Speed -100 634 100 634 0 0 2874 0 --:--:-- --:--:-- --:--:-- 2881 -100 273M 100 273M 0 0 1736k 0 0:02:41 0:02:41 --:--:-- 2049k - Weights loaded successfully ```
-![png](/img/examples/generative/cyclegan/cyclegan_25_1.png) + +![png](/img/examples/generative/cyclegan/cyclegan_23_1.png) + + diff --git a/scripts/examples_master.py b/scripts/examples_master.py index 6477e994a4..8c7c0de09e 100644 --- a/scripts/examples_master.py +++ b/scripts/examples_master.py @@ -759,6 +759,7 @@ "path": "cyclegan", "title": "CycleGAN", "subcategory": "Image generation", + "keras_3": True, }, { "path": "gan_ada",