
Commit cb14024

Add model class for running transformer with precomputed text latents (#180)
1 parent ba8ca02 commit cb14024

File tree

4 files changed: +647 -5 lines changed

diffusion/models/models.py

+159-1
@@ -21,7 +21,7 @@
 from diffusion.models.pixel_diffusion import PixelDiffusion
 from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
 from diffusion.models.stable_diffusion import StableDiffusion
-from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
+from diffusion.models.t2i_transformer import ComposerPrecomputedTextLatentsToImageMMDiT, ComposerTextToImageMMDiT
 from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
 from diffusion.models.transformer import DiffusionTransformer
 from diffusion.schedulers.schedulers import ContinuousTimeScheduler
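The import hunk above pulls in the new ComposerPrecomputedTextLatentsToImageMMDiT composer class. As a rough sketch (not part of this commit), a training batch of precomputed latents fed to this model might look like the following; the key names match the *_key defaults in the builder added below, while the tensor shapes (T5-XXL width 4096, CLIP-L width 768, sequence lengths 512 and 77, image size 256) are illustrative assumptions only.

import torch

# Hypothetical batch layout for illustration; keys follow the builder defaults,
# shapes are assumed rather than taken from this commit.
batch_size = 2
batch = {
    'image': torch.randn(batch_size, 3, 256, 256),  # raw images, encoded by the VAE inside the model
    'T5_LATENTS': torch.randn(batch_size, 512, 4096),  # precomputed T5 encoder outputs (assumed width)
    'T5_ATTENTION_MASK': torch.ones(batch_size, 512, dtype=torch.long),
    'CLIP_LATENTS': torch.randn(batch_size, 77, 768),  # precomputed CLIP hidden states (assumed width)
    'CLIP_ATTENTION_MASK': torch.ones(batch_size, 77, dtype=torch.long),
    'CLIP_POOLED': torch.randn(batch_size, 768),  # pooled CLIP text embedding (assumed width)
}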
@@ -1008,6 +1008,164 @@ def text_to_image_transformer(
     return model
 
 
+def precomputed_text_latents_to_image_transformer(
+    vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
+    autoencoder_path: Optional[str] = None,
+    autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
+    include_text_encoders: bool = False,
+    text_encoder_dtype: str = 'bfloat16',
+    cache_dir: str = '/tmp/hf_files',
+    num_layers: int = 28,
+    max_image_side: int = 1280,
+    conditioning_features: int = 768,
+    conditioning_max_sequence_length: int = 512 + 77,
+    num_register_tokens: int = 0,
+    patch_size: int = 2,
+    latent_mean: Union[float, Tuple, str] = 0.0,
+    latent_std: Union[float, Tuple, str] = 7.67754318618,
+    timestep_mean: float = 0.0,
+    timestep_std: float = 1.0,
+    timestep_shift: float = 1.0,
+    image_key: str = 'image',
+    t5_latent_key: str = 'T5_LATENTS',
+    t5_mask_key: str = 'T5_ATTENTION_MASK',
+    clip_latent_key: str = 'CLIP_LATENTS',
+    clip_mask_key: str = 'CLIP_ATTENTION_MASK',
+    clip_pooled_key: str = 'CLIP_POOLED',
+    pretrained: bool = False,
+):
+    """Precomputed text latents to image transformer training setup.
+
+    Args:
+        vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'.
+        autoencoder_path (optional, str): Path to autoencoder weights if using a custom autoencoder. If not specified,
+            will use the vae from `vae_model_name`. Default: `None`.
+        include_text_encoders (bool): Whether to include text encoders in the model. Should only be done when running
+            inference. Default: `False`.
+        text_encoder_dtype (str): The dtype to use for the text encoder. One of [`float32`, `float16`, `bfloat16`].
+            Default: `bfloat16`.
+        cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`.
+        autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
+        num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by
+            this according to `num_features = 64 * num_layers` and `num_heads = num_layers`. Default: `28`.
+        max_image_side (int): Maximum side length of the image. Default: `1280`.
+        conditioning_features (int): Number of features in the conditioning transformer. Default: `768`.
+        conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `512 + 77`.
+        num_register_tokens (int): Number of additional register tokens to use. Default: `0`.
+        patch_size (int): Patch size for the transformer. Default: `2`.
+        latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
+            a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder
+            checkpoint. Defaults to `0.0`.
+        latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
+            a tuple of std_devs, or `'latent_statistics'` to try to use the value from the autoencoder
+            checkpoint. Defaults to `1/0.13025`.
+        timestep_mean (float): The mean of the timesteps. Default: `0.0`.
+        timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
+        timestep_shift (float): The shift of the timesteps. Default: `1.0`.
+        image_key (str): The key for the image in the batch. Default: `image`.
+        t5_latent_key (str): The key to use for the T5 latents in the precomputed latents. Default: `'T5_LATENTS'`.
+        t5_mask_key (str): The key to use for the T5 attention mask in the precomputed latents. Default: `'T5_ATTENTION_MASK'`.
+        clip_latent_key (str): The key to use for the CLIP latents in the precomputed latents. Default: `'CLIP_LATENTS'`.
+        clip_mask_key (str): The key to use for the CLIP attention mask in the precomputed latents. Default: `'CLIP_ATTENTION_MASK'`.
+        clip_pooled_key (str): The key to use for the pooled CLIP embedding in the precomputed latents. Default: `'CLIP_POOLED'`.
+        pretrained (bool): Whether to load pretrained weights. Not used. Defaults to `False`.
+    """
+    latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
+
+    precision = torch.float16
+    # Make the autoencoder
+    if autoencoder_path is None:
+        if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
+            raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
+        downsample_factor = 8
+        autoencoder_channels = 4
+        # Use the pretrained vae
+        try:
+            vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision)
+        except:  # for handling SDXL vae fp16 fixed checkpoint
+            vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision)
+    else:
+        # Use a custom autoencoder
+        vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision)
+        if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
+            raise ValueError(
+                'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
+        if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
+            assert isinstance(latent_statistics, dict)
+            latent_mean = tuple(latent_statistics['latent_channel_means'])
+        if isinstance(latent_std, str) and latent_std == 'latent_statistics':
+            assert isinstance(latent_statistics, dict)
+            latent_std = tuple(latent_statistics['latent_channel_stds'])
+        downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
+        autoencoder_channels = vae.config['latent_channels']
+    assert isinstance(vae, torch.nn.Module)
+    if isinstance(latent_mean, float):
+        latent_mean = (latent_mean,) * autoencoder_channels
+    if isinstance(latent_std, float):
+        latent_std = (latent_std,) * autoencoder_channels
+    assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)
+    # Figure out the maximum input sequence length
+    input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))
+    # Make the transformer model
+    transformer = DiffusionTransformer(num_features=64 * num_layers,
+                                       num_heads=num_layers,
+                                       num_layers=num_layers,
+                                       input_features=autoencoder_channels * (patch_size**2),
+                                       input_max_sequence_length=input_max_sequence_length,
+                                       input_dimension=2,
+                                       conditioning_features=64 * num_layers,
+                                       conditioning_max_sequence_length=conditioning_max_sequence_length,
+                                       conditioning_dimension=1,
+                                       expansion_factor=4,
+                                       num_register_tokens=num_register_tokens)
+
+    # Optionally load the tokenizers and text encoders
+    t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None
+    if include_text_encoders:
+        dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
+        dtype = dtype_map[text_encoder_dtype]
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
+                                                       subfolder='tokenizer',
+                                                       cache_dir=cache_dir,
+                                                       local_files_only=False)
+        t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
+                                               torch_dtype=dtype,
+                                               cache_dir=cache_dir,
+                                               local_files_only=False).encoder.eval()
+        clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
+                                                     subfolder='text_encoder',
+                                                     torch_dtype=dtype,
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=False).cuda().eval()
+
+    # Make the composer model
+    model = ComposerPrecomputedTextLatentsToImageMMDiT(model=transformer,
+                                                       autoencoder=vae,
+                                                       t5_tokenizer=t5_tokenizer,
+                                                       t5_encoder=t5_encoder,
+                                                       clip_tokenizer=clip_tokenizer,
+                                                       clip_encoder=clip_encoder,
+                                                       latent_mean=latent_mean,
+                                                       latent_std=latent_std,
+                                                       patch_size=patch_size,
+                                                       downsample_factor=downsample_factor,
+                                                       latent_channels=autoencoder_channels,
+                                                       timestep_mean=timestep_mean,
+                                                       timestep_std=timestep_std,
+                                                       timestep_shift=timestep_shift,
+                                                       image_key=image_key,
+                                                       t5_latent_key=t5_latent_key,
+                                                       t5_mask_key=t5_mask_key,
+                                                       clip_latent_key=clip_latent_key,
+                                                       clip_mask_key=clip_mask_key,
+                                                       clip_pooled_key=clip_pooled_key)
+
+    if torch.cuda.is_available():
+        model = DeviceGPU().module_to_device(model)
+    return model
+
+
 def build_autoencoder(input_channels: int = 3,
                       output_channels: int = 3,
                       hidden_channels: int = 128,
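For context, here is a minimal usage sketch (not part of this commit) of how the new builder added above could be called to construct a Composer model for training on precomputed text latents. The import path follows the file shown in the diff (diffusion/models/models.py); all argument values are the defaults from the diff.

from diffusion.models.models import precomputed_text_latents_to_image_transformer

# Build the model with the defaults from the diff: SDXL fp16-fix VAE,
# a 28-layer transformer (64 * 28 = 1792 features, 28 heads), and no text
# encoders, since training consumes precomputed T5/CLIP latents.
model = precomputed_text_latents_to_image_transformer(
    num_layers=28,
    max_image_side=1280,
    conditioning_max_sequence_length=512 + 77,
    include_text_encoders=False,
)

# For inference, the text encoders can be loaded as well (requires the
# Hugging Face checkpoints to be available under cache_dir):
# model = precomputed_text_latents_to_image_transformer(include_text_encoders=True,
#                                                        text_encoder_dtype='bfloat16')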
