Skip to content

Commit 71c53a6

Browse files
authored
Upgrade Cyclegan to Keras 3 (with Tensorflow backend) (#1944)
* upgrade cyclegan.py to Keras 3 with tensorflow backend * enabling keras 3 * Increase training epochs * add ipynb, md files and images * remove training logs
1 parent 5c827bd commit 71c53a6

File tree

96 files changed

+137
-161
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+137
-161
lines changed

examples/generative/cyclegan.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: CycleGAN
33
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
44
Date created: 2020/08/12
5-
Last modified: 2020/08/12
5+
Last modified: 2024/09/30
66
Description: Implementation of CycleGAN.
77
Accelerator: GPU
88
"""
@@ -17,26 +17,26 @@
1717
CycleGAN tries to learn this mapping without requiring paired input-output images,
1818
using cycle-consistent adversarial networks.
1919
20-
- [Paper](https://arxiv.org/pdf/1703.10593.pdf)
20+
- [Paper](https://arxiv.org/abs/1703.10593)
2121
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
2222
"""
2323

2424
"""
2525
## Setup
2626
"""
2727

28-
28+
import os
2929
import numpy as np
3030
import matplotlib.pyplot as plt
3131
import tensorflow as tf
32-
from tensorflow import keras
33-
from tensorflow.keras import layers
34-
import tensorflow_addons as tfa
32+
import keras
33+
from keras import layers, ops
3534
import tensorflow_datasets as tfds
3635

3736
tfds.disable_progress_bar()
3837
autotune = tf.data.AUTOTUNE
3938

39+
os.environ["KERAS_BACKEND"] = "tensorflow"
4040

4141
"""
4242
## Prepare the dataset
@@ -47,7 +47,7 @@
4747
"""
4848

4949
# Load the horse-zebra dataset using tensorflow-datasets.
50-
dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
50+
dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
5151
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
5252
test_horses, test_zebras = dataset["testA"], dataset["testB"]
5353

@@ -65,7 +65,7 @@
6565

6666

6767
def normalize_img(img):
68-
img = tf.cast(img, dtype=tf.float32)
68+
img = ops.cast(img, dtype=tf.float32)
6969
# Map values in the range [-1, 1]
7070
return (img / 127.5) - 1.0
7171

@@ -74,7 +74,7 @@ def preprocess_train_image(img, label):
7474
# Random flip
7575
img = tf.image.random_flip_left_right(img)
7676
# Resize to the original size first
77-
img = tf.image.resize(img, [*orig_img_size])
77+
img = ops.image.resize(img, [*orig_img_size])
7878
# Random crop to 256X256
7979
img = tf.image.random_crop(img, size=[*input_img_size])
8080
# Normalize the pixel values in the range [-1, 1]
@@ -84,7 +84,7 @@ def preprocess_train_image(img, label):
8484

8585
def preprocess_test_image(img, label):
8686
# Only resizing and normalization for the test images.
87-
img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
87+
img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
8888
img = normalize_img(img)
8989
return img
9090

@@ -165,7 +165,7 @@ def call(self, input_tensor, mask=None):
165165
[padding_width, padding_width],
166166
[0, 0],
167167
]
168-
return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
168+
return ops.pad(input_tensor, padding_tensor, mode="REFLECT")
169169

170170

171171
def residual_block(
@@ -190,7 +190,9 @@ def residual_block(
190190
padding=padding,
191191
use_bias=use_bias,
192192
)(x)
193-
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
193+
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
194+
x
195+
)
194196
x = activation(x)
195197

196198
x = ReflectionPadding2D()(x)
@@ -202,7 +204,9 @@ def residual_block(
202204
padding=padding,
203205
use_bias=use_bias,
204206
)(x)
205-
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
207+
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
208+
x
209+
)
206210
x = layers.add([input_tensor, x])
207211
return x
208212

@@ -226,7 +230,9 @@ def downsample(
226230
padding=padding,
227231
use_bias=use_bias,
228232
)(x)
229-
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
233+
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
234+
x
235+
)
230236
if activation:
231237
x = activation(x)
232238
return x
@@ -251,7 +257,9 @@ def upsample(
251257
kernel_initializer=kernel_initializer,
252258
use_bias=use_bias,
253259
)(x)
254-
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
260+
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
261+
x
262+
)
255263
if activation:
256264
x = activation(x)
257265
return x
@@ -298,7 +306,9 @@ def get_resnet_generator(
298306
x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
299307
x
300308
)
301-
x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
309+
x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
310+
x
311+
)
302312
x = layers.Activation("relu")(x)
303313

304314
# Downsampling
@@ -581,14 +591,14 @@ def on_epoch_end(self, epoch, logs=None):
581591

582592

583593
def generator_loss_fn(fake):
584-
fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
594+
fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
585595
return fake_loss
586596

587597

588598
# Define the loss function for the discriminators
589599
def discriminator_loss_fn(real, fake):
590-
real_loss = adv_loss_fn(tf.ones_like(real), real)
591-
fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
600+
real_loss = adv_loss_fn(ops.ones_like(real), real)
601+
fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
592602
return (real_loss + fake_loss) * 0.5
593603

594604

@@ -599,16 +609,16 @@ def discriminator_loss_fn(real, fake):
599609

600610
# Compile the model
601611
cycle_gan_model.compile(
602-
gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
603-
gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
604-
disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
605-
disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
612+
gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
613+
gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
614+
disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
615+
disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
606616
gen_loss_fn=generator_loss_fn,
607617
disc_loss_fn=discriminator_loss_fn,
608618
)
609619
# Callbacks
610620
plotter = GANMonitor()
611-
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
621+
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
612622
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
613623
filepath=checkpoint_filepath, save_weights_only=True
614624
)
@@ -617,31 +627,20 @@ def discriminator_loss_fn(real, fake):
617627
# 7 minutes on a single P100 backed machine.
618628
cycle_gan_model.fit(
619629
tf.data.Dataset.zip((train_horses, train_zebras)),
620-
epochs=1,
630+
epochs=90,
621631
callbacks=[plotter, model_checkpoint_callback],
622632
)
623633

624634
"""
625635
Test the performance of the model.
626-
627-
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
628-
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
629636
"""
630637

631638

632-
# This model was trained for 90 epochs. We will be loading those weights
633-
# here. Once the weights are loaded, we will take a few samples from the test
634-
# data and check the model's performance.
635-
636-
"""shell
637-
curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
638-
unzip -qq saved_checkpoints.zip
639-
"""
639+
# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
640640

641641

642642
# Load the checkpoints
643-
weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
644-
cycle_gan_model.load_weights(weight_file).expect_partial()
643+
cycle_gan_model.load_weights(checkpoint_filepath)
645644
print("Weights loaded successfully")
646645

647646
_, ax = plt.subplots(4, 2, figsize=(10, 15))
612 KB
658 KB
680 KB
731 KB
580 KB
723 KB
734 KB
670 KB
664 KB

0 commit comments

Comments
 (0)