Skip to content

Commit 723bf0a

Browse files
committed
complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder
1 parent d88c7ba commit 723bf0a

File tree

3 files changed

+87
-7
lines changed

3 files changed

+87
-7
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,8 @@ Once built, images will be saved to the same directory the command is invoked
10491049
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
10501050
- [x] allow for unet to be able to condition non-cross attention style as well
10511051
- [x] speed up inference, read up on papers (ddim)
1052-
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
1053-
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
1052+
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
1053+
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
10541054
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
10551055

10561056
## Citations

dalle2_pytorch/dalle2_pytorch.py

+84-4
Original file line numberDiff line numberDiff line change
@@ -2415,20 +2415,51 @@ def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = No
24152415
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
24162416

24172417
@torch.no_grad()
2418-
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
2418+
def p_sample_loop_ddpm(
2419+
self,
2420+
unet,
2421+
shape,
2422+
image_embed,
2423+
noise_scheduler,
2424+
predict_x_start = False,
2425+
learned_variance = False,
2426+
clip_denoised = True,
2427+
lowres_cond_img = None,
2428+
text_encodings = None,
2429+
cond_scale = 1,
2430+
is_latent_diffusion = False,
2431+
lowres_noise_level = None,
2432+
inpaint_image = None,
2433+
inpaint_mask = None
2434+
):
24192435
device = self.device
24202436

24212437
b = shape[0]
24222438
img = torch.randn(shape, device = device)
24232439

2440+
if exists(inpaint_image):
2441+
inpaint_image = self.normalize_img(inpaint_image)
2442+
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
2443+
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
2444+
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
2445+
inpaint_mask = inpaint_mask.bool()
2446+
24242447
if not is_latent_diffusion:
24252448
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
24262449

24272450
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
2451+
times = torch.full((b,), i, device = device, dtype = torch.long)
2452+
2453+
if exists(inpaint_image):
2454+
# following the repaint paper
2455+
# https://arxiv.org/abs/2201.09865
2456+
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
2457+
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
2458+
24282459
img = self.p_sample(
24292460
unet,
24302461
img,
2431-
torch.full((b,), i, device = device, dtype = torch.long),
2462+
times,
24322463
image_embed = image_embed,
24332464
text_encodings = text_encodings,
24342465
cond_scale = cond_scale,
@@ -2440,18 +2471,46 @@ def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_
24402471
clip_denoised = clip_denoised
24412472
)
24422473

2474+
if exists(inpaint_image):
2475+
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
2476+
24432477
unnormalize_img = self.unnormalize_img(img)
24442478
return unnormalize_img
24452479

24462480
@torch.no_grad()
2447-
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
2481+
def p_sample_loop_ddim(
2482+
self,
2483+
unet,
2484+
shape,
2485+
image_embed,
2486+
noise_scheduler,
2487+
timesteps,
2488+
eta = 1.,
2489+
predict_x_start = False,
2490+
learned_variance = False,
2491+
clip_denoised = True,
2492+
lowres_cond_img = None,
2493+
text_encodings = None,
2494+
cond_scale = 1,
2495+
is_latent_diffusion = False,
2496+
lowres_noise_level = None,
2497+
inpaint_image = None,
2498+
inpaint_mask = None
2499+
):
24482500
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
24492501

24502502
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
24512503

24522504
times = list(reversed(times.int().tolist()))
24532505
time_pairs = list(zip(times[:-1], times[1:]))
24542506

2507+
if exists(inpaint_image):
2508+
inpaint_image = self.normalize_img(inpaint_image)
2509+
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
2510+
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
2511+
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
2512+
inpaint_mask = inpaint_mask.bool()
2513+
24552514
img = torch.randn(shape, device = device)
24562515

24572516
if not is_latent_diffusion:
@@ -2463,6 +2522,12 @@ def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timestep
24632522

24642523
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
24652524

2525+
if exists(inpaint_image):
2526+
# following the repaint paper
2527+
# https://arxiv.org/abs/2201.09865
2528+
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
2529+
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
2530+
24662531
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
24672532

24682533
if learned_variance:
@@ -2486,6 +2551,9 @@ def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timestep
24862551
c1 * noise + \
24872552
c2 * pred_noise
24882553

2554+
if exists(inpaint_image):
2555+
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
2556+
24892557
img = self.unnormalize_img(img)
24902558
return img
24912559

@@ -2585,6 +2653,8 @@ def sample(
25852653
cond_scale = 1.,
25862654
stop_at_unet_number = None,
25872655
distributed = False,
2656+
inpaint_image = None,
2657+
inpaint_mask = None
25882658
):
25892659
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
25902660

@@ -2598,6 +2668,8 @@ def sample(
25982668
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
25992669
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
26002670

2671+
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
2672+
26012673
img = None
26022674
is_cuda = next(self.parameters()).is_cuda
26032675

@@ -2609,6 +2681,8 @@ def sample(
26092681
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
26102682

26112683
with context:
2684+
# prepare low resolution conditioning for upsamplers
2685+
26122686
lowres_cond_img = lowres_noise_level = None
26132687
shape = (batch_size, channel, image_size, image_size)
26142688

@@ -2619,12 +2693,16 @@ def sample(
26192693
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
26202694
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
26212695

2696+
# latent diffusion
2697+
26222698
is_latent_diffusion = isinstance(vae, VQGanVAE)
26232699
image_size = vae.get_encoded_fmap_size(image_size)
26242700
shape = (batch_size, vae.encoded_dim, image_size, image_size)
26252701

26262702
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
26272703

2704+
# denoising loop for image
2705+
26282706
img = self.p_sample_loop(
26292707
unet,
26302708
shape,
@@ -2638,7 +2716,9 @@ def sample(
26382716
lowres_noise_level = lowres_noise_level,
26392717
is_latent_diffusion = is_latent_diffusion,
26402718
noise_scheduler = noise_scheduler,
2641-
timesteps = sample_timesteps
2719+
timesteps = sample_timesteps,
2720+
inpaint_image = inpaint_image,
2721+
inpaint_mask = inpaint_mask
26422722
)
26432723

26442724
img = vae.decode(img)

dalle2_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.25.2'
1+
__version__ = '0.26.0'

0 commit comments

Comments
 (0)