Hi, please correct me if I'm wrong. I tried using the `inverse` function of DPM-Solver to invert the source latent into a noisy latent, then used the `sample` function on that noisy latent to get the edited image. However, the noisy latent returned by `inverse` is all NaN. I've left the code below, please have a check.
```python
import argparse
import torch
import sys
import os
import hashlib
import json
addpath = os.path.join('/'.join(os.path.dirname(os.path.abspath(__file__)).split('/')[:-1]), 'submodule/Sana')
sys.path.append(addpath)
from torch import Tensor
from app.sana_pipeline import SanaPipeline, classify_height_width_bin, guidance_type_select
from diffusion.data.datasets.utils import (
ASPECT_RATIO_512_TEST,
ASPECT_RATIO_1024_TEST,
ASPECT_RATIO_2048_TEST,
ASPECT_RATIO_4096_TEST,
)
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_encode, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger
from diffusion.model import gaussian_diffusion as gd
from diffusion.model.dpm_solver import DPM_Solver, NoiseScheduleFlow, NoiseScheduleVP, model_wrapper
class CustomDPM_Solver(DPM_Solver):
def __init__(
self,
model_fn,
noise_schedule,
algorithm_type="dpmsolver++",
correcting_x0_fn=None,
correcting_xt_fn=None,
thresholding_max_val=1.0,
dynamic_thresholding_ratio=0.995,
):
super().__init__(
model_fn,
noise_schedule,
algorithm_type=algorithm_type,
correcting_x0_fn=correcting_x0_fn,
correcting_xt_fn=correcting_xt_fn,
thresholding_max_val=thresholding_max_val,
dynamic_thresholding_ratio=dynamic_thresholding_ratio,
)
def inverse(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
method="multistep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpmsolver",
atol=0.0078,
rtol=0.05,
return_intermediate=False,
flow_shift=1.0,
):
"""
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
t_T = self.noise_schedule.T if t_end is None else t_end
assert (
t_0 > 0 and t_T > 0
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
return self.sample(
x,
steps=steps,
t_start=t_0,
t_end=t_T,
order=order,
skip_type=skip_type,
method=method,
lower_order_final=lower_order_final,
denoise_to_zero=denoise_to_zero,
solver_type=solver_type,
atol=atol,
rtol=rtol,
return_intermediate=return_intermediate,
flow_shift=flow_shift,
)
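# Builds the model wrapper and noise schedule, then returns the CustomDPM_Solver
# defined above so that inverse() accepts flow_shift.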
def DPMS(
model,
condition,
uncondition,
cfg_scale,
pag_scale=1.0,
pag_applied_layers=None,
model_type="noise", # or "x_start" or "v" or "score", "flow"
noise_schedule="linear",
guidance_type="classifier-free",
model_kwargs=None,
diffusion_steps=1000,
schedule="VP",
interval_guidance=None,
):
if pag_applied_layers is None:
pag_applied_layers = []
if model_kwargs is None:
model_kwargs = {}
if interval_guidance is None:
interval_guidance = [0, 1.0]
betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
## 1. Define the noise schedule.
if schedule == "VP":
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
elif schedule == "FLOW":
noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type,
model_kwargs=model_kwargs,
guidance_type=guidance_type,
pag_scale=pag_scale,
pag_applied_layers=pag_applied_layers,
condition=condition,
unconditional_condition=uncondition,
guidance_scale=cfg_scale,
interval_guidance=interval_guidance,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
return CustomDPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
class DPMInversePipeline(SanaPipeline):
def __init__(self, config_path):
super().__init__(config_path)
@torch.inference_mode()
def prepare_prompt(self, prompts):
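"""Tokenize the prompts (optionally prefixed with the chi system prompt) and return caption embeddings and attention masks."""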
if not self.config.text_encoder.chi_prompt:
max_length_all = self.config.text_encoder.model_max_length
prompts_all = prompts
else:
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
prompts_all = [chi_prompt + prompt for prompt in prompts]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = (
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
) # magic number 2: [bos], [_]
caption_token = self.tokenizer(
prompts_all,
max_length=max_length_all,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(device=self.device)
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
:, :, select_index
].to(self.weight_dtype)
emb_masks = caption_token.attention_mask[:, select_index]
return caption_embs, emb_masks
@torch.inference_mode()
def prepare_scheduler(self, caption_embs, null_y, guidance_scale, pag_guidance_scale, hw, ar, emb_masks):
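"""Build a flow DPM-Solver scheduler conditioned on the given caption embeddings."""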
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
if self.vis_sampler == "flow_euler":
raise NotImplementedError("Flow Euler is not supported for editing.")
elif self.vis_sampler == "flow_dpm-solver":
scheduler = DPMS(
self.model,
condition=caption_embs,
uncondition=null_y,
guidance_type=self.guidance_type,
cfg_scale=guidance_scale,
pag_scale=pag_guidance_scale,
pag_applied_layers=self.config.model.pag_applied_layers,
model_type="flow",
model_kwargs=model_kwargs,
schedule="FLOW",
)
scheduler.register_progress_bar(self.progress_fn)
return scheduler
else:
raise ValueError(f"Unsupported sampler: {self.vis_sampler}")
@torch.inference_mode()
def edit(
self,
src_prompt: list | str = None,
tgt_prompt: list | str = None,
src_img: list[Tensor] = None,
height=1024,
width=1024,
negative_prompt="",
num_inversion_steps=5,
num_inference_steps=20,
guidance_scale=4.5,
pag_guidance_scale=1.0,
generator=torch.Generator().manual_seed(42),
use_resolution_binning=True,
):
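"""
Invert each source image into a noisy latent under the source prompt, then sample
from that noisy latent under the target prompt to produce the edited images.
"""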
self.ori_height, self.ori_width = height, width
if use_resolution_binning:
self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
else:
self.height, self.width = height, width
self.latent_size_h, self.latent_size_w = (
self.height // self.config.vae.vae_downsample_rate,
self.width // self.config.vae.vae_downsample_rate,
)
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
# 1. pre-compute negative embedding
if negative_prompt != "":
null_caption_token = self.tokenizer(
negative_prompt,
max_length=self.max_sequence_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(self.device)
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
if src_prompt is None or tgt_prompt is None or src_img is None:
raise ValueError("src_prompt, tgt_prompt and src_img must be provided.")
src_prompts = src_prompt if isinstance(src_prompt, list) else [src_prompt]
tgt_prompts = tgt_prompt if isinstance(tgt_prompt, list) else [tgt_prompt]
src_imgs = src_img if isinstance(src_img, list) else [src_img]
samples = []
for sprompt, tprompt, imgs in zip(src_prompts, tgt_prompts, src_imgs):
# data prepare
num_images_per_prompt = imgs.size(0)
sprompts, tprompts, hw, ar = (
[], [],
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(num_images_per_prompt, 1),
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
)
for _ in range(num_images_per_prompt):
sprompts.append(prepare_prompt_ar(sprompt, self.base_ratios, device=self.device, show=False)[0].strip())
tprompts.append(prepare_prompt_ar(tprompt, self.base_ratios, device=self.device, show=False)[0].strip())
with torch.no_grad():
# prepare text feature
src_caption_embs, src_emb_masks = self.prepare_prompt(sprompts)
tgt_caption_embs, tgt_emb_masks = self.prepare_prompt(tprompts)
null_y = self.null_caption_embs.repeat(len(sprompts), 1, 1)[:, None].to(self.weight_dtype)
# inversion step
scheduler = self.prepare_scheduler(src_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=src_emb_masks)
latent = vae_encode(self.config.vae.vae_type, self.vae, imgs, False, self.device)
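# NOTE: the inversion below reuses the same guidance_scale / pag_guidance_scale as
# sampling, since they are baked into the scheduler built above with the source prompt.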
noisy_latent = scheduler.inverse(
x = latent,
steps=num_inversion_steps,
order=2,
skip_type="time_uniform_flow",
method="multistep",
flow_shift=self.flow_shift,
)
print(noisy_latent.max(), noisy_latent.min(), noisy_latent.mean(), noisy_latent.shape)
# sampling
scheduler = self.prepare_scheduler(tgt_caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=tgt_emb_masks)
sample = scheduler.sample(
noisy_latent,
steps=num_inference_steps,
order=2,
skip_type="time_uniform_flow",
method="multistep",
flow_shift=self.flow_shift
)
sample = sample.to(self.vae_dtype)
with torch.no_grad():
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
if use_resolution_binning:
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
samples.append(sample)
return samples
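# forward() below follows the stock SanaPipeline generation path, with the scheduler
# construction factored out into prepare_scheduler() above.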
@torch.inference_mode()
def forward(
self,
prompt=None,
height=1024,
width=1024,
negative_prompt="",
num_inference_steps=20,
guidance_scale=4.5,
pag_guidance_scale=1.0,
num_images_per_prompt=1,
generator=torch.Generator().manual_seed(42),
latents=None,
use_resolution_binning=True,
):
self.ori_height, self.ori_width = height, width
if use_resolution_binning:
self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
else:
self.height, self.width = height, width
self.latent_size_h, self.latent_size_w = (
self.height // self.config.vae.vae_downsample_rate,
self.width // self.config.vae.vae_downsample_rate,
)
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
# 1. pre-compute negative embedding
if negative_prompt != "":
null_caption_token = self.tokenizer(
negative_prompt,
max_length=self.max_sequence_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(self.device)
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
0
]
if prompt is None:
prompt = [""]
prompts = prompt if isinstance(prompt, list) else [prompt]
samples = []
for prompt in prompts:
# data prepare
prompts, hw, ar = (
[],
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
num_images_per_prompt, 1
),
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
)
for _ in range(num_images_per_prompt):
prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
with torch.no_grad():
# prepare text feature
if not self.config.text_encoder.chi_prompt:
max_length_all = self.config.text_encoder.model_max_length
prompts_all = prompts
else:
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
prompts_all = [chi_prompt + prompt for prompt in prompts]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = (
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
) # magic number 2: [bos], [_]
caption_token = self.tokenizer(
prompts_all,
max_length=max_length_all,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(device=self.device)
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
:, :, select_index
].to(self.weight_dtype)
emb_masks = caption_token.attention_mask[:, select_index]
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
n = len(prompts)
if latents is None:
z = torch.randn(
n,
self.config.vae.vae_latent_dim,
self.latent_size_h,
self.latent_size_w,
generator=generator,
device=self.device,
)
else:
z = latents.to(self.device)
scheduler = self.prepare_scheduler(caption_embs, null_y, guidance_scale, pag_guidance_scale, hw=hw, ar=ar, emb_masks=emb_masks)
sample = scheduler.sample(
z,
steps=num_inference_steps,
order=2,
skip_type="time_uniform_flow",
method="multistep",
flow_shift=self.flow_shift,
)
sample = sample.to(self.vae_dtype)
with torch.no_grad():
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
if use_resolution_binning:
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
samples.append(sample)
return samples
if __name__ == '__main__':
from torchvision.utils import save_image
parser = argparse.ArgumentParser(description="Generate images using DPMInversePipeline.")
parser.add_argument("--src_prompt", type=str, default="a yellow cat, frontal view, eye-level elevation, no tilt.",
help="Source text prompt for image generation.")
parser.add_argument("--tgt_prompt", type=str, default="a yellow cat, side view, eye-level elevation, no tilt.",
help="Target text prompt for image editing.")
parser.add_argument("--negative_prompt", type=str, default="", help="Negative text prompt for image generation.")
parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale for the pipeline.")
parser.add_argument("--pag_guidance_scale", type=float, default=1.0, help="PAG guidance scale for the pipeline.")
parser.add_argument("--num_inference_steps", type=int, default=20, help="Number of inference steps.")
parser.add_argument("--num_images_per_prompt", type=int, default=2, help="Number of images to generate per prompt.")
parser.add_argument("--num_inversion_steps", type=int, default=5, help="Number of inversion steps for image editing.")
parser.add_argument("--config_path", type=str,
default="configs/sana1-5_config/1024ms/Sana_1600M_1024px_allqknorm_bf16_lr2e5.yaml",
help="Path to the model configuration file.")
parser.add_argument("--from_pretrained", type=str,
default="hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
help="Path to the pretrained model weights.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.") # Added seed argument
args = parser.parse_args()
# Replace spaces with underscores in the source prompt
sanitized_prompt = args.src_prompt.replace(" ", "_")
# Generate a unique folder name based on settings as a JSON string
settings = {
"src_prompt": args.src_prompt,
"tgt_prompt": args.tgt_prompt,
"negative_prompt": args.negative_prompt,
"config_path": args.config_path,
"from_pretrained": args.from_pretrained,
"height": args.height,
"width": args.width,
"guidance_scale": args.guidance_scale,
"pag_guidance_scale": args.pag_guidance_scale,
"num_inference_steps": args.num_inference_steps,
"num_images_per_prompt": args.num_images_per_prompt,
"num_inversion_steps": args.num_inversion_steps,
"seed": args.seed # Added seed to settings
}
settings_str = json.dumps(settings, sort_keys=True)
# Encode settings_str as a hash code
settings_hash = hashlib.md5(settings_str.encode()).hexdigest()
# Create output directory using settings_str as the folder name
output_dir = os.path.join("editinv", sanitized_prompt, settings_hash)
os.makedirs(output_dir, exist_ok=True)
# Output file paths
generated_file = os.path.join(output_dir, "sample.png")
edited_file = os.path.join(output_dir, "edited.png")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device).manual_seed(args.seed) # Use user-configured seed
config_path = os.path.join(addpath, args.config_path)
sana = DPMInversePipeline(config_path)
sana.from_pretrained(args.from_pretrained)
# Generate images
images = sana(
prompt=args.src_prompt,
height=args.height,
width=args.width,
negative_prompt=args.negative_prompt,
guidance_scale=args.guidance_scale,
pag_guidance_scale=args.pag_guidance_scale,
num_inference_steps=args.num_inference_steps,
generator=generator,
num_images_per_prompt=args.num_images_per_prompt
)
print(f"Generated image shape: {images[0].shape}")
save_image(images[0], generated_file, nrow=1, normalize=True, value_range=(-1, 1))
print(f"Image saved to {generated_file}")
# Edit images
edited_images = sana.edit(
src_prompt=args.src_prompt,
tgt_prompt=args.tgt_prompt,
src_img=images,
height=args.height,
width=args.width,
negative_prompt=args.negative_prompt,
num_inversion_steps=args.num_inversion_steps,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
pag_guidance_scale=args.pag_guidance_scale,
generator=generator,
)
print(f"Edited image shape: {edited_images[0].shape}")
save_image(edited_images[0], edited_file, nrow=1, normalize=True, value_range=(-1, 1))
print(f"Edited image saved to {edited_file}")
The changes are:
- I customized the Sana pipeline to add an `edit` function for editing images.
- I customized the `inverse` function of the DPM-Solver to take `flow_shift` as an argument. It's worth noting that whether I use `flow_shift=1` or `flow_shift=3`, the noisy latents I get are identical (all NaN); a minimal per-step check is sketched below.
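For debugging, here is a minimal sketch of how the non-finite values could be localized per inversion step. This is not part of the script above: it assumes `scheduler` and `latent` are built exactly as in `edit()`, and that the Sana fork keeps the upstream DPM-Solver behavior of returning `(x, intermediates)` when `return_intermediate=True`.

```python
import torch

# Hypothetical check, not part of the repro script above: run the same inversion with
# return_intermediate=True and report at which step the latent first becomes non-finite.
noisy_latent, intermediates = scheduler.inverse(
    x=latent,
    steps=5,
    order=2,
    skip_type="time_uniform_flow",
    method="multistep",
    flow_shift=1.0,
    return_intermediate=True,
)
for i, xt in enumerate(intermediates):
    print(
        f"inversion step {i}: finite={torch.isfinite(xt).all().item()}, "
        f"abs max={xt.abs().max().item():.3e}"
    )
```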
Thank you!