
Add model class for running transformer with precomputed text latents #180


Merged · 32 commits · Nov 22, 2024
Changes from all commits
32 commits
c52e1f6
Update versions
corystephenson-db Sep 17, 2024
15e4a4d
Include mlflow
corystephenson-db Sep 18, 2024
a6ced94
Spelling is important
corystephenson-db Sep 18, 2024
fd7cffd
Torchmetrics :(
corystephenson-db Sep 18, 2024
75bfd30
Triton
corystephenson-db Sep 18, 2024
cc90fff
New xformers
corystephenson-db Oct 3, 2024
6cb12ee
Automatically add per-device batch size as a streaming kwarg
corystephenson-db Oct 3, 2024
755e5cc
Bump composer
corystephenson-db Oct 4, 2024
192dc17
Update streaming
corystephenson-db Oct 4, 2024
a004a9d
Fix huggingface warning
corystephenson-db Oct 4, 2024
1e2c20c
Update to new torch autocast
corystephenson-db Oct 4, 2024
37bbd4c
Update workflows
corystephenson-db Oct 4, 2024
0910514
numpy :(
corystephenson-db Oct 4, 2024
9825134
Update batch size autosetting
corystephenson-db Oct 6, 2024
b5aa661
Don't autoset per-device batch size
corystephenson-db Oct 6, 2024
829465c
Pass batch size to dataset class
corystephenson-db Oct 7, 2024
9d639de
Include a precomputed text latent transformer
corystephenson-db Oct 5, 2024
8b469df
Add model for precomputed text latent transformer
corystephenson-db Oct 5, 2024
c84f0d1
Conditioning features are now same as model features
corystephenson-db Oct 5, 2024
d7073fe
Simple projection to common dimensionality
corystephenson-db Oct 6, 2024
13274df
No per-sequence embeddings or post pooled layernorm
corystephenson-db Oct 7, 2024
16f3a49
Generate should match forward
corystephenson-db Oct 7, 2024
6939d02
Fix bug in generated latents size
corystephenson-db Oct 7, 2024
03fbe97
Cleanup
corystephenson-db Oct 11, 2024
1d9cdd1
Merge branch 'main' of https://github.com/coryMosaicML/diffusion into…
corystephenson-db Oct 11, 2024
4144a3d
Correctly calc seq len in flops calculation
corystephenson-db Oct 11, 2024
c302f68
Add optional register tokens
corystephenson-db Nov 14, 2024
ad4f117
Mask on dim 1
corystephenson-db Nov 14, 2024
639ae6d
Versions
corystephenson-db Nov 14, 2024
318d727
Salt and pepper
corystephenson-db Nov 16, 2024
614f1a9
Cleanup
corystephenson-db Nov 22, 2024
efb9c72
Fix versions
corystephenson-db Nov 22, 2024
160 changes: 159 additions & 1 deletion diffusion/models/models.py
@@ -20,7 +20,7 @@
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
from diffusion.models.t2i_transformer import ComposerPrecomputedTextLatentsToImageMMDiT, ComposerTextToImageMMDiT
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.models.transformer import DiffusionTransformer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
@@ -1010,6 +1010,164 @@ def text_to_image_transformer(
return model


def precomputed_text_latents_to_image_transformer(
vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
autoencoder_path: Optional[str] = None,
autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
include_text_encoders: bool = False,
text_encoder_dtype: str = 'bfloat16',
cache_dir: str = '/tmp/hf_files',
num_layers: int = 28,
max_image_side: int = 1280,
conditioning_features: int = 768,
conditioning_max_sequence_length: int = 512 + 77,
num_register_tokens: int = 0,
patch_size: int = 2,
latent_mean: Union[float, Tuple, str] = 0.0,
latent_std: Union[float, Tuple, str] = 7.67754318618,
timestep_mean: float = 0.0,
timestep_std: float = 1.0,
timestep_shift: float = 1.0,
image_key: str = 'image',
t5_latent_key: str = 'T5_LATENTS',
t5_mask_key: str = 'T5_ATTENTION_MASK',
clip_latent_key: str = 'CLIP_LATENTS',
clip_mask_key: str = 'CLIP_ATTENTION_MASK',
clip_pooled_key: str = 'CLIP_POOLED',
pretrained: bool = False,
):
"""Text to image transformer training setup.

Args:
vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'.
autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified,
will use the vae from `vae_model_name`. Default `None`.
include_text_encoders (bool): Whether to include text encoders in the model. Should only do this for running
inference. Default: `False`.
text_encoder_dtype (str): The dtype to use for the text encoder. One of [`float32`, `float16`, `bfloat16`].
Default: `bfloat16`.
cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`.
autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by
this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`.
max_image_side (int): Maximum side length of the image. Default: `1280`.
conditioning_features (int): Number of features in the conditioning transformer. Default: `768`.
conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `512 + 77`.
num_register_tokens (int): Number of additional register tokens to use. Default: `0`.
patch_size (int): Patch size for the transformer. Default: `2`.
latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `0.0`.
latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.13025`.
timestep_mean (float): The mean of the timesteps. Default: `0.0`.
timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
timestep_shift (float): The shift of the timesteps. Default: `1.0`.
image_key (str): The key for the image in the batch. Default: `image`.
t5_latent_key (str): The key to use for the T5 latents in the precomputed latents. Default: `'T5_LATENTS'`.
t5_mask_key (str): The key to use for the T5 attention mask in the precomputed latents. Default: `'T5_ATTENTION_MASK'`.
clip_latent_key (str): The key to use for the CLIP latents in the precomputed latents. Default: `'CLIP_LATENTS'`.
clip_mask_key (str): The key to use for the CLIP attention mask in the precomputed latents. Default: `'CLIP_ATTENTION_MASK'`.
clip_pooled_key (str): The key to use for the pooled CLIP embedding in the precomputed latents. Default: `'CLIP_POOLED'`.
pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False.
"""
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

precision = torch.float16
# Make the autoencoder
if autoencoder_path is None:
if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
downsample_factor = 8
autoencoder_channels = 4
# Use the pretrained vae
try:
vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision)
except: # for handling SDXL vae fp16 fixed checkpoint
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision)
else:
# Use a custom autoencoder
vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision)
if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
raise ValueError(
'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_mean = tuple(latent_statistics['latent_channel_means'])
if isinstance(latent_std, str) and latent_std == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_std = tuple(latent_statistics['latent_channel_stds'])
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
autoencoder_channels = vae.config['latent_channels']
assert isinstance(vae, torch.nn.Module)
if isinstance(latent_mean, float):
latent_mean = (latent_mean,) * autoencoder_channels
if isinstance(latent_std, float):
latent_std = (latent_std,) * autoencoder_channels
assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)
# Figure out the maximum input sequence length
input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))
# Make the transformer model
transformer = DiffusionTransformer(num_features=64 * num_layers,
num_heads=num_layers,
num_layers=num_layers,
input_features=autoencoder_channels * (patch_size**2),
input_max_sequence_length=input_max_sequence_length,
input_dimension=2,
conditioning_features=64 * num_layers,
conditioning_max_sequence_length=conditioning_max_sequence_length,
conditioning_dimension=1,
expansion_factor=4,
num_register_tokens=num_register_tokens)

# Optionally load the tokenizers and text encoders
t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None
if include_text_encoders:
dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
dtype = dtype_map[text_encoder_dtype]
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
subfolder='tokenizer',
cache_dir=cache_dir,
local_files_only=False)
t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
torch_dtype=dtype,
cache_dir=cache_dir,
local_files_only=False).encoder.eval()
clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
subfolder='text_encoder',
torch_dtype=dtype,
cache_dir=cache_dir,
local_files_only=False).cuda().eval()

# Make the composer model
model = ComposerPrecomputedTextLatentsToImageMMDiT(model=transformer,
autoencoder=vae,
t5_tokenizer=t5_tokenizer,
t5_encoder=t5_encoder,
clip_tokenizer=clip_tokenizer,
clip_encoder=clip_encoder,
latent_mean=latent_mean,
latent_std=latent_std,
patch_size=patch_size,
downsample_factor=downsample_factor,
latent_channels=autoencoder_channels,
timestep_mean=timestep_mean,
timestep_std=timestep_std,
timestep_shift=timestep_shift,
image_key=image_key,
t5_latent_key=t5_latent_key,
t5_mask_key=t5_mask_key,
clip_latent_key=clip_latent_key,
clip_mask_key=clip_mask_key,
clip_pooled_key=clip_pooled_key)

if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
return model


def build_autoencoder(input_channels: int = 3,
output_channels: int = 3,
hidden_channels: int = 128,
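A minimal usage sketch of the new builder, not part of this PR's diff. It assumes the function is importable from diffusion.models.models, that the defaults documented in the docstring are used, and that a training batch carries precomputed T5/CLIP latents under the default keys; the tensor shapes shown are illustrative assumptions only (T5-XXL hidden size 4096, SDXL CLIP text encoder hidden size 768).

import torch

from diffusion.models.models import precomputed_text_latents_to_image_transformer

# Build the composer model with the default 28-layer transformer
# (num_features = 64 * 28 = 1792, num_heads = 28).
model = precomputed_text_latents_to_image_transformer(
    num_layers=28,
    conditioning_max_sequence_length=512 + 77,  # T5 tokens + CLIP tokens
    num_register_tokens=0,
)

# Hypothetical batch with images and precomputed text latents under the
# default keys; shapes are assumptions for illustration only.
batch = {
    'image': torch.randn(2, 3, 256, 256),
    'T5_LATENTS': torch.randn(2, 512, 4096),
    'T5_ATTENTION_MASK': torch.ones(2, 512, dtype=torch.long),
    'CLIP_LATENTS': torch.randn(2, 77, 768),
    'CLIP_ATTENTION_MASK': torch.ones(2, 77, dtype=torch.long),
    'CLIP_POOLED': torch.randn(2, 768),
}
outputs = model(batch)  # Composer-style forward on a batch dict (assumed call pattern)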