
Image Editing via Inversion #255

Open
@KhoiDOO

Description

Hi, please correct me if I'm wrong. I tried using the inverse function in DPM-Solver to invert the source latent into a noisy latent, and then the sample function on that noisy latent to get the edited image. However, the noisy latent returned by inverse is entirely NaN. The full script is further below; please have a look.
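A minimal way to localize where the NaNs first appear is sketched here. The first_nonfinite_step helper is my own addition, and it assumes that inverse()/sample() return (x, intermediates) when return_intermediate=True, as in the reference DPM-Solver implementation:

import torch

def first_nonfinite_step(scheduler, latent, **inverse_kwargs):
    # If the latent is already non-finite here, the problem is in the VAE encode
    # (e.g. a dtype/precision issue) rather than in the solver itself.
    assert torch.isfinite(latent).all(), "latent is non-finite before inversion"
    # Assumption: inverse() returns (x, intermediates) when return_intermediate=True.
    x, intermediates = scheduler.inverse(x=latent, return_intermediate=True, **inverse_kwargs)
    for i, xt in enumerate(intermediates):
        if not torch.isfinite(xt).all():
            return i  # index of the first solver step that produces NaN/Inf
    return None

The full script: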

import argparse
import torch
import sys
import os
import hashlib
import json
addpath = os.path.join('/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-1]), 'submodule/Sana')
sys.path.append(addpath)

from torch import Tensor

from app.sana_pipeline import SanaPipeline, classify_height_width_bin, guidance_type_select
from diffusion.data.datasets.utils import (
    ASPECT_RATIO_512_TEST,
    ASPECT_RATIO_1024_TEST,
    ASPECT_RATIO_2048_TEST,
    ASPECT_RATIO_4096_TEST,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_encode, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

from diffusion.model import gaussian_diffusion as gd
from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper


class CustomDPM_Solver(DPM_Solver):
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.0,
        dynamic_thresholding_ratio=0.995,
    ):
        super().__init__(
            model_fn,
            noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
            correcting_xt_fn=correcting_xt_fn,
            thresholding_max_val=thresholding_max_val,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
        )
    
    def inverse(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
        flow_shift=1.0,
    ):
        """
        Invert the sample `x` from time `t_start` to `t_end` with DPM-Solver.
        For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
        t_T = self.noise_schedule.T if t_end is None else t_end
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        return self.sample(
            x,
            steps=steps,
            t_start=t_0,
            t_end=t_T,
            order=order,
            skip_type=skip_type,
            method=method,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero,
            solver_type=solver_type,
            atol=atol,
            rtol=rtol,
            return_intermediate=return_intermediate,
            flow_shift=flow_shift,
        )

def DPMS(
    model,
    condition,
    uncondition,
    cfg_scale,
    pag_scale=1.0,
    pag_applied_layers=None,
    model_type="noise",  # or "x_start" or "v" or "score", "flow"
    noise_schedule="linear",
    guidance_type="classifier-free",
    model_kwargs=None,
    diffusion_steps=1000,
    schedule="VP",
    interval_guidance=None,
):
    if pag_applied_layers is None:
        pag_applied_layers = []
    if model_kwargs is None:
        model_kwargs = {}
    if interval_guidance is None:
        interval_guidance = [0, 1.0]
    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
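    # Note: betas are only consumed by the VP branch below; the FLOW branch ignores them.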

    ## 1. Define the noise schedule.
    if schedule == "VP":
        noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
    elif schedule == "FLOW":
        noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")

    ## 2. Convert your discrete-time `model` to the continuous-time
    ## noise prediction model. Here is an example for a diffusion model
    ## `model` with the noise prediction type ("noise") .
    model_fn = model_wrapper(
        model,
        noise_schedule,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        pag_scale=pag_scale,
        pag_applied_layers=pag_applied_layers,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
        interval_guidance=interval_guidance,
    )
    ## 3. Define dpm-solver and sample by multistep DPM-Solver.
    return CustomDPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")


class DPMInversePipeline(SanaPipeline):
    def __init__(self, config_path):
        super().__init__(config_path)
    
    @torch.inference_mode()
    def prepare_prompt(self, prompts):
        if not self.config.text_encoder.chi_prompt:
            max_length_all = self.config.text_encoder.model_max_length
            prompts_all = prompts
        else:
            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
            prompts_all = [chi_prompt + prompt for prompt in prompts]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            max_length_all = (
                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
            )  # magic number 2: [bos], [_]

        caption_token = self.tokenizer(
            prompts_all,
            max_length=max_length_all,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device=self.device)
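        # Keep the first token plus the last (model_max_length - 1) positions of the
        # encoded sequence, and the matching attention-mask entries.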
        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
            :, :, select_index
        ].to(self.weight_dtype)
        emb_masks = caption_token.attention_mask[:, select_index]

        return caption_embs, emb_masks
    
    @torch.inference_mode()
    def prepare_scheduler(self, caption_embs, null_y, guidance_scale, pag_guidance_scale, hw, ar, emb_masks):
        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
        if self.vis_sampler == "flow_euler":
            raise NotImplementedError("Flow Euler is not supported for editing.")
        elif self.vis_sampler == "flow_dpm-solver":
            scheduler = DPMS(
                self.model,
                condition=caption_embs,
                uncondition=null_y,
                guidance_type=self.guidance_type,
                cfg_scale=guidance_scale,
                pag_scale=pag_guidance_scale,
                pag_applied_layers=self.config.model.pag_applied_layers,
                model_type="flow",
                model_kwargs=model_kwargs,
                schedule="FLOW",
            )
            scheduler.register_progress_bar(self.progress_fn)
            return scheduler
        else:
            raise ValueError(f"Unsupported sampler: {self.vis_sampler}")
        
    @torch.inference_mode()
    def edit(
        self,
        src_prompt: list | str = None,
        tgt_prompt: list | str = None,
        src_img: list[Tensor] = None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inversion_steps=5,
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        generator=torch.Generator().manual_seed(42),
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]

        if src_prompt is None or tgt_prompt is None or src_img is None:
            raise ValueError("src_prompt, tgt_prompt and src_img must be provided.")
        src_prompts = src_prompt if isinstance(src_prompt, list) else [src_prompt]
        tgt_prompts = tgt_prompt if isinstance(tgt_prompt, list) else [tgt_prompt]
        src_imgs = src_img if isinstance(src_img, list) else [src_img]
        samples = []

        for sprompt, tprompt, imgs in zip(src_prompts, tgt_prompts, src_imgs):
            # data prepare
            num_images_per_prompt = imgs.size(0)
            sprompts, tprompts, hw, ar = (
                [], [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(num_images_per_prompt, 1),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            for _ in range(num_images_per_prompt):
                sprompts.append(prepare_prompt_ar(sprompt, self.base_ratios, device=self.device, show=False)[0].strip())
                tprompts.append(prepare_prompt_ar(tprompt, self.base_ratios, device=self.device, show=False)[0].strip())

            with torch.no_grad():
                # prepare text feature
                src_caption_embs, src_emb_masks = self.prepare_prompt(sprompts)
                tgt_caption_embs, tgt_emb_masks = self.prepare_prompt(tprompts)
                
                null_y = self.null_caption_embs.repeat(len(sprompts), 1, 1)[:, None].to(self.weight_dtype)

                # inversion step
                scheduler = self.prepare_scheduler(src_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=src_emb_masks)
                latent = vae_encode(self.config.vae.vae_type, self.vae, imgs, False, self.device)
                noisy_latent = scheduler.inverse(
                    x=latent,
                    steps=num_inversion_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
                print(noisy_latent.max(), noisy_latent.min(), noisy_latent.mean(), noisy_latent.shape)
                
                # sampling
                scheduler = self.prepare_scheduler(tgt_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=tgt_emb_masks)
                sample = scheduler.sample(
                    noisy_latent,
                    steps=num_inference_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )
                    
            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

        return samples
    
    @torch.inference_mode()
    def forward(
        self,
        prompt=None,
        height=1024,
        width=1024,
        negative_prompt="",
        num_inference_steps=20,
        guidance_scale=4.5,
        pag_guidance_scale=1.0,
        num_images_per_prompt=1,
        generator=torch.Generator().manual_seed(42),
        latents=None,
        use_resolution_binning=True,
    ):
        self.ori_height, self.ori_width = height, width
        if use_resolution_binning:
            self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
        else:
            self.height, self.width = height, width
        self.latent_size_h, self.latent_size_w = (
            self.height // self.config.vae.vae_downsample_rate,
            self.width // self.config.vae.vae_downsample_rate,
        )
        self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

        # 1. pre-compute negative embedding
        if negative_prompt != "":
            null_caption_token = self.tokenizer(
                negative_prompt,
                max_length=self.max_sequence_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
                0
            ]

        if prompt is None:
            prompt = [""]
        prompts = prompt if isinstance(prompt, list) else [prompt]
        samples = []

        for prompt in prompts:
            # data prepare
            prompts, hw, ar = (
                [],
                torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
                    num_images_per_prompt, 1
                ),
                torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
            )

            for _ in range(num_images_per_prompt):
                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())

            with torch.no_grad():
                # prepare text feature
                if not self.config.text_encoder.chi_prompt:
                    max_length_all = self.config.text_encoder.model_max_length
                    prompts_all = prompts
                else:
                    chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
                    prompts_all = [chi_prompt + prompt for prompt in prompts]
                    num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
                    max_length_all = (
                        num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
                    )  # magic number 2: [bos], [_]

                caption_token = self.tokenizer(
                    prompts_all,
                    max_length=max_length_all,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(device=self.device)
                select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
                caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
                    :, :, select_index
                ].to(self.weight_dtype)
                emb_masks = caption_token.attention_mask[:, select_index]
                null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)

                n = len(prompts)
                if latents is None:
                    z = torch.randn(
                        n,
                        self.config.vae.vae_latent_dim,
                        self.latent_size_h,
                        self.latent_size_w,
                        generator=generator,
                        device=self.device,
                    )
                else:
                    z = latents.to(self.device)
                scheduler = self.prepare_scheduler(caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=emb_masks)
                sample = scheduler.sample(
                    z,
                    steps=num_inference_steps,
                    order=2,
                    skip_type="time_uniform_flow",
                    method="multistep",
                    flow_shift=self.flow_shift,
                )   

            sample = sample.to(self.vae_dtype)
            with torch.no_grad():
                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

            if use_resolution_binning:
                sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
            samples.append(sample)

        return samples
        

if __name__ == '__main__':

    from torchvision.utils import save_image

    parser = argparse.ArgumentParser(description="Generate images using DPMInversePipeline.")
    parser.add_argument("--src_prompt", type=str, default="a yellow cat, frontal view, eye-level elevation, no tilt.", 
                        help="Source text prompt for image generation.")
    parser.add_argument("--tgt_prompt", type=str, default="a yellow cat, side view, eye-level elevation, no tilt.", 
                        help="Target text prompt for image editing.")
    parser.add_argument("--negative_prompt", type=str, default="", help="Negative text prompt for image generation.")
    parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
    parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale for the pipeline.")
    parser.add_argument("--pag_guidance_scale", type=float, default=1.0, help="PAG guidance scale for the pipeline.")
    parser.add_argument("--num_inference_steps", type=int, default=20, help="Number of inference steps.")
    parser.add_argument("--num_images_per_prompt", type=int, default=2, help="Number of images to generate per prompt.")
    parser.add_argument("--num_inversion_steps", type=int, default=5, help="Number of inversion steps for image editing.")
    parser.add_argument("--config_path", type=str, 
                        default="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml", 
                        help="Path to the model configuration file.")
    parser.add_argument("--from_pretrained", type=str, 
                        default="hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth", 
                        help="Path to the pretrained model weights.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")  # Added seed argument

    args = parser.parse_args()

    # Replace spaces with underscores in the source prompt
    sanitized_prompt = args.src_prompt.replace(" ", "_")

    # Generate a unique folder name based on settings as a JSON string
    settings = {
        "src_prompt": args.src_prompt,
        "tgt_prompt": args.tgt_prompt,
        "negative_prompt": args.negative_prompt,
        "config_path": args.config_path,
        "from_pretrained": args.from_pretrained,
        "height": args.height,
        "width": args.width,
        "guidance_scale": args.guidance_scale,
        "pag_guidance_scale": args.pag_guidance_scale,
        "num_inference_steps": args.num_inference_steps,
        "num_images_per_prompt": args.num_images_per_prompt,
        "num_inversion_steps": args.num_inversion_steps,
        "seed": args.seed  # Added seed to settings
    }
    settings_str = json.dumps(settings, sort_keys=True)

    # Encode settings_str as a hash code
    settings_hash = hashlib.md5(settings_str.encode()).hexdigest()

    # Create output directory using settings_str as the folder name
    output_dir = os.path.join("editinv", sanitized_prompt, settings_hash)
    os.makedirs(output_dir, exist_ok=True)

    # Output file paths
    generated_file = os.path.join(output_dir, "sample.png")
    edited_file = os.path.join(output_dir, "edited.png")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=device).manual_seed(args.seed)  # Use user-configured seed

    config_path = os.path.join(addpath, args.config_path)

    sana = DPMInversePipeline(config_path)
    sana.from_pretrained(args.from_pretrained)

    # Generate images
    images = sana(
        prompt=args.src_prompt,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        generator=generator,
        num_images_per_prompt=args.num_images_per_prompt
    )

    print(f"Generated image shape: {images[0].shape}")
    save_image(images[0], generated_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Image saved to {generated_file}")

    # Edit images
    edited_images = sana.edit(
        src_prompt=args.src_prompt,
        tgt_prompt=args.tgt_prompt,
        src_img=images,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
        num_inversion_steps=args.num_inversion_steps,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        pag_guidance_scale=args.pag_guidance_scale,
        generator=generator,
    )

    print(f"Edited image shape: {edited_images[0].shape}")
    save_image(edited_images[0], edited_file, nrow=1, normalize=True, value_range=(-1, 1))
    print(f"Edited image saved to {edited_file}")

The changes are:

  • I customized the Sana pipeline to add an edit function for image editing.
  • I customized the inverse function of the DPM-Solver to take flow_shift as an argument. It's worth noting that with flow_shift=1 and flow_shift=3 the noisy latents I get are identical (all NaN); see the round-trip sketch after this list.
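
For reference, the round-trip check below is what I would try next to see whether the NaNs come from the inversion path itself rather than from the real-image latents. The specific settings (same prompt for source and target, matched step counts, guidance_scale=1.0) are my own guesses, not something from the repo:

# Round-trip sanity check: invert a model-generated image with its own prompt,
# then re-sample with the same prompt. A finite, roughly reconstructed output
# would point the NaNs at the real-image latents; NaNs here point at inverse().
recon_images = sana.edit(
    src_prompt=args.src_prompt,
    tgt_prompt=args.src_prompt,  # same prompt: should roughly reconstruct the input
    src_img=images,
    height=args.height,
    width=args.width,
    negative_prompt=args.negative_prompt,
    num_inversion_steps=args.num_inference_steps,  # match the sampling step count
    num_inference_steps=args.num_inference_steps,
    guidance_scale=1.0,  # assumption: weak/no CFG for a cleaner round trip
    pag_guidance_scale=args.pag_guidance_scale,
    generator=generator,
)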

Thank you!
