diff --git a/examples/generative/cyclegan.py b/examples/generative/cyclegan.py
index ca208efa88..bef76f5d9b 100644
--- a/examples/generative/cyclegan.py
+++ b/examples/generative/cyclegan.py
@@ -2,7 +2,7 @@
Title: CycleGAN
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/08/12
-Last modified: 2020/08/12
+Last modified: 2024/09/30
Description: Implementation of CycleGAN.
Accelerator: GPU
"""
@@ -17,7 +17,7 @@
CycleGAN tries to learn this mapping without requiring paired input-output images,
using cycle-consistent adversarial networks.
-- [Paper](https://arxiv.org/pdf/1703.10593.pdf)
+- [Paper](https://arxiv.org/abs/1703.10593)
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
"""
@@ -25,18 +25,18 @@
## Setup
"""
-
+import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers, ops
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
autotune = tf.data.AUTOTUNE
+os.environ["KERAS_BACKEND"] = "tensorflow"
"""
## Prepare the dataset
@@ -47,7 +47,7 @@
"""
# Load the horse-zebra dataset using tensorflow-datasets.
-dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
+dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]
@@ -65,7 +65,7 @@
def normalize_img(img):
- img = tf.cast(img, dtype=tf.float32)
+ img = ops.cast(img, dtype=tf.float32)
# Map values in the range [-1, 1]
return (img / 127.5) - 1.0
@@ -74,7 +74,7 @@ def preprocess_train_image(img, label):
# Random flip
img = tf.image.random_flip_left_right(img)
# Resize to the original size first
- img = tf.image.resize(img, [*orig_img_size])
+ img = ops.image.resize(img, [*orig_img_size])
# Random crop to 256X256
img = tf.image.random_crop(img, size=[*input_img_size])
# Normalize the pixel values in the range [-1, 1]
@@ -84,7 +84,7 @@ def preprocess_train_image(img, label):
def preprocess_test_image(img, label):
# Only resizing and normalization for the test images.
- img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
+ img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
img = normalize_img(img)
return img
@@ -165,7 +165,7 @@ def call(self, input_tensor, mask=None):
[padding_width, padding_width],
[0, 0],
]
- return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
+ return ops.pad(input_tensor, padding_tensor, mode="REFLECT")
def residual_block(
@@ -190,7 +190,9 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = activation(x)
x = ReflectionPadding2D()(x)
@@ -202,7 +204,9 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = layers.add([input_tensor, x])
return x
@@ -226,7 +230,9 @@ def downsample(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
if activation:
x = activation(x)
return x
@@ -251,7 +257,9 @@ def upsample(
kernel_initializer=kernel_initializer,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
if activation:
x = activation(x)
return x
@@ -298,7 +306,9 @@ def get_resnet_generator(
x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
x
)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = layers.Activation("relu")(x)
# Downsampling
@@ -581,14 +591,14 @@ def on_epoch_end(self, epoch, logs=None):
def generator_loss_fn(fake):
- fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
+ fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
return fake_loss
# Define the loss function for the discriminators
def discriminator_loss_fn(real, fake):
- real_loss = adv_loss_fn(tf.ones_like(real), real)
- fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
+ real_loss = adv_loss_fn(ops.ones_like(real), real)
+ fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
return (real_loss + fake_loss) * 0.5
@@ -599,16 +609,16 @@ def discriminator_loss_fn(real, fake):
# Compile the model
cycle_gan_model.compile(
- gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
gen_loss_fn=generator_loss_fn,
disc_loss_fn=discriminator_loss_fn,
)
# Callbacks
plotter = GANMonitor()
-checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
+checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath, save_weights_only=True
)
@@ -617,31 +627,20 @@ def discriminator_loss_fn(real, fake):
# 7 minutes on a single P100 backed machine.
cycle_gan_model.fit(
tf.data.Dataset.zip((train_horses, train_zebras)),
- epochs=1,
+ epochs=90,
callbacks=[plotter, model_checkpoint_callback],
)
"""
Test the performance of the model.
-
-You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
-and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
"""
-# This model was trained for 90 epochs. We will be loading those weights
-# here. Once the weights are loaded, we will take a few samples from the test
-# data and check the model's performance.
-
-"""shell
-curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
-unzip -qq saved_checkpoints.zip
-"""
+# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
# Load the checkpoints
-weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
-cycle_gan_model.load_weights(weight_file).expect_partial()
+cycle_gan_model.load_weights(checkpoint_filepath)
print("Weights loaded successfully")
_, ax = plt.subplots(4, 2, figsize=(10, 15))
diff --git a/examples/generative/img/cyclegan/cyclegan_21_1069.png b/examples/generative/img/cyclegan/cyclegan_21_1069.png
new file mode 100644
index 0000000000..0f077a8568
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_1069.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_10699.png b/examples/generative/img/cyclegan/cyclegan_21_10699.png
new file mode 100644
index 0000000000..88f3bf2574
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_10699.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_11769.png b/examples/generative/img/cyclegan/cyclegan_21_11769.png
new file mode 100644
index 0000000000..9a92e1502f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_11769.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_12839.png b/examples/generative/img/cyclegan/cyclegan_21_12839.png
new file mode 100644
index 0000000000..a3568b2fc2
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_12839.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_13909.png b/examples/generative/img/cyclegan/cyclegan_21_13909.png
new file mode 100644
index 0000000000..f336b92899
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_13909.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_14979.png b/examples/generative/img/cyclegan/cyclegan_21_14979.png
new file mode 100644
index 0000000000..d3fbad6fd7
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_14979.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_16049.png b/examples/generative/img/cyclegan/cyclegan_21_16049.png
new file mode 100644
index 0000000000..2371b01b6f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_16049.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_17119.png b/examples/generative/img/cyclegan/cyclegan_21_17119.png
new file mode 100644
index 0000000000..5a5804c8be
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_17119.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_18189.png b/examples/generative/img/cyclegan/cyclegan_21_18189.png
new file mode 100644
index 0000000000..6a511b56d6
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_18189.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_19259.png b/examples/generative/img/cyclegan/cyclegan_21_19259.png
new file mode 100644
index 0000000000..595bddc345
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_19259.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_20329.png b/examples/generative/img/cyclegan/cyclegan_21_20329.png
new file mode 100644
index 0000000000..2f5879f971
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_20329.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_2139.png b/examples/generative/img/cyclegan/cyclegan_21_2139.png
new file mode 100644
index 0000000000..e214d8b0c3
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_2139.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_21399.png b/examples/generative/img/cyclegan/cyclegan_21_21399.png
new file mode 100644
index 0000000000..923ed3098c
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_21399.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_22469.png b/examples/generative/img/cyclegan/cyclegan_21_22469.png
new file mode 100644
index 0000000000..597a6dc34d
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_22469.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_23539.png b/examples/generative/img/cyclegan/cyclegan_21_23539.png
new file mode 100644
index 0000000000..107b7373f0
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_23539.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_24609.png b/examples/generative/img/cyclegan/cyclegan_21_24609.png
new file mode 100644
index 0000000000..22ce7f84b0
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_24609.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_25679.png b/examples/generative/img/cyclegan/cyclegan_21_25679.png
new file mode 100644
index 0000000000..544843541f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_25679.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_26749.png b/examples/generative/img/cyclegan/cyclegan_21_26749.png
new file mode 100644
index 0000000000..5b1fcbb77a
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_26749.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_27819.png b/examples/generative/img/cyclegan/cyclegan_21_27819.png
new file mode 100644
index 0000000000..6be4bd72d9
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_27819.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_28889.png b/examples/generative/img/cyclegan/cyclegan_21_28889.png
new file mode 100644
index 0000000000..fc94fb0243
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_28889.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_29959.png b/examples/generative/img/cyclegan/cyclegan_21_29959.png
new file mode 100644
index 0000000000..5c460e436e
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_29959.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_31029.png b/examples/generative/img/cyclegan/cyclegan_21_31029.png
new file mode 100644
index 0000000000..2c8d40037f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_31029.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_3209.png b/examples/generative/img/cyclegan/cyclegan_21_3209.png
new file mode 100644
index 0000000000..2bd7d0bf1d
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_3209.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_32099.png b/examples/generative/img/cyclegan/cyclegan_21_32099.png
new file mode 100644
index 0000000000..98d12e6469
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_32099.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_33169.png b/examples/generative/img/cyclegan/cyclegan_21_33169.png
new file mode 100644
index 0000000000..9837dcf2e4
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_33169.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_34239.png b/examples/generative/img/cyclegan/cyclegan_21_34239.png
new file mode 100644
index 0000000000..6c79d1ba3f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_34239.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_35309.png b/examples/generative/img/cyclegan/cyclegan_21_35309.png
new file mode 100644
index 0000000000..0c45d26f0a
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_35309.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_36379.png b/examples/generative/img/cyclegan/cyclegan_21_36379.png
new file mode 100644
index 0000000000..d114388196
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_36379.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_37449.png b/examples/generative/img/cyclegan/cyclegan_21_37449.png
new file mode 100644
index 0000000000..c766263b48
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_37449.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_38519.png b/examples/generative/img/cyclegan/cyclegan_21_38519.png
new file mode 100644
index 0000000000..88b535d5c4
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_38519.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_39589.png b/examples/generative/img/cyclegan/cyclegan_21_39589.png
new file mode 100644
index 0000000000..7667223689
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_39589.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_40659.png b/examples/generative/img/cyclegan/cyclegan_21_40659.png
new file mode 100644
index 0000000000..2ca1af93ec
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_40659.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_41729.png b/examples/generative/img/cyclegan/cyclegan_21_41729.png
new file mode 100644
index 0000000000..70a388f8de
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_41729.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_4279.png b/examples/generative/img/cyclegan/cyclegan_21_4279.png
new file mode 100644
index 0000000000..387aaba129
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_4279.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_42799.png b/examples/generative/img/cyclegan/cyclegan_21_42799.png
new file mode 100644
index 0000000000..c2d8a5431b
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_42799.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_43869.png b/examples/generative/img/cyclegan/cyclegan_21_43869.png
new file mode 100644
index 0000000000..85d642c066
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_43869.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_44939.png b/examples/generative/img/cyclegan/cyclegan_21_44939.png
new file mode 100644
index 0000000000..803e6d6d5f
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_44939.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_46009.png b/examples/generative/img/cyclegan/cyclegan_21_46009.png
new file mode 100644
index 0000000000..749c639f0e
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_46009.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_47079.png b/examples/generative/img/cyclegan/cyclegan_21_47079.png
new file mode 100644
index 0000000000..4c0a595269
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_47079.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_48149.png b/examples/generative/img/cyclegan/cyclegan_21_48149.png
new file mode 100644
index 0000000000..d703272a28
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_48149.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_49219.png b/examples/generative/img/cyclegan/cyclegan_21_49219.png
new file mode 100644
index 0000000000..aaf8b6ad51
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_49219.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_50289.png b/examples/generative/img/cyclegan/cyclegan_21_50289.png
new file mode 100644
index 0000000000..e7755b9cd5
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_50289.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_51359.png b/examples/generative/img/cyclegan/cyclegan_21_51359.png
new file mode 100644
index 0000000000..3be7ada979
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_51359.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_52429.png b/examples/generative/img/cyclegan/cyclegan_21_52429.png
new file mode 100644
index 0000000000..d0875380ed
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_52429.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_5349.png b/examples/generative/img/cyclegan/cyclegan_21_5349.png
new file mode 100644
index 0000000000..7aa34fe34c
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_5349.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_53499.png b/examples/generative/img/cyclegan/cyclegan_21_53499.png
new file mode 100644
index 0000000000..f39210dc62
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_53499.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_54569.png b/examples/generative/img/cyclegan/cyclegan_21_54569.png
new file mode 100644
index 0000000000..37ea4a9b6c
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_54569.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_55639.png b/examples/generative/img/cyclegan/cyclegan_21_55639.png
new file mode 100644
index 0000000000..55b582ee87
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_55639.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_56709.png b/examples/generative/img/cyclegan/cyclegan_21_56709.png
new file mode 100644
index 0000000000..267ab05977
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_56709.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_57779.png b/examples/generative/img/cyclegan/cyclegan_21_57779.png
new file mode 100644
index 0000000000..44a34cc8c2
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_57779.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_58849.png b/examples/generative/img/cyclegan/cyclegan_21_58849.png
new file mode 100644
index 0000000000..75846f78bb
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_58849.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_59919.png b/examples/generative/img/cyclegan/cyclegan_21_59919.png
new file mode 100644
index 0000000000..00f1053fe5
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_59919.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_60989.png b/examples/generative/img/cyclegan/cyclegan_21_60989.png
new file mode 100644
index 0000000000..1ddcb5ec7d
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_60989.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_62059.png b/examples/generative/img/cyclegan/cyclegan_21_62059.png
new file mode 100644
index 0000000000..895be698e9
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_62059.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_63129.png b/examples/generative/img/cyclegan/cyclegan_21_63129.png
new file mode 100644
index 0000000000..76afc7f569
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_63129.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_6419.png b/examples/generative/img/cyclegan/cyclegan_21_6419.png
new file mode 100644
index 0000000000..4f7dcb19f7
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_6419.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_64199.png b/examples/generative/img/cyclegan/cyclegan_21_64199.png
new file mode 100644
index 0000000000..626fd2efc8
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_64199.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_65269.png b/examples/generative/img/cyclegan/cyclegan_21_65269.png
new file mode 100644
index 0000000000..1fd873391b
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_65269.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_66339.png b/examples/generative/img/cyclegan/cyclegan_21_66339.png
new file mode 100644
index 0000000000..4d42bfba05
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_66339.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_67409.png b/examples/generative/img/cyclegan/cyclegan_21_67409.png
new file mode 100644
index 0000000000..8f5b6cfafb
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_67409.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_68479.png b/examples/generative/img/cyclegan/cyclegan_21_68479.png
new file mode 100644
index 0000000000..b8b60c62c4
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_68479.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_69549.png b/examples/generative/img/cyclegan/cyclegan_21_69549.png
new file mode 100644
index 0000000000..109ce94e16
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_69549.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_70619.png b/examples/generative/img/cyclegan/cyclegan_21_70619.png
new file mode 100644
index 0000000000..8ada414e32
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_70619.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_71689.png b/examples/generative/img/cyclegan/cyclegan_21_71689.png
new file mode 100644
index 0000000000..6c4357b1c4
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_71689.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_72759.png b/examples/generative/img/cyclegan/cyclegan_21_72759.png
new file mode 100644
index 0000000000..2ef23d0981
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_72759.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_73829.png b/examples/generative/img/cyclegan/cyclegan_21_73829.png
new file mode 100644
index 0000000000..51e6f49dc9
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_73829.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_7489.png b/examples/generative/img/cyclegan/cyclegan_21_7489.png
new file mode 100644
index 0000000000..7083a5f7c1
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_7489.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_74899.png b/examples/generative/img/cyclegan/cyclegan_21_74899.png
new file mode 100644
index 0000000000..f1791950ca
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_74899.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_75969.png b/examples/generative/img/cyclegan/cyclegan_21_75969.png
new file mode 100644
index 0000000000..161c1faed9
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_75969.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_77039.png b/examples/generative/img/cyclegan/cyclegan_21_77039.png
new file mode 100644
index 0000000000..e83f908358
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_77039.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_78109.png b/examples/generative/img/cyclegan/cyclegan_21_78109.png
new file mode 100644
index 0000000000..ffd1e45d0a
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_78109.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_79179.png b/examples/generative/img/cyclegan/cyclegan_21_79179.png
new file mode 100644
index 0000000000..e86690ce1a
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_79179.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_80249.png b/examples/generative/img/cyclegan/cyclegan_21_80249.png
new file mode 100644
index 0000000000..27203d37bb
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_80249.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_81319.png b/examples/generative/img/cyclegan/cyclegan_21_81319.png
new file mode 100644
index 0000000000..dbcd08015c
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_81319.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_82389.png b/examples/generative/img/cyclegan/cyclegan_21_82389.png
new file mode 100644
index 0000000000..f62b2ed717
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_82389.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_83459.png b/examples/generative/img/cyclegan/cyclegan_21_83459.png
new file mode 100644
index 0000000000..074535c4ba
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_83459.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_84529.png b/examples/generative/img/cyclegan/cyclegan_21_84529.png
new file mode 100644
index 0000000000..c82a10c8d5
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_84529.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_8559.png b/examples/generative/img/cyclegan/cyclegan_21_8559.png
new file mode 100644
index 0000000000..bb60de0cc3
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_8559.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_85599.png b/examples/generative/img/cyclegan/cyclegan_21_85599.png
new file mode 100644
index 0000000000..0ec3509c94
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_85599.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_86669.png b/examples/generative/img/cyclegan/cyclegan_21_86669.png
new file mode 100644
index 0000000000..0a394b12bb
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_86669.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_87739.png b/examples/generative/img/cyclegan/cyclegan_21_87739.png
new file mode 100644
index 0000000000..aeb540e795
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_87739.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_88809.png b/examples/generative/img/cyclegan/cyclegan_21_88809.png
new file mode 100644
index 0000000000..dbb4a79ff6
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_88809.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_89879.png b/examples/generative/img/cyclegan/cyclegan_21_89879.png
new file mode 100644
index 0000000000..705f371f96
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_89879.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_90949.png b/examples/generative/img/cyclegan/cyclegan_21_90949.png
new file mode 100644
index 0000000000..1c538e5cd8
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_90949.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_92019.png b/examples/generative/img/cyclegan/cyclegan_21_92019.png
new file mode 100644
index 0000000000..7d57a0a90c
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_92019.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_93089.png b/examples/generative/img/cyclegan/cyclegan_21_93089.png
new file mode 100644
index 0000000000..f4199f24ff
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_93089.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_94159.png b/examples/generative/img/cyclegan/cyclegan_21_94159.png
new file mode 100644
index 0000000000..a7a1477786
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_94159.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_95229.png b/examples/generative/img/cyclegan/cyclegan_21_95229.png
new file mode 100644
index 0000000000..3ac98c2a6a
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_95229.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_9629.png b/examples/generative/img/cyclegan/cyclegan_21_9629.png
new file mode 100644
index 0000000000..a15295f900
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_9629.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_21_96299.png b/examples/generative/img/cyclegan/cyclegan_21_96299.png
new file mode 100644
index 0000000000..e2abf8a803
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_21_96299.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_23_1.png b/examples/generative/img/cyclegan/cyclegan_23_1.png
new file mode 100644
index 0000000000..52f4b7e32b
Binary files /dev/null and b/examples/generative/img/cyclegan/cyclegan_23_1.png differ
diff --git a/examples/generative/img/cyclegan/cyclegan_9_0.png b/examples/generative/img/cyclegan/cyclegan_9_0.png
index 8c03be0e98..886066b8e2 100644
Binary files a/examples/generative/img/cyclegan/cyclegan_9_0.png and b/examples/generative/img/cyclegan/cyclegan_9_0.png differ
diff --git a/examples/generative/ipynb/cyclegan.ipynb b/examples/generative/ipynb/cyclegan.ipynb
index 20841f8d5a..616f1521c1 100644
--- a/examples/generative/ipynb/cyclegan.ipynb
+++ b/examples/generative/ipynb/cyclegan.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)
\n",
"**Date created:** 2020/08/12
\n",
- "**Last modified:** 2020/08/12
\n",
+ "**Last modified:** 2024/09/30
\n",
"**Description:** Implementation of CycleGAN."
]
},
@@ -29,7 +29,7 @@
"CycleGAN tries to learn this mapping without requiring paired input-output images,\n",
"using cycle-consistent adversarial networks.\n",
"\n",
- "- [Paper](https://arxiv.org/pdf/1703.10593.pdf)\n",
+ "- [Paper](https://arxiv.org/abs/1703.10593)\n",
"- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)"
]
},
@@ -50,18 +50,18 @@
},
"outputs": [],
"source": [
+ "import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
- "\n",
"import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "from tensorflow.keras import layers\n",
- "\n",
- "import tensorflow_addons as tfa\n",
+ "import keras\n",
+ "from keras import layers, ops\n",
"import tensorflow_datasets as tfds\n",
"\n",
"tfds.disable_progress_bar()\n",
- "autotune = tf.data.AUTOTUNE\n"
+ "autotune = tf.data.AUTOTUNE\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\""
]
},
{
@@ -86,7 +86,7 @@
"outputs": [],
"source": [
"# Load the horse-zebra dataset using tensorflow-datasets.\n",
- "dataset, _ = tfds.load(\"cycle_gan/horse2zebra\", with_info=True, as_supervised=True)\n",
+ "dataset, _ = tfds.load(name=\"cycle_gan/horse2zebra\", with_info=True, as_supervised=True)\n",
"train_horses, train_zebras = dataset[\"trainA\"], dataset[\"trainB\"]\n",
"test_horses, test_zebras = dataset[\"testA\"], dataset[\"testB\"]\n",
"\n",
@@ -104,7 +104,7 @@
"\n",
"\n",
"def normalize_img(img):\n",
- " img = tf.cast(img, dtype=tf.float32)\n",
+ " img = ops.cast(img, dtype=tf.float32)\n",
" # Map values in the range [-1, 1]\n",
" return (img / 127.5) - 1.0\n",
"\n",
@@ -113,7 +113,7 @@
" # Random flip\n",
" img = tf.image.random_flip_left_right(img)\n",
" # Resize to the original size first\n",
- " img = tf.image.resize(img, [*orig_img_size])\n",
+ " img = ops.image.resize(img, [*orig_img_size])\n",
" # Random crop to 256X256\n",
" img = tf.image.random_crop(img, size=[*input_img_size])\n",
" # Normalize the pixel values in the range [-1, 1]\n",
@@ -123,7 +123,7 @@
"\n",
"def preprocess_test_image(img, label):\n",
" # Only resizing and normalization for the test images.\n",
- " img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])\n",
+ " img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])\n",
" img = normalize_img(img)\n",
" return img\n"
]
@@ -243,7 +243,7 @@
" [padding_width, padding_width],\n",
" [0, 0],\n",
" ]\n",
- " return tf.pad(input_tensor, padding_tensor, mode=\"REFLECT\")\n",
+ " return ops.pad(input_tensor, padding_tensor, mode=\"REFLECT\")\n",
"\n",
"\n",
"def residual_block(\n",
@@ -268,7 +268,9 @@
" padding=padding,\n",
" use_bias=use_bias,\n",
" )(x)\n",
- " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
+ " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n",
+ " x\n",
+ " )\n",
" x = activation(x)\n",
"\n",
" x = ReflectionPadding2D()(x)\n",
@@ -280,7 +282,9 @@
" padding=padding,\n",
" use_bias=use_bias,\n",
" )(x)\n",
- " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
+ " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n",
+ " x\n",
+ " )\n",
" x = layers.add([input_tensor, x])\n",
" return x\n",
"\n",
@@ -304,7 +308,9 @@
" padding=padding,\n",
" use_bias=use_bias,\n",
" )(x)\n",
- " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
+ " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n",
+ " x\n",
+ " )\n",
" if activation:\n",
" x = activation(x)\n",
" return x\n",
@@ -329,7 +335,9 @@
" kernel_initializer=kernel_initializer,\n",
" use_bias=use_bias,\n",
" )(x)\n",
- " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
+ " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n",
+ " x\n",
+ " )\n",
" if activation:\n",
" x = activation(x)\n",
" return x\n"
@@ -389,7 +397,9 @@
" x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(\n",
" x\n",
" )\n",
- " x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
+ " x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(\n",
+ " x\n",
+ " )\n",
" x = layers.Activation(\"relu\")(x)\n",
"\n",
" # Downsampling\n",
@@ -721,15 +731,17 @@
"adv_loss_fn = keras.losses.MeanSquaredError()\n",
"\n",
"# Define the loss function for the generators\n",
+ "\n",
+ "\n",
"def generator_loss_fn(fake):\n",
- " fake_loss = adv_loss_fn(tf.ones_like(fake), fake)\n",
+ " fake_loss = adv_loss_fn(ops.ones_like(fake), fake)\n",
" return fake_loss\n",
"\n",
"\n",
"# Define the loss function for the discriminators\n",
"def discriminator_loss_fn(real, fake):\n",
- " real_loss = adv_loss_fn(tf.ones_like(real), real)\n",
- " fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)\n",
+ " real_loss = adv_loss_fn(ops.ones_like(real), real)\n",
+ " fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)\n",
" return (real_loss + fake_loss) * 0.5\n",
"\n",
"\n",
@@ -740,26 +752,25 @@
"\n",
"# Compile the model\n",
"cycle_gan_model.compile(\n",
- " gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n",
- " gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n",
- " disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n",
- " disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),\n",
+ " gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n",
+ " gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n",
+ " disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n",
+ " disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),\n",
" gen_loss_fn=generator_loss_fn,\n",
" disc_loss_fn=discriminator_loss_fn,\n",
")\n",
"# Callbacks\n",
"plotter = GANMonitor()\n",
- "checkpoint_filepath = \"./model_checkpoints/cyclegan_checkpoints.{epoch:03d}\"\n",
+ "checkpoint_filepath = \"./model_checkpoints/cyclegan_checkpoints.weights.h5\"\n",
"model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
- " filepath=checkpoint_filepath,\n",
- " save_weights_only=True\n",
+ " filepath=checkpoint_filepath, save_weights_only=True\n",
")\n",
"\n",
"# Here we will train the model for just one epoch as each epoch takes around\n",
"# 7 minutes on a single P100 backed machine.\n",
"cycle_gan_model.fit(\n",
" tf.data.Dataset.zip((train_horses, train_zebras)),\n",
- " epochs=1,\n",
+ " epochs=90,\n",
" callbacks=[plotter, model_checkpoint_callback],\n",
")"
]
@@ -770,10 +781,7 @@
"colab_type": "text"
},
"source": [
- "Test the performance of the model.\n",
- "\n",
- "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)\n",
- "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN)."
+ "Test the performance of the model."
]
},
{
@@ -785,35 +793,11 @@
"outputs": [],
"source": [
"\n",
- "# This model was trained for 90 epochs. We will be loading those weights\n",
- "# here. Once the weights are loaded, we will take a few samples from the test\n",
- "# data and check the model's performance."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "!curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip\n",
- "!unzip -qq saved_checkpoints.zip"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
+ "# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.\n",
+ "\n",
"\n",
"# Load the checkpoints\n",
- "weight_file = \"./saved_checkpoints/cyclegan_checkpoints.090\"\n",
- "cycle_gan_model.load_weights(weight_file).expect_partial()\n",
+ "cycle_gan_model.load_weights(checkpoint_filepath)\n",
"print(\"Weights loaded successfully\")\n",
"\n",
"_, ax = plt.subplots(4, 2, figsize=(10, 15))\n",
diff --git a/examples/generative/md/cyclegan.md b/examples/generative/md/cyclegan.md
index a9857ed082..b27b7f0f54 100644
--- a/examples/generative/md/cyclegan.md
+++ b/examples/generative/md/cyclegan.md
@@ -2,7 +2,7 @@
**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)
**Date created:** 2020/08/12
-**Last modified:** 2020/08/12
+**Last modified:** 2024/09/30
**Description:** Implementation of CycleGAN.
@@ -20,7 +20,7 @@ aligned image pairs. However, obtaining paired examples isn't always feasible.
CycleGAN tries to learn this mapping without requiring paired input-output images,
using cycle-consistent adversarial networks.
-- [Paper](https://arxiv.org/pdf/1703.10593.pdf)
+- [Paper](https://arxiv.org/abs/1703.10593)
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
---
@@ -28,19 +28,18 @@ using cycle-consistent adversarial networks.
```python
+import os
import numpy as np
import matplotlib.pyplot as plt
-
import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-
-import tensorflow_addons as tfa
+import keras
+from keras import layers, ops
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
autotune = tf.data.AUTOTUNE
+os.environ["KERAS_BACKEND"] = "tensorflow"
```
---
@@ -53,7 +52,7 @@ dataset.
```python
# Load the horse-zebra dataset using tensorflow-datasets.
-dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
+dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]
@@ -71,7 +70,7 @@ batch_size = 1
def normalize_img(img):
- img = tf.cast(img, dtype=tf.float32)
+ img = ops.cast(img, dtype=tf.float32)
# Map values in the range [-1, 1]
return (img / 127.5) - 1.0
@@ -80,7 +79,7 @@ def preprocess_train_image(img, label):
# Random flip
img = tf.image.random_flip_left_right(img)
# Resize to the original size first
- img = tf.image.resize(img, [*orig_img_size])
+ img = ops.image.resize(img, [*orig_img_size])
# Random crop to 256X256
img = tf.image.random_crop(img, size=[*input_img_size])
# Normalize the pixel values in the range [-1, 1]
@@ -90,7 +89,7 @@ def preprocess_train_image(img, label):
def preprocess_test_image(img, label):
# Only resizing and normalization for the test images.
- img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
+ img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
img = normalize_img(img)
return img
@@ -149,7 +148,9 @@ plt.show()
```
+

+
---
@@ -181,7 +182,7 @@ class ReflectionPadding2D(layers.Layer):
[padding_width, padding_width],
[0, 0],
]
- return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
+ return ops.pad(input_tensor, padding_tensor, mode="REFLECT")
def residual_block(
@@ -206,7 +207,9 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = activation(x)
x = ReflectionPadding2D()(x)
@@ -218,7 +221,9 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = layers.add([input_tensor, x])
return x
@@ -242,7 +247,9 @@ def downsample(
padding=padding,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
if activation:
x = activation(x)
return x
@@ -267,7 +274,9 @@ def upsample(
kernel_initializer=kernel_initializer,
use_bias=use_bias,
)(x)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
if activation:
x = activation(x)
return x
@@ -316,7 +325,9 @@ def get_resnet_generator(
x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
x
)
- x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
x = layers.Activation("relu")(x)
# Downsampling
@@ -430,6 +441,14 @@ class CycleGan(keras.Model):
self.lambda_cycle = lambda_cycle
self.lambda_identity = lambda_identity
+ def call(self, inputs):
+ return (
+ self.disc_X(inputs),
+ self.disc_Y(inputs),
+ self.gen_G(inputs),
+ self.gen_F(inputs),
+ )
+
def compile(
self,
gen_G_optimizer,
@@ -596,15 +615,17 @@ class GANMonitor(keras.callbacks.Callback):
adv_loss_fn = keras.losses.MeanSquaredError()
# Define the loss function for the generators
+
+
def generator_loss_fn(fake):
- fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
+ fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
return fake_loss
# Define the loss function for the discriminators
def discriminator_loss_fn(real, fake):
- real_loss = adv_loss_fn(tf.ones_like(real), real)
- fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
+ real_loss = adv_loss_fn(ops.ones_like(real), real)
+ fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
return (real_loss + fake_loss) * 0.5
@@ -615,16 +636,16 @@ cycle_gan_model = CycleGan(
# Compile the model
cycle_gan_model.compile(
- gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
- disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
gen_loss_fn=generator_loss_fn,
disc_loss_fn=discriminator_loss_fn,
)
# Callbacks
plotter = GANMonitor()
-checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
+checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath, save_weights_only=True
)
@@ -633,49 +654,21 @@ model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
# 7 minutes on a single P100 backed machine.
cycle_gan_model.fit(
tf.data.Dataset.zip((train_horses, train_zebras)),
- epochs=1,
+ epochs=90,
callbacks=[plotter, model_checkpoint_callback],
)
```
-