2525## Setup
2626"""
2727
28-
28+ import os
2929import numpy as np
3030import matplotlib .pyplot as plt
3131import tensorflow as tf
32- from tensorflow import keras
33- from tensorflow . keras import layers
34- import tensorflow_addons as tfa
32+ import keras
33+ from keras import layers , ops
34+ # import tensorflow_addons as tfa
3535import tensorflow_datasets as tfds
3636
3737tfds .disable_progress_bar ()
3838autotune = tf .data .AUTOTUNE
3939
40+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
4041
4142"""
4243## Prepare the dataset
4748"""
4849
4950# Load the horse-zebra dataset using tensorflow-datasets.
50- dataset , _ = tfds .load ("cycle_gan/horse2zebra" , with_info = True , as_supervised = True )
51+ dataset , _ = tfds .load (name = "cycle_gan/horse2zebra" , with_info = True , as_supervised = True )
5152train_horses , train_zebras = dataset ["trainA" ], dataset ["trainB" ]
5253test_horses , test_zebras = dataset ["testA" ], dataset ["testB" ]
5354
6566
6667
6768def normalize_img (img ):
68- img = tf .cast (img , dtype = tf .float32 )
69+ img = ops .cast (img , dtype = tf .float32 )
6970 # Map values in the range [-1, 1]
7071 return (img / 127.5 ) - 1.0
7172
@@ -74,7 +75,7 @@ def preprocess_train_image(img, label):
7475 # Random flip
7576 img = tf .image .random_flip_left_right (img )
7677 # Resize to the original size first
77- img = tf .image .resize (img , [* orig_img_size ])
78+ img = ops .image .resize (img , [* orig_img_size ])
7879 # Random crop to 256X256
7980 img = tf .image .random_crop (img , size = [* input_img_size ])
8081 # Normalize the pixel values in the range [-1, 1]
@@ -84,7 +85,7 @@ def preprocess_train_image(img, label):
8485
8586def preprocess_test_image (img , label ):
8687 # Only resizing and normalization for the test images.
87- img = tf .image .resize (img , [input_img_size [0 ], input_img_size [1 ]])
88+ img = ops .image .resize (img , [input_img_size [0 ], input_img_size [1 ]])
8889 img = normalize_img (img )
8990 return img
9091
@@ -165,7 +166,7 @@ def call(self, input_tensor, mask=None):
165166 [padding_width , padding_width ],
166167 [0 , 0 ],
167168 ]
168- return tf .pad (input_tensor , padding_tensor , mode = "REFLECT" )
169+ return ops .pad (input_tensor , padding_tensor , mode = "REFLECT" )
169170
170171
171172def residual_block (
@@ -190,7 +191,7 @@ def residual_block(
190191 padding = padding ,
191192 use_bias = use_bias ,
192193 )(x )
193- x = tfa .layers .InstanceNormalization ( gamma_initializer = gamma_initializer )(x )
194+ x = keras .layers .GroupNormalization ( groups = 1 , gamma_initializer = gamma_initializer )(x )
194195 x = activation (x )
195196
196197 x = ReflectionPadding2D ()(x )
@@ -202,7 +203,7 @@ def residual_block(
202203 padding = padding ,
203204 use_bias = use_bias ,
204205 )(x )
205- x = tfa .layers .InstanceNormalization ( gamma_initializer = gamma_initializer )(x )
206+ x = keras .layers .GroupNormalization ( groups = 1 , gamma_initializer = gamma_initializer )(x )
206207 x = layers .add ([input_tensor , x ])
207208 return x
208209
@@ -226,7 +227,7 @@ def downsample(
226227 padding = padding ,
227228 use_bias = use_bias ,
228229 )(x )
229- x = tfa .layers .InstanceNormalization ( gamma_initializer = gamma_initializer )(x )
230+ x = keras .layers .GroupNormalization ( groups = 1 , gamma_initializer = gamma_initializer )(x )
230231 if activation :
231232 x = activation (x )
232233 return x
@@ -251,7 +252,7 @@ def upsample(
251252 kernel_initializer = kernel_initializer ,
252253 use_bias = use_bias ,
253254 )(x )
254- x = tfa .layers .InstanceNormalization ( gamma_initializer = gamma_initializer )(x )
255+ x = keras .layers .GroupNormalization ( groups = 1 , gamma_initializer = gamma_initializer )(x )
255256 if activation :
256257 x = activation (x )
257258 return x
@@ -298,7 +299,7 @@ def get_resnet_generator(
298299 x = layers .Conv2D (filters , (7 , 7 ), kernel_initializer = kernel_init , use_bias = False )(
299300 x
300301 )
301- x = tfa .layers .InstanceNormalization ( gamma_initializer = gamma_initializer )(x )
302+ x = keras .layers .GroupNormalization ( groups = 1 , gamma_initializer = gamma_initializer )(x )
302303 x = layers .Activation ("relu" )(x )
303304
304305 # Downsampling
@@ -581,14 +582,14 @@ def on_epoch_end(self, epoch, logs=None):
581582
582583
583584def generator_loss_fn (fake ):
584- fake_loss = adv_loss_fn (tf .ones_like (fake ), fake )
585+ fake_loss = adv_loss_fn (ops .ones_like (fake ), fake )
585586 return fake_loss
586587
587588
588589# Define the loss function for the discriminators
589590def discriminator_loss_fn (real , fake ):
590- real_loss = adv_loss_fn (tf .ones_like (real ), real )
591- fake_loss = adv_loss_fn (tf .zeros_like (fake ), fake )
591+ real_loss = adv_loss_fn (ops .ones_like (real ), real )
592+ fake_loss = adv_loss_fn (ops .zeros_like (fake ), fake )
592593 return (real_loss + fake_loss ) * 0.5
593594
594595
@@ -599,16 +600,16 @@ def discriminator_loss_fn(real, fake):
599600
600601# Compile the model
601602cycle_gan_model .compile (
602- gen_G_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
603- gen_F_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
604- disc_X_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
605- disc_Y_optimizer = keras .optimizers .legacy . Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
603+ gen_G_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
604+ gen_F_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
605+ disc_X_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
606+ disc_Y_optimizer = keras .optimizers .Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
606607 gen_loss_fn = generator_loss_fn ,
607608 disc_loss_fn = discriminator_loss_fn ,
608609)
609610# Callbacks
610611plotter = GANMonitor ()
611- checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d} "
612+ checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5 "
612613model_checkpoint_callback = keras .callbacks .ModelCheckpoint (
613614 filepath = checkpoint_filepath , save_weights_only = True
614615)
@@ -623,25 +624,14 @@ def discriminator_loss_fn(real, fake):
623624
624625"""
625626Test the performance of the model.
626-
627- You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)
628- and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN).
629627"""
630628
631629
632- # This model was trained for 90 epochs. We will be loading those weights
633- # here. Once the weights are loaded, we will take a few samples from the test
634- # data and check the model's performance.
635-
636- """shell
637- curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
638- unzip -qq saved_checkpoints.zip
639- """
630+ # Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
640631
641632
642633# Load the checkpoints
643- weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
644- cycle_gan_model .load_weights (weight_file ).expect_partial ()
634+ cycle_gan_model .load_weights (checkpoint_filepath )
645635print ("Weights loaded successfully" )
646636
647637_ , ax = plt .subplots (4 , 2 , figsize = (10 , 15 ))
0 commit comments