22Title: CycleGAN
33Author: [A_K_Nain](https://twitter.com/A_K_Nain)
44Date created: 2020/08/12
5- Last modified: 2020/08/12
5+ Last modified: 2024/09/30
66Description: Implementation of CycleGAN.
77Accelerator: GPU
88"""
1717CycleGAN tries to learn this mapping without requiring paired input-output images,
1818using cycle-consistent adversarial networks.
1919
20- - [Paper](https://arxiv.org/pdf/1703.10593.pdf)
20+ - [Paper](https://arxiv.org/abs/1703.10593)
2121- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
2222"""
2323
2424"""
2525## Setup
2626"""
2727
28-
28+ import os
2929import numpy as np
3030import matplotlib .pyplot as plt
3131import tensorflow as tf
32- from tensorflow import keras
33- from tensorflow .keras import layers
34- import tensorflow_addons as tfa
32+ import keras
33+ from keras import layers , ops
3534import tensorflow_datasets as tfds
3635
3736tfds .disable_progress_bar ()
3837autotune = tf .data .AUTOTUNE
3938
39+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
4040
4141"""
4242## Prepare the dataset
4747"""
4848
4949# Load the horse-zebra dataset using tensorflow-datasets.
50- dataset , _ = tfds .load ("cycle_gan/horse2zebra" , with_info = True , as_supervised = True )
50+ dataset , _ = tfds .load (name = "cycle_gan/horse2zebra" , with_info = True , as_supervised = True )
5151train_horses , train_zebras = dataset ["trainA" ], dataset ["trainB" ]
5252test_horses , test_zebras = dataset ["testA" ], dataset ["testB" ]
5353
6565
6666
6767def normalize_img (img ):
68- img = tf .cast (img , dtype = tf .float32 )
68+ img = ops .cast (img , dtype = tf .float32 )
6969 # Map values in the range [-1, 1]
7070 return (img / 127.5 ) - 1.0
7171
@@ -74,7 +74,7 @@ def preprocess_train_image(img, label):
7474 # Random flip
7575 img = tf .image .random_flip_left_right (img )
7676 # Resize to the original size first
77- img = tf .image .resize (img , [* orig_img_size ])
77+ img = ops .image .resize (img , [* orig_img_size ])
7878 # Random crop to 256X256
7979 img = tf .image .random_crop (img , size = [* input_img_size ])
8080 # Normalize the pixel values in the range [-1, 1]
@@ -84,7 +84,7 @@ def preprocess_train_image(img, label):
8484
8585def preprocess_test_image (img , label ):
8686 # Only resizing and normalization for the test images.
87- img = tf .image .resize (img , [input_img_size [0 ], input_img_size [1 ]])
87+ img = ops .image .resize (img , [input_img_size [0 ], input_img_size [1 ]])
8888 img = normalize_img (img )
8989 return img
9090
@@ -165,7 +165,7 @@ def call(self, input_tensor, mask=None):
165165 [padding_width , padding_width ],
166166 [0 , 0 ],
167167 ]
168- return tf .pad (input_tensor , padding_tensor , mode = "REFLECT" )
168+ return ops .pad (input_tensor , padding_tensor , mode = "REFLECT" )
169169
170170
171171def residual_block (
@@ -190,7 +190,9 @@ def residual_block(
190190 padding = padding ,
191191 use_bias = use_bias ,
192192 )(x )
193- x = tfa .layers .InstanceNormalization (gamma_initializer = gamma_initializer )(x )
193+ x = keras .layers .GroupNormalization (groups = 1 , gamma_initializer = gamma_initializer )(
194+ x
195+ )
194196 x = activation (x )
195197
196198 x = ReflectionPadding2D ()(x )
@@ -202,7 +204,9 @@ def residual_block(
202204 padding = padding ,
203205 use_bias = use_bias ,
204206 )(x )
205- x = tfa .layers .InstanceNormalization (gamma_initializer = gamma_initializer )(x )
207+ x = keras .layers .GroupNormalization (groups = 1 , gamma_initializer = gamma_initializer )(
208+ x
209+ )
206210 x = layers .add ([input_tensor , x ])
207211 return x
208212
@@ -226,7 +230,9 @@ def downsample(
226230 padding = padding ,
227231 use_bias = use_bias ,
228232 )(x )
229- x = tfa .layers .InstanceNormalization (gamma_initializer = gamma_initializer )(x )
233+ x = keras .layers .GroupNormalization (groups = 1 , gamma_initializer = gamma_initializer )(
234+ x
235+ )
230236 if activation :
231237 x = activation (x )
232238 return x
@@ -251,7 +257,9 @@ def upsample(
251257 kernel_initializer = kernel_initializer ,
252258 use_bias = use_bias ,
253259 )(x )
254- x = tfa .layers .InstanceNormalization (gamma_initializer = gamma_initializer )(x )
260+ x = keras .layers .GroupNormalization (groups = 1 , gamma_initializer = gamma_initializer )(
261+ x
262+ )
255263 if activation :
256264 x = activation (x )
257265 return x
@@ -298,7 +306,9 @@ def get_resnet_generator(
298306 x = layers .Conv2D (filters , (7 , 7 ), kernel_initializer = kernel_init , use_bias = False )(
299307 x
300308 )
301- x = tfa .layers .InstanceNormalization (gamma_initializer = gamma_initializer )(x )
309+ x = keras .layers .GroupNormalization (groups = 1 , gamma_initializer = gamma_initializer )(
310+ x
311+ )
302312 x = layers .Activation ("relu" )(x )
303313
304314 # Downsampling
@@ -581,14 +591,14 @@ def on_epoch_end(self, epoch, logs=None):
581591
582592
583593def generator_loss_fn (fake ):
584- fake_loss = adv_loss_fn (tf .ones_like (fake ), fake )
594+ fake_loss = adv_loss_fn (ops .ones_like (fake ), fake )
585595 return fake_loss
586596
587597
588598# Define the loss function for the discriminators
589599def discriminator_loss_fn (real , fake ):
590- real_loss = adv_loss_fn (tf .ones_like (real ), real )
591- fake_loss = adv_loss_fn (tf .zeros_like (fake ), fake )
600+ real_loss = adv_loss_fn (ops .ones_like (real ), real )
601+ fake_loss = adv_loss_fn (ops .zeros_like (fake ), fake )
592602 return (real_loss + fake_loss ) * 0.5
593603
594604
@@ -599,16 +609,16 @@ def discriminator_loss_fn(real, fake):
599609
600610# Compile the model
601611cycle_gan_model .compile (
602- gen_G_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
603- gen_F_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
604- disc_X_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
605- disc_Y_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
612+ gen_G_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
613+ gen_F_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
614+ disc_X_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
615+ disc_Y_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
606616 gen_loss_fn = generator_loss_fn ,
607617 disc_loss_fn = discriminator_loss_fn ,
608618)
609619# Callbacks
610620plotter = GANMonitor ()
611- checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
621+ checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
612622model_checkpoint_callback = keras .callbacks .ModelCheckpoint (
613623 filepath = checkpoint_filepath , save_weights_only = True
614624)
@@ -617,31 +627,20 @@ def discriminator_loss_fn(real, fake):
617627# 7 minutes on a single P100 backed machine.
618628cycle_gan_model .fit (
619629 tf .data .Dataset .zip ((train_horses , train_zebras )),
620- epochs = 1 ,
630+ epochs = 90 ,
621631 callbacks = [plotter , model_checkpoint_callback ],
622632)
623633
624634"""
625635Test the performance of the model.
626-
627- You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
628- and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
629636"""
630637
631638
632- # This model was trained for 90 epochs. We will be loading those weights
633- # here. Once the weights are loaded, we will take a few samples from the test
634- # data and check the model's performance.
635-
636- """shell
637- curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
638- unzip -qq saved_checkpoints.zip
639- """
639+ # Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
640640
641641
642642# Load the checkpoints
643- weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
644- cycle_gan_model .load_weights (weight_file ).expect_partial ()
643+ cycle_gan_model .load_weights (checkpoint_filepath )
645644print ("Weights loaded successfully" )
646645
647646_ , ax = plt .subplots (4 , 2 , figsize = (10 , 15 ))
0 commit comments