| """ | ||
| Title: Stable Diffusion 3 in KerasHub! | ||
| Author: [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta) | ||
| Date created: 2024/10/09 | ||
| Last modified: 2024/10/09 | ||
| Description: Generate new images using KerasHub's Stable Diffusion 3 model. | ||
| Accelerator: GPU | ||
| """ | ||

"""
## Overview

Stable Diffusion 3 is a powerful, open-source latent diffusion model (LDM)
designed to generate high-quality novel images based on text prompts. Released
by [Stability AI](https://stability.ai/), it was pre-trained on 1 billion
images and fine-tuned on 33 million high-quality aesthetic and preference
images, resulting in greatly improved performance compared to previous
versions of Stable Diffusion models.

In this guide, we will explore KerasHub's implementation of the
[Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium),
including the text-to-image, image-to-image, and inpainting tasks.

To get started, let's install a few dependencies and sort out some imports:
"""

"""shell
!pip install -Uq keras
!pip install -Uq git+https://github.com/keras-team/keras-hub.git
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import time

import keras
import keras_hub
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

"""
## Introduction

Before diving into how latent diffusion models work, let's start by generating
some images using KerasHub's APIs.

To avoid reinitializing variables for different tasks, we'll instantiate and
load the trained `backbone` and `preprocessor` using KerasHub's `from_preset`
factory method. If you only want to perform one task at a time, you can use a
simpler API like this:

```python
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
```

That will automatically load and configure the trained `backbone` and
`preprocessor` for you.
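
If you prefer that one-task-at-a-time workflow, the other tasks covered later
in this guide can presumably be loaded the same way. The snippet below is a
hedged sketch of that pattern and is not used in this guide, since we share a
single `backbone` across tasks instead:

```python
# Assumed to mirror the text-to-image one-liner above; not needed below.
image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
```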

Note that in this guide, we'll use `height=512` and `width=512` for faster
image generation. For higher-quality output, it's recommended to use the
default size of `1024`. Since the entire backbone has about 3 billion
parameters, which can be challenging to fit on a consumer-level GPU, we set
`dtype="float16"` to reduce GPU memory usage -- the officially released
weights are also in float16.

It is also worth noting that the preset "stable_diffusion_3_medium" excludes
the T5XXL text encoder, as it requires significantly more GPU memory.
The performance degradation is negligible in most cases.
"""

backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", height=512, width=512, dtype="float16"
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)
text_to_image = keras_hub.models.StableDiffusion3TextToImage(backbone, preprocessor)

"""
Next, we give it a prompt:
"""

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# When using JAX or TensorFlow backends, you might experience a significant
# compilation time during the first `generate()` call. It will be much faster
# after that.
generated_image = text_to_image.generate(prompt)
generated_image = Image.fromarray(generated_image)
plt.axis("off")
plt.imshow(generated_image)

"""
Pretty impressive! But how does this work?

Let's dig into what "latent diffusion model" means.

Consider the concept of "super-resolution," where a deep learning model
"denoises" an input image, turning it into a higher-resolution version. The
model uses its training data distribution to hallucinate the visual details
that are most likely given the input. To learn more about super-resolution,
you can check out the following Keras.io tutorials:

- [Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/)
- [Enhanced Deep Residual Networks for single-image super-resolution](https://keras.io/examples/vision/edsr/)

When we push this idea to the limit, we may start asking -- what if we just
run such a model on pure noise? The model would then "denoise the noise" and
start hallucinating a brand new image. By repeating the process multiple
times, we can turn a small patch of noise into an increasingly clear and
high-resolution artificial picture.

This is the key idea of latent diffusion, proposed in
[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).
To understand diffusion in depth, you can check the Keras.io tutorial
[Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/).

To transition from latent diffusion to a text-to-image system, one key feature
must be added: the ability to control the generated visual content using
prompt keywords. In Stable Diffusion 3, the text encoders from the CLIP and
T5XXL models are used to obtain text embeddings, which are then fed into the
diffusion model to condition the diffusion process. This approach is based on
the concept of "classifier-free guidance", proposed in
[Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).

When we combine these ideas, we get a high-level overview of the architecture
of Stable Diffusion 3:

- Text encoders: Convert the text prompt into text embeddings.
- Diffusion model: Repeatedly "denoises" a smaller latent image patch.
- Decoder: Transforms the final latent patch into a higher-resolution image.

First, the text prompt is projected into the latent space by multiple text
encoders, which are pretrained and frozen language models. Next, the text
embeddings, along with a randomly generated noise patch (typically from a
Gaussian distribution), are fed into the diffusion model. The diffusion model
repeatedly "denoises" the noise patch over a series of steps (the more steps,
the clearer and more refined the image becomes -- the default value is 28
steps). Finally, the latent patch is passed through the decoder from the VAE
model to render the image in high resolution.
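
To make these three stages concrete, here is a minimal, purely illustrative
sketch of the text-to-image loop. The helper names (`encode_text`,
`sample_noise`, `denoise_step`, `decode_latent`) are hypothetical placeholders
for this explanation, not the KerasHub API:

```python
def generate_sketch(prompt, num_steps=28):
    embeddings = encode_text(prompt)  # CLIP (and optionally T5XXL) text encoders
    latent = sample_noise()  # random Gaussian latent patch
    for step in range(num_steps):
        # Each step removes a bit of noise, conditioned on the text embeddings.
        latent = denoise_step(latent, embeddings, step)
    return decode_latent(latent)  # the VAE decoder renders the final image
```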

That completes the high-level overview of the Stable Diffusion 3 architecture.

This relatively simple system starts looking like magic once we train on
billions of pictures and their captions. As Feynman said about the universe:
_"It's not complicated, it's just a lot of it!"_
"""

"""
## Text-to-image task

Now that we know the basics of Stable Diffusion 3 and the text-to-image task,
let's explore it further using the KerasHub APIs.

To enable batch processing, we can feed a list of prompts into the model.
Before doing so, let's define a small helper function that concatenates the
generated images into a single image for easier side-by-side comparison:
"""


def concate_images(images):
    """Concatenate images horizontally and return them as a single PIL image.

    Accepts a single 3D numpy array, a 4D numpy array (a batch of images), or
    a list of 3D numpy arrays.
    """
    if isinstance(images, list):
        concated_images = np.concatenate(list(images), axis=1)
        return Image.fromarray(concated_images)
    elif len(images.shape) < 4:
        return Image.fromarray(images)
    else:
        concated_images = np.concatenate(list(images), axis=1)
        return Image.fromarray(concated_images)


generated_images = text_to_image.generate([prompt] * 3)
generated_image = concate_images(generated_images)
plt.axis("off")
plt.imshow(generated_image)

"""
`num_steps` controls the number of denoising steps. More denoising steps
typically produce higher-quality images, but generation takes longer. In
Stable Diffusion 3, it defaults to `28`.
"""

num_steps = [10, 28, 50]
generated_images = []
for n in num_steps:
    st = time.time()
    generated_images.append(text_to_image.generate(prompt, num_steps=n))
    print(f"Time taken (`num_steps={n}`): {time.time() - st:.2f}s")

generated_image = concate_images(generated_images)
plt.axis("off")
plt.imshow(generated_image)

"""
We can use `"negative_prompts"` to guide the model away from generating
specific styles and elements. The input format becomes a dict with the keys
`"prompts"` and `"negative_prompts"`.

If `"negative_prompts"` is not provided, it defaults to `""`, which is
interpreted as an unconditioned prompt.
"""

inputs = {"prompts": [prompt] * 3, "negative_prompts": ["Green color"] * 3}
generated_images = text_to_image.generate(inputs)
generated_image = concate_images(generated_images)
plt.axis("off")
plt.imshow(generated_image)

"""
`guidance_scale` affects how strongly the `"prompts"` influence image
generation. A lower value gives the model more creativity to generate images
that are only loosely related to the prompt. Higher values push the model to
follow the prompt more closely. If this value is too high, you may observe
some artifacts in the generated image. In Stable Diffusion 3, it defaults to
`7.0`.
"""

generated_images = [
    text_to_image.generate(prompt, guidance_scale=2.5),
    text_to_image.generate(prompt, guidance_scale=7.0),
    text_to_image.generate(prompt, guidance_scale=10.5),
]
generated_image = concate_images(generated_images)
plt.axis("off")
plt.imshow(generated_image)

"""
Note that `negative_prompts` and `guidance_scale` are related. The formula in
the implementation can be represented as follows:
`predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)`.
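
As a quick illustration of this formula (with made-up numbers, purely to show
how `guidance_scale` amplifies the difference between the conditional and
unconditional predictions):

```python
positive_noise = np.array([0.8, -0.2])  # prediction conditioned on the prompt
negative_noise = np.array([0.1, 0.3])  # prediction for the negative prompt
guidance_scale = 7.0
# The result is pushed 7x along the direction from negative to positive.
predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)
print(predicted_noise)  # [ 5.  -3.2]
```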
| """ | ||
|
|
||
| """ | ||
| ## Image-to-image task | ||
|
|
||
| It is possible to use a referece image as the starting point for the diffusion | ||
james77777778 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| process. This requires an additional module in the pipeline -- the encoder of | ||
james77777778 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| the VAE model. | ||
|
|
||
| The reference image is encoded by the VAE encoder into the latent space, where | ||
| noise is then added. The subsequent steps follow the same procedure as the | ||
james77777778 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| text-to-image task. | ||
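
In pseudocode, the difference from text-to-image is only in how the initial
latent is obtained. As before, the helper names (`vae_encode`, `add_noise`,
`denoise_step`, `decode_latent`) are hypothetical placeholders, not the
KerasHub API:

```python
latent = vae_encode(reference_image)  # instead of starting from pure noise
latent = add_noise(latent)  # partially noise the encoded latent
for step in range(num_steps):
    latent = denoise_step(latent, embeddings, step)  # same loop as before
image = decode_latent(latent)
```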

The input format becomes a dict with the keys `"images"`, `"prompts"` and
optionally `"negative_prompts"`.
"""

image_to_image = keras_hub.models.StableDiffusion3ImageToImage(backbone, preprocessor)

image = keras.utils.get_file(
    "cat.png",
    origin="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png",
)
image = Image.open(image).convert("RGB")
width, height = image.size

# Crop the image to fit the height and width of the backbone.
image = image.crop(
    (width // 2 - 256, height // 2 - 256, width // 2 + 256, height // 2 + 256)
)

# Note that the values of the image must be in the range of [-1.0, 1.0].
image_array = np.array(image).astype("float32")
image_array = image_array / 127.5 - 1.0
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, "
prompt += "adorable, Pixar, Disney, 8k"

generated_image = image_to_image.generate({"images": image_array, "prompts": prompt})

display_image = concate_images([np.array(image), generated_image])
plt.axis("off")
plt.imshow(display_image)
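
"""
In the display above, the reference cat image is on the left and the generated
image is on the right: the generated image follows the prompt while retaining
the broad layout of the reference.
"""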
| """ | ||
| ## Inpaint task | ||
|
|
||
| To extent the image-to-image task, we can also control the generated area using | ||
james77777778 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| a mask. This process is called inpainting, where specific areas of an image are | ||
| replaced or edited. | ||
|
|
||
| Inpainting relies on a mask to determine which regions of the image to modify. | ||
| The areas to inpaint are represented by white pixels (`True`), while the areas | ||
| to preserve are represented by black pixels (`False`). | ||
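
For example, a mask that inpaints only the central square of a 512x512 image
could be built like this (a hypothetical illustration; below we instead load a
ready-made mask image):

```python
mask = np.zeros((512, 512), dtype="bool")
mask[128:384, 128:384] = True  # True: regenerate; False: keep unchanged
```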

The input format becomes a dict with the keys `"images"`, `"masks"`,
`"prompts"` and optionally `"negative_prompts"`.
"""

inpaint = keras_hub.models.StableDiffusion3Inpaint(backbone, preprocessor)

image = keras.utils.get_file(
    "inpaint.png",
    origin="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
)
image = Image.open(image).convert("RGB")
image_array = np.array(image).astype("float32")
image_array = image_array / 127.5 - 1.0
mask = keras.utils.get_file(
    "inpaint_mask.png",
    origin="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png",
)
mask = Image.open(mask).convert("L")

# Note that the mask values are of boolean dtype.
mask_array = np.array(mask).astype("bool")
prompt = "concept art digital painting of an elven castle, "
prompt += "inspired by lord of the rings, highly detailed, 8k"

generated_image = inpaint.generate(
    {"images": image_array, "masks": mask_array, "prompts": prompt}
)

display_image = concate_images(
    [np.array(image), np.array(mask.convert("RGB")), generated_image]
)
plt.axis("off")
plt.imshow(display_image)
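
"""
From left to right, the display shows the reference image, the inpainting
mask, and the generated image, in which the masked region has been replaced
with new content following the prompt.
"""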
| """ | ||
| ## Conclusion | ||
|
|
||
| KerasHub's `StableDiffusion3` supports a variety of applications and, with the | ||
| help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch! | ||
| """ | ||