@@ -516,6 +516,17 @@ def q_sample(self, x_start, t, noise=None):
516
516
extract (self .sqrt_one_minus_alphas_cumprod , t , x_start .shape ) * noise
517
517
)
518
518
519
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
    """
    Re-noise `x_from`, taken to be a sample at timestep `from_t`, forward to timestep `to_t`.

    Draws from the gaussian transition q(x_to | x_from) implied by the cumulative
    schedule buffers `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`:

        x_to = (alpha_next / alpha) * x_from + sqrt(sigma_next^2 - (alpha_next / alpha)^2 * sigma^2) * noise

    For a single step (from_t = t - 1, to_t = t) the noise scale reduces to sqrt(beta_t).

    Args:
        x_from: tensor holding the sample at `from_t`
        from_t: timestep indices `x_from` currently sits at, passed to `extract`
        to_t: timestep indices to diffuse forward to - expected to be >= from_t,
            otherwise the variance below is negative and gets clamped to zero
        noise: optional noise tensor, freshly drawn with `torch.randn_like(x_from)` when omitted

    Returns:
        tensor shaped like `x_from`, the re-noised sample at `to_t`
    """
    shape = x_from.shape
    noise = default(noise, lambda: torch.randn_like(x_from))

    # NOTE(review): relies on `extract` returning per-sample coefficients broadcastable to `shape` - confirm against its definition
    alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
    sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
    alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
    sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)

    # how much of the existing sample (signal + noise already in it) survives the jump
    scale = alpha_next / alpha

    # the fresh noise is independent of the noise already present in x_from, so the two
    # contributions add in variance, not in standard deviation - subtracting the standard
    # deviations (sigma_next - scale * sigma) would leave the result under-noised
    variance = sigma_next ** 2 - (scale * sigma) ** 2

    # clamp guards against tiny negative values from floating point round-off (e.g. from_t == to_t)
    std = variance.clamp(min = 0.).sqrt()

    return x_from * scale + noise * std
529
+
519
530
def predict_start_from_noise (self , x_t , t , noise ):
520
531
return (
521
532
extract (self .sqrt_recip_alphas_cumprod , t , x_t .shape ) * x_t -
@@ -2432,14 +2443,18 @@ def p_sample_loop_ddpm(
2432
2443
is_latent_diffusion = False ,
2433
2444
lowres_noise_level = None ,
2434
2445
inpaint_image = None ,
2435
- inpaint_mask = None
2446
+ inpaint_mask = None ,
2447
+ inpaint_resample_times = 5
2436
2448
):
2437
2449
device = self .device
2438
2450
2439
2451
b = shape [0 ]
2440
2452
img = torch .randn (shape , device = device )
2441
2453
2442
- if exists (inpaint_image ):
2454
+ is_inpaint = exists (inpaint_image )
2455
+ resample_times = inpaint_resample_times if is_inpaint else 1
2456
+
2457
+ if is_inpaint :
2443
2458
inpaint_image = self .normalize_img (inpaint_image )
2444
2459
inpaint_image = resize_image_to (inpaint_image , shape [- 1 ], nearest = True )
2445
2460
inpaint_mask = rearrange (inpaint_mask , 'b h w -> b 1 h w' ).float ()
@@ -2449,31 +2464,40 @@ def p_sample_loop_ddpm(
2449
2464
if not is_latent_diffusion :
2450
2465
lowres_cond_img = maybe (self .normalize_img )(lowres_cond_img )
2451
2466
2452
- for i in tqdm (reversed (range (0 , noise_scheduler .num_timesteps )), desc = 'sampling loop time step' , total = noise_scheduler .num_timesteps ):
2453
- times = torch .full ((b ,), i , device = device , dtype = torch .long )
2454
-
2455
- if exists (inpaint_image ):
2456
- # following the repaint paper
2457
- # https://arxiv.org/abs/2201.09865
2458
- noised_inpaint_image = noise_scheduler .q_sample (inpaint_image , t = times )
2459
- img = (img * ~ inpaint_mask ) + (noised_inpaint_image * inpaint_mask )
2460
-
2461
- img = self .p_sample (
2462
- unet ,
2463
- img ,
2464
- times ,
2465
- image_embed = image_embed ,
2466
- text_encodings = text_encodings ,
2467
- cond_scale = cond_scale ,
2468
- lowres_cond_img = lowres_cond_img ,
2469
- lowres_noise_level = lowres_noise_level ,
2470
- predict_x_start = predict_x_start ,
2471
- noise_scheduler = noise_scheduler ,
2472
- learned_variance = learned_variance ,
2473
- clip_denoised = clip_denoised
2474
- )
2467
+ for time in tqdm (reversed (range (0 , noise_scheduler .num_timesteps )), desc = 'sampling loop time step' , total = noise_scheduler .num_timesteps ):
2468
+ is_last_timestep = time == 0
2475
2469
2476
- if exists (inpaint_image ):
2470
+ for r in reversed (range (0 , resample_times )):
2471
+ is_last_resample_step = r == 0
2472
+
2473
+ times = torch .full ((b ,), time , device = device , dtype = torch .long )
2474
+
2475
+ if is_inpaint :
2476
+ # following the repaint paper
2477
+ # https://arxiv.org/abs/2201.09865
2478
+ noised_inpaint_image = noise_scheduler .q_sample (inpaint_image , t = times )
2479
+ img = (img * ~ inpaint_mask ) + (noised_inpaint_image * inpaint_mask )
2480
+
2481
+ img = self .p_sample (
2482
+ unet ,
2483
+ img ,
2484
+ times ,
2485
+ image_embed = image_embed ,
2486
+ text_encodings = text_encodings ,
2487
+ cond_scale = cond_scale ,
2488
+ lowres_cond_img = lowres_cond_img ,
2489
+ lowres_noise_level = lowres_noise_level ,
2490
+ predict_x_start = predict_x_start ,
2491
+ noise_scheduler = noise_scheduler ,
2492
+ learned_variance = learned_variance ,
2493
+ clip_denoised = clip_denoised
2494
+ )
2495
+
2496
+ if is_inpaint and not (is_last_timestep or is_last_resample_step ):
2497
+ # in repaint, you renoise and resample up to 10 times every step
2498
+ img = noise_scheduler .q_sample_from_to (img , times - 1 , times )
2499
+
2500
+ if is_inpaint :
2477
2501
img = (img * ~ inpaint_mask ) + (inpaint_image * inpaint_mask )
2478
2502
2479
2503
unnormalize_img = self .unnormalize_img (img )
@@ -2497,7 +2521,8 @@ def p_sample_loop_ddim(
2497
2521
is_latent_diffusion = False ,
2498
2522
lowres_noise_level = None ,
2499
2523
inpaint_image = None ,
2500
- inpaint_mask = None
2524
+ inpaint_mask = None ,
2525
+ inpaint_resample_times = 5
2501
2526
):
2502
2527
batch , device , total_timesteps , alphas , eta = shape [0 ], self .device , noise_scheduler .num_timesteps , noise_scheduler .alphas_cumprod_prev , self .ddim_sampling_eta
2503
2528
@@ -2506,7 +2531,10 @@ def p_sample_loop_ddim(
2506
2531
times = list (reversed (times .int ().tolist ()))
2507
2532
time_pairs = list (zip (times [:- 1 ], times [1 :]))
2508
2533
2509
- if exists (inpaint_image ):
2534
+ is_inpaint = exists (inpaint_image )
2535
+ resample_times = inpaint_resample_times if is_inpaint else 1
2536
+
2537
+ if is_inpaint :
2510
2538
inpaint_image = self .normalize_img (inpaint_image )
2511
2539
inpaint_image = resize_image_to (inpaint_image , shape [- 1 ], nearest = True )
2512
2540
inpaint_mask = rearrange (inpaint_mask , 'b h w -> b 1 h w' ).float ()
@@ -2519,39 +2547,49 @@ def p_sample_loop_ddim(
2519
2547
lowres_cond_img = maybe (self .normalize_img )(lowres_cond_img )
2520
2548
2521
2549
for time , time_next in tqdm (time_pairs , desc = 'sampling loop time step' ):
2522
- alpha = alphas [time ]
2523
- alpha_next = alphas [time_next ]
2550
+ is_last_timestep = time_next == 0
2524
2551
2525
- time_cond = torch .full ((batch ,), time , device = device , dtype = torch .long )
2552
+ for r in reversed (range (0 , resample_times )):
2553
+ is_last_resample_step = r == 0
2526
2554
2527
- if exists (inpaint_image ):
2528
- # following the repaint paper
2529
- # https://arxiv.org/abs/2201.09865
2530
- noised_inpaint_image = noise_scheduler .q_sample (inpaint_image , t = time_cond )
2531
- img = (img * ~ inpaint_mask ) + (noised_inpaint_image * inpaint_mask )
2555
+ alpha = alphas [time ]
2556
+ alpha_next = alphas [time_next ]
2532
2557
2533
- pred = unet . forward_with_cond_scale ( img , time_cond , image_embed = image_embed , text_encodings = text_encodings , cond_scale = cond_scale , lowres_cond_img = lowres_cond_img , lowres_noise_level = lowres_noise_level )
2558
+ time_cond = torch . full (( batch ,), time , device = device , dtype = torch . long )
2534
2559
2535
- if learned_variance :
2536
- pred , _ = pred .chunk (2 , dim = 1 )
2560
+ if is_inpaint :
2561
+ # following the repaint paper
2562
+ # https://arxiv.org/abs/2201.09865
2563
+ noised_inpaint_image = noise_scheduler .q_sample (inpaint_image , t = time_cond )
2564
+ img = (img * ~ inpaint_mask ) + (noised_inpaint_image * inpaint_mask )
2537
2565
2538
- if predict_x_start :
2539
- x_start = pred
2540
- pred_noise = noise_scheduler .predict_noise_from_start (img , t = time_cond , x0 = pred )
2541
- else :
2542
- x_start = noise_scheduler .predict_start_from_noise (img , t = time_cond , noise = pred )
2543
- pred_noise = pred
2566
+ pred = unet .forward_with_cond_scale (img , time_cond , image_embed = image_embed , text_encodings = text_encodings , cond_scale = cond_scale , lowres_cond_img = lowres_cond_img , lowres_noise_level = lowres_noise_level )
2544
2567
2545
- if clip_denoised :
2546
- x_start = self . dynamic_threshold ( x_start )
2568
+ if learned_variance :
2569
+ pred , _ = pred . chunk ( 2 , dim = 1 )
2547
2570
2548
- c1 = eta * ((1 - alpha / alpha_next ) * (1 - alpha_next ) / (1 - alpha )).sqrt ()
2549
- c2 = ((1 - alpha_next ) - torch .square (c1 )).sqrt ()
2550
- noise = torch .randn_like (img ) if time_next > 0 else 0.
2571
+ if predict_x_start :
2572
+ x_start = pred
2573
+ pred_noise = noise_scheduler .predict_noise_from_start (img , t = time_cond , x0 = pred )
2574
+ else :
2575
+ x_start = noise_scheduler .predict_start_from_noise (img , t = time_cond , noise = pred )
2576
+ pred_noise = pred
2577
+
2578
+ if clip_denoised :
2579
+ x_start = self .dynamic_threshold (x_start )
2580
+
2581
+ c1 = eta * ((1 - alpha / alpha_next ) * (1 - alpha_next ) / (1 - alpha )).sqrt ()
2582
+ c2 = ((1 - alpha_next ) - torch .square (c1 )).sqrt ()
2583
+ noise = torch .randn_like (img ) if not is_last_timestep else 0.
2584
+
2585
+ img = x_start * alpha_next .sqrt () + \
2586
+ c1 * noise + \
2587
+ c2 * pred_noise
2551
2588
2552
- img = x_start * alpha_next .sqrt () + \
2553
- c1 * noise + \
2554
- c2 * pred_noise
2589
+ if is_inpaint and not (is_last_timestep or is_last_resample_step ):
2590
+ # in repaint, you renoise and resample up to 10 times every step
2591
+ time_next_cond = torch .full ((batch ,), time_next , device = device , dtype = torch .long )
2592
+ img = noise_scheduler .q_sample_from_to (img , time_next_cond , time_cond )
2555
2593
2556
2594
if exists (inpaint_image ):
2557
2595
img = (img * ~ inpaint_mask ) + (inpaint_image * inpaint_mask )
@@ -2658,7 +2696,8 @@ def sample(
2658
2696
stop_at_unet_number = None ,
2659
2697
distributed = False ,
2660
2698
inpaint_image = None ,
2661
- inpaint_mask = None
2699
+ inpaint_mask = None ,
2700
+ inpaint_resample_times = 5
2662
2701
):
2663
2702
assert self .unconditional or exists (image_embed ), 'image embed must be present on sampling from decoder unless if trained unconditionally'
2664
2703
@@ -2730,7 +2769,8 @@ def sample(
2730
2769
noise_scheduler = noise_scheduler ,
2731
2770
timesteps = sample_timesteps ,
2732
2771
inpaint_image = inpaint_image ,
2733
- inpaint_mask = inpaint_mask
2772
+ inpaint_mask = inpaint_mask ,
2773
+ inpaint_resample_times = inpaint_resample_times
2734
2774
)
2735
2775
2736
2776
img = vae .decode (img )
0 commit comments