
Commit 3076a22

upgrade cyclegan.py to Keras 3 with tensorflow backend
1 parent 8436d18 commit 3076a22

File tree

1 file changed (+25, -35 lines)

examples/generative/cyclegan.py

Lines changed: 25 additions & 35 deletions
@@ -25,18 +25,19 @@
 ## Setup
 """
 
-
+import os
 import numpy as np
 import matplotlib.pyplot as plt
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers, ops
+#import tensorflow_addons as tfa
 import tensorflow_datasets as tfds
 
 tfds.disable_progress_bar()
 autotune = tf.data.AUTOTUNE
 
+os.environ["KERAS_BACKEND"] = "tensorflow"
 
 """
 ## Prepare the dataset
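The new setup block imports `keras` directly and selects the backend through an environment variable. One detail worth noting when reusing this block: Keras 3 reads `KERAS_BACKEND` when `keras` is first imported, so in a standalone script the variable should be set before the import. A minimal sketch of that ordering (the variable name and value are taken from the diff; only the placement is an assumption):

import os

# Keras 3 picks up the backend at import time, so set it before importing keras.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers, ops
import tensorflow as tf  # still needed for tf.data / tf.image in this example

print(keras.backend.backend())  # prints "tensorflow"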
@@ -47,7 +48,7 @@
 """
 
 # Load the horse-zebra dataset using tensorflow-datasets.
-dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
+dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
 train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
 test_horses, test_zebras = dataset["testA"], dataset["testB"]
 
@@ -65,7 +66,7 @@
 
 
 def normalize_img(img):
-    img = tf.cast(img, dtype=tf.float32)
+    img = ops.cast(img, dtype=tf.float32)
     # Map values in the range [-1, 1]
     return (img / 127.5) - 1.0
 
@@ -74,7 +75,7 @@ def preprocess_train_image(img, label):
     # Random flip
     img = tf.image.random_flip_left_right(img)
     # Resize to the original size first
-    img = tf.image.resize(img, [*orig_img_size])
+    img = ops.image.resize(img, [*orig_img_size])
     # Random crop to 256X256
     img = tf.image.random_crop(img, size=[*input_img_size])
     # Normalize the pixel values in the range [-1, 1]
@@ -84,7 +85,7 @@ def preprocess_train_image(img, label):
 
 def preprocess_test_image(img, label):
     # Only resizing and normalization for the test images.
-    img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
+    img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
     img = normalize_img(img)
     return img
 
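The preprocessing hunks swap `tf.cast` and `tf.image.resize` for their `keras.ops` equivalents while the random flip and crop stay on `tf.image`. With the TensorFlow backend, `ops.cast` and `ops.image.resize` dispatch to TensorFlow ops, so they still trace correctly inside a `tf.data` pipeline. A quick standalone check of the two calls, reusing the imports from the setup hunk (the array and sizes below are illustrative, not taken from the example):

import numpy as np

dummy = np.random.randint(0, 255, size=(300, 300, 3), dtype="uint8")
img = ops.cast(dummy, "float32")         # keras.ops equivalent of tf.cast
img = ops.image.resize(img, (256, 256))  # keras.ops equivalent of tf.image.resize
img = img / 127.5 - 1.0                  # map values into [-1, 1], as in normalize_img
print(ops.shape(img))                    # (256, 256, 3)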
@@ -165,7 +166,7 @@ def call(self, input_tensor, mask=None):
             [padding_width, padding_width],
             [0, 0],
         ]
-        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
+        return ops.pad(input_tensor, padding_tensor, mode="REFLECT")
 
 
 def residual_block(
@@ -190,7 +191,7 @@ def residual_block(
         padding=padding,
         use_bias=use_bias,
     )(x)
-    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
     x = activation(x)
 
     x = ReflectionPadding2D()(x)
@@ -202,7 +203,7 @@ def residual_block(
         padding=padding,
         use_bias=use_bias,
     )(x)
-    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
     x = layers.add([input_tensor, x])
     return x
 
@@ -226,7 +227,7 @@ def downsample(
         padding=padding,
         use_bias=use_bias,
     )(x)
-    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
     if activation:
         x = activation(x)
     return x
@@ -251,7 +252,7 @@ def upsample(
         kernel_initializer=kernel_initializer,
         use_bias=use_bias,
     )(x)
-    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
     if activation:
         x = activation(x)
     return x
@@ -298,7 +299,7 @@ def get_resnet_generator(
     x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
         x
     )
-    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
+    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(x)
     x = layers.Activation("relu")(x)
 
     # Downsampling
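`tensorflow_addons` has no Keras 3 build, so every `tfa.layers.InstanceNormalization` in `residual_block`, `downsample`, `upsample`, and `get_resnet_generator` above is replaced with `keras.layers.GroupNormalization`. For reference, the Keras docs note that group normalization matches instance normalization when `groups` equals the channel count, and normalizes across all channels at once (like layer normalization) when `groups=1`, which is the value this commit uses. A small sketch of the drop-in call pattern, reusing the setup imports; the `make_norm` helper, initializer, and shapes are illustrative only:

gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

def make_norm(groups=1):
    # Same call pattern as the tfa.layers.InstanceNormalization it replaces.
    return keras.layers.GroupNormalization(groups=groups, gamma_initializer=gamma_init)

features = ops.ones((1, 64, 64, 128))  # dummy NHWC feature map
normalized = make_norm()(features)     # groups=1, as in this commit
print(ops.shape(normalized))           # (1, 64, 64, 128)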
@@ -581,14 +582,14 @@ def on_epoch_end(self, epoch, logs=None):
 
 
 def generator_loss_fn(fake):
-    fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
+    fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
     return fake_loss
 
 
 # Define the loss function for the discriminators
 def discriminator_loss_fn(real, fake):
-    real_loss = adv_loss_fn(tf.ones_like(real), real)
-    fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
+    real_loss = adv_loss_fn(ops.ones_like(real), real)
+    fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
     return (real_loss + fake_loss) * 0.5
 
 
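The loss helpers now build their target tensors with `keras.ops` instead of `tf`, so they work on whatever tensor type the active backend produces. A minimal check of the two functions, assuming `adv_loss_fn` is the mean-squared-error (LSGAN-style) adversarial loss defined elsewhere in the example:

adv_loss_fn = keras.losses.MeanSquaredError()  # assumption: MSE adversarial loss, as in the example

def generator_loss_fn(fake):
    return adv_loss_fn(ops.ones_like(fake), fake)

def discriminator_loss_fn(real, fake):
    real_loss = adv_loss_fn(ops.ones_like(real), real)
    fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
    return (real_loss + fake_loss) * 0.5

fake_logits = ops.zeros((2, 30, 30, 1))       # dummy discriminator output
print(float(generator_loss_fn(fake_logits)))  # 1.0: every "fake" prediction is far from the target of 1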
@@ -599,16 +600,16 @@ def discriminator_loss_fn(real, fake):
 
 # Compile the model
 cycle_gan_model.compile(
-    gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
-    gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
-    disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
-    disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
+    gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+    gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+    disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
     gen_loss_fn=generator_loss_fn,
     disc_loss_fn=discriminator_loss_fn,
 )
 # Callbacks
 plotter = GANMonitor()
-checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
+checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
 model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
     filepath=checkpoint_filepath, save_weights_only=True
 )
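Two Keras 3 requirements drive this hunk: the `legacy` Adam optimizers from tf.keras are no longer used, so the plain `keras.optimizers.Adam` takes their place, and `ModelCheckpoint(save_weights_only=True)` now insists that the filepath end in `.weights.h5`. If per-epoch checkpoints are still wanted (what the old `{epoch:03d}` pattern provided), the epoch placeholder can be kept as long as the required suffix stays at the end. The path below is illustrative, not the one in the commit:

# Hypothetical per-epoch variant of the checkpoint callback in this diff.
checkpoint_filepath = "./model_checkpoints/cyclegan_{epoch:03d}.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, save_weights_only=True
)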
@@ -623,25 +624,14 @@ def discriminator_loss_fn(real, fake):
 
 """
 Test the performance of the model.
-
-You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
-and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
 """
 
 
-# This model was trained for 90 epochs. We will be loading those weights
-# here. Once the weights are loaded, we will take a few samples from the test
-# data and check the model's performance.
-
-"""shell
-curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
-unzip -qq saved_checkpoints.zip
-"""
+# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
 
 
 # Load the checkpoints
-weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
-cycle_gan_model.load_weights(weight_file).expect_partial()
+cycle_gan_model.load_weights(checkpoint_filepath)
 print("Weights loaded successfully")
 
 _, ax = plt.subplots(4, 2, figsize=(10, 15))
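With the download of the 90-epoch pre-trained checkpoint removed, the example now reloads the weights that its own `ModelCheckpoint` callback wrote during `fit`. A sketch of the inference loop that roughly matches the plotting code below this point in the file, assuming the test pipeline yields batches of one preprocessed image and the `CycleGan` wrapper exposes the horse-to-zebra generator as `gen_G`:

for i, img in enumerate(test_horses.take(4)):
    # Translate one horse image to a zebra image and undo the [-1, 1] scaling.
    prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)
    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Translated image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()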
