
Commit 9008531

fix repaint

1 parent 417ff80

2 files changed (+96, -56 lines)

dalle2_pytorch/dalle2_pytorch.py (+95, -55)
@@ -516,6 +516,17 @@ def q_sample(self, x_start, t, noise=None):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
+    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
+        shape = x_from.shape
+        noise = default(noise, lambda: torch.randn_like(x_from))
+
+        alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
+        sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
+        alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
+        sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
+
+        return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
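The new q_sample_from_to is the renoising step that the RePaint loop below relies on. Here is a minimal standalone sketch of the arithmetic, assuming the same x = alpha * x0 + sigma * eps convention as q_sample; q_sample_from_to_ref and the scalar schedule values are illustrative stand-ins, not part of the commit:

```python
# Standalone sketch (not part of the commit) of what q_sample_from_to computes.
# q_sample_from_to_ref is a hypothetical copy of the arithmetic with the schedule
# values passed in as scalars instead of being extract()-ed from registered buffers.
import torch

def q_sample_from_to_ref(x_from, alpha, sigma, alpha_next, sigma_next, noise):
    return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

x0, eps, fresh = torch.randn(4), torch.randn(4), torch.randn(4)

alpha, sigma = 0.95, (1 - 0.95 ** 2) ** 0.5             # sqrt_alphas_cumprod / sqrt_one_minus_... at from_t
alpha_next, sigma_next = 0.90, (1 - 0.90 ** 2) ** 0.5   # the same at to_t (a noisier level, as when renoising)

x_from = alpha * x0 + sigma * eps                       # a q_sample-style noised image at from_t
x_to = q_sample_from_to_ref(x_from, alpha, sigma, alpha_next, sigma_next, fresh)

# expanding: x_to = alpha_next * x0
#                 + (sigma * alpha_next / alpha) * eps
#                 + (sigma_next - sigma * alpha_next / alpha) * fresh
# i.e. the signal is rescaled to alpha_next and the old/fresh noise coefficients sum to sigma_next
expected = alpha_next * x0 + (sigma * alpha_next / alpha) * eps + (sigma_next - sigma * alpha_next / alpha) * fresh
assert torch.allclose(x_to, expected, atol = 1e-6)
```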
@@ -2432,14 +2443,18 @@ def p_sample_loop_ddpm(
         is_latent_diffusion = False,
         lowres_noise_level = None,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         device = self.device
 
         b = shape[0]
         img = torch.randn(shape, device = device)
 
-        if exists(inpaint_image):
+        is_inpaint = exists(inpaint_image)
+        resample_times = inpaint_resample_times if is_inpaint else 1
+
+        if is_inpaint:
             inpaint_image = self.normalize_img(inpaint_image)
             inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
             inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2449,31 +2464,40 @@ def p_sample_loop_ddpm(
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
-        for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
-
-            if exists(inpaint_image):
-                # following the repaint paper
-                # https://arxiv.org/abs/2201.09865
-                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
-                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
-
-            img = self.p_sample(
-                unet,
-                img,
-                times,
-                image_embed = image_embed,
-                text_encodings = text_encodings,
-                cond_scale = cond_scale,
-                lowres_cond_img = lowres_cond_img,
-                lowres_noise_level = lowres_noise_level,
-                predict_x_start = predict_x_start,
-                noise_scheduler = noise_scheduler,
-                learned_variance = learned_variance,
-                clip_denoised = clip_denoised
-            )
+        for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
+            is_last_timestep = time == 0
 
-        if exists(inpaint_image):
+            for r in reversed(range(0, resample_times)):
+                is_last_resample_step = r == 0
+
+                times = torch.full((b,), time, device = device, dtype = torch.long)
+
+                if is_inpaint:
+                    # following the repaint paper
+                    # https://arxiv.org/abs/2201.09865
+                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
+                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
+                img = self.p_sample(
+                    unet,
+                    img,
+                    times,
+                    image_embed = image_embed,
+                    text_encodings = text_encodings,
+                    cond_scale = cond_scale,
+                    lowres_cond_img = lowres_cond_img,
+                    lowres_noise_level = lowres_noise_level,
+                    predict_x_start = predict_x_start,
+                    noise_scheduler = noise_scheduler,
+                    learned_variance = learned_variance,
+                    clip_denoised = clip_denoised
+                )
+
+                if is_inpaint and not (is_last_timestep or is_last_resample_step):
+                    # in repaint, you renoise and resample up to 10 times every step
+                    img = noise_scheduler.q_sample_from_to(img, times - 1, times)
+
+        if is_inpaint:
             img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
 
         unnormalize_img = self.unnormalize_img(img)
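Stripped of the conditioning arguments, the control flow this hunk introduces looks roughly like the sketch below. It is a toy, runnable skeleton: denoise_step, noise_to, and renoise are placeholders standing in for self.p_sample, noise_scheduler.q_sample, and noise_scheduler.q_sample_from_to, and the tensors are dummies.

```python
# Toy skeleton of the RePaint-style loop nest added above (placeholders, not the real model).
import torch

num_timesteps, resample_times, is_inpaint = 10, 3, True

img = torch.randn(1, 3, 8, 8)
inpaint_image = torch.zeros(1, 3, 8, 8)
inpaint_mask = torch.zeros(1, 1, 8, 8).bool()                          # True = keep pixels from inpaint_image

denoise_step = lambda x, t: 0.99 * x                                   # stands in for self.p_sample(...)
noise_to     = lambda x, t: x + 0.1 * torch.randn_like(x)              # stands in for q_sample at level t
renoise      = lambda x, from_t, to_t: x + 0.1 * torch.randn_like(x)   # stands in for q_sample_from_to

for time in reversed(range(num_timesteps)):
    is_last_timestep = time == 0

    for r in reversed(range(resample_times if is_inpaint else 1)):
        is_last_resample_step = r == 0

        if is_inpaint:
            # paste the known region, noised to the current level, into the working image
            img = img * ~inpaint_mask + noise_to(inpaint_image, time) * inpaint_mask

        img = denoise_step(img, time)

        if is_inpaint and not (is_last_timestep or is_last_resample_step):
            # repaint: step the image back up one noise level, then run this timestep again
            img = renoise(img, time - 1, time)
```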
@@ -2497,7 +2521,8 @@ def p_sample_loop_ddim(
         is_latent_diffusion = False,
         lowres_noise_level = None,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

@@ -2506,7 +2531,10 @@ def p_sample_loop_ddim(
         times = list(reversed(times.int().tolist()))
         time_pairs = list(zip(times[:-1], times[1:]))
 
-        if exists(inpaint_image):
+        is_inpaint = exists(inpaint_image)
+        resample_times = inpaint_resample_times if is_inpaint else 1
+
+        if is_inpaint:
             inpaint_image = self.normalize_img(inpaint_image)
             inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
             inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2519,39 +2547,49 @@ def p_sample_loop_ddim(
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
-            alpha = alphas[time]
-            alpha_next = alphas[time_next]
+            is_last_timestep = time_next == 0
 
-            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+            for r in reversed(range(0, resample_times)):
+                is_last_resample_step = r == 0
 
-            if exists(inpaint_image):
-                # following the repaint paper
-                # https://arxiv.org/abs/2201.09865
-                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
-                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+                alpha = alphas[time]
+                alpha_next = alphas[time_next]
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+                time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+                if is_inpaint:
+                    # following the repaint paper
+                    # https://arxiv.org/abs/2201.09865
+                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
+                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
 
-            if predict_x_start:
-                x_start = pred
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
-            else:
-                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
-                pred_noise = pred
+                pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
-            if clip_denoised:
-                x_start = self.dynamic_threshold(x_start)
+                if learned_variance:
+                    pred, _ = pred.chunk(2, dim = 1)
 
-            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
-            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
-            noise = torch.randn_like(img) if time_next > 0 else 0.
+                if predict_x_start:
+                    x_start = pred
+                    pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+                else:
+                    x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                    pred_noise = pred
+
+                if clip_denoised:
+                    x_start = self.dynamic_threshold(x_start)
+
+                c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+                c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+                noise = torch.randn_like(img) if not is_last_timestep else 0.
+
+                img = x_start * alpha_next.sqrt() + \
+                      c1 * noise + \
+                      c2 * pred_noise
 
-            img = x_start * alpha_next.sqrt() + \
-                  c1 * noise + \
-                  c2 * pred_noise
+                if is_inpaint and not (is_last_timestep or is_last_resample_step):
+                    # in repaint, you renoise and resample up to 10 times every step
+                    time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
+                    img = noise_scheduler.q_sample_from_to(img, time_cond, time_next_cond)
 
         if exists(inpaint_image):
             img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
@@ -2658,7 +2696,8 @@ def sample(
         stop_at_unet_number = None,
         distributed = False,
         inpaint_image = None,
-        inpaint_mask = None
+        inpaint_mask = None,
+        inpaint_resample_times = 5
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'

@@ -2730,7 +2769,8 @@ def sample(
                 noise_scheduler = noise_scheduler,
                 timesteps = sample_timesteps,
                 inpaint_image = inpaint_image,
-                inpaint_mask = inpaint_mask
+                inpaint_mask = inpaint_mask,
+                inpaint_resample_times = inpaint_resample_times
             )
 
             img = vae.decode(img)
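For context, a hypothetical call showing where the new keyword ends up. decoder, image_embed, inpaint_image, and inpaint_mask are assumed to already exist; the shape comments follow the handling in the diff (the mask is rearranged from (batch, height, width), and True marks pixels kept from inpaint_image).

```python
# hypothetical usage of the new keyword threaded through Decoder.sample in this commit
images = decoder.sample(
    image_embed = image_embed,              # conditioning image embeddings
    inpaint_image = inpaint_image,          # (batch, channels, height, width) image to inpaint around
    inpaint_mask = inpaint_mask,            # (batch, height, width) bool, True = keep from inpaint_image
    inpaint_resample_times = 5              # RePaint resampling passes per timestep (the new default)
)
```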

dalle2_pytorch/version.py (+1, -1)

@@ -1 +1 @@
-__version__ = '1.0.3'
+__version__ = '1.0.4'
