diff --git a/guidance/if_utils.py b/guidance/if_utils.py
index 0dcce221..b4ad016a 100644
--- a/guidance/if_utils.py
+++ b/guidance/if_utils.py
@@ -34,7 +34,7 @@ def seed_everything(seed):
 
 
 class IF(nn.Module):
-    def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
+    def __init__(self, device, vram_O, t_range=[0.02, 0.98], fp16=True):
         super().__init__()
 
         self.device = device
@@ -45,8 +45,10 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
 
         is_torch2 = torch.__version__[0] == '2'
 
+        self.precision_t = torch.float16 if fp16 else torch.float32
+
         # Create model
-        pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
+        pipe = IFPipeline.from_pretrained(model_key, variant="fp16" if fp16 else None, torch_dtype=self.precision_t)
 
         if not is_torch2:
             pipe.enable_xformers_memory_efficient_attention()
@@ -175,6 +177,50 @@ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num
 
         return imgs
 
+    def img_opt(self, text_embeddings, images, guidance_scale=40.0, sorted=True):
+
+        if sorted:
+            timesteps = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device).sort()[0]
+        else:
+            timesteps = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)
+
+        with torch.no_grad():
+
+            # add noise to the images using the sampled timesteps
+            noise = torch.randn_like(images)
+            images_noisy = self.scheduler.add_noise(images, noise, timesteps).to(self.device)
+
+            # predict the noise residual
+            model_input = torch.cat([images_noisy] * 2)
+            model_input = self.scheduler.scale_model_input(model_input, timesteps)
+            tt = torch.cat([timesteps] * 2)
+            text_input = text_embeddings.repeat_interleave(len(images_noisy), 0)
+            noise_pred = []
+            # run the UNet one sample at a time (batch_size=1) to keep peak memory low
+            for s, t, text in zip(model_input, tt, text_input):
+                noise_pred.append(self.unet(sample=s[None, ...],
+                                            timestep=t[None, ...],
+                                            encoder_hidden_states=text[None, ...]).sample)
+
+            noise_pred = torch.cat(noise_pred)
+
+            # perform guidance (IF predicts noise and variance, so split off the variance channels first)
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - self.alphas[timesteps])
+        grad = w[:, None, None, None] * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # since we omitted an item in grad, we need to use the custom function to specify the gradient
+        loss = SpecifyGradient.apply(images, grad)
+
+        return images, loss
+
+
 if __name__ == '__main__':
 
     import argparse
diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py
index 3a00ab9f..7ead3ffe 100644
--- a/guidance/sd_utils.py
+++ b/guidance/sd_utils.py
@@ -153,11 +153,11 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa
                 # see zero123_utils.py's version for a simpler implementation.
                alphas = self.scheduler.alphas.to(latents)
                total_timesteps = self.max_step - self.min_step + 1
-                index = total_timesteps - t.to(latents.device) - 1 
+                index = total_timesteps - t.to(latents.device) - 1
                b = len(noise_pred)
                a_t = alphas[index].reshape(b,1,1,1).to(self.device)
                sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
-                sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device) 
+                sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
                pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
                result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))
@@ -242,6 +242,46 @@ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num
 
         return imgs
 
+    def img_opt(self, text_embeddings, latents, guidance_scale=40.0, sorted=True):
+
+        with torch.no_grad():
+
+            if sorted:
+                timesteps = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device).sort()[0]
+            else:
+                timesteps = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn_like(latents)
+            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps).to(self.device)
+
+            # predict the noise residual
+            samples_unet = torch.cat([noisy_latents] * 2)
+            timesteps_unet = torch.cat([timesteps] * 2)
+            text_unet = text_embeddings.repeat_interleave(len(noisy_latents), 0)
+            noise_pred = []
+            # run the UNet one sample at a time (batch_size=1) to keep peak memory low
+            for s, t, text in zip(samples_unet, timesteps_unet, text_unet):
+                noise_pred.append(self.unet(sample=s[None, ...],
+                                            timestep=t[None, ...],
+                                            encoder_hidden_states=text[None, ...]).sample)
+
+            noise_pred = torch.cat(noise_pred)
+
+            # perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - self.alphas[timesteps])
+        grad = w[:, None, None, None] * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # since we omitted an item in grad, we need to use the custom function to specify the gradient
+        loss = SpecifyGradient.apply(latents, grad)
+
+        return latents, loss
+
 
 if __name__ == '__main__':
diff --git a/main.py b/main.py
index 6cccafce..c613ab0b 100644
--- a/main.py
+++ b/main.py
@@ -159,6 +159,9 @@ def __call__ (self, parser, namespace, values, option_string = None):
     parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level")
     parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level")
 
+    # optimize an image directly according to the text prompt
+    parser.add_argument('--img_opt', action='store_true', help="optimize image according to text prompt")
+
     opt = parser.parse_args()
 
     if opt.O:
@@ -196,7 +199,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
         else:
             # use stable-diffusion when providing both text and image
             opt.guidance = ['SD', 'clip']
-        
+
         if not opt.dont_override_stuff:
             opt.guidance_scale = 10
             opt.t_range = [0.2, 0.6]
@@ -212,7 +215,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
         opt.latent_iter_ratio = 0
         if not opt.dont_override_stuff:
             opt.albedo_iter_ratio = 0
-        
+
         # make shape init more stable
         opt.progressive_view = True
         opt.progressive_level = True
@@ -249,7 +252,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
         opt.w = int(opt.w * opt.dmtet_reso_scale)
         opt.known_view_scale = 1
 
-        if not opt.dont_override_stuff:        
+        if not opt.dont_override_stuff:
             opt.t_range = [0.02, 0.50] # ref: magic3D
 
         if opt.images is not None:
@@ -271,7 +274,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
         if not opt.dont_override_stuff:
             # disable as they disturb progressive view
             opt.jitter_pose = False
-        
+
         opt.uniform_sphere_rate = 0
         # back up full range
         opt.full_radius_range = opt.radius_range
@@ -334,6 +337,25 @@ def __call__ (self, parser, namespace, values, option_string = None):
         if opt.save_mesh:
             trainer.save_mesh()
 
+    elif opt.img_opt:
+
+        if 'SD' in opt.guidance:
+            from guidance.sd_utils import StableDiffusion
+            guide = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range)
+            name = "SD"
+
+        elif 'IF' in opt.guidance:
+            from guidance.if_utils import IF
+            guide = IF(device, opt.vram_O, opt.t_range, opt.fp16)
+            name = "DeepFloydIF"
+
+        seed = opt.seed or 0
+
+        name += f"_fp{16 if opt.fp16 else 32}_iters{opt.iters}_lr{opt.lr}_seed{seed}_{opt.text.replace(' ', '_')}"
+
+        img_opt = ImageOpt(opt, guide, name, seed)
+        img_opt.train()
+
     elif opt.test:
 
         guidance = None # no need to load guidance model at test
@@ -376,7 +398,7 @@ def __call__ (self, parser, namespace, values, option_string = None):
 
         if 'IF' in opt.guidance:
             from guidance.if_utils import IF
-            guidance['IF'] = IF(device, opt.vram_O, opt.t_range)
+            guidance['IF'] = IF(device, opt.vram_O, opt.t_range, fp16=opt.fp16)
 
         if 'zero123' in opt.guidance:
             from guidance.zero123_utils import Zero123
diff --git a/nerf/utils.py b/nerf/utils.py
index 0fc29103..af505ac6 100644
--- a/nerf/utils.py
+++ b/nerf/utils.py
@@ -1257,3 +1257,106 @@ def get_GPU_mem():
         mems.append(int(((mem_total - mem_free)/1024**3)*1000)/1000)
         mem += mems[-1]
     return mem, mems
+
+
+class ImageOpt():
+
+    def __init__(self, opt, guide, name="SD_imgopt", seed=0):
+
+        self.opt = opt
+        self.guide = guide
+        self.name = name
+
+        # the guidance model stays frozen; only the image/latent tensor is optimized
+        for p in guide.parameters():
+            p.requires_grad = False
+
+        torch.manual_seed(seed)
+        if opt.IF:
+            # DeepFloyd IF works directly in pixel space: 64x64 RGB images in [-1, 1]
+            images = torch.rand((opt.batch_size, 3, 64, 64), device=guide.device, dtype=guide.precision_t) * 2 - 1
+        else:
+            # Stable Diffusion works in its 4-channel, 64x64 latent space
+            images = torch.randn((opt.batch_size, 4, 64, 64), device=guide.device, dtype=guide.precision_t)
+
+        self.images = images.requires_grad_(True)
+
+        self.optimizer = optim.Adam([self.images], lr=opt.lr) # naive adam
+
+        self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
+
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.opt.fp16)
+
+        self.sample_freq = max(1, self.opt.iters // 20)
+
+        uncond_embeddings = guide.get_text_embeds([""])
+        text_embeddings = guide.get_text_embeds([self.opt.text])
+        self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+    def train(self):
+
+        self.local_step = 0
+        samples = []
+
+        for i in tqdm.tqdm(range(self.opt.iters)):
+
+            self.local_step += 1
+
+            self.optimizer.zero_grad()
+
+            with torch.cuda.amp.autocast(enabled=self.opt.fp16):
+                self.images, loss = self.guide.img_opt(self.text_embeddings, self.images)
+
+            # hooked grad clipping for RGB space
+            if self.opt.grad_clip_rgb >= 0:
+                def _hook(grad):
+                    if self.opt.fp16:
+                        # correctly handle the scale
+                        grad_scale = self.scaler._get_scale_async()
+                        return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)
+                    else:
+                        return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)
+                self.images.register_hook(_hook)
+                # self.images.retain_grad()
+
+            self.scaler.scale(loss).backward()
+
+            self.post_train_step()
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            self.lr_scheduler.step()
+
+            if i % self.sample_freq == 0:
+                if self.opt.IF:
+                    samples.append(self.images.detach().cpu().permute(0, 2, 3, 1).numpy())
+                else:
+                    with torch.no_grad():
+                        images = []
+                        for latent in self.images:
+                            images.append(self.guide.decode_latents(latent[None, ...].detach()))
+                        image = torch.cat(images).cpu().permute(0, 2, 3, 1).numpy()
+                    samples.append(image)
+
+        gif = self.make_gif_from_imgs(samples, resize=(4 if "SD" in self.name else 1))
+        os.makedirs(self.opt.workspace, exist_ok=True)
+        imageio.mimwrite(os.path.join(self.opt.workspace, f'{self.name}_imgopt.mp4'), gif, fps=10, quality=8, macro_block_size=1)
+
+    def make_gif_from_imgs(self, frames, resize=1.0, upto=None, repeat_first=2, repeat_last=5, skip=1,
+                           f=0, s=0.75, t=2):
+        imgs = []
+        from PIL import Image
+        for i, img in tqdm.tqdm(enumerate(frames[:upto:skip]), total=len(frames[:upto:skip])):
+            # tile the batch horizontally into a single frame
+            img = np.moveaxis(img, 0, 1).reshape(img.shape[1], -1, 3)
+            img = np.array(Image.fromarray((img*255).astype(np.uint8)).resize((int(img.shape[1]/resize), int(img.shape[0]/resize)), Image.Resampling.LANCZOS))
+            text = f"{i*self.sample_freq:05d}"
+            img = cv2.putText(img=img, text=text, org=(0, 20), fontFace=f, fontScale=s, color=(0,0,0), thickness=t)
+            imgs.append(img)
+        # pad with repeated first/last frames so the video pauses at the start and end
+        return [imgs[0]]*repeat_first + imgs + [imgs[-1]]*repeat_last
+
+    def post_train_step(self):
+
+        # unscale grad before modifying it!
+        # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
+        self.scaler.unscale_(self.optimizer)
+
+        # clip grad
+        if self.opt.grad_clip >= 0:
+            torch.nn.utils.clip_grad_value_(self.images, self.opt.grad_clip)
diff --git a/scripts/run_img_opt.sh b/scripts/run_img_opt.sh
new file mode 100644
index 00000000..4c2aa630
--- /dev/null
+++ b/scripts/run_img_opt.sh
@@ -0,0 +1,8 @@
+# SD fp32
+CUDA_VISIBLE_DEVICES=0 python main.py --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
+# DeepFloydIF fp32
+CUDA_VISIBLE_DEVICES=0 python main.py --IF --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
+# # SD fp16
+# CUDA_VISIBLE_DEVICES=0 python main.py -O --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
+# # DeepFloydIF fp16
+# CUDA_VISIBLE_DEVICES=0 python main.py --IF -O --img_opt --text "A high quality 3D render of a strawberry" --workspace trial_imgopt --batch_size 10 --iters 100 --lr 0.1 --seed 0
\ No newline at end of file
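
Note: both img_opt variants rely on the SpecifyGradient autograd function that already lives in the repo's guidance utilities. It returns a dummy scalar loss whose backward pass injects the precomputed SDS gradient into the optimized tensor, which is what lets scaler.scale(loss).backward() in ImageOpt.train drive the image/latent updates. For reference, a minimal sketch of that helper (as it typically appears in stable-dreamfusion; it is not part of this diff) looks like this:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class SpecifyGradient(torch.autograd.Function):
    # returns a dummy scalar "loss"; its backward pass hands the stored SDS gradient
    # (w(t) * (noise_pred - noise)) straight to the input tensor
    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # the value itself is never used, only the gradient it carries
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        # grad_scale is the upstream gradient, which includes any AMP loss scaling
        gt_grad = gt_grad * grad_scale
        return gt_grad, None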