22Title: Data-efficient GANs with Adaptive Discriminator Augmentation
33Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
44Date created: 2021/10/28
5- Last modified: 2021/10/28
5+ Last modified: 2025/01/23
66Description: Generating images from limited data using the Caltech Birds dataset.
77Accelerator: GPU
88"""
@@ -62,12 +62,17 @@ class of generative deep learning models, commonly used for image generation. Th
6262## Setup
6363"""
6464
65+ import os
66+
67+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
68+
6569import matplotlib .pyplot as plt
6670import tensorflow as tf
6771import tensorflow_datasets as tfds
6872
69- from tensorflow import keras
70- from tensorflow .keras import layers
73+ import keras
74+ from keras import ops
75+ from keras import layers
7176
7277"""
7378## Hyperparameterers
@@ -115,46 +120,47 @@ class of generative deep learning models, commonly used for image generation. Th
115120
116121
117122def round_to_int (float_value ):
118- return tf .cast (tf . math . round (float_value ), dtype = tf . int32 )
123+ return ops .cast (ops . round (float_value ), " int32" )
119124
120125
121126def preprocess_image (data ):
122127 # unnormalize bounding box coordinates
123- height = tf .cast (tf .shape (data ["image" ])[0 ], dtype = tf . float32 )
124- width = tf .cast (tf .shape (data ["image" ])[1 ], dtype = tf . float32 )
125- bounding_box = data ["bbox" ] * tf .stack ([height , width , height , width ])
128+ height = ops .cast (ops .shape (data ["image" ])[0 ], " float32" )
129+ width = ops .cast (ops .shape (data ["image" ])[1 ], " float32" )
130+ bounding_box = data ["bbox" ] * ops .stack ([height , width , height , width ])
126131
127132 # calculate center and length of longer side, add padding
128133 target_center_y = 0.5 * (bounding_box [0 ] + bounding_box [2 ])
129134 target_center_x = 0.5 * (bounding_box [1 ] + bounding_box [3 ])
130- target_size = tf .maximum (
135+ target_size = ops .maximum (
131136 (1.0 + padding ) * (bounding_box [2 ] - bounding_box [0 ]),
132137 (1.0 + padding ) * (bounding_box [3 ] - bounding_box [1 ]),
133138 )
134139
135140 # modify crop size to fit into image
136- target_height = tf . reduce_min (
141+ target_height = ops . min (
137142 [target_size , 2.0 * target_center_y , 2.0 * (height - target_center_y )]
138143 )
139- target_width = tf . reduce_min (
144+ target_width = ops . min (
140145 [target_size , 2.0 * target_center_x , 2.0 * (width - target_center_x )]
141146 )
142147
143- # crop image
144- image = tf . image . crop_to_bounding_box (
148+ # crop image, `ops.image.crop_images` only works with non-tensor croppings
149+ image = ops . slice (
145150 data ["image" ],
146- offset_height = round_to_int (target_center_y - 0.5 * target_height ),
147- offset_width = round_to_int (target_center_x - 0.5 * target_width ),
148- target_height = round_to_int (target_height ),
149- target_width = round_to_int (target_width ),
151+ start_indices = (
152+ round_to_int (target_center_y - 0.5 * target_height ),
153+ round_to_int (target_center_x - 0.5 * target_width ),
154+ 0 ,
155+ ),
156+ shape = (round_to_int (target_height ), round_to_int (target_width ), 3 ),
150157 )
151158
152159 # resize and clip
153- # for image downsampling, area interpolation is the preferred method
154- image = tf .image .resize (
155- image , size = [image_size , image_size ], method = tf .image .ResizeMethod .AREA
156- )
157- return tf .clip_by_value (image / 255.0 , 0.0 , 1.0 )
160+ image = ops .cast (image , "float32" )
161+ image = ops .image .resize (image , [image_size , image_size ])
162+
163+ return ops .clip (image / 255.0 , 0.0 , 1.0 )
158164
159165
160166def prepare_dataset (split ):
@@ -231,8 +237,10 @@ def __init__(self, name="kid", **kwargs):
231237 )
232238
233239 def polynomial_kernel (self , features_1 , features_2 ):
234- feature_dimensions = tf .cast (tf .shape (features_1 )[1 ], dtype = tf .float32 )
235- return (features_1 @ tf .transpose (features_2 ) / feature_dimensions + 1.0 ) ** 3.0
240+ feature_dimensions = ops .cast (ops .shape (features_1 )[1 ], "float32" )
241+ return (
242+ features_1 @ ops .transpose (features_2 ) / feature_dimensions + 1.0
243+ ) ** 3.0
236244
237245 def update_state (self , real_images , generated_images , sample_weight = None ):
238246 real_features = self .encoder (real_images , training = False )
@@ -246,15 +254,15 @@ def update_state(self, real_images, generated_images, sample_weight=None):
246254 kernel_cross = self .polynomial_kernel (real_features , generated_features )
247255
248256 # estimate the squared maximum mean discrepancy using the average kernel values
249- batch_size = tf .shape (real_features )[0 ]
250- batch_size_f = tf .cast (batch_size , dtype = tf . float32 )
251- mean_kernel_real = tf . reduce_sum (kernel_real * (1.0 - tf .eye (batch_size ))) / (
257+ batch_size = ops .shape (real_features )[0 ]
258+ batch_size_f = ops .cast (batch_size , " float32" )
259+ mean_kernel_real = ops . sum (kernel_real * (1.0 - ops .eye (batch_size ))) / (
252260 batch_size_f * (batch_size_f - 1.0 )
253261 )
254- mean_kernel_generated = tf . reduce_sum (
255- kernel_generated * (1.0 - tf .eye (batch_size ))
262+ mean_kernel_generated = ops . sum (
263+ kernel_generated * (1.0 - ops .eye (batch_size ))
256264 ) / (batch_size_f * (batch_size_f - 1.0 ))
257- mean_kernel_cross = tf . reduce_mean (kernel_cross )
265+ mean_kernel_cross = ops . mean (kernel_cross )
258266 kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
259267
260268 # update the average KID estimate
@@ -299,7 +307,7 @@ def reset_state(self):
299307# "hard sigmoid", useful for binary accuracy calculation from logits
300308def step (values ):
301309 # negative values -> 0.0, positive values -> 1.0
302- return 0.5 * (1.0 + tf .sign (values ))
310+ return 0.5 * (1.0 + ops .sign (values ))
303311
304312
305313# augments images with a probability that is dynamically updated during training
@@ -308,7 +316,8 @@ def __init__(self):
308316 super ().__init__ ()
309317
310318 # stores the current probability of an image being augmented
311- self .probability = tf .Variable (0.0 )
319+ self .probability = keras .Variable (0.0 )
320+ self .seed_generator = keras .random .SeedGenerator (42 )
312321
313322 # the corresponding augmentation names from the paper are shown above each layer
314323 # the authors show (see figure 4), that the blitting and geometric augmentations
@@ -336,28 +345,26 @@ def __init__(self):
336345
337346 def call (self , images , training ):
338347 if training :
339- augmented_images = self .augmenter (images , training )
348+ augmented_images = self .augmenter (images , training = training )
340349
341350 # during training either the original or the augmented images are selected
342351 # based on self.probability
343- augmentation_values = tf .random .uniform (
344- shape = (batch_size , 1 , 1 , 1 ), minval = 0.0 , maxval = 1.0
352+ augmentation_values = keras .random .uniform (
353+ shape = (batch_size , 1 , 1 , 1 ), seed = self . seed_generator
345354 )
346- augmentation_bools = tf . math .less (augmentation_values , self .probability )
355+ augmentation_bools = ops .less (augmentation_values , self .probability )
347356
348- images = tf .where (augmentation_bools , augmented_images , images )
357+ images = ops .where (augmentation_bools , augmented_images , images )
349358 return images
350359
351360 def update (self , real_logits ):
352- current_accuracy = tf . reduce_mean (step (real_logits ))
361+ current_accuracy = ops . mean (step (real_logits ))
353362
354363 # the augmentation probability is updated based on the discriminator's
355364 # accuracy on real images
356365 accuracy_error = current_accuracy - target_accuracy
357366 self .probability .assign (
358- tf .clip_by_value (
359- self .probability + accuracy_error / integration_steps , 0.0 , 1.0
360- )
367+ ops .clip (self .probability + accuracy_error / integration_steps , 0.0 , 1.0 )
361368 )
362369
363370
@@ -445,13 +452,17 @@ class GAN_ADA(keras.Model):
445452 def __init__ (self ):
446453 super ().__init__ ()
447454
455+ self .seed_generator = keras .random .SeedGenerator (seed = 42 )
448456 self .augmenter = AdaptiveAugmenter ()
449457 self .generator = get_generator ()
450458 self .ema_generator = keras .models .clone_model (self .generator )
451459 self .discriminator = get_discriminator ()
452460
453461 self .generator .summary ()
454462 self .discriminator .summary ()
463+ # we have created all layers at this point, so we can mark the model
464+ # as having been built
465+ self .built = True
455466
456467 def compile (self , generator_optimizer , discriminator_optimizer , ** kwargs ):
457468 super ().compile (** kwargs )
@@ -479,32 +490,34 @@ def metrics(self):
479490 ]
480491
481492 def generate (self , batch_size , training ):
482- latent_samples = tf .random .normal (shape = (batch_size , noise_size ))
493+ latent_samples = keras .random .normal (
494+ shape = (batch_size , noise_size ), seed = self .seed_generator
495+ )
483496 # use ema_generator during inference
484497 if training :
485- generated_images = self .generator (latent_samples , training )
498+ generated_images = self .generator (latent_samples , training = training )
486499 else :
487- generated_images = self .ema_generator (latent_samples , training )
500+ generated_images = self .ema_generator (latent_samples , training = training )
488501 return generated_images
489502
490503 def adversarial_loss (self , real_logits , generated_logits ):
491504 # this is usually called the non-saturating GAN loss
492505
493- real_labels = tf .ones (shape = (batch_size , 1 ))
494- generated_labels = tf .zeros (shape = (batch_size , 1 ))
506+ real_labels = ops .ones (shape = (batch_size , 1 ))
507+ generated_labels = ops .zeros (shape = (batch_size , 1 ))
495508
496509 # the generator tries to produce images that the discriminator considers as real
497510 generator_loss = keras .losses .binary_crossentropy (
498511 real_labels , generated_logits , from_logits = True
499512 )
500513 # the discriminator tries to determine if images are real or generated
501514 discriminator_loss = keras .losses .binary_crossentropy (
502- tf . concat ([real_labels , generated_labels ], axis = 0 ),
503- tf . concat ([real_logits , generated_logits ], axis = 0 ),
515+ ops . concatenate ([real_labels , generated_labels ], axis = 0 ),
516+ ops . concatenate ([real_logits , generated_logits ], axis = 0 ),
504517 from_logits = True ,
505518 )
506519
507- return tf . reduce_mean (generator_loss ), tf . reduce_mean (discriminator_loss )
520+ return ops . mean (generator_loss ), ops . mean (discriminator_loss )
508521
509522 def train_step (self , real_images ):
510523 real_images = self .augmenter (real_images , training = True )
@@ -604,8 +617,8 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5)
604617)
605618
606619# save the best model based on the validation KID metric
607- checkpoint_path = "gan_model"
608- checkpoint_callback = tf . keras .callbacks .ModelCheckpoint (
620+ checkpoint_path = "gan_model.weights.h5 "
621+ checkpoint_callback = keras .callbacks .ModelCheckpoint (
609622 filepath = checkpoint_path ,
610623 save_weights_only = True ,
611624 monitor = "val_kid" ,
0 commit comments