Skip to content

Commit 9199eee

Browse files
authored
Data-efficient GANs with Adaptive Discriminator Augmentation to keras 3.0 (Tensorflow backend only) (#2035)
* keras3 migration * Keras3 migration - Add generated files
1 parent 54c7d73 commit 9199eee

File tree

6 files changed

+3801
-275
lines changed

6 files changed

+3801
-275
lines changed

examples/generative/gan_ada.py

Lines changed: 63 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
Title: Data-efficient GANs with Adaptive Discriminator Augmentation
33
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
44
Date created: 2021/10/28
5-
Last modified: 2021/10/28
5+
Last modified: 2025/01/23
66
Description: Generating images from limited data using the Caltech Birds dataset.
77
Accelerator: GPU
88
"""
@@ -62,12 +62,17 @@ class of generative deep learning models, commonly used for image generation. Th
6262
## Setup
6363
"""
6464

65+
import os
66+
67+
os.environ["KERAS_BACKEND"] = "tensorflow"
68+
6569
import matplotlib.pyplot as plt
6670
import tensorflow as tf
6771
import tensorflow_datasets as tfds
6872

69-
from tensorflow import keras
70-
from tensorflow.keras import layers
73+
import keras
74+
from keras import ops
75+
from keras import layers
7176

7277
"""
7378
## Hyperparameters
@@ -115,46 +120,47 @@ class of generative deep learning models, commonly used for image generation. Th
115120

116121

117122
def round_to_int(float_value):
118-
return tf.cast(tf.math.round(float_value), dtype=tf.int32)
123+
return ops.cast(ops.round(float_value), "int32")
119124

120125

121126
def preprocess_image(data):
122127
# unnormalize bounding box coordinates
123-
height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
124-
width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
125-
bounding_box = data["bbox"] * tf.stack([height, width, height, width])
128+
height = ops.cast(ops.shape(data["image"])[0], "float32")
129+
width = ops.cast(ops.shape(data["image"])[1], "float32")
130+
bounding_box = data["bbox"] * ops.stack([height, width, height, width])
126131

127132
# calculate center and length of longer side, add padding
128133
target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
129134
target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
130-
target_size = tf.maximum(
135+
target_size = ops.maximum(
131136
(1.0 + padding) * (bounding_box[2] - bounding_box[0]),
132137
(1.0 + padding) * (bounding_box[3] - bounding_box[1]),
133138
)
134139

135140
# modify crop size to fit into image
136-
target_height = tf.reduce_min(
141+
target_height = ops.min(
137142
[target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
138143
)
139-
target_width = tf.reduce_min(
144+
target_width = ops.min(
140145
[target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
141146
)
142147

143-
# crop image
144-
image = tf.image.crop_to_bounding_box(
148+
# crop image; `ops.image.crop_images` only works with non-tensor crop values, so `ops.slice` is used instead
149+
image = ops.slice(
145150
data["image"],
146-
offset_height=round_to_int(target_center_y - 0.5 * target_height),
147-
offset_width=round_to_int(target_center_x - 0.5 * target_width),
148-
target_height=round_to_int(target_height),
149-
target_width=round_to_int(target_width),
151+
start_indices=(
152+
round_to_int(target_center_y - 0.5 * target_height),
153+
round_to_int(target_center_x - 0.5 * target_width),
154+
0,
155+
),
156+
shape=(round_to_int(target_height), round_to_int(target_width), 3),
150157
)
151158

152159
# resize and clip
153-
# for image downsampling, area interpolation is the preferred method
154-
image = tf.image.resize(
155-
image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
156-
)
157-
return tf.clip_by_value(image / 255.0, 0.0, 1.0)
160+
image = ops.cast(image, "float32")
161+
image = ops.image.resize(image, [image_size, image_size])
162+
163+
return ops.clip(image / 255.0, 0.0, 1.0)
158164

159165

160166
def prepare_dataset(split):
@@ -231,8 +237,10 @@ def __init__(self, name="kid", **kwargs):
231237
)
232238

233239
def polynomial_kernel(self, features_1, features_2):
234-
feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
235-
return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0
240+
feature_dimensions = ops.cast(ops.shape(features_1)[1], "float32")
241+
return (
242+
features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
243+
) ** 3.0
236244

237245
def update_state(self, real_images, generated_images, sample_weight=None):
238246
real_features = self.encoder(real_images, training=False)
@@ -246,15 +254,15 @@ def update_state(self, real_images, generated_images, sample_weight=None):
246254
kernel_cross = self.polynomial_kernel(real_features, generated_features)
247255

248256
# estimate the squared maximum mean discrepancy using the average kernel values
249-
batch_size = tf.shape(real_features)[0]
250-
batch_size_f = tf.cast(batch_size, dtype=tf.float32)
251-
mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
257+
batch_size = ops.shape(real_features)[0]
258+
batch_size_f = ops.cast(batch_size, "float32")
259+
mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
252260
batch_size_f * (batch_size_f - 1.0)
253261
)
254-
mean_kernel_generated = tf.reduce_sum(
255-
kernel_generated * (1.0 - tf.eye(batch_size))
262+
mean_kernel_generated = ops.sum(
263+
kernel_generated * (1.0 - ops.eye(batch_size))
256264
) / (batch_size_f * (batch_size_f - 1.0))
257-
mean_kernel_cross = tf.reduce_mean(kernel_cross)
265+
mean_kernel_cross = ops.mean(kernel_cross)
258266
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
259267

260268
# update the average KID estimate
@@ -299,7 +307,7 @@ def reset_state(self):
299307
# "hard sigmoid", useful for binary accuracy calculation from logits
300308
def step(values):
301309
# negative values -> 0.0, positive values -> 1.0
302-
return 0.5 * (1.0 + tf.sign(values))
310+
return 0.5 * (1.0 + ops.sign(values))
303311

304312

305313
# augments images with a probability that is dynamically updated during training
@@ -308,7 +316,8 @@ def __init__(self):
308316
super().__init__()
309317

310318
# stores the current probability of an image being augmented
311-
self.probability = tf.Variable(0.0)
319+
self.probability = keras.Variable(0.0)
320+
self.seed_generator = keras.random.SeedGenerator(42)
312321

313322
# the corresponding augmentation names from the paper are shown above each layer
314323
# the authors show (see figure 4), that the blitting and geometric augmentations
@@ -336,28 +345,26 @@ def __init__(self):
336345

337346
def call(self, images, training):
338347
if training:
339-
augmented_images = self.augmenter(images, training)
348+
augmented_images = self.augmenter(images, training=training)
340349

341350
# during training either the original or the augmented images are selected
342351
# based on self.probability
343-
augmentation_values = tf.random.uniform(
344-
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
352+
augmentation_values = keras.random.uniform(
353+
shape=(batch_size, 1, 1, 1), seed=self.seed_generator
345354
)
346-
augmentation_bools = tf.math.less(augmentation_values, self.probability)
355+
augmentation_bools = ops.less(augmentation_values, self.probability)
347356

348-
images = tf.where(augmentation_bools, augmented_images, images)
357+
images = ops.where(augmentation_bools, augmented_images, images)
349358
return images
350359

351360
def update(self, real_logits):
352-
current_accuracy = tf.reduce_mean(step(real_logits))
361+
current_accuracy = ops.mean(step(real_logits))
353362

354363
# the augmentation probability is updated based on the discriminator's
355364
# accuracy on real images
356365
accuracy_error = current_accuracy - target_accuracy
357366
self.probability.assign(
358-
tf.clip_by_value(
359-
self.probability + accuracy_error / integration_steps, 0.0, 1.0
360-
)
367+
ops.clip(self.probability + accuracy_error / integration_steps, 0.0, 1.0)
361368
)
362369

363370

@@ -445,13 +452,17 @@ class GAN_ADA(keras.Model):
445452
def __init__(self):
446453
super().__init__()
447454

455+
self.seed_generator = keras.random.SeedGenerator(seed=42)
448456
self.augmenter = AdaptiveAugmenter()
449457
self.generator = get_generator()
450458
self.ema_generator = keras.models.clone_model(self.generator)
451459
self.discriminator = get_discriminator()
452460

453461
self.generator.summary()
454462
self.discriminator.summary()
463+
# we have created all layers at this point, so we can mark the model
464+
# as having been built
465+
self.built = True
455466

456467
def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
457468
super().compile(**kwargs)
@@ -479,32 +490,34 @@ def metrics(self):
479490
]
480491

481492
def generate(self, batch_size, training):
482-
latent_samples = tf.random.normal(shape=(batch_size, noise_size))
493+
latent_samples = keras.random.normal(
494+
shape=(batch_size, noise_size), seed=self.seed_generator
495+
)
483496
# use ema_generator during inference
484497
if training:
485-
generated_images = self.generator(latent_samples, training)
498+
generated_images = self.generator(latent_samples, training=training)
486499
else:
487-
generated_images = self.ema_generator(latent_samples, training)
500+
generated_images = self.ema_generator(latent_samples, training=training)
488501
return generated_images
489502

490503
def adversarial_loss(self, real_logits, generated_logits):
491504
# this is usually called the non-saturating GAN loss
492505

493-
real_labels = tf.ones(shape=(batch_size, 1))
494-
generated_labels = tf.zeros(shape=(batch_size, 1))
506+
real_labels = ops.ones(shape=(batch_size, 1))
507+
generated_labels = ops.zeros(shape=(batch_size, 1))
495508

496509
# the generator tries to produce images that the discriminator considers as real
497510
generator_loss = keras.losses.binary_crossentropy(
498511
real_labels, generated_logits, from_logits=True
499512
)
500513
# the discriminator tries to determine if images are real or generated
501514
discriminator_loss = keras.losses.binary_crossentropy(
502-
tf.concat([real_labels, generated_labels], axis=0),
503-
tf.concat([real_logits, generated_logits], axis=0),
515+
ops.concatenate([real_labels, generated_labels], axis=0),
516+
ops.concatenate([real_logits, generated_logits], axis=0),
504517
from_logits=True,
505518
)
506519

507-
return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)
520+
return ops.mean(generator_loss), ops.mean(discriminator_loss)
508521

509522
def train_step(self, real_images):
510523
real_images = self.augmenter(real_images, training=True)
@@ -604,8 +617,8 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5)
604617
)
605618

606619
# save the best model based on the validation KID metric
607-
checkpoint_path = "gan_model"
608-
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
620+
checkpoint_path = "gan_model.weights.h5"
621+
checkpoint_callback = keras.callbacks.ModelCheckpoint(
609622
filepath=checkpoint_path,
610623
save_weights_only=True,
611624
monitor="val_kid",
603 KB
Loading
910 KB
Loading
402 KB
Loading

0 commit comments

Comments
 (0)