|
21 | 21 | from diffusion.models.pixel_diffusion import PixelDiffusion
|
22 | 22 | from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
|
23 | 23 | from diffusion.models.stable_diffusion import StableDiffusion
|
24 |
| -from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT |
| 24 | +from diffusion.models.t2i_transformer import ComposerPrecomputedTextLatentsToImageMMDiT, ComposerTextToImageMMDiT |
25 | 25 | from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
|
26 | 26 | from diffusion.models.transformer import DiffusionTransformer
|
27 | 27 | from diffusion.schedulers.schedulers import ContinuousTimeScheduler
|
@@ -1008,6 +1008,164 @@ def text_to_image_transformer(
|
1008 | 1008 | return model
|
1009 | 1009 |
|
1010 | 1010 |
|
| 1011 | +def precomputed_text_latents_to_image_transformer( |
| 1012 | + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', |
| 1013 | + autoencoder_path: Optional[str] = None, |
| 1014 | + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', |
| 1015 | + include_text_encoders: bool = False, |
| 1016 | + text_encoder_dtype: str = 'bfloat16', |
| 1017 | + cache_dir: str = '/tmp/hf_files', |
| 1018 | + num_layers: int = 28, |
| 1019 | + max_image_side: int = 1280, |
| 1020 | + conditioning_features: int = 768, |
| 1021 | + conditioning_max_sequence_length: int = 512 + 77, |
| 1022 | + num_register_tokens: int = 0, |
| 1023 | + patch_size: int = 2, |
| 1024 | + latent_mean: Union[float, Tuple, str] = 0.0, |
| 1025 | + latent_std: Union[float, Tuple, str] = 7.67754318618, |
| 1026 | + timestep_mean: float = 0.0, |
| 1027 | + timestep_std: float = 1.0, |
| 1028 | + timestep_shift: float = 1.0, |
| 1029 | + image_key: str = 'image', |
| 1030 | + t5_latent_key: str = 'T5_LATENTS', |
| 1031 | + t5_mask_key: str = 'T5_ATTENTION_MASK', |
| 1032 | + clip_latent_key: str = 'CLIP_LATENTS', |
| 1033 | + clip_mask_key: str = 'CLIP_ATTENTION_MASK', |
| 1034 | + clip_pooled_key: str = 'CLIP_POOLED', |
| 1035 | + pretrained: bool = False, |
| 1036 | +): |
| 1037 | + """Text to image transformer training setup. |
| 1038 | +
|
| 1039 | + Args: |
| 1040 | + vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'. |
| 1041 | + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, |
| 1042 | + will use the vae from `vae_model_name`. Default: `None`. |
| 1043 | + include_text_encoders (bool): Whether to include text encoders in the model. This should only be done when |
| 1044 | + running inference. Default: `False`. |
| 1045 | + text_encoder_dtype (str): The dtype to use for the text encoder. One of [`float32`, `float16`, `bfloat16`]. |
| 1046 | + Default: `bfloat16`. |
| 1047 | + cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`. |
| 1048 | + autoencoder_local_path (str): Local path to download the autoencoder weights to. Default: `/tmp/autoencoder_weights.pt`. |
| 1049 | + num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by |
| 1050 | + this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. |
| 1051 | + max_image_side (int): Maximum side length of the image. Default: `1280`. |
| 1052 | + conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. |
| 1053 | + conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `512 + 77`. |
| 1054 | + num_register_tokens (int): Number of additional register tokens to use. Default: `0`. |
| 1055 | + patch_size (int): Patch size for the transformer. Default: `2`. |
| 1056 | + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, |
| 1057 | + a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder |
| 1058 | + checkpoint. Defaults to `0.0`. |
| 1059 | + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, |
| 1060 | + a tuple of std_devs, or `'latent_statistics'` to try to use the value from the autoencoder |
| 1061 | + checkpoint. Defaults to `7.67754318618` (i.e. `1/0.13025`). |
| 1062 | + timestep_mean (float): The mean of the timesteps. Default: `0.0`. |
| 1063 | + timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. |
| 1064 | + timestep_shift (float): The shift of the timesteps. Default: `1.0`. |
| 1065 | + image_key (str): The key for the image in the batch. Default: `image`. |
| 1066 | + t5_latent_key (str): The key to use for the T5 latents in the precomputed latents. Default: `'T5_LATENTS'`. |
| 1067 | + t5_mask_key (str): The key to use for the T5 attention mask in the precomputed latents. Default: `'T5_ATTENTION_MASK'`. |
| 1068 | + clip_latent_key (str): The key to use for the CLIP latents in the precomputed latents. Default: `'CLIP_LATENTS'`. |
| 1069 | + clip_mask_key (str): The key to use for the CLIP attention mask in the precomputed latents. Default: `'CLIP_ATTENTION_MASK'`. |
| 1070 | + clip_pooled_key (str): The key to use for the pooled CLIP embedding in the precomputed latents. Default: `'CLIP_POOLED'`. |
| 1071 | + pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. |
| 1072 | + """ |
| 1073 | + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) |
| 1074 | + |
| 1075 | + precision = torch.float16 |
| 1076 | + # Make the autoencoder |
| 1077 | + if autoencoder_path is None: |
| 1078 | + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': |
| 1079 | + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') |
| 1080 | + downsample_factor = 8 |
| 1081 | + autoencoder_channels = 4 |
| 1082 | + # Use the pretrained vae |
| 1083 | + try: |
| 1084 | + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) |
| 1085 | + except: # for handling SDXL vae fp16 fixed checkpoint |
| 1086 | + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) |
| 1087 | + else: |
| 1088 | + # Use a custom autoencoder |
| 1089 | + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) |
| 1090 | + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): |
| 1091 | + raise ValueError( |
| 1092 | + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') |
| 1093 | + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': |
| 1094 | + assert isinstance(latent_statistics, dict) |
| 1095 | + latent_mean = tuple(latent_statistics['latent_channel_means']) |
| 1096 | + if isinstance(latent_std, str) and latent_std == 'latent_statistics': |
| 1097 | + assert isinstance(latent_statistics, dict) |
| 1098 | + latent_std = tuple(latent_statistics['latent_channel_stds']) |
| 1099 | + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) |
| 1100 | + autoencoder_channels = vae.config['latent_channels'] |
| 1101 | + assert isinstance(vae, torch.nn.Module) |
| 1102 | + if isinstance(latent_mean, float): |
| 1103 | + latent_mean = (latent_mean,) * autoencoder_channels |
| 1104 | + if isinstance(latent_std, float): |
| 1105 | + latent_std = (latent_std,) * autoencoder_channels |
| 1106 | + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) |
| 1107 | + # Figure out the maximum input sequence length |
| 1108 | + input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) |
| 1109 | + # Make the transformer model |
| 1110 | + transformer = DiffusionTransformer(num_features=64 * num_layers, |
| 1111 | + num_heads=num_layers, |
| 1112 | + num_layers=num_layers, |
| 1113 | + input_features=autoencoder_channels * (patch_size**2), |
| 1114 | + input_max_sequence_length=input_max_sequence_length, |
| 1115 | + input_dimension=2, |
| 1116 | + conditioning_features=64 * num_layers, |
| 1117 | + conditioning_max_sequence_length=conditioning_max_sequence_length, |
| 1118 | + conditioning_dimension=1, |
| 1119 | + expansion_factor=4, |
| 1120 | + num_register_tokens=num_register_tokens) |
| 1121 | + |
| 1122 | + # Optionally load the tokenizers and text encoders |
| 1123 | + t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None |
| 1124 | + if include_text_encoders: |
| 1125 | + dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16} |
| 1126 | + dtype = dtype_map[text_encoder_dtype] |
| 1127 | + t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) |
| 1128 | + clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', |
| 1129 | + subfolder='tokenizer', |
| 1130 | + cache_dir=cache_dir, |
| 1131 | + local_files_only=False) |
| 1132 | + t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl', |
| 1133 | + torch_dtype=dtype, |
| 1134 | + cache_dir=cache_dir, |
| 1135 | + local_files_only=False).encoder.eval() |
| 1136 | + clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', |
| 1137 | + subfolder='text_encoder', |
| 1138 | + torch_dtype=dtype, |
| 1139 | + cache_dir=cache_dir, |
| 1140 | + local_files_only=False).cuda().eval() |
| 1141 | + |
| 1142 | + # Make the composer model |
| 1143 | + model = ComposerPrecomputedTextLatentsToImageMMDiT(model=transformer, |
| 1144 | + autoencoder=vae, |
| 1145 | + t5_tokenizer=t5_tokenizer, |
| 1146 | + t5_encoder=t5_encoder, |
| 1147 | + clip_tokenizer=clip_tokenizer, |
| 1148 | + clip_encoder=clip_encoder, |
| 1149 | + latent_mean=latent_mean, |
| 1150 | + latent_std=latent_std, |
| 1151 | + patch_size=patch_size, |
| 1152 | + downsample_factor=downsample_factor, |
| 1153 | + latent_channels=autoencoder_channels, |
| 1154 | + timestep_mean=timestep_mean, |
| 1155 | + timestep_std=timestep_std, |
| 1156 | + timestep_shift=timestep_shift, |
| 1157 | + image_key=image_key, |
| 1158 | + t5_latent_key=t5_latent_key, |
| 1159 | + t5_mask_key=t5_mask_key, |
| 1160 | + clip_latent_key=clip_latent_key, |
| 1161 | + clip_mask_key=clip_mask_key, |
| 1162 | + clip_pooled_key=clip_pooled_key) |
| 1163 | + |
| 1164 | + if torch.cuda.is_available(): |
| 1165 | + model = DeviceGPU().module_to_device(model) |
| 1166 | + return model |
| 1167 | + |
| 1168 | + |
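Not part of the diff: a minimal usage sketch of the new builder. The import path (`diffusion.models.models`) and the spelled-out arguments are assumptions based on the neighboring builders in this file; the values simply restate the defaults above to make the derived sizes explicit.

```python
# Hypothetical usage sketch, not part of this PR. The import path is assumed
# from the surrounding module; all argument values restate the defaults.
from diffusion.models.models import precomputed_text_latents_to_image_transformer

model = precomputed_text_latents_to_image_transformer(
    vae_model_name='madebyollin/sdxl-vae-fp16-fix',
    num_layers=28,                               # width = 64 * 28 = 1792 features, 28 heads
    max_image_side=1280,                         # ceil(1280 / (8 * 2)) = 80 patches per side
    conditioning_max_sequence_length=512 + 77,   # T5 sequence length plus CLIP's 77 tokens
    latent_std=7.67754318618,                    # 1 / 0.13025, the SDXL VAE scaling factor
)

# Training batches are expected to carry precomputed text latents under the
# default keys rather than raw captions:
#   'image', 'T5_LATENTS', 'T5_ATTENTION_MASK',
#   'CLIP_LATENTS', 'CLIP_ATTENTION_MASK', 'CLIP_POOLED'
```

Because the text latents are precomputed offline, the T5 and CLIP encoders are only instantiated when `include_text_encoders=True`, which the docstring reserves for inference.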
1011 | 1169 | def build_autoencoder(input_channels: int = 3,
|
1012 | 1170 | output_channels: int = 3,
|
1013 | 1171 | hidden_channels: int = 128,
|
|