Add Photon model and pipeline support #12456
base: main
Conversation
This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: core transformer architecture
- PhotonPipeline: text-to-image generation pipeline
- Attention processor updates for the Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
I'm not sure about this one: I'm saving the VAE weights even though they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead point directly to the originals?
For now, it's okay to keep this as is. This way, everything is under the same model repo.
print(f"✓ Saved VAE to {vae_path}") | ||
|
||
|
||
def download_and_save_text_encoder(output_path: str): |
Same here for the Text Encoder.
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, it's okay to keep this as is. This way, everything is under the same model repo.
```python
from einops import rearrange
from einops.layers.torch import Rearrange
```
We need to get rid of the `einops` dependency and use native PyTorch ops here.
I changed it to native PyTorch. Out of curiosity, why do you recommend avoiding einops?
We try to avoid additional dependencies, especially when things can be done in native PyTorch.
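For illustration, the usual kind of replacement (shapes below are made up, not taken from the PR diff):

```python
import torch

x = torch.randn(2, 16, 8 * 64)  # (batch, seq_len, heads * head_dim)
num_heads, head_dim = 8, 64

# einops: rearrange(x, "b s (h d) -> b h s d", h=num_heads)
y = x.view(x.shape[0], x.shape[1], num_heads, head_dim).transpose(1, 2)

# einops: rearrange(y, "b h s d -> b s (h d)")
z = y.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1)
```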
```python
return xq_out.reshape(*xq.shape).type_as(xq)


class EmbedND(nn.Module):
```
Does this share similarity with Flux?
Yes, it comes from the BFL original implementation.
I tried to modify and reuse the logic from transformer_flux.py, but I didn't manage to make it work without heavy changes and additional complexity.
I added a comment to explicitly say that it comes from there. Is that OK for you, or do you want me to keep trying to use the code from transformer_flux.py?
Oh okay, then it's fine to keep it here. I would maybe rename it to PhotonEmbedND and leave a note that it's inspired by Flux. WDYT?
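For reference, a minimal sketch of what the renamed module could look like, following the BFL/Flux reference implementation (the `rope` helper and shapes are illustrative, not the PR's exact code):

```python
import torch
from torch import nn


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # Rotary-embedding table for a single axis, as in the BFL reference code.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return out.reshape(*out.shape[:-1], 2, 2).float()


class PhotonEmbedND(nn.Module):
    """N-dimensional rotary embedding. Inspired by Flux (transformer_flux.py / BFL)."""

    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)
```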
Looks like this wasn't addressed.
Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.
Also, it would be great to see some samples of Photon!
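For generating samples, a typical text-to-image usage sketch (the repo id below is hypothetical, pending the final checkpoint upload):

```python
import torch
from diffusers import PhotonPipeline

# Hypothetical repo id; the actual location depends on where the checkpoint lands.
pipe = PhotonPipeline.from_pretrained("org/photon", torch_dtype=torch.bfloat16).to("cuda")

image = pipe("a photo of a red panda in the snow", num_inference_steps=28).images[0]
image.save("photon_sample.png")
```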
Thanks! Left a couple more comments. Let's also add the pipeline-level tests.
<div class="flex flex-wrap space-x-1"> | ||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/> | ||
</div> |
I don't think we're supporting it yet? If so, we can remove it for now.
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/> | ||
</div> | ||
|
||
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. |
Cc: @stevhliu for a review on the docs.
```python
return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
```
Could we write it in a fashion similar to `FluxAttnProcessor`?
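As a sketch of that style (a stateless processor with all projections living on the `attn` module, using SDPA; argument names and the rotary helper are assumptions, not the PR's final code):

```python
import torch
import torch.nn.functional as F

from diffusers.models.embeddings import apply_rotary_emb


class PhotonAttnProcessor:
    """Flux-style attention processor sketch: no state, everything read off `attn`."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        batch_size = hidden_states.shape[0]
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[0](hidden_states)
```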
```python
    gate: Tensor


class Modulation(nn.Module):
```
For intermediate blocks like this, we avoid using a dataclass to return outputs.
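I.e., something along these lines (a hedged sketch; the dimensions and the 3-way split are illustrative):

```python
import torch
import torch.nn.functional as F
from torch import nn


class Modulation(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 3 * dim)

    def forward(self, vec: torch.Tensor):
        shift, scale, gate = self.lin(F.silu(vec)).chunk(3, dim=-1)
        # Plain tuple instead of a dataclass for intermediate outputs.
        return shift, scale, gate
```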
```python
):
    """Prepare initial latents for the diffusion process."""
    if latents is None:
        spatial_compression = self.vae_spatial_compression_ratio
```
For image models (where there are no separate spatial and temporal compression factors), we usually just refer to it as `vae_scale_factor`.
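For reference, the pattern most image pipelines use in `__init__` (this assumes an AutoencoderKL-style config; DC-AE exposes its compression factor differently, so treat this as a sketch):

```python
# One power-of-two spatial factor derived from the number of VAE downsampling stages.
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
```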
```python
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
```
We support passing prompt embeddings too, in case users want to supply them precomputed:

```python
prompt_embeds: Optional[torch.FloatTensor] = None,
```
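A sketch of the usual contract (the `encode_prompt` helper name is an assumption; other arguments elided):

```python
from typing import List, Optional, Union

import torch


# Inside the pipeline class:
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
):
    # Mirror the usual diffusers contract: exactly one of `prompt` / `prompt_embeds`.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Pass either `prompt` or `prompt_embeds`, not both.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("One of `prompt` or `prompt_embeds` is required.")
    if prompt_embeds is None:
        prompt_embeds = self.encode_prompt(prompt)  # hypothetical helper name
```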
```python
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
```
Prefer this pattern:

```python
height = height or self.default_sample_size * self.vae_scale_factor
```
I did it this way because the model works with two different VAEs that have different scale factors.
Is it OK not to make it depend on self.vae_scale_factor? Otherwise, it's hard to define a default value.
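One way to reconcile the two (illustrative only; `DEFAULT_RESOLUTION` is the pixel-space default already used in the PR):

```python
# Keep a fixed default *pixel* resolution and derive the latent default from
# whichever VAE the pipeline was loaded with (commonly x8 for the Flux VAE and
# x32 for DC-AE), so the reviewer's pattern works for both.
default_sample_size = DEFAULT_RESOLUTION // self.vae_scale_factor
height = height or default_sample_size * self.vae_scale_factor
width = width or default_sample_size * self.vae_scale_factor
```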
```python
)[0]

# Apply CFG
if self.do_classifier_free_guidance:
```
I didn't see `negative_prompt` in the `__call__()` of the pipeline. Is that expected?
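If it is meant to be supported, the usual shape is something like this (hedged sketch; the helper names are assumptions, and `negative_prompt` typically defaults to `""` when guidance is enabled):

```python
def encode_cfg_inputs(self, prompt, negative_prompt=None):
    text_embeddings, cross_attn_mask = self.encode_prompt(prompt)
    uncond_text_embeddings = uncond_cross_attn_mask = None
    if self.do_classifier_free_guidance:
        uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(negative_prompt or "")
    return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
```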
```python
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
ca_mask = None
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
    ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
```
These can be moved out of the loop, right?
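Something like this (illustrative), since the concatenated tensors are constant across steps:

```python
# Build the CFG batch once, before the denoising loop.
ca_embed, ca_mask = text_embeddings, cross_attn_mask
if self.do_classifier_free_guidance:
    ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
    if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
        ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)

for i, t in enumerate(timesteps):
    ...  # the per-step model call reuses the precomputed ca_embed / ca_mask
```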