diff --git a/training/stable_diffusion/README.md b/training/stable_diffusion/README.md
index bdadff29b..112478b70 100644
--- a/training/stable_diffusion/README.md
+++ b/training/stable_diffusion/README.md
@@ -32,7 +32,7 @@ Make sure to customize the training parameters in the script to suit your specif
 
 ## Inference
 
-For inference, you can use the `inf-loop.py` Python code. Follow these steps:
+For inference, you can use the `inf_txt2img_loop.py` Python code. Follow these steps:
 
 1. Provide your desired prompts as input in the script.
 2. Run the `inf_txt2img_loop.py` script.
@@ -40,5 +40,5 @@ For inference, you can use the `inf-loop.py` Python code. Follow these steps:
 
 Here's an example command to run the inference script:
-deepspeed inf_txt2img_loop.py
+deepspeed inf_txt2img_loop.py --out_dir out_images/
 
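For a quick check of the distilled LoRA weights without the DeepSpeed launcher, the same pipeline can be assembled directly with diffusers. This is only a sketch built from pieces that appear elsewhere in this change (the `mytrainbash.sh` output directory and default prompt, the `diffusers==0.23.1` pin, and the calls made in `inf_txt2img_loop.py`); the output filename is illustrative.

```python
# Minimal single-GPU sketch (assumes diffusers==0.23.1 as pinned in requirements.txt).
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Apply the LoRA attention processors produced by train_sd_distill_lora.py.
pipe.unet.load_attn_procs("./sd-distill-lora-multi-50k-50")

image = pipe("A man dancing", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("sample.png")  # illustrative output path
```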
\ No newline at end of file diff --git a/training/stable_diffusion/inf_txt2img_loop.py b/training/stable_diffusion/inf_txt2img_loop.py index 20482bff4..1d0203d5f 100644 --- a/training/stable_diffusion/inf_txt2img_loop.py +++ b/training/stable_diffusion/inf_txt2img_loop.py @@ -2,13 +2,15 @@ import torch import os from local_pipeline_stable_diffusion import StableDiffusionPipeline +from diffusers import DPMSolverMultistepScheduler from diffusers import StableDiffusionPipeline as StableDiffusionPipelineBaseline import argparse + seed = 123450011 parser = argparse.ArgumentParser() -parser.add_argument("--ft_model", default="new_sd-distill-v21-10k-1e", type=str, help="Path to the fine-tuned model") -parser.add_argument("--b_model", default="stabilityai/stable-diffusion-2-1-base", type=str, help="Path to the baseline model") +parser.add_argument("--finetuned_model", default="./sd-distill-lora-multi-50k-50", type=str, help="Path to the fine-tuned model") +parser.add_argument("--base_model", default="stabilityai/stable-diffusion-2-1-base", type=str, help="Path to the baseline model") parser.add_argument("--out_dir", default="image_out/", type=str, help="Path to the generated images") parser.add_argument('--guidance_scale', type=float, default=7.5, help='Guidance Scale') parser.add_argument("--use_local_pipe", action='store_true', help="Use local SD pipeline") @@ -40,17 +42,27 @@ "A person holding a cat"] +# Load the pipelines +pipe_new = StableDiffusionPipeline.from_pretrained(args.base_model, torch_dtype=torch.float16).to("cuda") +pipe_baseline = StableDiffusionPipelineBaseline.from_pretrained(args.base_model, torch_dtype=torch.float16).to("cuda") + +pipe_new.scheduler = DPMSolverMultistepScheduler.from_config(pipe_new.scheduler.config) +pipe_baseline.scheduler = DPMSolverMultistepScheduler.from_config(pipe_baseline.scheduler.config) + +# Load the Lora weights +pipe_new.unet.load_attn_procs(args.finetuned_model) + +pipe_new = deepspeed.init_inference(pipe_new, mp_size=world_size, dtype=torch.half) +pipe_baseline = deepspeed.init_inference(pipe_baseline, mp_size=world_size, dtype=torch.half) + +# Generate the images for prompt in prompts: - #--- new image - pipe_new = StableDiffusionPipeline.from_pretrained(args.ft_model, torch_dtype=torch.float16).to("cuda") - generator = torch.Generator("cuda").manual_seed(seed) - pipe_new = deepspeed.init_inference(pipe_new, mp_size=world_size, dtype=torch.half) - image_new = pipe_new(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0] - image_new.save(args.out_dir+"/NEW__seed_"+str(seed)+"_"+prompt[0:100]+".png") - - #--- baseline image - pipe_baseline = StableDiffusionPipelineBaseline.from_pretrained(args.b_model, torch_dtype=torch.float16).to("cuda") - generator = torch.Generator("cuda").manual_seed(seed) - pipe_baseline = deepspeed.init_inference(pipe_baseline, mp_size=world_size, dtype=torch.half) - image_baseline = pipe_baseline(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0] - image_baseline.save(args.out_dir+"/BASELINE_seed_"+str(seed)+"_"+prompt[0:100]+".png") + #--- baseline image + generator = torch.Generator("cuda").manual_seed(seed) + image_baseline = pipe_baseline(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + image_baseline.save(args.out_dir+"BASELINE_seed_"+str(seed)+"_"+prompt[0:100]+".png") + + #--- new image + generator = torch.Generator("cuda").manual_seed(seed) + image_new = pipe_new(prompt, num_inference_steps=50, 
guidance_scale=7.5).images[0] + image_new.save(args.out_dir+"NEW_seed_"+str(seed)+"_"+prompt[0:100]+".png") diff --git a/training/stable_diffusion/local_pipeline_stable_diffusion.py b/training/stable_diffusion/local_pipeline_stable_diffusion.py index 64abf7d7f..ebffd521b 100644 --- a/training/stable_diffusion/local_pipeline_stable_diffusion.py +++ b/training/stable_diffusion/local_pipeline_stable_diffusion.py @@ -29,11 +29,11 @@ is_accelerate_available, is_accelerate_version, logging, - randn_tensor, replace_example_docstring, ) -from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker diff --git a/training/stable_diffusion/mytrainbash.sh b/training/stable_diffusion/mytrainbash.sh index fcd2f8508..0dce85887 100644 --- a/training/stable_diffusion/mytrainbash.sh +++ b/training/stable_diffusion/mytrainbash.sh @@ -1,5 +1,5 @@ export MODEL_NAME="stabilityai/stable-diffusion-2-1-base" -export OUTPUT_DIR="./sd-distill-v21" +export OUTPUT_DIR="./sd-distill-lora-multi-50k-50" if [ ! -d "$OUTPUT_DIR" ]; then mkdir "$OUTPUT_DIR" @@ -7,15 +7,15 @@ if [ ! -d "$OUTPUT_DIR" ]; then else echo "Folder '$OUTPUT_DIR' already exists" fi - -accelerate launch train_sd_distil_lora.py \ + +accelerate launch train_sd_distill_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --output_dir=$OUTPUT_DIR \ - --default_prompt="A man dancing" \ --resolution=512 \ --train_batch_size=1 \ --gradient_accumulation_steps=1 \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ - --lr_warmup_steps=0 + --lr_warmup_steps=0 \ + --default_prompt="A man dancing" diff --git a/training/stable_diffusion/requirements.txt b/training/stable_diffusion/requirements.txt index 7a612982f..1223b5130 100644 --- a/training/stable_diffusion/requirements.txt +++ b/training/stable_diffusion/requirements.txt @@ -1,4 +1,5 @@ accelerate>=0.16.0 +diffusers==0.23.1 torchvision transformers>=4.25.1 ftfy diff --git a/training/stable_diffusion/train_sd_distil_lora.py b/training/stable_diffusion/train_sd_distill_lora.py similarity index 68% rename from training/stable_diffusion/train_sd_distil_lora.py rename to training/stable_diffusion/train_sd_distill_lora.py index 012cb0e0f..7544b76dd 100644 --- a/training/stable_diffusion/train_sd_distil_lora.py +++ b/training/stable_diffusion/train_sd_distill_lora.py @@ -20,9 +20,10 @@ import logging import math import os +import shutil import warnings from pathlib import Path - +from typing import Dict import numpy as np import torch import torch.nn.functional as F @@ -31,13 +32,15 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, model_info, upload_folder +from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image +from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from torch.autograd import profiler import diffusers from diffusers import ( @@ -45,23 +48,41 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionPipeline, UNet2DConditionModel, ) +from diffusers.loaders import ( + 
LoraLoaderMixin, + text_encoder_lora_state_dict, +) +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -if is_wandb_available(): - import wandb - # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.17.0.dev0") +check_min_version("0.18.0.dev0") logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -73,103 +94,26 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode base_model: {base_model} instance_prompt: {prompt} tags: -- stable-diffusion -- stable-diffusion-diffusers +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} - text-to-image - diffusers -- dreambooth +- lora inference: true --- """ model_card = f""" -# DreamBooth - {repo_id} +# LoRA DreamBooth - {repo_id} -This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). -You can find some example images in the following. \n +These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n {img_str} -DreamBooth for the text encoder was enabled: {train_text_encoder}. +LoRA for the text encoder was enabled: {train_text_encoder}. """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) -def log_validation( - text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds -): - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - - pipeline_args = {} - - if text_encoder is not None: - pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) - - if vae is not None: - pipeline_args["vae"] = vae - - # create pipeline (note: unet and vae are loaded again in float32) - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - tokenizer=tokenizer, - unet=accelerator.unwrap_model(unet), - revision=args.revision, - torch_dtype=weight_dtype, - **pipeline_args, - ) - - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - if args.pre_compute_text_embeddings: - pipeline_args = { - "prompt_embeds": prompt_embeds, - "negative_prompt_embeds": negative_prompt_embeds, - } - else: - pipeline_args = {"prompt": args.validation_prompt} - - # run inference - generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) - images = [] - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() - - return images - - def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, @@ -208,10 +152,7 @@ def parse_args(input_args=None): type=str, default=None, required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." - ), + help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( "--tokenizer_name", @@ -239,6 +180,27 @@ def parse_args(input_args=None): default=None, help="The prompt to specify images in the same class as provided instance images.", ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) parser.add_argument( "--with_prior_preservation", default=False, @@ -258,7 +220,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="text-inversion-model", + default="lora-dreambooth-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -291,7 +253,7 @@ def parse_args(input_args=None): parser.add_argument( "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 
) - parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--num_train_epochs", type=int, default=50) parser.add_argument( "--max_train_steps", type=int, @@ -303,22 +265,16 @@ def parse_args(input_args=None): type=int, default=500, help=( - "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " - "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." - "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." - "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" - "instructions." + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=None, - help=( - "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." - " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" - " for more details" - ), + help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", @@ -343,7 +299,7 @@ def parse_args(input_args=None): parser.add_argument( "--learning_rate", type=float, - default=5e-6, + default=5e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -371,9 +327,6 @@ def parse_args(input_args=None): help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -382,6 +335,9 @@ def parse_args(input_args=None): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -421,28 +377,6 @@ def parse_args(input_args=None): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="A prompt that is used during validation to verify that the model is learning.", - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=4, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=100, - help=( - "Run validation every X steps. 
Validation consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`" - " and logging the images." - ), - ) parser.add_argument( "--mixed_precision", type=str, @@ -468,25 +402,6 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) - parser.add_argument( - "--set_grads_to_none", - action="store_true", - help=( - "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" - " behaviors, so disable this argument if it causes any problems. More info:" - " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" - ), - ) - - parser.add_argument( - "--offset_noise", - action="store_true", - default=False, - help=( - "Fine-tuning against a modified noise" - " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." - ), - ) parser.add_argument( "--pre_compute_text_embeddings", action="store_true", @@ -506,7 +421,23 @@ def parse_args(input_args=None): help="Whether to use attention mask for the text encoder", ) parser.add_argument( - "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), ) if input_args is not None: @@ -545,7 +476,7 @@ class DreamBoothDataset(Dataset): def __init__( self, instance_prompts, - instance_images, + instance_images, tokenizer, class_data_root=None, class_prompt=None, @@ -563,9 +494,10 @@ def __init__( self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states self.tokenizer_max_length = tokenizer_max_length self.num_instance_images = len(instance_prompts) - self.instance_images = instance_images self.instance_prompts = instance_prompts + self.instance_images = instance_images self._length = self.num_instance_images + self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), @@ -580,9 +512,11 @@ def __len__(self): def __getitem__(self, index): example = {} + #instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) instance_image = self.instance_images[index % self.num_instance_images] + instance_image = exif_transpose(instance_image) uncond_tokens = [""] * args.train_batch_size - + if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) @@ -593,11 +527,11 @@ def __getitem__(self, index): example["instance_prompt_ids"] = text_inputs.input_ids example["instance_attention_mask"] = text_inputs.attention_mask - # Compute the unconditional prompt - uncond_inputs = tokenize_prompt( + # Compute the unconditional prompt + uncond_inputs = tokenize_prompt( self.tokenizer, uncond_tokens, tokenizer_max_length=self.tokenizer_max_length - ) - example["uncond_prompt_ids"] = uncond_inputs.input_ids + ) + 
example["uncond_prompt_ids"] = uncond_inputs.input_ids example["uncond_attention_mask"] = uncond_inputs.attention_mask return example @@ -606,9 +540,9 @@ def collate_fn(examples, with_prior_preservation=False): has_attention_mask = "instance_attention_mask" in examples[0] input_ids = [example["instance_prompt_ids"] for example in examples] - uncond_ids = [example["uncond_prompt_ids"] for example in examples] + uncond_ids = [example["uncond_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - + if has_attention_mask: attention_mask = [example["instance_attention_mask"] for example in examples] uncond_attention_mask = [example["uncond_attention_mask"] for example in examples] @@ -629,7 +563,7 @@ def collate_fn(examples, with_prior_preservation=False): input_ids = torch.cat(input_ids, dim=0) uncond_ids = torch.cat(uncond_ids, dim=0) - + batch = { "input_ids": input_ids, "uncond_ids": uncond_ids, @@ -660,16 +594,6 @@ def __getitem__(self, index): return example -def model_has_vae(args): - config_file_name = os.path.join("vae", AutoencoderKL.config_name) - if os.path.isdir(args.pretrained_model_name_or_path): - config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) - return os.path.isfile(config_file_name) - else: - files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings - return any(file.rfilename == config_file_name for file in files_in_repo) - - def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): if tokenizer_max_length is not None: max_length = tokenizer_max_length @@ -704,25 +628,42 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte return prompt_embeds +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + r""" + Returns: + a state dict containing just the attention processor parameters. + """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter + + return attn_processors_state_dict + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, - project_dir=logging_dir, project_config=accelerator_project_config, ) if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. 
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " @@ -792,7 +733,7 @@ def main(args): del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() - + # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: @@ -822,59 +763,38 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - - if model_has_vae(args): + try: vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision ) - else: + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here vae = None - unet = UNet2DConditionModel.from_pretrained( + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision - ) - teacher_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - #Turn off gradients for the teacher - for param in teacher_unet.parameters(): - param.requires_grad = False - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for model in models: - sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" - model.save_pretrained(os.path.join(output_dir, sub_dir)) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - while len(models) > 0: - # pop models so that they are not loaded again - model = models.pop() - - if isinstance(model, type(accelerator.unwrap_model(text_encoder))): - # load transformers style into model - load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") - model.config = load_model.config - else: - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - + # We only train the additional adapter LoRA layers if vae is not None: vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - if not args.train_text_encoder: - text_encoder.requires_grad_(False) + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -889,28 +809,100 @@ def load_model_hook(models, input_dir): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder.gradient_checkpointing_enable() + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. + # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 + # => 32 layers + + # Set correct lora layers + unet_lora_attn_procs = {} + unet_lora_parameters = [] + + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) - # Check that all trainable models are in full precision - low_precision_error_string = ( - "Please make sure to always have all model weights in full float32 precision when starting training - even if" - " doing mixed precision training. copy of the weights should still be float32." - ) + module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_attn_procs[name] = module + unet_lora_parameters.extend(module.parameters()) - if accelerator.unwrap_model(unet).dtype != torch.float32: - raise ValueError( - f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + unet.set_attn_processor(unet_lora_attn_procs) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. 
+ if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32) + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_attn_processors_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, ) - if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: - raise ValueError( - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." - f" {low_precision_error_string}" + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ ) + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -936,7 +928,9 @@ def load_model_hook(models, input_dir): # Optimizer creation params_to_optimize = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + itertools.chain(unet_lora_parameters, text_lora_parameters) + if args.train_text_encoder + else unet_lora_parameters ) optimizer = optimizer_class( params_to_optimize, @@ -983,23 +977,23 @@ def compute_text_embeddings(prompt): validation_prompt_encoder_hidden_states = None validation_prompt_negative_prompt_embeds = None pre_computed_instance_prompt_encoder_hidden_states = None - - from datasets import load_dataset - dataset_hf = load_dataset('poloclub/diffusiondb', '2m_first_10k') - raw_train_dataset = dataset_hf['train'] - - #Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_prompts=raw_train_dataset['prompt'], - instance_images=raw_train_dataset['image'], - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - encoder_hidden_states=pre_computed_encoder_hidden_states, - 
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, - tokenizer_max_length=args.tokenizer_max_length, - ) + from datasets import load_dataset + dataset_hf = load_dataset('poloclub/diffusiondb', '2m_first_50k') + raw_train_dataset = dataset_hf['train'] + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_prompts=raw_train_dataset['prompt'], + instance_images=raw_train_dataset['image'], + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, @@ -1031,26 +1025,9 @@ def compute_text_embeddings(prompt): ) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler ) - - teacher_unet.to(accelerator.device) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move vae and text_encoder to device and cast to weight_dtype - if vae is not None: - vae.to(accelerator.device, dtype=weight_dtype) - - if not args.train_text_encoder and text_encoder is not None: - text_encoder.to(accelerator.device, dtype=weight_dtype) - + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: @@ -1061,7 +1038,7 @@ def compute_text_embeddings(prompt): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("dreambooth", config=vars(args)) + accelerator.init_trackers("dreambooth-lora", config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1105,17 +1082,16 @@ def compute_text_embeddings(prompt): # Only show the progress bar once on each machine. 
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") - for epoch in range(first_epoch, args.num_train_epochs): - + print("epoch", epoch) unet.train() - print("epoch:", epoch) if args.train_text_encoder: text_encoder.train() - - # For each prompt* for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step + #print ("step", step) + #if step == 2: + # break if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: if step % args.gradient_accumulation_steps == 0: progress_bar.update(1) @@ -1126,19 +1102,14 @@ def compute_text_embeddings(prompt): if vae is not None: # Convert images to latent space - model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor else: model_input = pixel_values - # Sample noise that we'll add to the model input - if args.offset_noise: - noise = torch.randn_like(model_input) + 0.1 * torch.randn( - model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device - ) - else: - noise = torch.randn_like(model_input) - bsz = model_input.shape[0] + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape # Sample a random timestep for each image timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device @@ -1158,23 +1129,37 @@ def compute_text_embeddings(prompt): batch["input_ids"], batch["attention_mask"], text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) + ) encoder_hidden_states_uncond = encode_prompt( - text_encoder, - batch["uncond_ids"], - batch["uncond_attention_mask"], - text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, - ) + text_encoder, + batch["uncond_ids"], + batch["uncond_attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None # Predict the student noise residual - model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample #student_noise_pred - # The teacher noise residual is based on the inference pipeline: uncond_noise +gc * (cond_noise - uncond_noise) + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, cross_attention_kwargs={"scale":1}, class_labels=class_labels + ).sample + gc = 7.5 - teacher_cond_noise = teacher_unet(noisy_model_input, timesteps, encoder_hidden_states).sample - teacher_uncond_noise = teacher_unet(noisy_model_input, timesteps, encoder_hidden_states_uncond).sample - teacher_noise_pred = teacher_uncond_noise + gc * (teacher_cond_noise - teacher_uncond_noise) + with torch.no_grad(): + teacher_cond_noise = unet(noisy_model_input, timesteps, encoder_hidden_states, cross_attention_kwargs={"scale":0}, class_labels=class_labels).sample + teacher_uncond_noise = unet(noisy_model_input, timesteps, encoder_hidden_states_uncond, cross_attention_kwargs={"scale":0}, class_labels=class_labels).sample + teacher_noise_pred = teacher_uncond_noise + gc * (teacher_cond_noise - teacher_uncond_noise) + # if model predicts variance, throw away the 
prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. if model_pred.shape[1] == 6: model_pred, _ = torch.chunk(model_pred, 2, dim=1) @@ -1185,19 +1170,22 @@ def compute_text_embeddings(prompt): target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = F.mse_loss(model_pred.float(), teacher_noise_pred.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), teacher_noise_pred.float(), reduction="mean") + accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain(unet.parameters(), text_encoder.parameters()) + itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder - else unet.parameters() + else unet_lora_parameters ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() lr_scheduler.step() - optimizer.zero_grad(set_to_none=args.set_grads_to_none) + optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -1205,26 +1193,31 @@ def compute_text_embeddings(prompt): global_step += 1 if accelerator.is_main_process: - images = [] if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - images = log_validation( - text_encoder, - tokenizer, - unet, - vae, - args, - accelerator, - weight_dtype, - epoch, - validation_prompt_encoder_hidden_states, - validation_prompt_negative_prompt_embeds, - ) - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) @@ -1232,22 +1225,104 @@ def compute_text_embeddings(prompt): if global_step >= args.max_train_steps: break - # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + if args.validation_images is None: + images = [] + for _ in range(args.num_validation_images): + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + with torch.cuda.amp.autocast(): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - pipeline_args = {} - - if text_encoder is not None: - pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = unet_attn_processors_state_dict(unet) + + if text_encoder is not None and args.train_text_encoder: + text_encoder = accelerator.unwrap_model(text_encoder) + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) + else: + text_encoder_lora_layers = None - if args.skip_save_text_encoder: - pipeline_args["text_encoder"] = None + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + # Final inference + # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - revision=args.revision, - **pipeline_args, + args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype ) # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it @@ -1261,9 +1336,35 @@ def compute_text_embeddings(prompt): scheduler_args["variance_type"] = variance_type - pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) - pipeline.save_pretrained(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) if args.push_to_hub: save_model_card( @@ -1273,6 +1374,7 @@ def compute_text_embeddings(prompt): train_text_encoder=args.train_text_encoder, prompt=args.default_prompt, repo_folder=args.output_dir, + pipeline=pipeline, ) upload_folder( repo_id=repo_id, @@ -1283,6 +1385,7 @@ def compute_text_embeddings(prompt): accelerator.end_training() + if __name__ == "__main__": args = parse_args() main(args)
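To recap the objective implemented in the training loop of `train_sd_distill_lora.py`: the LoRA-augmented UNet acts as its own frozen teacher by running with `cross_attention_kwargs={"scale": 0}` under `torch.no_grad()`, and the student prediction (LoRA scale 1) is regressed onto the classifier-free-guidance-combined teacher output with a hard-coded guidance scale of 7.5. A minimal sketch of that target follows; the function name is chosen here for illustration and is not part of the script.

```python
import torch.nn.functional as F

def distillation_loss(student_pred, teacher_cond, teacher_uncond, guidance_scale=7.5):
    # Guidance-combined teacher target, as in the training loop above:
    # uncond + gc * (cond - uncond), with gc hard-coded to 7.5 there.
    teacher_pred = teacher_uncond + guidance_scale * (teacher_cond - teacher_uncond)
    # The LoRA "student" is trained to reproduce the guided prediction in a single pass.
    return F.mse_loss(student_pred.float(), teacher_pred.float(), reduction="mean")
```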