
Stylegan migration #2303

Open
kharshith-k wants to merge 4 commits into keras-team:master from
kharshith-k:stylegan-migration

Conversation

@kharshith-k
Member

Please find my PR for migrating the StyleGAN tutorial from Keras 2 to Keras 3. Full training of the model will be started after the review.

Colab Notebook

@gemini-code-assist
Contributor

Summary of Changes

Hello @kharshith-k, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on migrating the StyleGAN tutorial to Keras 3, ensuring compatibility with the latest Keras API. The changes encompass updating the codebase to use keras.ops for tensor manipulations, refactoring custom components into Keras layers, and optimizing data handling for improved performance. The overall goal is to modernize the tutorial and align it with the current Keras ecosystem.

Highlights

  • Keras 3 Migration: Migrated the StyleGAN tutorial from tensorflow.keras to the new Keras 3 API, updating all relevant imports and API calls.
  • TensorFlow Operations Replaced: Replaced direct tf.* tensor operations with their backend-agnostic keras.ops.* equivalents across utility functions and custom layers.
  • Custom Layer Refactoring: Refactored the minibatch_std function into a dedicated MinibatchStd Keras layer and updated EqualizedConv, EqualizedDense, and AdaIN layers for Keras 3 compatibility, including using internal Conv2D layers and proper layer building patterns.
  • Data Loading and Performance Optimizations: Adjusted data loading and preprocessing, including increasing batch sizes for better utilization on A100 40GB GPUs and updating image_dataset_from_directory parameters.


Changelog
  • examples/generative/stylegan.py
    • Updated Keras imports and removed TensorFlow-specific imports.
    • Modified batch sizes and data loading parameters for improved GPU utilization.
    • Replaced TensorFlow operations with keras.ops equivalents in utility functions.
    • Converted minibatch_std function into a MinibatchStd Keras layer.
    • Refactored EqualizedConv to use an internal layers.Conv2D and keras.ops.pad.
    • Updated EqualizedDense and AdaIN layers for Keras 3 compatibility.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request begins the migration of the StyleGAN example to Keras 3. While some TensorFlow operations have been correctly replaced with their Keras 3 keras.ops equivalents, the migration is incomplete and introduces several critical issues. Key problems include incorrect handling of imports which will lead to runtime errors, a performance degradation in the data pipeline, and incorrect refactoring of custom layers that will either fail or are not aligned with Keras 3's backend-agnostic principles. My review provides specific feedback and suggestions to address these critical issues and guide the completion of the migration.

Comment on lines +41 to +43
import keras
from keras import layers
from keras.models import Sequential

critical

The migration to Keras 3 is incomplete. import tensorflow as tf has been removed, but tf is still used extensively throughout the file (e.g., in the StyleGAN class), which will cause a NameError. Additionally, InstanceNormalization is used but it's no longer imported. To fix this, you should add back import tensorflow as tf and import InstanceNormalization from keras.layers.

Suggested change
-import keras
-from keras import layers
-from keras.models import Sequential
+import tensorflow as tf
+import keras
+from keras import layers
+from keras.models import Sequential
+from keras.layers import InstanceNormalization

Comment on lines +146 to +162
class MinibatchStd(layers.Layer):
    def __init__(self, group_size=4, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.group_size = group_size
        self.epsilon = epsilon

    def call(self, input_tensor):
        shape = keras.ops.shape(input_tensor)
        n, h, w, c = shape[0], shape[1], shape[2], shape[3]
        group_size = keras.ops.minimum(self.group_size, n)
        x = keras.ops.reshape(input_tensor, (group_size, -1, h, w, c))
        group_mean = keras.ops.mean(x, axis=0, keepdims=False)
        group_var = keras.ops.var(x, axis=0, keepdims=False)
        group_std = keras.ops.sqrt(group_var + self.epsilon)
        avg_std = keras.ops.mean(group_std, axis=[1, 2, 3], keepdims=True)
        x_std = keras.ops.tile(avg_std, (group_size, h, w, 1))
        return keras.ops.concatenate([input_tensor, x_std], axis=-1)

critical

Refactoring minibatch_std from a function to a keras.layers.Layer is a good architectural improvement. However, the usage of this layer in Discriminator.build_base (line 463) has not been updated. It's still being called as a function minibatch_std(input_tensor), which will raise an error. It should be instantiated as a layer: MinibatchStd()(input_tensor).
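For reference, the statistic that MinibatchStd appends can be reproduced outside Keras; a NumPy sketch of the same computation, where group_size=4 and epsilon=1e-8 mirror the layer's defaults:

```python
import numpy as np

def minibatch_std(x, group_size=4, epsilon=1e-8):
    # x: (n, h, w, c). Append one channel holding the per-group
    # average standard deviation, as the MinibatchStd layer does.
    n, h, w, c = x.shape
    g = min(group_size, n)
    grouped = x.reshape(g, -1, h, w, c)
    std = np.sqrt(grouped.var(axis=0) + epsilon)       # (n//g, h, w, c)
    avg_std = std.mean(axis=(1, 2, 3), keepdims=True)  # (n//g, 1, 1, 1)
    x_std = np.tile(avg_std, (g, h, w, 1))             # (n, h, w, 1)
    return np.concatenate([x, x_std], axis=-1)         # (n, h, w, c + 1)
```

An (n, h, w, c) input comes back as (n, h, w, c + 1), which is also a quick way to verify that the Keras layer was wired in correctly.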

Comment on lines 199 to 207
    def call(self, inputs):
        if self.pad:
-           x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
+           x = keras.ops.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="reflect")
        else:
            x = inputs
-       output = (
-           tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
-       )
+       # Set the kernel weights of the Conv2D layer to the scaled weights
+       self.conv2d.kernel.assign(self.scale * self.w)
+       output = self.conv2d(x) + self.b
        return output

critical

Assigning to self.conv2d.kernel inside the call method is problematic. It's a stateful operation that is not backend-agnostic (the .assign() method is specific to TensorFlow variables) and can behave unexpectedly with graph compilation. A better approach is to use the functional keras.ops.conv which is backend-agnostic and stateless in this context. This would also simplify the build method, as you would no longer need to create the self.conv2d layer.

Suggested change
    def call(self, inputs):
        if self.pad:
            x = keras.ops.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="reflect")
        else:
            x = inputs
-       # Set the kernel weights of the Conv2D layer to the scaled weights
-       self.conv2d.kernel.assign(self.scale * self.w)
-       output = self.conv2d(x) + self.b
+       # Use the backend-agnostic functional convolution; keras.ops.conv takes
+       # no bias argument, so the bias is added separately
+       output = keras.ops.conv(x, self.scale * self.w, padding="valid") + self.b
        return output
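For context, folding self.scale into the kernel at call time implements the equalized learning rate trick from progressive GANs: weights are stored as unit-variance draws and rescaled on every forward pass. A NumPy sketch of the scaling, assuming the gain=2 and He-style fan-in that EqualizedConv appears to use (the concrete kernel and channel sizes here are illustrative):

```python
import numpy as np

# Equalized learning rate: store w ~ N(0, 1), scale at call time by
# sqrt(gain / fan_in) so the effective variance matches He init,
# while the optimizer sees unit-scale parameters.
kernel_size, in_channels, out_channels, gain = 3, 64, 128, 2.0
fan_in = kernel_size * kernel_size * in_channels  # 3 * 3 * 64 = 576
scale = np.sqrt(gain / fan_in)

w = np.random.randn(kernel_size, kernel_size, in_channels, out_channels)
scaled_w = scale * w  # what the layer multiplies into the kernel each call
```

Because the scaling is a pure function of fixed hyperparameters, it is stateless and needs no `.assign()`, which is why the functional keras.ops.conv form above is both simpler and backend-agnostic.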

-        partial(resize_image, res), num_parallel_calls=tf.data.AUTOTUNE
-    ).unbatch()
     dl = dl.shuffle(200).batch(batch_size, drop_remainder=True).prefetch(1).repeat()
+    dl = ds_train.map(lambda x: resize_image(res, x))

high

By removing num_parallel_calls=tf.data.AUTOTUNE from the .map() operation, you are disabling parallel processing for image resizing. This can significantly slow down the data loading pipeline, creating a bottleneck during training. It's recommended to re-introduce this to leverage multi-core processing for the mapping function.

Suggested change
-dl = ds_train.map(lambda x: resize_image(res, x))
+dl = ds_train.map(lambda x: resize_image(res, x), num_parallel_calls=tf.data.AUTOTUNE)
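A self-contained sketch of the parallelized pipeline the reviewer recommends; the dataset and the resize helper below are stand-ins for the tutorial's ds_train and resize_image, not the actual tutorial code:

```python
import tensorflow as tf

def resize_image(res, image):
    # Stand-in for the tutorial's resize helper.
    return tf.image.resize(image, (res, res))

ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 16, 16, 3)))
# AUTOTUNE lets tf.data choose the parallelism level for the map,
# so image resizing does not bottleneck the training loop.
ds = ds.map(lambda x: resize_image(8, x), num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.shuffle(200).batch(4, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
```

Note that keras.ops migration does not affect the input pipeline: the tutorial still uses tf.data for loading, so `import tensorflow as tf` remains necessary either way.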

