Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 27 additions & 37 deletions examples/generative/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: CycleGAN
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/08/12
Last modified: 2020/08/12
Last modified: 2024/09/30
Description: Implementation of CycleGAN.
Accelerator: GPU
"""
Expand All @@ -25,18 +25,19 @@
## Setup
"""


import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import keras
from keras import layers, ops
#import tensorflow_addons as tfa
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
autotune = tf.data.AUTOTUNE

os.environ["KERAS_BACKEND"] = "tensorflow"

"""
## Prepare the dataset
Expand All @@ -47,7 +48,7 @@
"""

# Load the horse-zebra dataset using tensorflow-datasets.
dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]

Expand All @@ -65,7 +66,7 @@


def normalize_img(img):
img = tf.cast(img, dtype=tf.float32)
img = ops.cast(img, dtype=tf.float32)
# Map values in the range [-1, 1]
return (img / 127.5) - 1.0

Expand All @@ -74,7 +75,7 @@ def preprocess_train_image(img, label):
# Random flip
img = tf.image.random_flip_left_right(img)
# Resize to the original size first
img = tf.image.resize(img, [*orig_img_size])
img = ops.image.resize(img, [*orig_img_size])
# Random crop to 256X256
img = tf.image.random_crop(img, size=[*input_img_size])
# Normalize the pixel values in the range [-1, 1]
Expand All @@ -84,7 +85,7 @@ def preprocess_train_image(img, label):

def preprocess_test_image(img, label):
# Only resizing and normalization for the test images.
img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
img = normalize_img(img)
return img

Expand Down Expand Up @@ -165,7 +166,7 @@ def call(self, input_tensor, mask=None):
[padding_width, padding_width],
[0, 0],
]
return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
return ops.pad(input_tensor, padding_tensor, mode="REFLECT")


def residual_block(
Expand All @@ -190,7 +191,7 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
x = activation(x)

x = ReflectionPadding2D()(x)
Expand All @@ -202,7 +203,7 @@ def residual_block(
padding=padding,
use_bias=use_bias,
)(x)
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
x = layers.add([input_tensor, x])
return x

Expand All @@ -226,7 +227,7 @@ def downsample(
padding=padding,
use_bias=use_bias,
)(x)
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
if activation:
x = activation(x)
return x
Expand All @@ -251,7 +252,7 @@ def upsample(
kernel_initializer=kernel_initializer,
use_bias=use_bias,
)(x)
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
if activation:
x = activation(x)
return x
Expand Down Expand Up @@ -298,7 +299,7 @@ def get_resnet_generator(
x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
x
)
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
x = layers.Activation("relu")(x)

# Downsampling
Expand Down Expand Up @@ -581,14 +582,14 @@ def on_epoch_end(self, epoch, logs=None):


def generator_loss_fn(fake):
fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
return fake_loss


# Define the loss function for the discriminators
def discriminator_loss_fn(real, fake):
real_loss = adv_loss_fn(tf.ones_like(real), real)
fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
real_loss = adv_loss_fn(ops.ones_like(real), real)
fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
return (real_loss + fake_loss) * 0.5


Expand All @@ -599,16 +600,16 @@ def discriminator_loss_fn(real, fake):

# Compile the model
cycle_gan_model.compile(
gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
gen_loss_fn=generator_loss_fn,
disc_loss_fn=discriminator_loss_fn,
)
# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath, save_weights_only=True
)
Expand All @@ -617,31 +618,20 @@ def discriminator_loss_fn(real, fake):
# 7 minutes on a single P100 backed machine.
cycle_gan_model.fit(
tf.data.Dataset.zip((train_horses, train_zebras)),
epochs=1,
epochs=90,
callbacks=[plotter, model_checkpoint_callback],
)

"""
Test the performance of the model.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
"""


# This model was trained for 90 epochs. We will be loading those weights
# here. Once the weights are loaded, we will take a few samples from the test
# data and check the model's performance.

"""shell
curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
unzip -qq saved_checkpoints.zip
"""
# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.


# Load the checkpoints
weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
cycle_gan_model.load_weights(weight_file).expect_partial()
cycle_gan_model.load_weights(checkpoint_filepath)
print("Weights loaded successfully")

_, ax = plt.subplots(4, 2, figsize=(10, 15))
Expand Down
1 change: 1 addition & 0 deletions scripts/examples_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@
"path": "cyclegan",
"title": "CycleGAN",
"subcategory": "Image generation",
"keras_3": True,
},
{
"path": "gan_ada",
Expand Down
Loading