Skip to content

Conversation

DavidBert
Copy link

@DavidBert DavidBert commented Oct 9, 2025

This commit adds support for the Photon image generation model:

  • PhotonTransformer2DModel: Core transformer architecture
  • PhotonPipeline: Text-to-image generation pipeline
  • Attention processor updates for Photon-specific attention mechanism
  • Conversion script for loading Photon checkpoints
  • Documentation and tests
image_10 image_4 image_0 image_1

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: Core transformer architecture
- PhotonPipeline: Text-to-image generation pipeline
- Attention processor updates for Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure on this one: I'm saving the VAE weights while they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead look directly for the original ones?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

print(f"✓ Saved VAE to {vae_path}")


def download_and_save_text_encoder(output_path: str):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for the Text Encoder.

print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

Comment on lines 19 to 20
from einops import rearrange
from einops.layers.torch import Rearrange
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to get rid of the einops dependency and use native PyTorch ops here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it for native Pytorch. Out of curiosity why do you recommend avoiding using einops?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We try to avoid additional dependencies especially when things can be done in native PyTorch.

return xq_out.reshape(*xq.shape).type_as(xq)


class EmbedND(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this share similarity with Flux?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it comes from the BFL original implementation.
I tried to modify and use the logic from transformer_flux.py but I didn't manage to make it work without heavy changes and additional complexity.
I added a comment to explicitely say that it come from there. Is it OK for you or do you want me to continue trying to use the code from transformer_flux.py?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay. Then it's fine to keep it here. I would maybe rename it to PhotoEmbedND and leave a note that it's inspired from Flux. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this wasn't addressed.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.

Also, it would be great to see some samples of Photon!

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left a couple more comments. Let's also add the pipeline-level tests.

Comment on lines 17 to 19
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we're supporting it yet? If so, we can remove for now.

<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cc: @stevhliu for a review on the docs.

return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we write it in a fashion similar to

?

return xq_out.reshape(*xq.shape).type_as(xq)


class EmbedND(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this wasn't addressed.

gate: Tensor


class Modulation(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For intermediate blocks like this, we avoid using a dataclass to return outputs.

):
"""Prepare initial latents for the diffusion process."""
if latents is None:
spatial_compression = self.vae_spatial_compression_ratio
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For image models (where there ate no separate spatial and temporal compression factors). we usually just refer to it as vae_scale_factor:

https://github.com/huggingface/diffusers/blob/8abc7aeb715c0149ee0a9982b2d608ce97f55215/src/diffusers/pipelines/flux/pipeline_flux.py#L209C14-L209C34

def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support passing prompt embeddings too in case users want to supply them precomputed:

prompt_embeds: Optional[torch.FloatTensor] = None,

Comment on lines 484 to 486
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer this pattern:

height = height or self.default_sample_size * self.vae_scale_factor

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did it this way because the model works for two different vae with different scale_factors.
Is it ok to not make it depend of self.vae_scale_factor? It makes it hard to define a default value otherwise.

)[0]

# Apply CFG
if self.do_classifier_free_guidance:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see negative_prompt in the __call__() of the pipeline. Is that expected?

Comment on lines 561 to 564
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
ca_mask = None
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can be moved out of the loop, right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants